diff --git a/run.c b/run.c index 56ceff5..8da8823 100644 --- a/run.c +++ b/run.c @@ -347,29 +347,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // ---------------------------------------------------------------------------- // byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt -int str_lookup(char *str, char **vocab, int vocab_size) { - // find the first perfect match for str in vocab, return its index or -1 if not found - for (int i = 0; i < vocab_size; i++) { - if (strcmp(str, vocab[i]) == 0) { - return i; - } - } - return -1; +typedef struct { + char *str; + int id; +} TokenIndex; + +int compare_tokens(const void *a, const void *b) { + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +} + +int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { + // find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = {str=str}; + TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + return res!=NULL ? res->id : -1; } void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { + // sort vocabulary + TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex)); + for (int i = 0; i < vocab_size; i++) { + sorted_vocab[i].str = vocab[i]; + sorted_vocab[i].id = i; + } + qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator + size_t str_len = 0; + + // add_dummy_prefix is true by default + tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); + *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string - *n_tokens = 0; // the number of tokens for (char *c = text; *c != '\0'; c++) { - sprintf(str_buffer, "%c", *c); - int id = str_lookup(str_buffer, vocab, vocab_size); - if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); } - tokens[*n_tokens] = id; - (*n_tokens)++; + // reset buffer if the current byte is ASCII or leading byte + if ((*c & 0xC0) != 0x80) + str_len = 0; + + str_buffer[str_len++] = *c; // append byte to the buffer + str_buffer[str_len] = '\0'; + + if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding + continue; + + int id = str_lookup(str_buffer, sorted_vocab, vocab_size); + + if (id != -1) { + tokens[(*n_tokens)++] = id; + } else { + // byte_fallback encoding + for (int i=0; i best_score) { // this merge pair exists in vocab! record its score and position best_score = vocab_scores[id]; @@ -404,6 +437,20 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } free(str_buffer); + free(sorted_vocab); +} + +// convert token to printable string +char *token_to_str(char **vocab, int token, int prev_token) { + // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) + char *token_str = (prev_token == 1 && vocab[token][0] == ' ') ? vocab[token]+1 : vocab[token]; + // make '<0x01>' into '\x01' + static char byte_piece[4]; + if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) { + byte_piece[1] = '\0'; + token_str = byte_piece; + } + return token_str; } // ---------------------------------------------------------------------------- @@ -609,7 +656,7 @@ int main(int argc, char *argv[]) { int *prompt_tokens = NULL; int num_prompt_tokens = 0; if (prompt != NULL) { - prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int)); + prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); } @@ -652,9 +699,7 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - printf("%s", token_str); + printf("%s", token_to_str(vocab, next, token)); fflush(stdout); token = next; diff --git a/tokenizer.bin b/tokenizer.bin index e0a8a7b..e6c1b23 100644 Binary files a/tokenizer.bin and b/tokenizer.bin differ diff --git a/tokenizer.py b/tokenizer.py index bc2a35a..f3c0cc3 100644 --- a/tokenizer.py +++ b/tokenizer.py @@ -51,8 +51,6 @@ class Tokenizer: t = '\n\n' elif i == self.eos_id: t = '\n\n' - elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'): - t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01' t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace b = t.encode('utf-8') # bytes of this token, utf-8 encoded