refactor step 1. the tokenizer, and all the other abstractions, are a total mess, refactoring things a bit

This commit is contained in:
Andrej Karpathy
2023-08-20 18:18:23 +00:00
parent 1e335a41cf
commit c74456f3f0
+81 -55
View File
@@ -341,7 +341,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
}
// ----------------------------------------------------------------------------
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
typedef struct {
char** vocab;
float* vocab_scores;
int vocab_size;
unsigned int max_token_length;
char byte_piece[2];
} Tokenizer;
void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
// i should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
// read in the file
FILE *file = fopen(tokenizer, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); }
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
int len;
for (int i = 0; i < vocab_size; i++) {
if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
t->vocab[i] = (char *)malloc(len + 1);
if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
t->vocab[i][len] = '\0'; // add the string terminating token
}
fclose(file);
}
void free_tokenizer(Tokenizer* t) {
for (int i = 0; i < t->vocab_size; i++) {
free(t->vocab[i]);
}
free(t->vocab);
free(t->vocab_scores);
}
char* get_piece(Tokenizer* t, int prev_token, int token) {
char *piece = t->vocab[token];
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == 1 && piece[0] == ' ') { piece++; }
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
unsigned char byte_val;
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
// ok this token is a raw byte token, careful to only print printable chars or whitespace
// some of the other bytes can be various control codes, backspace, etc. => skip
if (isprint(byte_val) || isspace(byte_val)) {
t->byte_piece[0] = byte_val;
piece = &t->byte_piece[0];
}
}
return piece;
}
typedef struct {
char *str;
@@ -359,22 +414,23 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
return res != NULL ? res->id : -1;
}
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
void bpe_encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
// sort vocabulary
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
for (int i = 0; i < vocab_size; i++) {
sorted_vocab[i].str = vocab[i];
TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
for (int i = 0; i < t->vocab_size; i++) {
sorted_vocab[i].str = t->vocab[i];
sorted_vocab[i].id = i;
}
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
// create a temporary buffer that will store merge candidates of always two consecutive tokens
char* str_buffer = malloc((max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
size_t str_len = 0;
// add_dummy_prefix is true by default
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size);
*n_tokens = 1; // the number of tokens
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -410,7 +466,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
}
// ok c+1 is not a continuation byte, so we've read in a full codepoint
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
if (id != -1) {
// we found this codepoint in vocab, add it as a token
@@ -434,11 +490,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
for (int i=0; i < (*n_tokens-1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
if (id != -1 && vocab_scores[id] > best_score) {
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
if (id != -1 && t->vocab_scores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = vocab_scores[id];
best_score = t->vocab_scores[id];
best_id = id;
best_idx = i;
}
@@ -587,8 +643,8 @@ void error_usage() {
int main(int argc, char *argv[]) {
// default inits
char *checkpoint = NULL; // e.g. out/model.bin
char *tokenizer = "tokenizer.bin";
char *checkpoint_path = NULL; // e.g. out/model.bin
char *tokenizer_path = "tokenizer.bin";
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
rng_seed = 0; // seed rng with time by default
@@ -596,7 +652,7 @@ int main(int argc, char *argv[]) {
char *prompt = NULL; // prompt string
// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
for (int i = 2; i < argc; i+=2) {
// do some basic validation
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
@@ -608,7 +664,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
else { error_usage(); }
}
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
@@ -619,29 +675,14 @@ int main(int argc, char *argv[]) {
int fd = 0; // file descriptor for memory mapping
float* data = NULL; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size);
read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size);
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
// read in the tokenizer .bin file
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
unsigned int max_token_length;
{
FILE *file = fopen(tokenizer, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
int len;
for (int i = 0; i < config.vocab_size; i++) {
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;}
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
vocab[i] = (char *)malloc(len + 1);
if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
vocab[i][len] = '\0'; // add the string terminating token
}
fclose(file);
}
Tokenizer tokenizer;
build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size);
// create and init the application RunState
RunState state;
@@ -653,7 +694,7 @@ int main(int argc, char *argv[]) {
int num_prompt_tokens = 0;
if (prompt != NULL) {
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
bpe_encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
}
// start the main loop
@@ -695,22 +736,9 @@ int main(int argc, char *argv[]) {
// data-dependent terminating condition: the BOS (1) token delimits sequences
if (next == 1) { break; }
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
unsigned char byte_val;
if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) {
// ok this token is a raw byte token, carefuly to only print printable chars or whitespace
// some of the other bytes can be various control codes, backspace, etc. => skip
if (isprint(byte_val) || isspace(byte_val)) {
char byte_piece[2];
byte_piece[0] = byte_val;
byte_piece[1] = '\0';
printf("%s", byte_piece);
}
} else {
printf("%s", token_str);
}
// print the token as string, decode it with the Tokenizer object
char* piece = get_piece(&tokenizer, token, next);
printf("%s", piece);
fflush(stdout);
token = next;
@@ -728,9 +756,7 @@ int main(int argc, char *argv[]) {
// memory and file handles cleanup
free_run_state(&state);
free(probindex);
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
free(vocab);
free(vocab_scores);
free_tokenizer(&tokenizer);
if (prompt_tokens != NULL) free(prompt_tokens);
if (data != MAP_FAILED) munmap(data, file_size);
if (fd != -1) close(fd);