add BOS and EOS function to the Tokenizer as we start to converge closer to the Llama 2 code from Meta, and as we're about to add the Chat capability
This commit is contained in:
@@ -441,8 +441,10 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
return res != NULL ? res->id : -1;
|
||||
}
|
||||
|
||||
void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
|
||||
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
||||
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
|
||||
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
|
||||
|
||||
if (t->sorted_vocab == NULL) {
|
||||
// lazily malloc and sort the vocabulary
|
||||
@@ -455,13 +457,24 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
}
|
||||
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
|
||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
|
||||
size_t str_len = 0;
|
||||
|
||||
// start at 0 tokens
|
||||
*n_tokens = 0;
|
||||
|
||||
// add optional BOS (=1) token, if desired
|
||||
if (bos) tokens[(*n_tokens)++] = 1;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
// so prepend a dummy prefix token to the input string, but only if text != ""
|
||||
// TODO: pretty sure this isn't correct in the general case but I don't have the
|
||||
// energy to read more of the sentencepiece code to figure out what it's doing
|
||||
if (text[0] != '\0') {
|
||||
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
tokens[(*n_tokens)++] = dummy_prefix;
|
||||
}
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
// Code point ↔ UTF-8 conversion
|
||||
@@ -543,6 +556,9 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
(*n_tokens)--; // token length decreased
|
||||
}
|
||||
|
||||
// add optional EOS (=2) token, if desired
|
||||
if (eos) tokens[(*n_tokens)++] = 2;
|
||||
|
||||
free(str_buffer);
|
||||
}
|
||||
|
||||
@@ -704,18 +720,19 @@ long time_in_ms() {
|
||||
|
||||
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
|
||||
|
||||
// encode the (string) prompt into tokens sequence, if any is given
|
||||
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
||||
int num_prompt_tokens = 0; // the total number of prompt tokens
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||
encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
|
||||
// encode the (string) prompt into tokens sequence
|
||||
int num_prompt_tokens = 0;
|
||||
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
|
||||
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
||||
if (num_prompt_tokens < 1) {
|
||||
fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// start the main loop
|
||||
long start = 0; // used to time our code, only initialized after first iteration
|
||||
int next; // will store the next token in the sequence
|
||||
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
||||
int token = prompt_tokens[0]; // kick off with the first token in the prompt
|
||||
int pos = 0; // position in the sequence
|
||||
while (pos < steps) {
|
||||
|
||||
@@ -723,16 +740,16 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
float* logits = forward(transformer, token, pos);
|
||||
|
||||
// advance the state state machine
|
||||
if (pos < num_prompt_tokens) {
|
||||
if (pos < num_prompt_tokens - 1) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
next = prompt_tokens[pos + 1];
|
||||
} else {
|
||||
// otherwise sample the next token from the logits
|
||||
next = sample(sampler, logits);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
// data-dependent terminating condition: the BOS (=1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// print the token as string, decode it with the Tokenizer object
|
||||
@@ -756,7 +773,8 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// int main
|
||||
// CLI, include only if not testing
|
||||
#ifndef TESTING
|
||||
|
||||
void error_usage() {
|
||||
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
|
||||
@@ -779,7 +797,7 @@ int main(int argc, char *argv[]) {
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
char *prompt = ""; // prompt string
|
||||
unsigned long long rng_seed = 0; // seed rng with time by default
|
||||
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
@@ -827,3 +845,4 @@ int main(int argc, char *argv[]) {
|
||||
free_transformer(&transformer);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user