make decode safer and fix issue with skipping bad byte tokens
This commit is contained in:
@@ -380,7 +380,7 @@ typedef struct {
|
||||
TokenIndex *sorted_vocab;
|
||||
int vocab_size;
|
||||
unsigned int max_token_length;
|
||||
char byte_piece[2];
|
||||
unsigned char byte_pieces[512]; // stores all single-byte strings
|
||||
} Tokenizer;
|
||||
|
||||
int compare_tokens(const void *a, const void *b) {
|
||||
@@ -393,8 +393,11 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
||||
// malloc space to hold the scores and the strings
|
||||
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
||||
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
||||
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
|
||||
t->sorted_vocab = NULL; // initialized lazily
|
||||
for (int i = 0; i < 256; i++) {
|
||||
t->byte_pieces[i * 2] = (unsigned char)i;
|
||||
t->byte_pieces[i * 2 + 1] = '\0';
|
||||
}
|
||||
// read in the file
|
||||
FILE *file = fopen(tokenizer_path, "rb");
|
||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
|
||||
@@ -422,18 +425,28 @@ char* decode(Tokenizer* t, int prev_token, int token) {
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
// parse this and convert and return the actual byte
|
||||
unsigned char byte_val;
|
||||
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
|
||||
// ok this token is a raw byte token, careful to only print printable chars or whitespace
|
||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
||||
if (isprint(byte_val) || isspace(byte_val)) {
|
||||
t->byte_piece[0] = byte_val;
|
||||
piece = &t->byte_piece[0];
|
||||
}
|
||||
piece = (char*)t->byte_pieces + byte_val * 2;
|
||||
}
|
||||
return piece;
|
||||
}
|
||||
|
||||
void safe_printf(char *piece) {
|
||||
// piece might be a raw byte token, and we only want to print printable chars or whitespace
|
||||
// because some of the other bytes can be various control codes, backspace, etc.
|
||||
if (piece == NULL) { return; }
|
||||
if (piece[0] == '\0') { return; }
|
||||
if (piece[1] == '\0') {
|
||||
unsigned char byte_val = piece[0];
|
||||
if (!(isprint(byte_val) || isspace(byte_val))) {
|
||||
return; // bad byte, don't print it
|
||||
}
|
||||
}
|
||||
printf("%s", piece);
|
||||
}
|
||||
|
||||
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
||||
TokenIndex tok = { .str = str }; // acts as the key to search for
|
||||
@@ -754,7 +767,7 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
|
||||
// print the token as string, decode it with the Tokenizer object
|
||||
char* piece = decode(tokenizer, token, next);
|
||||
printf("%s", piece);
|
||||
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
|
||||
fflush(stdout);
|
||||
token = next;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user