From c02865df300f3bd9e567ce061000dc23bf785a17 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Fri, 4 Aug 2023 04:18:20 +0300 Subject: [PATCH 1/3] prompt tokenizer improvements: utf8 support, add_dummy_prefix and byte_fallback options to match sentencepiece --- run.c | 46 ++++++++++++++++++++++++++++++++++++++-------- tokenizer.bin | Bin 432717 -> 433869 bytes tokenizer.py | 2 -- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/run.c b/run.c index 02cc877..f69c21a 100644 --- a/run.c +++ b/run.c @@ -356,15 +356,34 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator + size_t str_len = 0; + + // add_dummy_prefix is true by default + tokens[0] = str_lookup(" ", vocab, vocab_size); + *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string - *n_tokens = 0; // the number of tokens for (char *c = text; *c != '\0'; c++) { - sprintf(str_buffer, "%c", *c); + // reset buffer if the current byte is ASCII or leading byte + if ((*c & 0xC0) != 0x80) + str_len = 0; + + str_buffer[str_len++] = *c; // append byte to the buffer + str_buffer[str_len] = '\0'; + + if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding + continue; + int id = str_lookup(str_buffer, vocab, vocab_size); - if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); } - tokens[*n_tokens] = id; - (*n_tokens)++; + + if (id != -1) { + tokens[(*n_tokens)++] = id; + } else { + // byte_fallback encoding + for (int i=0; i' into '\x01' + static char byte_piece[4]; + if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) { + byte_piece[1] = '\0'; + token_str = byte_piece; + } + return token_str; +} + // ---------------------------------------------------------------------------- // utilities: time / rng @@ -637,9 +669,7 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - printf("%s", token_str); + printf("%s", token_to_str(vocab, next, token)); fflush(stdout); token = next; diff --git a/tokenizer.bin b/tokenizer.bin index e0a8a7bc47fda2ecaab1e2cc2255d140925ed704..e6c1b23ec0c18ebfd7162ef838a77d559b7e3316 100644 GIT binary patch delta 3646 zcmY+_JC55h6op~uHGp?8m{pi>5(o_Nxh8ikQlyPlwZ=WA;@0b2^I-P!AZkNk-Je;;9kR*~q(nto$A~_`gNc~wK@G>GLq>NOMDpH%1 z08nhe(LeM8LII(GP(Uak6cCE1@BtJ6iXGV3dI$xC0zv_yfKWmxH{lAD0LlZ{hh9P` zA(RkG2qlCPLb(eUpaf7(z&>6=C?S*(DhL&X3PN=VN4)}2&A@(B5Gn{2gbG3hp@L9N z!ch;P0M4WWimL#SurO#`R_)Ez<%p@vXHs3Fu4Y6vxe20$bF{Er$44TJ_l z1EGP?KxhCo0GfDR?w{?}GY}dG4TJ_l1EGb`VuThzJAh07^U}2tS_mzK7D5Z5h0tPz z7C^fINB>q2p@q;w=pb|uItU#`=m2ypuy6GaLI&;jTkz-|9I2t9-z zLJy&b&_n1kLJy#S0+;!Z8_GlIA@mS>2t9-zLXQ!80K*0x{i8lW7$6J~1_%R$0m6V0 z1^~kjT=!pqFhCd}3=jqg1B7uCuD}RjJb=soix5T#BZLvc2w{XU?!x8!MgZdk9Q!Xq z7$J-hMhFvx3Bq&;M|}b?&A@(35GDu{gbBg~VS+GC!ch-k0aLzrjb zEd!VV%pJlEVTLe6m?6v%W(YHY8NhPE{woj`2n&P-!UAD|umD&9ECKtkKv*Cw5Ecjv zgayKi5mo@}08aa_LRcZJ5LO5)gcZVy5mo@}0zCF#g|I?cA*>KK2pfbABWwV+6}a!e z24RD+LD(Q{5H<)KM%Vys58&%_`!@(%yZ*YLf8Wnv#`!wVw{gCY^W)$B{PO4hKiNUJ AVE_OC delta 2485 zcmYM!Rdd^56h%=7%Pv#OOfEAsGcz+YgVLsqDKn?cZOV|Dp6C3hN^N{jOfy961DvQLEgNA#5$4I^ZMY>)$TK{-$!Q~(u0B~Teui4la~Nk&ycHBcSY z05w4^P#e?%bwNE)ALM}sX@vV5GHL`GgC?LUXa<^t7N8|)1zLkPpe<-uh9J?NQ3sF$ z9YH718FT?%K{wDH^Z-3UFVGtVeNbP}5A+8Ez(6nv3g5d0;+R02YEpU@=$%mV#x8 zXn(YRIinR|C0GSkgEe3+SO?aF4PYbK1U7>$A%b}QRz};vcCZ8N1iQd)um|h~`@nv1 z02~B|BE;{XID+Vv^EW@5#X4*qu?noC)-mh2b;3Gnow80_XX5)GuRH7EIqST2!78*a zT9>TL))nijb4eFzn#pY}%&KUu}rXX}gg)%s?Aw|-bZtzXt}>(Br6N*F5Na)fe( za)fe(a)fe(a)fe(a)ctnAfd3rbn|kAa)fe(a)fe(a)fe(a)fe(a)fe(iX#l7`sE1a z2;~Un2;~Un2;~Un2=%~!$q~vC$`Pt4e*eao;0Wai'): - t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01' t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace b = t.encode('utf-8') # bytes of this token, utf-8 encoded From daa9fd9b8a288996b5a3c7913881a46d15cc3932 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Sat, 12 Aug 2023 23:12:35 +0300 Subject: [PATCH 2/3] sort vocabulary for faster lookup with bsearch() --- run.c | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/run.c b/run.c index f69c21a..46d7a41 100644 --- a/run.c +++ b/run.c @@ -342,24 +342,38 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // ---------------------------------------------------------------------------- // byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt -int str_lookup(char *str, char **vocab, int vocab_size) { - // find the first perfect match for str in vocab, return its index or -1 if not found - for (int i = 0; i < vocab_size; i++) { - if (strcmp(str, vocab[i]) == 0) { - return i; - } - } - return -1; +typedef struct { + char *str; + int id; +} TokenIndex; + +int compare_tokens(const void *a, const void *b) { + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +} + +int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { + // find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = {str=str}; + TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + return res!=NULL ? res->id : -1; } void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { + // sort vocabulary + TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex)); + for (int i = 0; i < vocab_size; i++) { + sorted_vocab[i].str = vocab[i]; + sorted_vocab[i].id = i; + } + qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator size_t str_len = 0; // add_dummy_prefix is true by default - tokens[0] = str_lookup(" ", vocab, vocab_size); + tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string @@ -374,7 +388,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding continue; - int id = str_lookup(str_buffer, vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1) { tokens[(*n_tokens)++] = id; @@ -395,7 +409,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u for (int i=0; i < (*n_tokens-1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]); - int id = str_lookup(str_buffer, vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1 && vocab_scores[id] > best_score) { // this merge pair exists in vocab! record its score and position best_score = vocab_scores[id]; @@ -418,6 +432,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } free(str_buffer); + free(sorted_vocab); } // convert token to printable string From 36b54321e519cdabac2ecb3a1247db82f2aea4bb Mon Sep 17 00:00:00 2001 From: atamyrat Date: Sun, 13 Aug 2023 23:23:32 +0300 Subject: [PATCH 3/3] bugfix: allocate +1 in tokens buffer for dummy whitespace --- run.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run.c b/run.c index 46d7a41..4680dc5 100644 --- a/run.c +++ b/run.c @@ -641,7 +641,7 @@ int main(int argc, char *argv[]) { int *prompt_tokens = NULL; int num_prompt_tokens = 0; if (prompt != NULL) { - prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int)); + prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); }