make decode safer and fix issue with skipping bad byte tokens
This commit is contained in:
@@ -380,7 +380,7 @@ typedef struct {
|
|||||||
TokenIndex *sorted_vocab;
|
TokenIndex *sorted_vocab;
|
||||||
int vocab_size;
|
int vocab_size;
|
||||||
unsigned int max_token_length;
|
unsigned int max_token_length;
|
||||||
char byte_piece[2];
|
unsigned char byte_pieces[512]; // stores all single-byte strings
|
||||||
} Tokenizer;
|
} Tokenizer;
|
||||||
|
|
||||||
int compare_tokens(const void *a, const void *b) {
|
int compare_tokens(const void *a, const void *b) {
|
||||||
@@ -393,8 +393,11 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
|||||||
// malloc space to hold the scores and the strings
|
// malloc space to hold the scores and the strings
|
||||||
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
||||||
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
||||||
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
|
|
||||||
t->sorted_vocab = NULL; // initialized lazily
|
t->sorted_vocab = NULL; // initialized lazily
|
||||||
|
for (int i = 0; i < 256; i++) {
|
||||||
|
t->byte_pieces[i * 2] = (unsigned char)i;
|
||||||
|
t->byte_pieces[i * 2 + 1] = '\0';
|
||||||
|
}
|
||||||
// read in the file
|
// read in the file
|
||||||
FILE *file = fopen(tokenizer_path, "rb");
|
FILE *file = fopen(tokenizer_path, "rb");
|
||||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
|
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
|
||||||
@@ -422,18 +425,28 @@ char* decode(Tokenizer* t, int prev_token, int token) {
|
|||||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||||
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
||||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||||
|
// parse this and convert and return the actual byte
|
||||||
unsigned char byte_val;
|
unsigned char byte_val;
|
||||||
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
|
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
|
||||||
// ok this token is a raw byte token, careful to only print printable chars or whitespace
|
piece = (char*)t->byte_pieces + byte_val * 2;
|
||||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
|
||||||
if (isprint(byte_val) || isspace(byte_val)) {
|
|
||||||
t->byte_piece[0] = byte_val;
|
|
||||||
piece = &t->byte_piece[0];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return piece;
|
return piece;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void safe_printf(char *piece) {
|
||||||
|
// piece might be a raw byte token, and we only want to print printable chars or whitespace
|
||||||
|
// because some of the other bytes can be various control codes, backspace, etc.
|
||||||
|
if (piece == NULL) { return; }
|
||||||
|
if (piece[0] == '\0') { return; }
|
||||||
|
if (piece[1] == '\0') {
|
||||||
|
unsigned char byte_val = piece[0];
|
||||||
|
if (!(isprint(byte_val) || isspace(byte_val))) {
|
||||||
|
return; // bad byte, don't print it
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("%s", piece);
|
||||||
|
}
|
||||||
|
|
||||||
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||||
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
||||||
TokenIndex tok = { .str = str }; // acts as the key to search for
|
TokenIndex tok = { .str = str }; // acts as the key to search for
|
||||||
@@ -754,7 +767,7 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
|||||||
|
|
||||||
// print the token as string, decode it with the Tokenizer object
|
// print the token as string, decode it with the Tokenizer object
|
||||||
char* piece = decode(tokenizer, token, next);
|
char* piece = decode(tokenizer, token, next);
|
||||||
printf("%s", piece);
|
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
token = next;
|
token = next;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user