From daa9fd9b8a288996b5a3c7913881a46d15cc3932 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Sat, 12 Aug 2023 23:12:35 +0300 Subject: [PATCH] sort vocabulary for faster lookup with bsearch() --- run.c | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/run.c b/run.c index f69c21a..46d7a41 100644 --- a/run.c +++ b/run.c @@ -342,24 +342,38 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // ---------------------------------------------------------------------------- // byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt -int str_lookup(char *str, char **vocab, int vocab_size) { - // find the first perfect match for str in vocab, return its index or -1 if not found - for (int i = 0; i < vocab_size; i++) { - if (strcmp(str, vocab[i]) == 0) { - return i; - } - } - return -1; +typedef struct { + char *str; + int id; +} TokenIndex; + +int compare_tokens(const void *a, const void *b) { + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +} + +int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { + // find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = {str=str}; + TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + return res!=NULL ? res->id : -1; } void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { + // sort vocabulary + TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex)); + for (int i = 0; i < vocab_size; i++) { + sorted_vocab[i].str = vocab[i]; + sorted_vocab[i].id = i; + } + qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator size_t str_len = 0; // add_dummy_prefix is true by default - tokens[0] = str_lookup(" ", vocab, vocab_size); + tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string @@ -374,7 +388,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding continue; - int id = str_lookup(str_buffer, vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1) { tokens[(*n_tokens)++] = id; @@ -395,7 +409,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u for (int i=0; i < (*n_tokens-1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]); - int id = str_lookup(str_buffer, vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1 && vocab_scores[id] > best_score) { // this merge pair exists in vocab! record its score and position best_score = vocab_scores[id]; @@ -418,6 +432,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } free(str_buffer); + free(sorted_vocab); } // convert token to printable string