diff --git a/run.c b/run.c index bf4ce34..da1fbc4 100644 --- a/run.c +++ b/run.c @@ -380,7 +380,7 @@ typedef struct { TokenIndex *sorted_vocab; int vocab_size; unsigned int max_token_length; - char byte_piece[2]; + unsigned char byte_pieces[512]; // stores all single-byte strings } Tokenizer; int compare_tokens(const void *a, const void *b) { @@ -393,8 +393,11 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { // malloc space to hold the scores and the strings t->vocab = (char**)malloc(vocab_size * sizeof(char*)); t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); - t->byte_piece[1] = '\0'; // null terminate the byte_piece string t->sorted_vocab = NULL; // initialized lazily + for (int i = 0; i < 256; i++) { + t->byte_pieces[i * 2] = (unsigned char)i; + t->byte_pieces[i * 2 + 1] = '\0'; + } // read in the file FILE *file = fopen(tokenizer_path, "rb"); if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } @@ -422,18 +425,28 @@ char* decode(Tokenizer* t, int prev_token, int token) { // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) if (prev_token == 1 && piece[0] == ' ') { piece++; } // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' + // parse this and convert and return the actual byte unsigned char byte_val; if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { - // ok this token is a raw byte token, careful to only print printable chars or whitespace - // some of the other bytes can be various control codes, backspace, etc. => skip - if (isprint(byte_val) || isspace(byte_val)) { - t->byte_piece[0] = byte_val; - piece = &t->byte_piece[0]; - } + piece = (char*)t->byte_pieces + byte_val * 2; } return piece; } +void safe_printf(char *piece) { + // piece might be a raw byte token, and we only want to print printable chars or whitespace + // because some of the other bytes can be various control codes, backspace, etc. + if (piece == NULL) { return; } + if (piece[0] == '\0') { return; } + if (piece[1] == '\0') { + unsigned char byte_val = piece[0]; + if (!(isprint(byte_val) || isspace(byte_val))) { + return; // bad byte, don't print it + } + } + printf("%s", piece); +} + int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { // efficiently find the perfect match for str in vocab, return its index or -1 if not found TokenIndex tok = { .str = str }; // acts as the key to search for @@ -754,7 +767,7 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, // print the token as string, decode it with the Tokenizer object char* piece = decode(tokenizer, token, next); - printf("%s", piece); + safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes fflush(stdout); token = next;