diff --git a/run.c b/run.c index 02cc877..f69c21a 100644 --- a/run.c +++ b/run.c @@ -356,15 +356,34 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u // a temporary buffer to merge two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator + size_t str_len = 0; + + // add_dummy_prefix is true by default + tokens[0] = str_lookup(" ", vocab, vocab_size); + *n_tokens = 1; // the number of tokens // first encode every individual byte in the input string - *n_tokens = 0; // the number of tokens for (char *c = text; *c != '\0'; c++) { - sprintf(str_buffer, "%c", *c); + // reset buffer if the current byte is ASCII or leading byte + if ((*c & 0xC0) != 0x80) + str_len = 0; + + str_buffer[str_len++] = *c; // append byte to the buffer + str_buffer[str_len] = '\0'; + + if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding + continue; + int id = str_lookup(str_buffer, vocab, vocab_size); - if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); } - tokens[*n_tokens] = id; - (*n_tokens)++; + + if (id != -1) { + tokens[(*n_tokens)++] = id; + } else { + // byte_fallback encoding + for (int i=0; i' into '\x01' + static char byte_piece[4]; + if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) { + byte_piece[1] = '\0'; + token_str = byte_piece; + } + return token_str; +} + // ---------------------------------------------------------------------------- // utilities: time / rng @@ -637,9 +669,7 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - printf("%s", token_str); + printf("%s", token_to_str(vocab, next, token)); fflush(stdout); token = next; diff --git a/tokenizer.bin b/tokenizer.bin index e0a8a7b..e6c1b23 100644 Binary files a/tokenizer.bin and b/tokenizer.bin differ diff --git a/tokenizer.py b/tokenizer.py index 35eee20..bbe1bc9 100644 --- a/tokenizer.py +++ b/tokenizer.py @@ -52,8 +52,6 @@ class Tokenizer: t = '\n\n' elif i == self.eos_id: t = '\n\n' - elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'): - t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01' t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace b = t.encode('utf-8') # bytes of this token, utf-8 encoded