make decode safer and fix issue with skipping bad byte tokens

This commit is contained in:
Andrej Karpathy
2023-08-23 01:08:31 +00:00
parent 4b3e66021a
commit 7ac65cb2c2
+22 -9
View File
@@ -380,7 +380,7 @@ typedef struct {
TokenIndex *sorted_vocab;
int vocab_size;
unsigned int max_token_length;
char byte_piece[2];
unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;
int compare_tokens(const void *a, const void *b) {
@@ -393,8 +393,11 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
// malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
t->sorted_vocab = NULL; // initialized lazily
for (int i = 0; i < 256; i++) {
t->byte_pieces[i * 2] = (unsigned char)i;
t->byte_pieces[i * 2 + 1] = '\0';
}
// read in the file
FILE *file = fopen(tokenizer_path, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
@@ -422,18 +425,28 @@ char* decode(Tokenizer* t, int prev_token, int token) {
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == 1 && piece[0] == ' ') { piece++; }
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val;
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
// ok this token is a raw byte token, careful to only print printable chars or whitespace
// some of the other bytes can be various control codes, backspace, etc. => skip
if (isprint(byte_val) || isspace(byte_val)) {
t->byte_piece[0] = byte_val;
piece = &t->byte_piece[0];
}
piece = (char*)t->byte_pieces + byte_val * 2;
}
return piece;
}
void safe_printf(char *piece) {
// piece might be a raw byte token, and we only want to print printable chars or whitespace
// because some of the other bytes can be various control codes, backspace, etc.
if (piece == NULL) { return; }
if (piece[0] == '\0') { return; }
if (piece[1] == '\0') {
unsigned char byte_val = piece[0];
if (!(isprint(byte_val) || isspace(byte_val))) {
return; // bad byte, don't print it
}
}
printf("%s", piece);
}
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
TokenIndex tok = { .str = str }; // acts as the key to search for
@@ -754,7 +767,7 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next);
printf("%s", piece);
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout);
token = next;