add optional BOS and EOS flags to the Tokenizer's encode function, as we start to converge closer to the Llama 2 code from Meta, and as we're about to add the Chat capability
This commit is contained in:
@@ -55,6 +55,14 @@ test:
|
||||
testc:
|
||||
pytest -k runc
|
||||
|
||||
# run the C tests, without touching pytest / python
|
||||
# to increase verbosity level run e.g. as `make testcc VERBOSITY=1`
|
||||
VERBOSITY ?= 0
|
||||
.PHONY: testcc
|
||||
testcc:
|
||||
$(CC) -DVERBOSITY=$(VERBOSITY) -O3 -o testc test.c -lm
|
||||
./testc
|
||||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm -f run
|
||||
|
||||
@@ -243,6 +243,14 @@ $ pytest
|
||||
|
||||
This will currently invoke two tests inside `test_all.py`, which forward the model in both C and Python for 200 steps and check the output against a known good expected output. The tests currently run in only a few seconds, but will have to download and cache the stories260K models in a temporary `test` directory (only ~2MB download).
|
||||
|
||||
There are also some tests in C, in the file [test.c](test.c). You can run these with `make testcc`, or, to also print the expected and actual tokens for each test:
|
||||
|
||||
```
|
||||
make testcc VERBOSITY=1
|
||||
```
|
||||
|
||||
Call for help: help add more tests.
|
||||
|
||||
## ack
|
||||
|
||||
I trained the llama2.c storyteller models on a 4X A100 40GB box graciously provided by the excellent [Lambda labs](https://lambdalabs.com/service/gpu-cloud), thank you.
|
||||
@@ -319,6 +327,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
- support Llama 2 7B Chat models with a Chat UI/UX in run.c, very similar to llama.cpp
|
||||
- ability to calculate perplexity in run.c, exactly as done in llama.cpp
|
||||
- add support in run.c of reading version 1+ files from export, later deprecate "version 0"
|
||||
- add more tests inside [test.c](test.c) (call for help!)
|
||||
- runq.c (int8 quantization) add
|
||||
- run.cu (CUDA) investigate and merge
|
||||
- make it easier to add a new dataset with not too much pain
|
||||
|
||||
@@ -441,8 +441,10 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
return res != NULL ? res->id : -1;
|
||||
}
|
||||
|
||||
void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
|
||||
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
||||
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
|
||||
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
|
||||
|
||||
if (t->sorted_vocab == NULL) {
|
||||
// lazily malloc and sort the vocabulary
|
||||
@@ -455,13 +457,24 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
}
|
||||
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
|
||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
|
||||
size_t str_len = 0;
|
||||
|
||||
// start at 0 tokens
|
||||
*n_tokens = 0;
|
||||
|
||||
// add optional BOS (=1) token, if desired
|
||||
if (bos) tokens[(*n_tokens)++] = 1;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
// so prepend a dummy prefix token to the input string, but only if text != ""
|
||||
// TODO: pretty sure this isn't correct in the general case but I don't have the
|
||||
// energy to read more of the sentencepiece code to figure out what it's doing
|
||||
if (text[0] != '\0') {
|
||||
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
tokens[(*n_tokens)++] = dummy_prefix;
|
||||
}
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
// Code point ↔ UTF-8 conversion
|
||||
@@ -543,6 +556,9 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
(*n_tokens)--; // token length decreased
|
||||
}
|
||||
|
||||
// add optional EOS (=2) token, if desired
|
||||
if (eos) tokens[(*n_tokens)++] = 2;
|
||||
|
||||
free(str_buffer);
|
||||
}
|
||||
|
||||
@@ -704,18 +720,19 @@ long time_in_ms() {
|
||||
|
||||
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
|
||||
|
||||
// encode the (string) prompt into tokens sequence, if any is given
|
||||
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
||||
int num_prompt_tokens = 0; // the total number of prompt tokens
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||
encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
|
||||
// encode the (string) prompt into tokens sequence
|
||||
int num_prompt_tokens = 0;
|
||||
int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
|
||||
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
||||
if (num_prompt_tokens < 1) {
|
||||
fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// start the main loop
|
||||
long start = 0; // used to time our code, only initialized after first iteration
|
||||
int next; // will store the next token in the sequence
|
||||
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
||||
int token = prompt_tokens[0]; // kick off with the first token in the prompt
|
||||
int pos = 0; // position in the sequence
|
||||
while (pos < steps) {
|
||||
|
||||
@@ -723,16 +740,16 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
float* logits = forward(transformer, token, pos);
|
||||
|
||||
// advance the state machine
|
||||
if (pos < num_prompt_tokens) {
|
||||
if (pos < num_prompt_tokens - 1) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
next = prompt_tokens[pos + 1];
|
||||
} else {
|
||||
// otherwise sample the next token from the logits
|
||||
next = sample(sampler, logits);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
// data-dependent terminating condition: the BOS (=1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// print the token as string, decode it with the Tokenizer object
|
||||
@@ -756,7 +773,8 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// int main
|
||||
// CLI, include only if not testing
|
||||
#ifndef TESTING
|
||||
|
||||
void error_usage() {
|
||||
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
|
||||
@@ -779,7 +797,7 @@ int main(int argc, char *argv[]) {
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
char *prompt = ""; // prompt string
|
||||
unsigned long long rng_seed = 0; // seed rng with time by default
|
||||
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
@@ -827,3 +845,4 @@ int main(int argc, char *argv[]) {
|
||||
free_transformer(&transformer);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
#define TESTING
|
||||
#include "run.c"
|
||||
|
||||
// Minimal integer assertion helper for the C tests.
// If a != b, reports both values on stderr and terminates the process with
// EXIT_FAILURE; otherwise returns normally.
void assert_eq(int a, int b) {
    if (a != b) {
        // diagnostics go to stderr, matching the error reporting style of run.c
        fprintf(stderr, "Assertion failed: %d != %d\n", a, b);
        exit(EXIT_FAILURE);
    }
}
|
||||
|
||||
// Encode `prompt` with the given tokenizer (BOS prepended, no EOS appended) and
// verify the result against `expected_tokens`, which holds `num_expected_tokens`
// entries. On the first mismatch assert_eq() terminates the process.
// When compiled with -DVERBOSITY=1, the expected and actual tokens are printed.
void test_prompt_encoding(Tokenizer* tokenizer, char* prompt, int* expected_tokens, int num_expected_tokens) {
    // encode; +3 leaves room for '\0', an optional BOS and an optional EOS
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
    if (prompt_tokens == NULL) {
        fprintf(stderr, "malloc failed while allocating the prompt token buffer\n");
        exit(EXIT_FAILURE);
    }
    int num_prompt_tokens = 0; // the total number of prompt tokens
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);

    #if VERBOSITY == 1
    // print maybe
    printf("expected tokens:\n");
    for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]);
    printf("\n");
    printf("actual tokens:\n");
    for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
    printf("\n");
    #endif

    // verify: the counts must match first, so the element-wise loop below
    // never reads past the end of expected_tokens
    assert_eq(num_prompt_tokens, num_expected_tokens);
    for (int i = 0; i < num_prompt_tokens; i++) {
        assert_eq(prompt_tokens[i], expected_tokens[i]);
    }

    #if VERBOSITY == 1
    printf("OK\n");
    printf("---\n");
    #endif
    free(prompt_tokens);
}
|
||||
|
||||
void text_prompt_encodings() {
|
||||
// let's verify that the Tokenizer works as expected
|
||||
|
||||
char *tokenizer_path = "tokenizer.bin";
|
||||
int vocab_size = 32000;
|
||||
Tokenizer tokenizer;
|
||||
build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
|
||||
|
||||
// test 0 (test the empty string) (I added this as a simple case)
|
||||
char *prompt0 = "";
|
||||
int expected_tokens0[] = {1};
|
||||
test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int));
|
||||
|
||||
// the tests below are taken from the Meta Llama 2 repo example code
|
||||
// https://github.com/facebookresearch/llama/blob/main/example_text_completion.py
|
||||
// and the expected tokens come from me breaking in the debugger in Python
|
||||
|
||||
// test 1
|
||||
char *prompt = "I believe the meaning of life is";
|
||||
int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
|
||||
test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int));
|
||||
|
||||
// test 2
|
||||
char* prompt2 = "Simply put, the theory of relativity states that ";
|
||||
int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871};
|
||||
test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int));
|
||||
|
||||
// test 3
|
||||
char* prompt3 = "A brief message congratulating the team on the launch:\n\n Hi everyone,\n\n I just ";
|
||||
int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706, 6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871};
|
||||
test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int));
|
||||
|
||||
// test 4
|
||||
char* prompt4 = "Translate English to French:\n\n sea otter => loutre de mer\n peppermint => menthe poivrée\n plush girafe => girafe peluche\n cheese =>";
|
||||
int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276, 316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13, 4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968, 1149};
|
||||
test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int));
|
||||
}
|
||||
|
||||
// Entry point of the C test binary: runs the tokenizer encoding tests and
// prints "ALL OK" if none of them terminated the process.
// The command-line arguments are accepted but not used.
int main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;

    text_prompt_encodings();

    printf("ALL OK\n");
    return 0;
}
|
||||
Reference in New Issue
Block a user