diff --git a/Makefile b/Makefile index a4c6588..11c35c8 100644 --- a/Makefile +++ b/Makefile @@ -55,6 +55,14 @@ test: testc: pytest -k runc +# run the C tests, without touching pytest / python +# to increase verbosity level run e.g. as `make testcc VERBOSITY=1` +VERBOSITY ?= 0 +.PHONY: testcc +testcc: + $(CC) -DVERBOSITY=$(VERBOSITY) -O3 -o testc test.c -lm + ./testc + .PHONY: clean clean: rm -f run diff --git a/README.md b/README.md index 440e5e2..e9df1f6 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,14 @@ $ pytest This will currently invoke two tests inside `test_all.py`, which forward the model in both C and Python for 200 steps and check the output against a known good expected output. The tests currently run in only a few seconds, but will have to download and cache the stories260K models in a temporary `test` directory (only ~2MB download). +There are also some tests in C, in the file [test.c](test.c). You can run these with `make testcc`, or to see more stuff printed: + +``` +make testcc VERBOSITY=1 +``` + +Call for help: help add more tests. + ## ack I trained the llama2.c storyteller models on a 4X A100 40GB box graciously provided by the excellent [Lambda labs](https://lambdalabs.com/service/gpu-cloud), thank you. @@ -319,6 +327,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg - support Llama 2 7B Chat models with a Chat UI/UX in run.c, very similar to llama.cpp - ability to calculate perplexity in run.c, exactly as done in llama.cpp - add support in run.c of reading version 1+ files from export, later deprecate "version 0" +- add more tests inside [test.c](test.c) (call for help!) - runq.c (int8 quantization) add - run.cu (CUDA) investigate and merge - make it easier to add a new dataset with not too much pain diff --git a/run.c b/run.c index 1f50d59..bf4ce34 100644 --- a/run.c +++ b/run.c @@ -441,8 +441,10 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { return res != NULL ? res->id : -1; } -void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { +void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) { // encode the string text (input) into an upper-bound preallocated tokens[] array + // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2) + if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); } if (t->sorted_vocab == NULL) { // lazily malloc and sort the vocabulary @@ -455,13 +457,24 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { } // create a temporary buffer that will store merge candidates of always two consecutive tokens - // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1) + // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1) char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); size_t str_len = 0; + // start at 0 tokens + *n_tokens = 0; + + // add optional BOS (=1) token, if desired + if (bos) tokens[(*n_tokens)++] = 1; + // add_dummy_prefix is true by default - tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size); - *n_tokens = 1; // the number of tokens + // so prepend a dummy prefix token to the input string, but only if text != "" + // TODO: pretty sure this isn't correct in the general case but I don't have the + // energy to read more of the sentencepiece code to figure out what it's doing + if (text[0] != '\0') { + int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size); + tokens[(*n_tokens)++] = dummy_prefix; + } // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: // Code point ↔ UTF-8 conversion @@ -543,6 +556,9 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { (*n_tokens)--; // token length decreased } + // add optional EOS (=2) token, if desired + if (eos) tokens[(*n_tokens)++] = 2; + free(str_buffer); } @@ -704,18 +720,19 @@ long time_in_ms() { void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) { - // encode the (string) prompt into tokens sequence, if any is given - int *prompt_tokens = NULL; // the sequence of prompt tokens - int num_prompt_tokens = 0; // the total number of prompt tokens - if (prompt != NULL) { - prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); - encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens); + // encode the (string) prompt into tokens sequence + int num_prompt_tokens = 0; + int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS + encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); + if (num_prompt_tokens < 1) { + fprintf(stderr, "something is wrong, expected at least 1 prompt token\n"); + exit(EXIT_FAILURE); } // start the main loop long start = 0; // used to time our code, only initialized after first iteration int next; // will store the next token in the sequence - int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer + int token = prompt_tokens[0]; // kick off with the first token in the prompt int pos = 0; // position in the sequence while (pos < steps) { @@ -723,16 +740,16 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, float* logits = forward(transformer, token, pos); // advance the state state machine - if (pos < num_prompt_tokens) { + if (pos < num_prompt_tokens - 1) { // if we are still processing the input prompt, force the next prompt token - next = prompt_tokens[pos]; + next = prompt_tokens[pos + 1]; } else { // otherwise sample the next token from the logits next = sample(sampler, logits); } pos++; - // data-dependent terminating condition: the BOS (1) token delimits sequences + // data-dependent terminating condition: the BOS (=1) token delimits sequences if (next == 1) { break; } // print the token as string, decode it with the Tokenizer object @@ -756,7 +773,8 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, } // ---------------------------------------------------------------------------- -// int main +// CLI, include only if not testing +#ifndef TESTING void error_usage() { fprintf(stderr, "Usage: run [options]\n"); @@ -779,7 +797,7 @@ int main(int argc, char *argv[]) { float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower int steps = 256; // number of steps to run for - char *prompt = NULL; // prompt string + char *prompt = ""; // prompt string unsigned long long rng_seed = 0; // seed rng with time by default // poor man's C argparse so we can override the defaults above from the command line @@ -827,3 +845,4 @@ int main(int argc, char *argv[]) { free_transformer(&transformer); return 0; } +#endif diff --git a/test.c b/test.c new file mode 100644 index 0000000..887e2bf --- /dev/null +++ b/test.c @@ -0,0 +1,81 @@ +#define TESTING +#include "run.c" + +void assert_eq(int a, int b) { + if (a != b) { + printf("Assertion failed: %d != %d\n", a, b); + exit(EXIT_FAILURE); + } +} + +void test_prompt_encoding(Tokenizer* tokenizer, char* prompt, int* expected_tokens, int num_expected_tokens) { + // encode + int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); + int num_prompt_tokens = 0; // the total number of prompt tokens + encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); + + #if VERBOSITY == 1 + // print maybe + printf("expected tokens:\n"); + for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]); + printf("\n"); + printf("actual tokens:\n"); + for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]); + printf("\n"); + #endif + + // verify + assert_eq(num_prompt_tokens, num_expected_tokens); + for (int i = 0; i < num_prompt_tokens; i++) { + assert_eq(prompt_tokens[i], expected_tokens[i]); + } + + #if VERBOSITY == 1 + printf("OK\n"); + printf("---\n"); + #endif + free(prompt_tokens); +} + +void text_prompt_encodings() { + // let's verify that the Tokenizer works as expected + + char *tokenizer_path = "tokenizer.bin"; + int vocab_size = 32000; + Tokenizer tokenizer; + build_tokenizer(&tokenizer, tokenizer_path, vocab_size); + + // test 0 (test the empty string) (I added this as a simple case) + char *prompt0 = ""; + int expected_tokens0[] = {1}; + test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int)); + + // the tests below are taken from the Meta Llama 2 repo example code + // https://github.com/facebookresearch/llama/blob/main/example_text_completion.py + // and the expected tokens come from me breaking in the debugger in Python + + // test 1 + char *prompt = "I believe the meaning of life is"; + int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338}; + test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int)); + + // test 2 + char* prompt2 = "Simply put, the theory of relativity states that "; + int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871}; + test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int)); + + // test 3 + char* prompt3 = "A brief message congratulating the team on the launch:\n\n Hi everyone,\n\n I just "; + int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706, 6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871}; + test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int)); + + // test 4 + char* prompt4 = "Translate English to French:\n\n sea otter => loutre de mer\n peppermint => menthe poivrée\n plush girafe => girafe peluche\n cheese =>"; + int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276, 316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13, 4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968, 1149}; + test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int)); +} + +int main(int argc, char *argv[]) { + text_prompt_encodings(); + printf("ALL OK\n"); +}