From 379f083b85740ccabe80b61259e7346ad27ede3d Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Tue, 22 Aug 2023 01:56:51 +0000
Subject: [PATCH] make sorted vocab a buffer of Tokenizer

---
 run.c | 42 +++++++++++++++++++++++-------------------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff --git a/run.c b/run.c
index 1ac3c56..705719c 100644
--- a/run.c
+++ b/run.c
@@ -369,14 +369,24 @@ float* forward(Transformer* transformer, int token, int pos) {
 // ----------------------------------------------------------------------------
 // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
 
+typedef struct {
+    char *str;
+    int id;
+} TokenIndex;
+
 typedef struct {
     char** vocab;
     float* vocab_scores;
+    TokenIndex *sorted_vocab;
     int vocab_size;
     unsigned int max_token_length;
     char byte_piece[2];
 } Tokenizer;
 
+int compare_tokens(const void *a, const void *b) {
+    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
+}
+
 void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
     // i should have written the vocab_size into the tokenizer file... sigh
     t->vocab_size = vocab_size;
@@ -384,6 +394,7 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
     t->vocab = (char**)malloc(vocab_size * sizeof(char*));
     t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
     t->byte_piece[1] = '\0'; // null terminate the byte_piece string
+    t->sorted_vocab = NULL; // initialized lazily
     // read in the file
     FILE *file = fopen(tokenizer_path, "rb");
     if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
@@ -403,6 +414,7 @@ void free_tokenizer(Tokenizer* t) {
     for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
     free(t->vocab);
     free(t->vocab_scores);
+    free(t->sorted_vocab);
 }
 
 char* decode(Tokenizer* t, int prev_token, int token) {
@@ -422,15 +434,6 @@ char* decode(Tokenizer* t, int prev_token, int token) {
     return piece;
 }
 
-typedef struct {
-    char *str;
-    int id;
-} TokenIndex;
-
-int compare_tokens(const void *a, const void *b) {
-    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
-}
-
 int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
     TokenIndex tok = { .str = str }; // acts as the key to search for
@@ -441,20 +444,22 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
 void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
     // encode the string text (input) into an upper-bound preallocated tokens[] array
 
-    // sort vocabulary
-    TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
-    for (int i = 0; i < t->vocab_size; i++) {
-        sorted_vocab[i].str = t->vocab[i];
-        sorted_vocab[i].id = i;
+    if (t->sorted_vocab == NULL) {
+        // lazily malloc and sort the vocabulary
+        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
+        for (int i = 0; i < t->vocab_size; i++) {
+            t->sorted_vocab[i].str = t->vocab[i];
+            t->sorted_vocab[i].id = i;
+        }
+        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
     }
-    qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
 
     // create a temporary buffer that will store merge candidates of always two consecutive tokens
     char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
     size_t str_len = 0;
 
     // add_dummy_prefix is true by default
-    tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size);
+    tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size);
     *n_tokens = 1; // the number of tokens
 
     // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -490,7 +495,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
         }
 
         // ok c+1 is not a continuation byte, so we've read in a full codepoint
-        int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
+        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
 
         if (id != -1) {
             // we found this codepoint in vocab, add it as a token
@@ -515,7 +520,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
         for (int i=0; i < (*n_tokens-1); i++) {
             // check if we can merge the pair (tokens[i], tokens[i+1])
             sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
-            int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
+            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
             if (id != -1 && t->vocab_scores[id] > best_score) {
                 // this merge pair exists in vocab! record its score and position
                 best_score = t->vocab_scores[id];
@@ -538,7 +543,6 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
     }
 
     free(str_buffer);
-    free(sorted_vocab);
 }
 
 // ----------------------------------------------------------------------------