thoroughly commented the UTF-8 byte reading code
This commit is contained in:
@@ -358,8 +358,8 @@ int compare_tokens(const void *a, const void *b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||||
// find the perfect match for str in vocab, return its index or -1 if not found
|
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
||||||
TokenIndex tok = {str=str};
|
TokenIndex tok = { .str = str }; // acts as the key to search for
|
||||||
TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||||
return res != NULL ? res->id : -1;
|
return res != NULL ? res->id : -1;
|
||||||
}
|
}
|
||||||
@@ -374,7 +374,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
}
|
}
|
||||||
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||||
|
|
||||||
// a temporary buffer to merge two consecutive tokens
|
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||||
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||||
size_t str_len = 0;
|
size_t str_len = 0;
|
||||||
|
|
||||||
@@ -382,24 +382,47 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
|||||||
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
||||||
*n_tokens = 1; // the number of tokens
|
*n_tokens = 1; // the number of tokens
|
||||||
|
|
||||||
// first encode every individual byte in the input string
|
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||||
for (char *c = text; *c != '\0'; c++) {
|
// Code point ↔ UTF-8 conversion
|
||||||
// reset buffer if the current byte is ASCII or leading byte
|
// First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
|
||||||
if ((*c & 0xC0) != 0x80)
|
// U+0000 U+007F 0xxxxxxx
|
||||||
str_len = 0;
|
// U+0080 U+07FF 110xxxxx 10xxxxxx
|
||||||
|
// U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
|
||||||
|
// U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||||
|
|
||||||
str_buffer[str_len++] = *c; // append byte to the buffer
|
// process the raw (UTF-8) byte sequence of the input string
|
||||||
|
for (char *c = text; *c != '\0'; c++) {
|
||||||
|
|
||||||
|
// reset buffer if the current byte is ASCII or a leading byte
|
||||||
|
// 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
|
||||||
|
// 0x80 is 10000000
|
||||||
|
// in UTF-8, all continuation bytes start with "10" in first two bits
|
||||||
|
// so in English this is: "if this byte is not a continuation byte"
|
||||||
|
if ((*c & 0xC0) != 0x80) {
|
||||||
|
// this byte must be either a leading byte (11...) or an ASCII char (0x...)
|
||||||
|
// => reset our location, as we're starting a new UTF-8 codepoint
|
||||||
|
str_len = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// append the current byte to the buffer
|
||||||
|
str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
|
||||||
str_buffer[str_len] = '\0';
|
str_buffer[str_len] = '\0';
|
||||||
|
|
||||||
if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding
|
// while the next character is a continuation byte, continue appending
|
||||||
|
if ((*(c+1) & 0xC0) == 0x80) {
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
||||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||||
|
|
||||||
if (id != -1) {
|
if (id != -1) {
|
||||||
|
// we found this codepoint in vocab, add it as a token
|
||||||
tokens[(*n_tokens)++] = id;
|
tokens[(*n_tokens)++] = id;
|
||||||
} else {
|
} else {
|
||||||
// byte_fallback encoding
|
// byte_fallback encoding: just encode each byte as a token
|
||||||
|
// +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
|
||||||
|
// so the individual bytes only start at index 3
|
||||||
for (int i=0; i < str_len; i++) {
|
for (int i=0; i < str_len; i++) {
|
||||||
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
|
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user