refactor step 1. the tokenizer, and all the other abstractions, are a total mess, refactoring things a bit
This commit is contained in:
@@ -341,7 +341,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
|
||||
|
||||
typedef struct {
|
||||
char** vocab;
|
||||
float* vocab_scores;
|
||||
int vocab_size;
|
||||
unsigned int max_token_length;
|
||||
char byte_piece[2];
|
||||
} Tokenizer;
|
||||
|
||||
void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
|
||||
// i should have written the vocab_size into the tokenizer file... sigh
|
||||
t->vocab_size = vocab_size;
|
||||
// malloc space to hold the scores and the strings
|
||||
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
||||
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
||||
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
|
||||
// read in the file
|
||||
FILE *file = fopen(tokenizer, "rb");
|
||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); }
|
||||
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||
int len;
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
|
||||
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||
t->vocab[i] = (char *)malloc(len + 1);
|
||||
if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||
t->vocab[i][len] = '\0'; // add the string terminating token
|
||||
}
|
||||
fclose(file);
|
||||
}
|
||||
|
||||
void free_tokenizer(Tokenizer* t) {
|
||||
for (int i = 0; i < t->vocab_size; i++) {
|
||||
free(t->vocab[i]);
|
||||
}
|
||||
free(t->vocab);
|
||||
free(t->vocab_scores);
|
||||
}
|
||||
|
||||
char* get_piece(Tokenizer* t, int prev_token, int token) {
|
||||
char *piece = t->vocab[token];
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
unsigned char byte_val;
|
||||
if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
|
||||
// ok this token is a raw byte token, careful to only print printable chars or whitespace
|
||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
||||
if (isprint(byte_val) || isspace(byte_val)) {
|
||||
t->byte_piece[0] = byte_val;
|
||||
piece = &t->byte_piece[0];
|
||||
}
|
||||
}
|
||||
return piece;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
char *str;
|
||||
@@ -359,22 +414,23 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
return res != NULL ? res->id : -1;
|
||||
}
|
||||
|
||||
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
||||
void bpe_encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
||||
|
||||
// sort vocabulary
|
||||
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
sorted_vocab[i].str = vocab[i];
|
||||
TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < t->vocab_size; i++) {
|
||||
sorted_vocab[i].str = t->vocab[i];
|
||||
sorted_vocab[i].id = i;
|
||||
}
|
||||
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||
size_t str_len = 0;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
||||
tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
@@ -410,7 +466,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
}
|
||||
|
||||
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
|
||||
|
||||
if (id != -1) {
|
||||
// we found this codepoint in vocab, add it as a token
|
||||
@@ -434,11 +490,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
|
||||
for (int i=0; i < (*n_tokens-1); i++) {
|
||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
if (id != -1 && vocab_scores[id] > best_score) {
|
||||
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
|
||||
if (id != -1 && t->vocab_scores[id] > best_score) {
|
||||
// this merge pair exists in vocab! record its score and position
|
||||
best_score = vocab_scores[id];
|
||||
best_score = t->vocab_scores[id];
|
||||
best_id = id;
|
||||
best_idx = i;
|
||||
}
|
||||
@@ -587,8 +643,8 @@ void error_usage() {
|
||||
int main(int argc, char *argv[]) {
|
||||
|
||||
// default inits
|
||||
char *checkpoint = NULL; // e.g. out/model.bin
|
||||
char *tokenizer = "tokenizer.bin";
|
||||
char *checkpoint_path = NULL; // e.g. out/model.bin
|
||||
char *tokenizer_path = "tokenizer.bin";
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
rng_seed = 0; // seed rng with time by default
|
||||
@@ -596,7 +652,7 @@ int main(int argc, char *argv[]) {
|
||||
char *prompt = NULL; // prompt string
|
||||
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
|
||||
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
|
||||
for (int i = 2; i < argc; i+=2) {
|
||||
// do some basic validation
|
||||
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
|
||||
@@ -608,7 +664,7 @@ int main(int argc, char *argv[]) {
|
||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
||||
@@ -619,29 +675,14 @@ int main(int argc, char *argv[]) {
|
||||
int fd = 0; // file descriptor for memory mapping
|
||||
float* data = NULL; // memory mapped data pointer
|
||||
ssize_t file_size; // size of the checkpoint file in bytes
|
||||
read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size);
|
||||
read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size);
|
||||
|
||||
// right now we cannot run for more than config.seq_len steps
|
||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||
|
||||
// read in the tokenizer .bin file
|
||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
||||
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
||||
unsigned int max_token_length;
|
||||
{
|
||||
FILE *file = fopen(tokenizer, "rb");
|
||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
|
||||
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
int len;
|
||||
for (int i = 0; i < config.vocab_size; i++) {
|
||||
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;}
|
||||
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
vocab[i] = (char *)malloc(len + 1);
|
||||
if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
vocab[i][len] = '\0'; // add the string terminating token
|
||||
}
|
||||
fclose(file);
|
||||
}
|
||||
Tokenizer tokenizer;
|
||||
build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size);
|
||||
|
||||
// create and init the application RunState
|
||||
RunState state;
|
||||
@@ -653,7 +694,7 @@ int main(int argc, char *argv[]) {
|
||||
int num_prompt_tokens = 0;
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
|
||||
bpe_encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
|
||||
}
|
||||
|
||||
// start the main loop
|
||||
@@ -695,22 +736,9 @@ int main(int argc, char *argv[]) {
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
unsigned char byte_val;
|
||||
if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) {
|
||||
// ok this token is a raw byte token, carefuly to only print printable chars or whitespace
|
||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
||||
if (isprint(byte_val) || isspace(byte_val)) {
|
||||
char byte_piece[2];
|
||||
byte_piece[0] = byte_val;
|
||||
byte_piece[1] = '\0';
|
||||
printf("%s", byte_piece);
|
||||
}
|
||||
} else {
|
||||
printf("%s", token_str);
|
||||
}
|
||||
// print the token as string, decode it with the Tokenizer object
|
||||
char* piece = get_piece(&tokenizer, token, next);
|
||||
printf("%s", piece);
|
||||
fflush(stdout);
|
||||
token = next;
|
||||
|
||||
@@ -728,9 +756,7 @@ int main(int argc, char *argv[]) {
|
||||
// memory and file handles cleanup
|
||||
free_run_state(&state);
|
||||
free(probindex);
|
||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
||||
free(vocab);
|
||||
free(vocab_scores);
|
||||
free_tokenizer(&tokenizer);
|
||||
if (prompt_tokens != NULL) free(prompt_tokens);
|
||||
if (data != MAP_FAILED) munmap(data, file_size);
|
||||
if (fd != -1) close(fd);
|
||||
|
||||
Reference in New Issue
Block a user