prompt tokenizer improvements: utf8 support, add_dummy_prefix and byte_fallback options to match sentencepiece
This commit is contained in:
@@ -356,15 +356,34 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
|
||||
// a temporary buffer to merge two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||
size_t str_len = 0;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", vocab, vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
|
||||
// first encode every individual byte in the input string
|
||||
*n_tokens = 0; // the number of tokens
|
||||
for (char *c = text; *c != '\0'; c++) {
|
||||
sprintf(str_buffer, "%c", *c);
|
||||
// reset buffer if the current byte is ASCII or leading byte
|
||||
if ((*c & 0xC0) != 0x80)
|
||||
str_len = 0;
|
||||
|
||||
str_buffer[str_len++] = *c; // append byte to the buffer
|
||||
str_buffer[str_len] = '\0';
|
||||
|
||||
if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding
|
||||
continue;
|
||||
|
||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||
if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); }
|
||||
tokens[*n_tokens] = id;
|
||||
(*n_tokens)++;
|
||||
|
||||
if (id != -1) {
|
||||
tokens[(*n_tokens)++] = id;
|
||||
} else {
|
||||
// byte_fallback encoding
|
||||
for (int i=0; i<str_len; i++) {
|
||||
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// merge the best consecutive pair each iteration, according the scores in vocab_scores
|
||||
@@ -401,6 +420,19 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
free(str_buffer);
|
||||
}
|
||||
|
||||
// convert token to printable string
|
||||
char *token_to_str(char **vocab, int token, int prev_token) {
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (prev_token == 1 && vocab[token][0] == ' ') ? vocab[token]+1 : vocab[token];
|
||||
// make '<0x01>' into '\x01'
|
||||
static char byte_piece[4];
|
||||
if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) {
|
||||
byte_piece[1] = '\0';
|
||||
token_str = byte_piece;
|
||||
}
|
||||
return token_str;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// utilities: time / rng
|
||||
|
||||
@@ -637,9 +669,7 @@ int main(int argc, char *argv[]) {
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
printf("%s", token_str);
|
||||
printf("%s", token_to_str(vocab, next, token));
|
||||
fflush(stdout);
|
||||
token = next;
|
||||
|
||||
|
||||
Binary file not shown.
@@ -52,8 +52,6 @@ class Tokenizer:
|
||||
t = '\n<s>\n'
|
||||
elif i == self.eos_id:
|
||||
t = '\n</s>\n'
|
||||
elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
|
||||
t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
|
||||
t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
|
||||
b = t.encode('utf-8') # bytes of this token, utf-8 encoded
|
||||
|
||||
|
||||
Reference in New Issue
Block a user