add BOS and EOS function to the Tokenizer as we start to converge closer to the Llama 2 code from Meta, and as we're about to add the Chat capability

This commit is contained in:
Andrej Karpathy
2023-08-23 00:08:22 +00:00
parent d26a499207
commit d1eb18b8ec
4 changed files with 133 additions and 16 deletions
+8
View File
@@ -55,6 +55,14 @@ test:
testc:
pytest -k runc
# run the C tests, without touching pytest / python
# to increase verbosity level run e.g. as `make testcc VERBOSITY=1`
VERBOSITY ?= 0
.PHONY: testcc
testcc:
$(CC) -DVERBOSITY=$(VERBOSITY) -O3 -o testc test.c -lm
./testc
.PHONY: clean
clean:
rm -f run
+9
View File
@@ -243,6 +243,14 @@ $ pytest
This will currently invoke two tests inside `test_all.py`, which forward the model in both C and Python for 200 steps and check the output against a known good expected output. The tests currently run in only a few seconds, but will have to download and cache the stories260K models in a temporary `test` directory (only ~2MB download).
There are also some tests in C, in the file [test.c](test.c). You can run these with `make testcc`, or to see more stuff printed:
```
make testcc VERBOSITY=1
```
Call for help: help add more tests.
## ack
I trained the llama2.c storyteller models on a 4X A100 40GB box graciously provided by the excellent [Lambda labs](https://lambdalabs.com/service/gpu-cloud), thank you.
@@ -319,6 +327,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
- support Llama 2 7B Chat models with a Chat UI/UX in run.c, very similar to llama.cpp
- ability to calculate perplexity in run.c, exactly as done in llama.cpp
- add support in run.c of reading version 1+ files from export, later deprecate "version 0"
- add more tests inside [test.c](test.c) (call for help!)
- runq.c (int8 quantization) add
- run.cu (CUDA) investigate and merge
- make it easier to add a new dataset with not too much pain
+35 -16
View File
@@ -441,8 +441,10 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
return res != NULL ? res->id : -1;
}
void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
if (t->sorted_vocab == NULL) {
// lazily malloc and sort the vocabulary
@@ -455,13 +457,24 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
}
// create a temporary buffer that will store merge candidates of always two consecutive tokens
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
size_t str_len = 0;
// start at 0 tokens
*n_tokens = 0;
// add optional BOS (=1) token, if desired
if (bos) tokens[(*n_tokens)++] = 1;
// add_dummy_prefix is true by default
tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size);
*n_tokens = 1; // the number of tokens
// so prepend a dummy prefix token to the input string, but only if text != ""
// TODO: pretty sure this isn't correct in the general case but I don't have the
// energy to read more of the sentencepiece code to figure out what it's doing
if (text[0] != '\0') {
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
tokens[(*n_tokens)++] = dummy_prefix;
}
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
// Code point ↔ UTF-8 conversion
@@ -543,6 +556,9 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
(*n_tokens)--; // token length decreased
}
// add optional EOS (=2) token, if desired
if (eos) tokens[(*n_tokens)++] = 2;
free(str_buffer);
}
@@ -704,18 +720,19 @@ long time_in_ms() {
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL; // the sequence of prompt tokens
int num_prompt_tokens = 0; // the total number of prompt tokens
if (prompt != NULL) {
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
if (num_prompt_tokens < 1) {
fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
exit(EXIT_FAILURE);
}
// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int token = prompt_tokens[0]; // kick off with the first token in the prompt
int pos = 0; // position in the sequence
while (pos < steps) {
@@ -723,16 +740,16 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
float* logits = forward(transformer, token, pos);
// advance the state state machine
if (pos < num_prompt_tokens) {
if (pos < num_prompt_tokens - 1) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
next = prompt_tokens[pos + 1];
} else {
// otherwise sample the next token from the logits
next = sample(sampler, logits);
}
pos++;
// data-dependent terminating condition: the BOS (1) token delimits sequences
// data-dependent terminating condition: the BOS (=1) token delimits sequences
if (next == 1) { break; }
// print the token as string, decode it with the Tokenizer object
@@ -756,7 +773,8 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
}
// ----------------------------------------------------------------------------
// int main
// CLI, include only if not testing
#ifndef TESTING
void error_usage() {
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
@@ -779,7 +797,7 @@ int main(int argc, char *argv[]) {
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
char *prompt = ""; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
// poor man's C argparse so we can override the defaults above from the command line
@@ -827,3 +845,4 @@ int main(int argc, char *argv[]) {
free_transformer(&transformer);
return 0;
}
#endif
+81
View File
@@ -0,0 +1,81 @@
#define TESTING
#include "run.c"
void assert_eq(int a, int b) {
if (a != b) {
printf("Assertion failed: %d != %d\n", a, b);
exit(EXIT_FAILURE);
}
}
void test_prompt_encoding(Tokenizer* tokenizer, char* prompt, int* expected_tokens, int num_expected_tokens) {
// encode
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
int num_prompt_tokens = 0; // the total number of prompt tokens
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
#if VERBOSITY == 1
// print maybe
printf("expected tokens:\n");
for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]);
printf("\n");
printf("actual tokens:\n");
for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
printf("\n");
#endif
// verify
assert_eq(num_prompt_tokens, num_expected_tokens);
for (int i = 0; i < num_prompt_tokens; i++) {
assert_eq(prompt_tokens[i], expected_tokens[i]);
}
#if VERBOSITY == 1
printf("OK\n");
printf("---\n");
#endif
free(prompt_tokens);
}
void text_prompt_encodings() {
// let's verify that the Tokenizer works as expected
char *tokenizer_path = "tokenizer.bin";
int vocab_size = 32000;
Tokenizer tokenizer;
build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
// test 0 (test the empty string) (I added this as a simple case)
char *prompt0 = "";
int expected_tokens0[] = {1};
test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int));
// the tests below are taken from the Meta Llama 2 repo example code
// https://github.com/facebookresearch/llama/blob/main/example_text_completion.py
// and the expected tokens come from me breaking in the debugger in Python
// test 1
char *prompt = "I believe the meaning of life is";
int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int));
// test 2
char* prompt2 = "Simply put, the theory of relativity states that ";
int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871};
test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int));
// test 3
char* prompt3 = "A brief message congratulating the team on the launch:\n\n Hi everyone,\n\n I just ";
int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706, 6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871};
test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int));
// test 4
char* prompt4 = "Translate English to French:\n\n sea otter => loutre de mer\n peppermint => menthe poivrée\n plush girafe => girafe peluche\n cheese =>";
int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276, 316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13, 4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968, 1149};
test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int));
}
int main(int argc, char *argv[]) {
text_prompt_encodings();
printf("ALL OK\n");
}