sort vocabulary for faster lookup with bsearch()
This commit is contained in:
@@ -342,24 +342,38 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||||
|
|
||||||
// One vocabulary entry paired with an integer id, so an array of entries can be
// ordered by string (see compare_tokens) while a lookup still reports the id.
typedef struct {
    char *str; // token text; the struct only stores the pointer
    int id;    // value returned by str_lookup() when this entry matches
} TokenIndex;

// Comparator for qsort()/bsearch() over TokenIndex arrays.
// Orders entries by strcmp() of their strings; the id field is ignored.
int compare_tokens(const void *a, const void *b) {
    const TokenIndex *lhs = a;
    const TokenIndex *rhs = b;
    return strcmp(lhs->str, rhs->str);
}

// Find the entry whose string exactly matches str and return its id.
//   str          - NUL-terminated string to search for
//   sorted_vocab - array of vocab_size entries, already sorted with compare_tokens
//                  (bsearch requires the array to be in comparator order)
//   vocab_size   - number of entries in sorted_vocab
// Returns the matching entry's id, or -1 if there is no match or the arguments
// describe an empty/invalid vocabulary.
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
    // a negative count would convert to a huge size_t inside bsearch, and
    // bsearch needs valid pointers even for zero elements, so bail out early
    if (str == NULL || sorted_vocab == NULL || vocab_size <= 0) {
        return -1;
    }
    // search key: only .str is read by compare_tokens; designated initializer
    // sets it explicitly and zero-initializes the remaining members
    TokenIndex tok = { .str = str };
    TokenIndex *res = bsearch(&tok, sorted_vocab, (size_t)vocab_size, sizeof(TokenIndex), compare_tokens);
    return res != NULL ? res->id : -1;
}
|
||||||
|
|
||||||
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
||||||
|
|
||||||
|
// sort vocabulary
|
||||||
|
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
|
||||||
|
for (int i = 0; i < vocab_size; i++) {
|
||||||
|
sorted_vocab[i].str = vocab[i];
|
||||||
|
sorted_vocab[i].id = i;
|
||||||
|
}
|
||||||
|
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||||
|
|
||||||
// a temporary buffer to merge two consecutive tokens
|
// a temporary buffer to merge two consecutive tokens
|
||||||
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||||
size_t str_len = 0;
|
size_t str_len = 0;
|
||||||
|
|
||||||
// add_dummy_prefix is true by default
|
// add_dummy_prefix is true by default
|
||||||
tokens[0] = str_lookup(" ", vocab, vocab_size);
|
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
||||||
*n_tokens = 1; // the number of tokens
|
*n_tokens = 1; // the number of tokens
|
||||||
|
|
||||||
// first encode every individual byte in the input string
|
// first encode every individual byte in the input string
|
||||||
@@ -374,7 +388,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding
|
if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||||
|
|
||||||
if (id != -1) {
|
if (id != -1) {
|
||||||
tokens[(*n_tokens)++] = id;
|
tokens[(*n_tokens)++] = id;
|
||||||
@@ -395,7 +409,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
for (int i=0; i < (*n_tokens-1); i++) {
|
for (int i=0; i < (*n_tokens-1); i++) {
|
||||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||||
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
||||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||||
if (id != -1 && vocab_scores[id] > best_score) {
|
if (id != -1 && vocab_scores[id] > best_score) {
|
||||||
// this merge pair exists in vocab! record its score and position
|
// this merge pair exists in vocab! record its score and position
|
||||||
best_score = vocab_scores[id];
|
best_score = vocab_scores[id];
|
||||||
@@ -418,6 +432,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
free(str_buffer);
|
free(str_buffer);
|
||||||
|
free(sorted_vocab);
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert token to printable string
|
// convert token to printable string
|
||||||
|
|||||||
Reference in New Issue
Block a user