refactor step 1. the tokenizer, and all the other abstractions, are a total mess, refactoring things a bit
This commit is contained in:
@@ -341,7 +341,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
char** vocab;
|
||||||
|
float* vocab_scores;
|
||||||
|
int vocab_size;
|
||||||
|
unsigned int max_token_length;
|
||||||
|
char byte_piece[2];
|
||||||
|
} Tokenizer;
|
||||||
|
|
||||||
|
void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
|
||||||
|
// i should have written the vocab_size into the tokenizer file... sigh
|
||||||
|
t->vocab_size = vocab_size;
|
||||||
|
// malloc space to hold the scores and the strings
|
||||||
|
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
||||||
|
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
||||||
|
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
|
||||||
|
// read in the file
|
||||||
|
FILE *file = fopen(tokenizer, "rb");
|
||||||
|
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); }
|
||||||
|
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||||
|
int len;
|
||||||
|
for (int i = 0; i < vocab_size; i++) {
|
||||||
|
if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
|
||||||
|
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||||
|
t->vocab[i] = (char *)malloc(len + 1);
|
||||||
|
if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||||
|
t->vocab[i][len] = '\0'; // add the string terminating token
|
||||||
|
}
|
||||||
|
fclose(file);
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_tokenizer(Tokenizer* t) {
|
||||||
|
for (int i = 0; i < t->vocab_size; i++) {
|
||||||
|
free(t->vocab[i]);
|
||||||
|
}
|
||||||
|
free(t->vocab);
|
||||||
|
free(t->vocab_scores);
|
||||||
|
}
|
||||||
|
|
||||||
|
char* get_piece(Tokenizer* t, int prev_token, int token) {
|
||||||
|
char *piece = t->vocab[token];
|
||||||
|
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||||
|
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
||||||
|
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||||
|
unsigned char byte_val;
|
||||||
|
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
|
||||||
|
// ok this token is a raw byte token, careful to only print printable chars or whitespace
|
||||||
|
// some of the other bytes can be various control codes, backspace, etc. => skip
|
||||||
|
if (isprint(byte_val) || isspace(byte_val)) {
|
||||||
|
t->byte_piece[0] = byte_val;
|
||||||
|
piece = &t->byte_piece[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return piece;
|
||||||
|
}
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
char *str;
|
char *str;
|
||||||
@@ -359,22 +414,23 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
|||||||
return res != NULL ? res->id : -1;
|
return res != NULL ? res->id : -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
void bpe_encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||||
|
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
||||||
|
|
||||||
// sort vocabulary
|
// sort vocabulary
|
||||||
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
|
TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
|
||||||
for (int i = 0; i < vocab_size; i++) {
|
for (int i = 0; i < t->vocab_size; i++) {
|
||||||
sorted_vocab[i].str = vocab[i];
|
sorted_vocab[i].str = t->vocab[i];
|
||||||
sorted_vocab[i].id = i;
|
sorted_vocab[i].id = i;
|
||||||
}
|
}
|
||||||
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||||
|
|
||||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||||
char* str_buffer = malloc((max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||||
size_t str_len = 0;
|
size_t str_len = 0;
|
||||||
|
|
||||||
// add_dummy_prefix is true by default
|
// add_dummy_prefix is true by default
|
||||||
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size);
|
||||||
*n_tokens = 1; // the number of tokens
|
*n_tokens = 1; // the number of tokens
|
||||||
|
|
||||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||||
@@ -410,7 +466,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
||||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
|
||||||
|
|
||||||
if (id != -1) {
|
if (id != -1) {
|
||||||
// we found this codepoint in vocab, add it as a token
|
// we found this codepoint in vocab, add it as a token
|
||||||
@@ -434,11 +490,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
|
|
||||||
for (int i=0; i < (*n_tokens-1); i++) {
|
for (int i=0; i < (*n_tokens-1); i++) {
|
||||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||||
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
|
||||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
|
||||||
if (id != -1 && vocab_scores[id] > best_score) {
|
if (id != -1 && t->vocab_scores[id] > best_score) {
|
||||||
// this merge pair exists in vocab! record its score and position
|
// this merge pair exists in vocab! record its score and position
|
||||||
best_score = vocab_scores[id];
|
best_score = t->vocab_scores[id];
|
||||||
best_id = id;
|
best_id = id;
|
||||||
best_idx = i;
|
best_idx = i;
|
||||||
}
|
}
|
||||||
@@ -587,8 +643,8 @@ void error_usage() {
|
|||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
|
||||||
// default inits
|
// default inits
|
||||||
char *checkpoint = NULL; // e.g. out/model.bin
|
char *checkpoint_path = NULL; // e.g. out/model.bin
|
||||||
char *tokenizer = "tokenizer.bin";
|
char *tokenizer_path = "tokenizer.bin";
|
||||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||||
rng_seed = 0; // seed rng with time by default
|
rng_seed = 0; // seed rng with time by default
|
||||||
@@ -596,7 +652,7 @@ int main(int argc, char *argv[]) {
|
|||||||
char *prompt = NULL; // prompt string
|
char *prompt = NULL; // prompt string
|
||||||
|
|
||||||
// poor man's C argparse so we can override the defaults above from the command line
|
// poor man's C argparse so we can override the defaults above from the command line
|
||||||
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
|
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
|
||||||
for (int i = 2; i < argc; i+=2) {
|
for (int i = 2; i < argc; i+=2) {
|
||||||
// do some basic validation
|
// do some basic validation
|
||||||
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
|
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
|
||||||
@@ -608,7 +664,7 @@ int main(int argc, char *argv[]) {
|
|||||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||||
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
|
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
|
||||||
else { error_usage(); }
|
else { error_usage(); }
|
||||||
}
|
}
|
||||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
||||||
@@ -619,29 +675,14 @@ int main(int argc, char *argv[]) {
|
|||||||
int fd = 0; // file descriptor for memory mapping
|
int fd = 0; // file descriptor for memory mapping
|
||||||
float* data = NULL; // memory mapped data pointer
|
float* data = NULL; // memory mapped data pointer
|
||||||
ssize_t file_size; // size of the checkpoint file in bytes
|
ssize_t file_size; // size of the checkpoint file in bytes
|
||||||
read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size);
|
read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size);
|
||||||
|
|
||||||
// right now we cannot run for more than config.seq_len steps
|
// right now we cannot run for more than config.seq_len steps
|
||||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||||
|
|
||||||
// read in the tokenizer .bin file
|
// read in the tokenizer .bin file
|
||||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
Tokenizer tokenizer;
|
||||||
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size);
|
||||||
unsigned int max_token_length;
|
|
||||||
{
|
|
||||||
FILE *file = fopen(tokenizer, "rb");
|
|
||||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
|
|
||||||
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
|
||||||
int len;
|
|
||||||
for (int i = 0; i < config.vocab_size; i++) {
|
|
||||||
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;}
|
|
||||||
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
|
||||||
vocab[i] = (char *)malloc(len + 1);
|
|
||||||
if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
|
||||||
vocab[i][len] = '\0'; // add the string terminating token
|
|
||||||
}
|
|
||||||
fclose(file);
|
|
||||||
}
|
|
||||||
|
|
||||||
// create and init the application RunState
|
// create and init the application RunState
|
||||||
RunState state;
|
RunState state;
|
||||||
@@ -653,7 +694,7 @@ int main(int argc, char *argv[]) {
|
|||||||
int num_prompt_tokens = 0;
|
int num_prompt_tokens = 0;
|
||||||
if (prompt != NULL) {
|
if (prompt != NULL) {
|
||||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||||
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
|
bpe_encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
// start the main loop
|
// start the main loop
|
||||||
@@ -695,22 +736,9 @@ int main(int argc, char *argv[]) {
|
|||||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||||
if (next == 1) { break; }
|
if (next == 1) { break; }
|
||||||
|
|
||||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
// print the token as string, decode it with the Tokenizer object
|
||||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
char* piece = get_piece(&tokenizer, token, next);
|
||||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
printf("%s", piece);
|
||||||
unsigned char byte_val;
|
|
||||||
if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) {
|
|
||||||
// ok this token is a raw byte token, carefuly to only print printable chars or whitespace
|
|
||||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
|
||||||
if (isprint(byte_val) || isspace(byte_val)) {
|
|
||||||
char byte_piece[2];
|
|
||||||
byte_piece[0] = byte_val;
|
|
||||||
byte_piece[1] = '\0';
|
|
||||||
printf("%s", byte_piece);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
printf("%s", token_str);
|
|
||||||
}
|
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
token = next;
|
token = next;
|
||||||
|
|
||||||
@@ -728,9 +756,7 @@ int main(int argc, char *argv[]) {
|
|||||||
// memory and file handles cleanup
|
// memory and file handles cleanup
|
||||||
free_run_state(&state);
|
free_run_state(&state);
|
||||||
free(probindex);
|
free(probindex);
|
||||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
free_tokenizer(&tokenizer);
|
||||||
free(vocab);
|
|
||||||
free(vocab_scores);
|
|
||||||
if (prompt_tokens != NULL) free(prompt_tokens);
|
if (prompt_tokens != NULL) free(prompt_tokens);
|
||||||
if (data != MAP_FAILED) munmap(data, file_size);
|
if (data != MAP_FAILED) munmap(data, file_size);
|
||||||
if (fd != -1) close(fd);
|
if (fd != -1) close(fd);
|
||||||
|
|||||||
Reference in New Issue
Block a user