Merge pull request #297 from karpathy/feature/utf8
Add UTF-8 support to prompts
This commit is contained in:
@@ -10,6 +10,7 @@ $ ./run
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <ctype.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
@@ -347,29 +348,85 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
// ----------------------------------------------------------------------------
|
||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||
|
||||
int str_lookup(char *str, char **vocab, int vocab_size) {
|
||||
// find the first perfect match for str in vocab, return its index or -1 if not found
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
if (strcmp(str, vocab[i]) == 0) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
// Pairs a vocabulary string with its index in the original vocab array, so the
// table can be sorted by string (for binary search) without losing token ids.
typedef struct {
    char *str; // token text, NUL-terminated; this is the sort/search key
    int id;    // value str_lookup returns when this entry matches
} TokenIndex;

// Comparator for qsort/bsearch over TokenIndex arrays.
// Orders two entries by strcmp of their str fields; the id field is ignored.
int compare_tokens(const void *a, const void *b) {
    // keep the const qualifier of the incoming pointers instead of casting it away
    const TokenIndex *lhs = a;
    const TokenIndex *rhs = b;
    return strcmp(lhs->str, rhs->str);
}

// Efficiently find the perfect match for str in sorted_vocab.
//   str          - NUL-terminated key to search for
//   sorted_vocab - table of vocab_size entries, already sorted with compare_tokens
//   vocab_size   - number of entries in sorted_vocab
// Returns the id stored in the matching entry, or -1 if there is no match
// (including when str/sorted_vocab is NULL or vocab_size is not positive).
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
    // a negative count would be converted to a huge size_t by bsearch, and a NULL
    // key would reach strcmp; neither can match anything, so report "not found"
    if (str == NULL || sorted_vocab == NULL || vocab_size <= 0) {
        return -1;
    }
    TokenIndex tok = { .str = str }; // acts as the key to search for (id is unused)
    TokenIndex *res = bsearch(&tok, sorted_vocab, (size_t)vocab_size, sizeof(TokenIndex), compare_tokens);
    return res != NULL ? res->id : -1;
}
|
||||
|
||||
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
||||
|
||||
// a temporary buffer to merge two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||
// sort vocabulary
|
||||
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
sorted_vocab[i].str = vocab[i];
|
||||
sorted_vocab[i].id = i;
|
||||
}
|
||||
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
|
||||
// first encode every individual byte in the input string
|
||||
*n_tokens = 0; // the number of tokens
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||
size_t str_len = 0;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
// Code point ↔ UTF-8 conversion
|
||||
// First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
|
||||
// U+0000 U+007F 0xxxxxxx
|
||||
// U+0080 U+07FF 110xxxxx 10xxxxxx
|
||||
// U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
|
||||
// U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
|
||||
// process the raw (UTF-8) byte sequence of the input string
|
||||
for (char *c = text; *c != '\0'; c++) {
|
||||
sprintf(str_buffer, "%c", *c);
|
||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||
if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); }
|
||||
tokens[*n_tokens] = id;
|
||||
(*n_tokens)++;
|
||||
|
||||
// reset buffer if the current byte is ASCII or a leading byte
|
||||
// 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
|
||||
// 0x80 is 10000000
|
||||
// in UTF-8, all continuation bytes start with "10" in first two bits
|
||||
// so in English this is: "if this byte is not a continuation byte"
|
||||
if ((*c & 0xC0) != 0x80) {
|
||||
// this byte must be either a leading byte (11...) or an ASCII char (0x...)
|
||||
// => reset our location, as we're starting a new UTF-8 codepoint
|
||||
str_len = 0;
|
||||
}
|
||||
|
||||
// append the current byte to the buffer
|
||||
str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
|
||||
str_buffer[str_len] = '\0';
|
||||
|
||||
// while the next character is a continuation byte, continue appending
|
||||
if ((*(c+1) & 0xC0) == 0x80) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
|
||||
if (id != -1) {
|
||||
// we found this codepoint in vocab, add it as a token
|
||||
tokens[(*n_tokens)++] = id;
|
||||
} else {
|
||||
// byte_fallback encoding: just encode each byte as a token
|
||||
// +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
|
||||
// so the individual bytes only start at index 3
|
||||
for (int i=0; i < str_len; i++) {
|
||||
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// merge the best consecutive pair each iteration, according to the scores in vocab_scores
|
||||
@@ -381,7 +438,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
for (int i=0; i < (*n_tokens-1); i++) {
|
||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
if (id != -1 && vocab_scores[id] > best_score) {
|
||||
// this merge pair exists in vocab! record its score and position
|
||||
best_score = vocab_scores[id];
|
||||
@@ -404,6 +461,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
}
|
||||
|
||||
free(str_buffer);
|
||||
free(sorted_vocab);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -609,7 +667,7 @@ int main(int argc, char *argv[]) {
|
||||
int *prompt_tokens = NULL;
|
||||
int num_prompt_tokens = 0;
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int));
|
||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
|
||||
}
|
||||
|
||||
@@ -654,7 +712,20 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
printf("%s", token_str);
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
unsigned char byte_val;
|
||||
if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) {
|
||||
// ok this token is a raw byte token, careful to only print printable chars or whitespace
|
||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
||||
if (isprint(byte_val) || isspace(byte_val)) {
|
||||
char byte_piece[2];
|
||||
byte_piece[0] = byte_val;
|
||||
byte_piece[1] = '\0';
|
||||
printf("%s", byte_piece);
|
||||
}
|
||||
} else {
|
||||
printf("%s", token_str);
|
||||
}
|
||||
fflush(stdout);
|
||||
token = next;
|
||||
|
||||
|
||||
@@ -51,11 +51,16 @@ if compile:
|
||||
print("Compiling the model...")
|
||||
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
||||
|
||||
# load the tokenizer, either provided, or attempt to find it
|
||||
# load the tokenizer
|
||||
vocab_source = checkpoint_dict.get("vocab_source", "llama2")
|
||||
vocab_size = gptconf.vocab_size
|
||||
if tokenizer:
|
||||
# a specific tokenizer is provided, use it
|
||||
tokenizer_model = tokenizer
|
||||
else:
|
||||
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
|
||||
# let's try to find the tokenizer model automatically. bit gross here...
|
||||
query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
|
||||
tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
|
||||
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
||||
|
||||
# encode the beginning of the prompt
|
||||
|
||||
Binary file not shown.
@@ -51,8 +51,6 @@ class Tokenizer:
|
||||
t = '\n<s>\n'
|
||||
elif i == self.eos_id:
|
||||
t = '\n</s>\n'
|
||||
elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
|
||||
t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
|
||||
t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
|
||||
b = t.encode('utf-8') # bytes of this token, utf-8 encoded
|
||||
|
||||
|
||||
Reference in New Issue
Block a user