make sorted vocab a buffer of Tokenizer
This commit is contained in:
@@ -369,14 +369,24 @@ float* forward(Transformer* transformer, int token, int pos) {
|
||||
// ----------------------------------------------------------------------------
|
||||
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
|
||||
|
||||
typedef struct {
|
||||
char *str;
|
||||
int id;
|
||||
} TokenIndex;
|
||||
|
||||
typedef struct {
|
||||
char** vocab;
|
||||
float* vocab_scores;
|
||||
TokenIndex *sorted_vocab;
|
||||
int vocab_size;
|
||||
unsigned int max_token_length;
|
||||
char byte_piece[2];
|
||||
} Tokenizer;
|
||||
|
||||
int compare_tokens(const void *a, const void *b) {
|
||||
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
|
||||
}
|
||||
|
||||
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
||||
// i should have written the vocab_size into the tokenizer file... sigh
|
||||
t->vocab_size = vocab_size;
|
||||
@@ -384,6 +394,7 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
|
||||
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
|
||||
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
||||
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
|
||||
t->sorted_vocab = NULL; // initialized lazily
|
||||
// read in the file
|
||||
FILE *file = fopen(tokenizer_path, "rb");
|
||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
|
||||
@@ -403,6 +414,7 @@ void free_tokenizer(Tokenizer* t) {
|
||||
for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
|
||||
free(t->vocab);
|
||||
free(t->vocab_scores);
|
||||
free(t->sorted_vocab);
|
||||
}
|
||||
|
||||
char* decode(Tokenizer* t, int prev_token, int token) {
|
||||
@@ -422,15 +434,6 @@ char* decode(Tokenizer* t, int prev_token, int token) {
|
||||
return piece;
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
char *str;
|
||||
int id;
|
||||
} TokenIndex;
|
||||
|
||||
int compare_tokens(const void *a, const void *b) {
|
||||
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
|
||||
}
|
||||
|
||||
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
||||
TokenIndex tok = { .str = str }; // acts as the key to search for
|
||||
@@ -441,20 +444,22 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
||||
|
||||
// sort vocabulary
|
||||
TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < t->vocab_size; i++) {
|
||||
sorted_vocab[i].str = t->vocab[i];
|
||||
sorted_vocab[i].id = i;
|
||||
if (t->sorted_vocab == NULL) {
|
||||
// lazily malloc and sort the vocabulary
|
||||
t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < t->vocab_size; i++) {
|
||||
t->sorted_vocab[i].str = t->vocab[i];
|
||||
t->sorted_vocab[i].id = i;
|
||||
}
|
||||
qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
}
|
||||
qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||
size_t str_len = 0;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size);
|
||||
tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
@@ -490,7 +495,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
}
|
||||
|
||||
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
||||
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
|
||||
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
||||
|
||||
if (id != -1) {
|
||||
// we found this codepoint in vocab, add it as a token
|
||||
@@ -515,7 +520,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
for (int i=0; i < (*n_tokens-1); i++) {
|
||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
|
||||
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
||||
if (id != -1 && t->vocab_scores[id] > best_score) {
|
||||
// this merge pair exists in vocab! record its score and position
|
||||
best_score = t->vocab_scores[id];
|
||||
@@ -538,7 +543,6 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
}
|
||||
|
||||
free(str_buffer);
|
||||
free(sorted_vocab);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user