diff --git a/run.c b/run.c index 56ceff5..8f565cd 100644 --- a/run.c +++ b/run.c @@ -10,6 +10,7 @@ $ ./run #include #include +#include #include #include #include @@ -347,29 +348,85 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // ---------------------------------------------------------------------------- // byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt -int str_lookup(char *str, char **vocab, int vocab_size) { - // find the first perfect match for str in vocab, return its index or -1 if not found - for (int i = 0; i < vocab_size; i++) { - if (strcmp(str, vocab[i]) == 0) { - return i; - } - } - return -1; +typedef struct { + char *str; + int id; +} TokenIndex; + +int compare_tokens(const void *a, const void *b) { + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +} + +int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { + // efficiently find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = { .str = str }; // acts as the key to search for + TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + return res != NULL ? 
res->id : -1; } void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { - // a temporary buffer to merge two consecutive tokens - char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator + // sort vocabulary + TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex)); + for (int i = 0; i < vocab_size; i++) { + sorted_vocab[i].str = vocab[i]; + sorted_vocab[i].id = i; + } + qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); - // first encode every individual byte in the input string - *n_tokens = 0; // the number of tokens + // create a temporary buffer that will store merge candidates of always two consecutive tokens + char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator + size_t str_len = 0; + + // add_dummy_prefix is true by default + tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); + *n_tokens = 1; // the number of tokens + + // Okay UTF-8 time. This will get messy. 
Here is the reference from Wikipedia: + // Code point ↔ UTF-8 conversion + // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4 + // U+0000 U+007F 0xxxxxxx + // U+0080 U+07FF 110xxxxx 10xxxxxx + // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx + // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + + // process the raw (UTF-8) byte sequence of the input string for (char *c = text; *c != '\0'; c++) { - sprintf(str_buffer, "%c", *c); - int id = str_lookup(str_buffer, vocab, vocab_size); - if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); } - tokens[*n_tokens] = id; - (*n_tokens)++; + + // reset buffer if the current byte is ASCII or a leading byte + // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest + // 0x80 is 10000000 + // in UTF-8, all continuation bytes start with "10" in first two bits + // so in English this is: "if this byte is not a continuation byte" + if ((*c & 0xC0) != 0x80) { + // this byte must be either a leading byte (11...) or an ASCII char (0x...) 
+ // => reset our location, as we're starting a new UTF-8 codepoint + str_len = 0; + } + + // append the current byte to the buffer + str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line + str_buffer[str_len] = '\0'; + + // while the next character is a continuation byte, continue appending + if ((*(c+1) & 0xC0) == 0x80) { + continue; + } + + // ok c+1 is not a continuation byte, so we've read in a full codepoint + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); + + if (id != -1) { + // we found this codepoint in vocab, add it as a token + tokens[(*n_tokens)++] = id; + } else { + // byte_fallback encoding: just encode each byte as a token + // +3 is here because the first 3 vocab elements are <unk>, <s>, </s> + // so the individual bytes only start at index 3 + for (int i=0; i < str_len; i++) { + tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3; + } + } } // merge the best consecutive pair each iteration, according the scores in vocab_scores @@ -381,7 +438,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u for (int i=0; i < (*n_tokens-1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]); - int id = str_lookup(str_buffer, vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1 && vocab_scores[id] > best_score) { // this merge pair exists in vocab! 
record its score and position best_score = vocab_scores[id]; @@ -404,6 +461,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } free(str_buffer); + free(sorted_vocab); } // ---------------------------------------------------------------------------- @@ -609,7 +667,7 @@ int main(int argc, char *argv[]) { int *prompt_tokens = NULL; int num_prompt_tokens = 0; if (prompt != NULL) { - prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int)); + prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); } @@ -654,7 +712,20 @@ int main(int argc, char *argv[]) { // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - printf("%s", token_str); + // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' + unsigned char byte_val; + if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) { + // ok this token is a raw byte token, careful to only print printable chars or whitespace + // some of the other bytes can be various control codes, backspace, etc. 
=> skip + if (isprint(byte_val) || isspace(byte_val)) { + char byte_piece[2]; + byte_piece[0] = byte_val; + byte_piece[1] = '\0'; + printf("%s", byte_piece); + } + } else { + printf("%s", token_str); + } fflush(stdout); token = next; diff --git a/sample.py b/sample.py index b26e277..d2f56ea 100644 --- a/sample.py +++ b/sample.py @@ -51,11 +51,16 @@ if compile: print("Compiling the model...") model = torch.compile(model) # requires PyTorch 2.0 (optional) -# load the tokenizer, either provided, or attempt to find it +# load the tokenizer +vocab_source = checkpoint_dict.get("vocab_source", "llama2") +vocab_size = gptconf.vocab_size if tokenizer: + # a specific tokenizer is provided, use it tokenizer_model = tokenizer else: - tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size) + # let's try to find the tokenizer model automatically. bit gross here... + query_vocab_size = 0 if vocab_source == "llama2" else vocab_size + tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size) enc = Tokenizer(tokenizer_model=tokenizer_model) # encode the beginning of the prompt diff --git a/tokenizer.bin b/tokenizer.bin index e0a8a7b..e6c1b23 100644 Binary files a/tokenizer.bin and b/tokenizer.bin differ diff --git a/tokenizer.py b/tokenizer.py index bc2a35a..f3c0cc3 100644 --- a/tokenizer.py +++ b/tokenizer.py @@ -51,8 +51,6 @@ class Tokenizer: t = '\n\n' elif i == self.eos_id: t = '\n\n' - elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'): - t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01' t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace b = t.encode('utf-8') # bytes of this token, utf-8 encoded