add BOS and EOS function to the Tokenizer as we start to converge closer to the Llama 2 code from Meta, and as we're about to add the Chat capability

This commit is contained in:
Andrej Karpathy
2023-08-23 00:08:22 +00:00
parent d26a499207
commit d1eb18b8ec
4 changed files with 133 additions and 16 deletions
+35 -16
View File
@@ -441,8 +441,10 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
return res != NULL ? res->id : -1;
}
void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
if (t->sorted_vocab == NULL) {
// lazily malloc and sort the vocabulary
@@ -455,13 +457,24 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
}
// create a temporary buffer that will store merge candidates of always two consecutive tokens
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
size_t str_len = 0;
// start at 0 tokens
*n_tokens = 0;
// add optional BOS (=1) token, if desired
if (bos) tokens[(*n_tokens)++] = 1;
// add_dummy_prefix is true by default
tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size);
*n_tokens = 1; // the number of tokens
// so prepend a dummy prefix token to the input string, but only if text != ""
// TODO: pretty sure this isn't correct in the general case but I don't have the
// energy to read more of the sentencepiece code to figure out what it's doing
if (text[0] != '\0') {
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
tokens[(*n_tokens)++] = dummy_prefix;
}
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
// Code point ↔ UTF-8 conversion
@@ -543,6 +556,9 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
(*n_tokens)--; // token length decreased
}
// add optional EOS (=2) token, if desired
if (eos) tokens[(*n_tokens)++] = 2;
free(str_buffer);
}
@@ -704,18 +720,19 @@ long time_in_ms() {
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL; // the sequence of prompt tokens
int num_prompt_tokens = 0; // the total number of prompt tokens
if (prompt != NULL) {
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
if (num_prompt_tokens < 1) {
fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
exit(EXIT_FAILURE);
}
// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int token = prompt_tokens[0]; // kick off with the first token in the prompt
int pos = 0; // position in the sequence
while (pos < steps) {
@@ -723,16 +740,16 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
float* logits = forward(transformer, token, pos);
// advance the state state machine
if (pos < num_prompt_tokens) {
if (pos < num_prompt_tokens - 1) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
next = prompt_tokens[pos + 1];
} else {
// otherwise sample the next token from the logits
next = sample(sampler, logits);
}
pos++;
// data-dependent terminating condition: the BOS (1) token delimits sequences
// data-dependent terminating condition: the BOS (=1) token delimits sequences
if (next == 1) { break; }
// print the token as string, decode it with the Tokenizer object
@@ -756,7 +773,8 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
}
// ----------------------------------------------------------------------------
// int main
// CLI, include only if not testing
#ifndef TESTING
void error_usage() {
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
@@ -779,7 +797,7 @@ int main(int argc, char *argv[]) {
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
char *prompt = ""; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
// poor man's C argparse so we can override the defaults above from the command line
@@ -827,3 +845,4 @@ int main(int argc, char *argv[]) {
free_transformer(&transformer);
return 0;
}
#endif