// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

// Holds the vocabulary loaded from the tokenizer file.
typedef struct {
    char** vocab;                  // vocab_size heap-allocated, NUL-terminated token strings
    float* vocab_scores;           // vocab_size merge scores, parallel to vocab
    int vocab_size;                // number of entries in vocab / vocab_scores
    unsigned int max_token_length; // first value stored in the tokenizer file
    char byte_piece[2];            // scratch one-character string returned by get_piece
} Tokenizer;

// Load the tokenizer file at path `tokenizer` into `t`.
//
// File layout, as read below: one unsigned int (max_token_length), followed by
// vocab_size records of { float score, int len, len bytes of token text }.
// The vocab size is not stored in the file, so the caller has to supply it.
//
// Any failure (bad argument, allocation failure, unreadable or truncated
// file, negative token length) prints a message to stderr and terminates the
// process with EXIT_FAILURE. On success the caller owns `t` and releases it
// with free_tokenizer().
void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
    if (vocab_size < 0) { fprintf(stderr, "invalid vocab size %d\n", vocab_size); exit(EXIT_FAILURE); }
    // i should have written the vocab_size into the tokenizer file... sigh
    t->vocab_size = vocab_size;
    // keep the scratch string empty and terminated from the start
    t->byte_piece[0] = '\0';
    t->byte_piece[1] = '\0';
    // zero-initialised space to hold the scores and the strings; at least one
    // element is requested so a NULL result always means a real failure
    size_t count = vocab_size > 0 ? (size_t)vocab_size : 1;
    t->vocab = (char**)calloc(count, sizeof(char*));
    t->vocab_scores = (float*)calloc(count, sizeof(float));
    if (!t->vocab || !t->vocab_scores) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
    // read in the file
    FILE *file = fopen(tokenizer, "rb");
    if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); }
    if (fread(&t->max_token_length, sizeof(t->max_token_length), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
    for (int i = 0; i < vocab_size; i++) {
        int len = 0;
        if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        // the length comes straight from the file: never trust a negative value
        if (len < 0) { fprintf(stderr, "invalid token length %d\n", len); exit(EXIT_FAILURE); }
        t->vocab[i] = (char*)malloc((size_t)len + 1);
        if (!t->vocab[i]) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }
        // count bytes (not one item of size len) so that an empty token, for
        // which fread legitimately returns 0, is not mistaken for a short read
        if (fread(t->vocab[i], 1, (size_t)len, file) != (size_t)len) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        t->vocab[i][len] = '\0'; // add the string terminating token
    }
    fclose(file);
}

// Release everything build_tokenizer() allocated inside `t` (not `t` itself).
// The pointers are reset afterwards, so calling this again on the same
// Tokenizer is a harmless no-op.
void free_tokenizer(Tokenizer* t) {
    if (t->vocab != NULL) {
        for (int i = 0; i < t->vocab_size; i++) {
            free(t->vocab[i]);
        }
    }
    free(t->vocab);
    free(t->vocab_scores);
    t->vocab = NULL;
    t->vocab_scores = NULL;
    t->vocab_size = 0;
}

// Return the printable string for `token`, given the token that preceded it.
//
// The result points either into t->vocab or at t->byte_piece; it is owned by
// the Tokenizer, must not be freed, and a byte_piece result is overwritten by
// the next call that decodes a raw-byte token.
//
// Raw-byte tokens that are neither printable nor whitespace, and token ids
// outside [0, vocab_size), yield an empty string.
char* get_piece(Tokenizer* t, int prev_token, int token) {
    t->byte_piece[1] = '\0'; // the scratch string is always a valid C string
    if (token < 0 || token >= t->vocab_size) {
        t->byte_piece[0] = '\0';
        return t->byte_piece;
    }
    char *piece = t->vocab[token];
    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
    if (prev_token == 1 && piece[0] == ' ') { piece++; }
    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
    unsigned char byte_val;
    if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
        // ok this token is a raw byte token, careful to only print printable chars or whitespace
        // some of the other bytes can be various control codes, backspace, etc. => skip
        // (skip means "emit nothing": handing back the vocab text here would
        // print the literal characters "<0x01>")
        if (isprint(byte_val) || isspace(byte_val)) {
            t->byte_piece[0] = (char)byte_val;
        } else {
            t->byte_piece[0] = '\0';
        }
        piece = t->byte_piece;
    }
    return piece;
}
str_len = 0; // add_dummy_prefix is true by default - tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); + tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size); *n_tokens = 1; // the number of tokens // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: @@ -410,7 +466,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } // ok c+1 is not a continuation byte, so we've read in a full codepoint - int id = str_lookup(str_buffer, sorted_vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size); if (id != -1) { // we found this codepoint in vocab, add it as a token @@ -434,11 +490,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u for (int i=0; i < (*n_tokens-1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) - sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]); - int id = str_lookup(str_buffer, sorted_vocab, vocab_size); - if (id != -1 && vocab_scores[id] > best_score) { + sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); + int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size); + if (id != -1 && t->vocab_scores[id] > best_score) { // this merge pair exists in vocab! record its score and position - best_score = vocab_scores[id]; + best_score = t->vocab_scores[id]; best_id = id; best_idx = i; } @@ -587,8 +643,8 @@ void error_usage() { int main(int argc, char *argv[]) { // default inits - char *checkpoint = NULL; // e.g. out/model.bin - char *tokenizer = "tokenizer.bin"; + char *checkpoint_path = NULL; // e.g. out/model.bin + char *tokenizer_path = "tokenizer.bin"; float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 
0.9 works well, but slower rng_seed = 0; // seed rng with time by default @@ -596,7 +652,7 @@ int main(int argc, char *argv[]) { char *prompt = NULL; // prompt string // poor man's C argparse so we can override the defaults above from the command line - if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); } + if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); } for (int i = 2; i < argc; i+=2) { // do some basic validation if (i + 1 >= argc) { error_usage(); } // must have arg after flag @@ -608,7 +664,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); } else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } - else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; } + else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } else { error_usage(); } } if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} @@ -619,29 +675,14 @@ int main(int argc, char *argv[]) { int fd = 0; // file descriptor for memory mapping float* data = NULL; // memory mapped data pointer ssize_t file_size; // size of the checkpoint file in bytes - read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size); + read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size); // right now we cannot run for more than config.seq_len steps if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } // read in the tokenizer .bin file - char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); - float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float)); - unsigned int max_token_length; - { - FILE *file = fopen(tokenizer, "rb"); - if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; } - if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; } - int len; - for (int i = 0; i < config.vocab_size; i++) { - if (fread(vocab_scores + i, sizeof(float), 
1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;} - if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; } - vocab[i] = (char *)malloc(len + 1); - if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; } - vocab[i][len] = '\0'; // add the string terminating token - } - fclose(file); - } + Tokenizer tokenizer; + build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size); // create and init the application RunState RunState state; @@ -653,7 +694,7 @@ int main(int argc, char *argv[]) { int num_prompt_tokens = 0; if (prompt != NULL) { prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); - bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); + bpe_encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens); } // start the main loop @@ -695,22 +736,9 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' - unsigned char byte_val; - if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) { - // ok this token is a raw byte token, carefuly to only print printable chars or whitespace - // some of the other bytes can be various control codes, backspace, etc. 
=> skip - if (isprint(byte_val) || isspace(byte_val)) { - char byte_piece[2]; - byte_piece[0] = byte_val; - byte_piece[1] = '\0'; - printf("%s", byte_piece); - } - } else { - printf("%s", token_str); - } + // print the token as string, decode it with the Tokenizer object + char* piece = get_piece(&tokenizer, token, next); + printf("%s", piece); fflush(stdout); token = next; @@ -728,9 +756,7 @@ int main(int argc, char *argv[]) { // memory and file handles cleanup free_run_state(&state); free(probindex); - for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); } - free(vocab); - free(vocab_scores); + free_tokenizer(&tokenizer); if (prompt_tokens != NULL) free(prompt_tokens); if (data != MAP_FAILED) munmap(data, file_size); if (fd != -1) close(fd);