Merge branch 'utf8' of https://github.com/atamurad/llama2.c into feature/utf8

This commit is contained in:
Andrej Karpathy
2023-08-15 00:18:53 +00:00
3 changed files with 64 additions and 21 deletions
+64 -19
View File
@@ -347,29 +347,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// ----------------------------------------------------------------------------
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
int str_lookup(char *str, char **vocab, int vocab_size) {
// find the first perfect match for str in vocab, return its index or -1 if not found
for (int i = 0; i < vocab_size; i++) {
if (strcmp(str, vocab[i]) == 0) {
return i;
}
}
return -1;
typedef struct {
char *str;
int id;
} TokenIndex;
int compare_tokens(const void *a, const void *b) {
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
// find the perfect match for str in vocab, return its index or -1 if not found
TokenIndex tok = {str=str};
TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
return res!=NULL ? res->id : -1;
}
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
// sort vocabulary
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
for (int i = 0; i < vocab_size; i++) {
sorted_vocab[i].str = vocab[i];
sorted_vocab[i].id = i;
}
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
// a temporary buffer to merge two consecutive tokens
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
size_t str_len = 0;
// add_dummy_prefix is true by default
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
*n_tokens = 1; // the number of tokens
// first encode every individual byte in the input string
*n_tokens = 0; // the number of tokens
for (char *c = text; *c != '\0'; c++) {
sprintf(str_buffer, "%c", *c);
int id = str_lookup(str_buffer, vocab, vocab_size);
if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); }
tokens[*n_tokens] = id;
(*n_tokens)++;
// reset buffer if the current byte is ASCII or leading byte
if ((*c & 0xC0) != 0x80)
str_len = 0;
str_buffer[str_len++] = *c; // append byte to the buffer
str_buffer[str_len] = '\0';
if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding
continue;
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
if (id != -1) {
tokens[(*n_tokens)++] = id;
} else {
// byte_fallback encoding
for (int i=0; i<str_len; i++) {
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
}
}
}
// merge the best consecutive pair each iteration, according the scores in vocab_scores
@@ -381,7 +414,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
for (int i=0; i < (*n_tokens-1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, vocab, vocab_size);
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
if (id != -1 && vocab_scores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = vocab_scores[id];
@@ -404,6 +437,20 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
}
free(str_buffer);
free(sorted_vocab);
}
// convert token to printable string
char *token_to_str(char **vocab, int token, int prev_token) {
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (prev_token == 1 && vocab[token][0] == ' ') ? vocab[token]+1 : vocab[token];
// make '<0x01>' into '\x01'
static char byte_piece[4];
if (sscanf(token_str, "<0x%02X>", (int*)(&byte_piece)) == 1) {
byte_piece[1] = '\0';
token_str = byte_piece;
}
return token_str;
}
// ----------------------------------------------------------------------------
@@ -609,7 +656,7 @@ int main(int argc, char *argv[]) {
int *prompt_tokens = NULL;
int num_prompt_tokens = 0;
if (prompt != NULL) {
prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int));
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
}
@@ -652,9 +699,7 @@ int main(int argc, char *argv[]) {
// data-dependent terminating condition: the BOS (1) token delimits sequences
if (next == 1) { break; }
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
printf("%s", token_str);
printf("%s", token_to_str(vocab, next, token));
fflush(stdout);
token = next;
BIN
View File
Binary file not shown.
-2
View File
@@ -51,8 +51,6 @@ class Tokenizer:
t = '\n<s>\n'
elif i == self.eos_id:
t = '\n</s>\n'
elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
t = t.replace('', ' ') # sentencepiece uses this character as whitespace
b = t.encode('utf-8') # bytes of this token, utf-8 encoded