From c02865df300f3bd9e567ce061000dc23bf785a17 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Fri, 4 Aug 2023 04:18:20 +0300 Subject: [PATCH 1/7] prompt tokenizer improvements: utf8 support, add_dummy_prefix and byte_fallback options to match sentencepiece --- run.c | 46 ++++++++++++++++++++++++++++++++++++++-------- tokenizer.bin | Bin 432717 -> 433869 bytes tokenizer.py | 2 -- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/run.c b/run.c index 02cc877..f69c21a 100644 --- a/run.c +++ b/run.c @@ -356,15 +356,34 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator + size_t str_len = 0; + + // add_dummy_prefix is true by default + tokens[0] = str_lookup(" ", vocab, vocab_size); + *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string - *n_tokens = 0; // the number of tokens for (char *c = text; *c != '\0'; c++) { - sprintf(str_buffer, "%c", *c); + // reset buffer if the current byte is ASCII or leading byte + if ((*c & 0xC0) != 0x80) + str_len = 0; + + str_buffer[str_len++] = *c; // append byte to the buffer + str_buffer[str_len] = '\0'; + + if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding + continue; + int id = str_lookup(str_buffer, vocab, vocab_size); - if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); } - tokens[*n_tokens] = id; - (*n_tokens)++; + + if (id != -1) { + tokens[(*n_tokens)++] = id; + } else { + // byte_fallback encoding + for (int i=0; i' into '\x01' + static char byte_piece[4]; + if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) { + byte_piece[1] = '\0'; + token_str = byte_piece; + } + return token_str; +} + // ---------------------------------------------------------------------------- // utilities: time / rng @@ -637,9 +669,7 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - printf("%s", token_str); + printf("%s", token_to_str(vocab, next, token)); fflush(stdout); token = next; diff --git a/tokenizer.bin b/tokenizer.bin index e0a8a7bc47fda2ecaab1e2cc2255d140925ed704..e6c1b23ec0c18ebfd7162ef838a77d559b7e3316 100644 GIT binary patch delta 3646 zcmY+_JC55h6op~uHGp?8m{pi>5(o_Nxh8ikQlyPlwZ=WA;@0b2^I-P!AZkNk-Je;;9kR*~q(nto$A~_`gNc~wK@G>GLq>NOMDpH%1 z08nhe(LeM8LII(GP(Uak6cCE1@BtJ6iXGV3dI$xC0zv_yfKWmxH{lAD0LlZ{hh9P` zA(RkG2qlCPLb(eUpaf7(z&>6=C?S*(DhL&X3PN=VN4)}2&A@(B5Gn{2gbG3hp@L9N z!ch;P0M4WWimL#SurO#`R_)Ez<%p@vXHs3Fu4Y6vxe20$bF{Er$44TJ_l z1EGP?KxhCo0GfDR?w{?}GY}dG4TJ_l1EGb`VuThzJAh07^U}2tS_mzK7D5Z5h0tPz z7C^fINB>q2p@q;w=pb|uItU#`=m2ypuy6GaLI&;jTkz-|9I2t9-z zLJy&b&_n1kLJy#S0+;!Z8_GlIA@mS>2t9-zLXQ!80K*0x{i8lW7$6J~1_%R$0m6V0 z1^~kjT=!pqFhCd}3=jqg1B7uCuD}RjJb=soix5T#BZLvc2w{XU?!x8!MgZdk9Q!Xq z7$J-hMhFvx3Bq&;M|}b?&A@(35GDu{gbBg~VS+GC!ch-k0aLzrjb zEd!VV%pJlEVTLe6m?6v%W(YHY8NhPE{woj`2n&P-!UAD|umD&9ECKtkKv*Cw5Ecjv zgayKi5mo@}08aa_LRcZJ5LO5)gcZVy5mo@}0zCF#g|I?cA*>KK2pfbABWwV+6}a!e z24RD+LD(Q{5H<)KM%Vys58&%_`!@(%yZ*YLf8Wnv#`!wVw{gCY^W)$B{PO4hKiNUJ AVE_OC delta 2485 zcmYM!Rdd^56h%=7%Pv#OOfEAsGcz+YgVLsqDKn?cZOV|Dp6C3hN^N{jOfy961DvQLEgNA#5$4I^ZMY>)$TK{-$!Q~(u0B~Teui4la~Nk&ycHBcSY z05w4^P#e?%bwNE)ALM}sX@vV5GHL`GgC?LUXa<^t7N8|)1zLkPpe<-uh9J?NQ3sF$ z9YH718FT?%K{wDH^Z-3UFVGtVeNbP}5A+8Ez(6nv3g5d0;+R02YEpU@=$%mV#x8 zXn(YRIinR|C0GSkgEe3+SO?aF4PYbK1U7>$A%b}QRz};vcCZ8N1iQd)um|h~`@nv1 z02~B|BE;{XID+Vv^EW@5#X4*qu?noC)-mh2b;3Gnow80_XX5)GuRH7EIqST2!78*a zT9>TL))nijb4eFzn#pY}%&KUu}rXX}gg)%s?Aw|-bZtzXt}>(Br6N*F5Na)fe( za)fe(a)fe(a)fe(a)ctnAfd3rbn|kAa)fe(a)fe(a)fe(a)fe(a)fe(iX#l7`sE1a z2;~Un2;~Un2;~Un2=%~!$q~vC$`Pt4e*eao;0Wai'): - t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01' t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace b = t.encode('utf-8') # bytes of this token, utf-8 encoded From daa9fd9b8a288996b5a3c7913881a46d15cc3932 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Sat, 12 Aug 2023 23:12:35 +0300 Subject: [PATCH 2/7] sort vocabulary for faster lookup with bsearch() --- run.c | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/run.c b/run.c index f69c21a..46d7a41 100644 --- a/run.c +++ b/run.c @@ -342,24 +342,38 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // ---------------------------------------------------------------------------- // byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt -int str_lookup(char *str, char **vocab, int vocab_size) { - // find the first perfect match for str in vocab, return its index or -1 if not found - for (int i = 0; i < vocab_size; i++) { - if (strcmp(str, vocab[i]) == 0) { - return i; - } - } - return -1; +typedef struct { + char *str; + int id; +} TokenIndex; + +int compare_tokens(const void *a, const void *b) { + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +} + +int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { + // find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = {str=str}; + TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + return res!=NULL ? res->id : -1; } void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { + // sort vocabulary + TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex)); + for (int i = 0; i < vocab_size; i++) { + sorted_vocab[i].str = vocab[i]; + sorted_vocab[i].id = i; + } + qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator size_t str_len = 0; // add_dummy_prefix is true by default - tokens[0] = str_lookup(" ", vocab, vocab_size); + tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string @@ -374,7 +388,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding continue; - int id = str_lookup(str_buffer, vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1) { tokens[(*n_tokens)++] = id; @@ -395,7 +409,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u for (int i=0; i < (*n_tokens-1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]); - int id = str_lookup(str_buffer, vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1 && vocab_scores[id] > best_score) { // this merge pair exists in vocab! record its score and position best_score = vocab_scores[id]; @@ -418,6 +432,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } free(str_buffer); + free(sorted_vocab); } // convert token to printable string From 36b54321e519cdabac2ecb3a1247db82f2aea4bb Mon Sep 17 00:00:00 2001 From: atamyrat Date: Sun, 13 Aug 2023 23:23:32 +0300 Subject: [PATCH 3/7] bugfix: allocate +1 in tokens buffer for dummy whitespace --- run.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run.c b/run.c index 46d7a41..4680dc5 100644 --- a/run.c +++ b/run.c @@ -641,7 +641,7 @@ int main(int argc, char *argv[]) { int *prompt_tokens = NULL; int num_prompt_tokens = 0; if (prompt != NULL) { - prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int)); + prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); } From 4bf36ecc1792ce2ed579d6c5718fc38b5a035677 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 15 Aug 2023 01:04:10 +0000 Subject: [PATCH 4/7] get rid of the special byte decoding logic --- run.c | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/run.c b/run.c index 8da8823..33560fe 100644 --- a/run.c +++ b/run.c @@ -358,7 +358,7 @@ int compare_tokens(const void *a, const void *b) { int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { // find the perfect match for str in vocab, return its index or -1 if not found - TokenIndex tok = {str=str}; + TokenIndex tok = {str=str}; TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); return res!=NULL ? res->id : -1; } @@ -440,19 +440,6 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u free(sorted_vocab); } -// convert token to printable string -char *token_to_str(char **vocab, int token, int prev_token) { - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (prev_token == 1 && vocab[token][0] == ' ') ? vocab[token]+1 : vocab[token]; - // make '<0x01>' into '\x01' - static char byte_piece[4]; - if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) { - byte_piece[1] = '\0'; - token_str = byte_piece; - } - return token_str; -} - // ---------------------------------------------------------------------------- // utilities: time / rng @@ -699,7 +686,9 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - printf("%s", token_to_str(vocab, next, token)); + // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) + char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; + printf("%s", token_str); fflush(stdout); token = next; From d459fd4243cddf5893231cbaa70da26e598cfa53 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 15 Aug 2023 01:42:33 +0000 Subject: [PATCH 5/7] add back careful processing of the byte tokens --- run.c | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/run.c b/run.c index 33560fe..37d3018 100644 --- a/run.c +++ b/run.c @@ -10,6 +10,7 @@ $ ./run #include #include +#include #include #include #include @@ -688,7 +689,20 @@ int main(int argc, char *argv[]) { // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - printf("%s", token_str); + // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' + unsigned char byte_val; + if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) { + // ok this token is a raw byte token, carefuly to only print printable chars or whitespace + // some of the other bytes can be various control codes, backspace, etc. => skip + if (isprint(byte_val) || isspace(byte_val)) { + char byte_piece[2]; + byte_piece[0] = byte_val; + byte_piece[1] = '\0'; + printf("%s", byte_piece); + } + } else { + printf("%s", token_str); + } fflush(stdout); token = next; From a9a0628c9254c0efcc0249cdf3d5dc0b692201a6 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 15 Aug 2023 02:18:49 +0000 Subject: [PATCH 6/7] thoroughly commented the UTF-8 byte reading code --- run.c | 49 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/run.c b/run.c index 37d3018..8f565cd 100644 --- a/run.c +++ b/run.c @@ -358,10 +358,10 @@ int compare_tokens(const void *a, const void *b) { } int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { - // find the perfect match for str in vocab, return its index or -1 if not found - TokenIndex tok = {str=str}; + // efficiently find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = { .str = str }; // acts as the key to search for TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); - return res!=NULL ? res->id : -1; + return res != NULL ? res->id : -1; } void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { @@ -374,7 +374,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); - // a temporary buffer to merge two consecutive tokens + // create a temporary buffer that will store merge candidates of always two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator size_t str_len = 0; @@ -382,25 +382,48 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); *n_tokens = 1; // the number of tokens - // first encode every individual byte in the input string - for (char *c = text; *c != '\0'; c++) { - // reset buffer if the current byte is ASCII or leading byte - if ((*c & 0xC0) != 0x80) - str_len = 0; + // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: + // Code point ↔ UTF-8 conversion + // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4 + // U+0000 U+007F 0xxxxxxx + // U+0080 U+07FF 110xxxxx 10xxxxxx + // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx + // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - str_buffer[str_len++] = *c; // append byte to the buffer + // process the raw (UTF-8) byte sequence of the input string + for (char *c = text; *c != '\0'; c++) { + + // reset buffer if the current byte is ASCII or a leading byte + // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest + // 0x80 is 10000000 + // in UTF-8, all continuation bytes start with "10" in first two bits + // so in English this is: "if this byte is not a continuation byte" + if ((*c & 0xC0) != 0x80) { + // this byte must be either a leading byte (11...) or an ASCII char (0x...) + // => reset our location, as we're starting a new UTF-8 codepoint + str_len = 0; + } + + // append the current byte to the buffer + str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line str_buffer[str_len] = '\0'; - if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding + // while the next character is a continuation byte, continue appending + if ((*(c+1) & 0xC0) == 0x80) { continue; + } + // ok c+1 is not a continuation byte, so we've read in a full codepoint int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1) { + // we found this codepoint in vocab, add it as a token tokens[(*n_tokens)++] = id; } else { - // byte_fallback encoding - for (int i=0; i, , + // so the individual bytes only start at index 3 + for (int i=0; i < str_len; i++) { tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3; } } From fe2de68688ec35502b566fcef227a94935a3f3b7 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 15 Aug 2023 02:33:01 +0000 Subject: [PATCH 7/7] fix sample.py from tokenizer changes before --- sample.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sample.py b/sample.py index b26e277..d2f56ea 100644 --- a/sample.py +++ b/sample.py @@ -51,11 +51,16 @@ if compile: print("Compiling the model...") model = torch.compile(model) # requires PyTorch 2.0 (optional) -# load the tokenizer, either provided, or attempt to find it +# load the tokenizer +vocab_source = checkpoint_dict.get("vocab_source", "llama2") +vocab_size = gptconf.vocab_size if tokenizer: + # a specific tokenizer is provided, use it tokenizer_model = tokenizer else: - tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size) + # let's try to find the tokenizer model automatically. bit gross here... + query_vocab_size = 0 if vocab_source == "llama2" else vocab_size + tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size) enc = Tokenizer(tokenizer_model=tokenizer_model) # encode the beginning of the prompt