diff --git a/run.c b/run.c index 37d3018..8f565cd 100644 --- a/run.c +++ b/run.c @@ -358,10 +358,10 @@ int compare_tokens(const void *a, const void *b) { } int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { - // find the perfect match for str in vocab, return its index or -1 if not found - TokenIndex tok = {str=str}; + // efficiently find the perfect match for str in vocab, return its index or -1 if not found + TokenIndex tok = { .str = str }; // acts as the key to search for TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); - return res!=NULL ? res->id : -1; + return res != NULL ? res->id : -1; } void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { @@ -374,7 +374,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); - // a temporary buffer to merge two consecutive tokens + // create a temporary buffer that will store merge candidates of always two consecutive tokens char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator size_t str_len = 0; @@ -382,25 +382,48 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); *n_tokens = 1; // the number of tokens - // first encode every individual byte in the input string - for (char *c = text; *c != '\0'; c++) { - // reset buffer if the current byte is ASCII or leading byte - if ((*c & 0xC0) != 0x80) - str_len = 0; + // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: + // Code point ↔ UTF-8 conversion + // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4 + // U+0000 U+007F 0xxxxxxx + // U+0080 U+07FF 110xxxxx 10xxxxxx + // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx + // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - str_buffer[str_len++] = *c; // append byte to the buffer + // process the raw (UTF-8) byte sequence of the input string + for (char *c = text; *c != '\0'; c++) { + + // reset buffer if the current byte is ASCII or a leading byte + // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest + // 0x80 is 10000000 + // in UTF-8, all continuation bytes start with "10" in first two bits + // so in English this is: "if this byte is not a continuation byte" + if ((*c & 0xC0) != 0x80) { + // this byte must be either a leading byte (11...) or an ASCII char (0x...) + // => reset our location, as we're starting a new UTF-8 codepoint + str_len = 0; + } + + // append the current byte to the buffer + str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line str_buffer[str_len] = '\0'; - if ((*(c+1) & 0xC0) == 0x80) // skip if in middle of multi-byte utf8 encoding + // while the next character is a continuation byte, continue appending + if ((*(c+1) & 0xC0) == 0x80) { continue; + } + // ok c+1 is not a continuation byte, so we've read in a full codepoint int id = str_lookup(str_buffer, sorted_vocab, vocab_size); if (id != -1) { + // we found this codepoint in vocab, add it as a token tokens[(*n_tokens)++] = id; } else { - // byte_fallback encoding - for (int i=0; i, , + // so the individual bytes only start at index 3 + for (int i=0; i < str_len; i++) { tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3; } }