make decode safer and fix issue with skipping bad byte tokens

This commit is contained in:
Andrej Karpathy
2023-08-23 01:08:31 +00:00
parent 4b3e66021a
commit 7ac65cb2c2
+22 -9
View File
@@ -380,7 +380,7 @@ typedef struct {
TokenIndex *sorted_vocab; TokenIndex *sorted_vocab;
int vocab_size; int vocab_size;
unsigned int max_token_length; unsigned int max_token_length;
char byte_piece[2]; unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer; } Tokenizer;
int compare_tokens(const void *a, const void *b) { int compare_tokens(const void *a, const void *b) {
@@ -393,8 +393,11 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
// malloc space to hold the scores and the strings // malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size * sizeof(char*)); t->vocab = (char**)malloc(vocab_size * sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
t->sorted_vocab = NULL; // initialized lazily t->sorted_vocab = NULL; // initialized lazily
for (int i = 0; i < 256; i++) {
t->byte_pieces[i * 2] = (unsigned char)i;
t->byte_pieces[i * 2 + 1] = '\0';
}
// read in the file // read in the file
FILE *file = fopen(tokenizer_path, "rb"); FILE *file = fopen(tokenizer_path, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
@@ -422,18 +425,28 @@ char* decode(Tokenizer* t, int prev_token, int token) {
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == 1 && piece[0] == ' ') { piece++; } if (prev_token == 1 && piece[0] == ' ') { piece++; }
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>' // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val; unsigned char byte_val;
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
// ok this token is a raw byte token, careful to only print printable chars or whitespace piece = (char*)t->byte_pieces + byte_val * 2;
// some of the other bytes can be various control codes, backspace, etc. => skip
if (isprint(byte_val) || isspace(byte_val)) {
t->byte_piece[0] = byte_val;
piece = &t->byte_piece[0];
}
} }
return piece; return piece;
} }
void safe_printf(char *piece) {
// piece might be a raw byte token, and we only want to print printable chars or whitespace
// because some of the other bytes can be various control codes, backspace, etc.
if (piece == NULL) { return; }
if (piece[0] == '\0') { return; }
if (piece[1] == '\0') {
unsigned char byte_val = piece[0];
if (!(isprint(byte_val) || isspace(byte_val))) {
return; // bad byte, don't print it
}
}
printf("%s", piece);
}
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
// efficiently find the perfect match for str in vocab, return its index or -1 if not found // efficiently find the perfect match for str in vocab, return its index or -1 if not found
TokenIndex tok = { .str = str }; // acts as the key to search for TokenIndex tok = { .str = str }; // acts as the key to search for
@@ -754,7 +767,7 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// print the token as string, decode it with the Tokenizer object // print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next); char* piece = decode(tokenizer, token, next);
printf("%s", piece); safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout); fflush(stdout);
token = next; token = next;