From c02865df300f3bd9e567ce061000dc23bf785a17 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Fri, 4 Aug 2023 04:18:20 +0300 Subject: [PATCH] prompt tokenizer improvements: utf8 support, add_dummy_prefix and byte_fallback options to match sentencepiece --- run.c | 46 ++++++++++++++++++++++++++++++++++++++-------- tokenizer.bin | Bin 432717 -> 433869 bytes tokenizer.py | 2 -- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/run.c b/run.c index 02cc877..f69c21a 100644 --- a/run.c +++ b/run.c @@ -356,15 +356,34 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator + size_t str_len = 0; + + // add_dummy_prefix is true by default + tokens[0] = str_lookup(" ", vocab, vocab_size); + *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string - *n_tokens = 0; // the number of tokens for (char *c = text; *c != '\0'; c++) { - sprintf(str_buffer, "%c", *c); + // reset buffer if the current byte is ASCII or leading byte + if ((*c & 0xC0) != 0x80) + str_len = 0; + + str_buffer[str_len++] = *c; // append byte to the buffer + str_buffer[str_len] = '\0'; + + if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding + continue; + int id = str_lookup(str_buffer, vocab, vocab_size); - if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); } - tokens[*n_tokens] = id; - (*n_tokens)++; + + if (id != -1) { + tokens[(*n_tokens)++] = id; + } else { + // byte_fallback encoding + for (int i=0; i' into '\x01' + static char byte_piece[4]; + if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) { + byte_piece[1] = '\0'; + token_str = byte_piece; + } + return token_str; +} + // ---------------------------------------------------------------------------- // utilities: time / rng @@ -637,9 +669,7 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - printf("%s", token_str); + printf("%s", token_to_str(vocab, next, token)); fflush(stdout); token = next; diff --git a/tokenizer.bin b/tokenizer.bin index e0a8a7bc47fda2ecaab1e2cc2255d140925ed704..e6c1b23ec0c18ebfd7162ef838a77d559b7e3316 100644 GIT binary patch delta 3646 zcmY+_JC55h6op~uHGp?8m{pi>5(o_Nxh8ikQlyPlwZ=WA;@0b2^I-P!AZkNk-Je;;9kR*~q(nto$A~_`gNc~wK@G>GLq>NOMDpH%1 z08nhe(LeM8LII(GP(Uak6cCE1@BtJ6iXGV3dI$xC0zv_yfKWmxH{lAD0LlZ{hh9P` zA(RkG2qlCPLb(eUpaf7(z&>6=C?S*(DhL&X3PN=VN4)}2&A@(B5Gn{2gbG3hp@L9N z!ch;P0M4WWimL#SurO#`R_)Ez<%p@vXHs3Fu4Y6vxe20$bF{Er$44TJ_l z1EGP?KxhCo0GfDR?w{?}GY}dG4TJ_l1EGb`VuThzJAh07^U}2tS_mzK7D5Z5h0tPz z7C^fINB>q2p@q;w=pb|uItU#`=m2ypuy6GaLI&;jTkz-|9I2t9-z zLJy&b&_n1kLJy#S0+;!Z8_GlIA@mS>2t9-zLXQ!80K*0x{i8lW7$6J~1_%R$0m6V0 z1^~kjT=!pqFhCd}3=jqg1B7uCuD}RjJb=soix5T#BZLvc2w{XU?!x8!MgZdk9Q!Xq z7$J-hMhFvx3Bq&;M|}b?&A@(35GDu{gbBg~VS+GC!ch-k0aLzrjb zEd!VV%pJlEVTLe6m?6v%W(YHY8NhPE{woj`2n&P-!UAD|umD&9ECKtkKv*Cw5Ecjv zgayKi5mo@}08aa_LRcZJ5LO5)gcZVy5mo@}0zCF#g|I?cA*>KK2pfbABWwV+6}a!e z24RD+LD(Q{5H<)KM%Vys58&%_`!@(%yZ*YLf8Wnv#`!wVw{gCY^W)$B{PO4hKiNUJ AVE_OC delta 2485 zcmYM!Rdd^56h%=7%Pv#OOfEAsGcz+YgVLsqDKn?cZOV|Dp6C3hN^N{jOfy961DvQLEgNA#5$4I^ZMY>)$TK{-$!Q~(u0B~Teui4la~Nk&ycHBcSY z05w4^P#e?%bwNE)ALM}sX@vV5GHL`GgC?LUXa<^t7N8|)1zLkPpe<-uh9J?NQ3sF$ z9YH718FT?%K{wDH^Z-3UFVGtVeNbP}5A+8Ez(6nv3g5d0;+R02YEpU@=$%mV#x8 zXn(YRIinR|C0GSkgEe3+SO?aF4PYbK1U7>$A%b}QRz};vcCZ8N1iQd)um|h~`@nv1 z02~B|BE;{XID+Vv^EW@5#X4*qu?noC)-mh2b;3Gnow80_XX5)GuRH7EIqST2!78*a zT9>TL))nijb4eFzn#pY}%&KUu}rXX}gg)%s?Aw|-bZtzXt}>(Br6N*F5Na)fe( za)fe(a)fe(a)fe(a)ctnAfd3rbn|kAa)fe(a)fe(a)fe(a)fe(a)fe(iX#l7`sE1a z2;~Un2;~Un2;~Un2=%~!$q~vC$`Pt4e*eao;0Wai'): - t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01' t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace b = t.encode('utf-8') # bytes of this token, utf-8 encoded