From c5e0e7fce4f90f810450a278316dc9fc96298d25 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 23 Aug 2023 16:27:48 +0000 Subject: [PATCH 1/4] attempt at chat function, but it was 8AM and I didn't have coffee yet. Seems to work but it's probably subtly broken or too complex. version 1 only, lots of hard-coded non-sensical buffer sizes. Have to go to work now --- run.c | 117 +++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 112 insertions(+), 5 deletions(-) diff --git a/run.c b/run.c index da1fbc4..6f281f8 100644 --- a/run.c +++ b/run.c @@ -732,6 +732,8 @@ long time_in_ms() { // generation loop void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) { + char *empty_prompt = ""; + if (prompt == NULL) { prompt = empty_prompt; } // encode the (string) prompt into tokens sequence int num_prompt_tokens = 0; @@ -785,6 +787,98 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, free(prompt_tokens); } +void read_stdin(const char* guide, char* buffer, size_t bufsize) { + // read a line from stdin, up to but not including \n + printf("%s", guide); + if (fgets(buffer, bufsize, stdin) != NULL) { + size_t len = strlen(buffer); + if (len > 0 && buffer[len - 1] == '\n') { + buffer[len - 1] = '\0'; // strip newline + } + } +} + +void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, + char *cli_user_prompt, char *cli_system_prompt, int steps) { + + // buffers for reading the system prompt and user prompt from stdin + char system_prompt[512]; + char user_prompt[512]; + char rendered_prompt[512]; + int num_prompt_tokens = 0; + int* prompt_tokens = (int*)malloc(512 * sizeof(int)); + int user_idx; + + // start the main loop + int8_t user_turn = 1; // user starts + int next; // will store the next token in the sequence + int token; // stores the current token to feed into the transformer + int prev_token; + int pos = 0; // position in the sequence + while (pos < steps) { + + // when it is the user's turn to contribute tokens to the dialog... + if (user_turn) { + // get the (optional) system prompt at position 0 + if (pos == 0) { + // at position 0, the user can also contribute a system prompt + if (cli_system_prompt == NULL) { + // system prompt was not passed in, attempt to get it from stdin + read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt)); + } else { + // system prompt was passed in, use it + strcpy(system_prompt, cli_system_prompt); + } + } + // get the user prompt + if (pos == 0 && cli_user_prompt != NULL) { + // user prompt for position 0 was passed in, use it + strcpy(user_prompt, cli_user_prompt); + } else { + // otherwise get user prompt from stdin + read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt)); + } + // render user/system prompts into the Llama 2 Chat schema + if (pos == 0 && system_prompt[0] != '\0') { + char system_template[] = "[INST] <>\n%s\n<>\n\n%s\n[/INST]"; + sprintf(rendered_prompt, system_template, system_prompt, user_prompt); + } else { + char user_template[] = "[INST] %s [/INST]"; + sprintf(rendered_prompt, user_template, user_prompt); + } + // encode the rendered prompt into tokens + encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens); + user_idx = 0; // reset the user index + user_turn = 0; + } + + // determine the token to pass into the transformer next + prev_token = token; + if (user_idx < num_prompt_tokens) { + // if we are still processing the input prompt, force the next prompt token + token = prompt_tokens[user_idx++]; + } else { + // otherwise use the next token sampled from previous turn + token = next; + } + // EOS (=2) token ends the Assistant turn + if (token == 2) { user_turn = 1; } + + // forward the transformer to get logits for the next token + float* logits = forward(transformer, token, pos); + next = sample(sampler, logits); + pos++; + + // print the token as string, decode it with the Tokenizer object + char* piece = decode(tokenizer, prev_token, token); + safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes + fflush(stdout); + } + printf("\n"); + free(prompt_tokens); +} + + // ---------------------------------------------------------------------------- // CLI, include only if not testing #ifndef TESTING @@ -799,6 +893,8 @@ void error_usage() { fprintf(stderr, " -n number of steps to run for, default 256. 0 = max_seq_len\n"); fprintf(stderr, " -i input prompt\n"); fprintf(stderr, " -z optional path to custom tokenizer\n"); + fprintf(stderr, " -m mode: generate|chat, default: generate\n"); + fprintf(stderr, " -y (optional) system prompt in chat mode\n"); exit(EXIT_FAILURE); } @@ -807,11 +903,13 @@ int main(int argc, char *argv[]) { // default parameters char *checkpoint_path = NULL; // e.g. out/model.bin char *tokenizer_path = "tokenizer.bin"; - float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher - float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower - int steps = 256; // number of steps to run for - char *prompt = ""; // prompt string + float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher + float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower + int steps = 256; // number of steps to run for + char *prompt = NULL; // prompt string unsigned long long rng_seed = 0; // seed rng with time by default + char *mode = "generate"; // generate|chat + char *system_prompt = NULL; // the (optional) system prompt to use in chat mode // poor man's C argparse so we can override the defaults above from the command line if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); } @@ -827,6 +925,8 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } + else if (argv[i][1] == 'm') { mode = argv[i + 1]; } + else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; } else { error_usage(); } } @@ -850,7 +950,14 @@ int main(int argc, char *argv[]) { build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed); // run! - generate(&transformer, &tokenizer, &sampler, prompt, steps); + if (strcmp(mode, "generate") == 0) { + generate(&transformer, &tokenizer, &sampler, prompt, steps); + } else if (strcmp(mode, "chat") == 0) { + chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps); + } else { + fprintf(stderr, "unknown mode: %s\n", mode); + error_usage(); + } // memory and file handles cleanup free_sampler(&sampler); From 40fb902cf07084d43a78b45b31d977c5e3659dea Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 24 Aug 2023 03:33:44 +0000 Subject: [PATCH 2/4] fix chat format bug i think --- run.c | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/run.c b/run.c index 6f281f8..15b2683 100644 --- a/run.c +++ b/run.c @@ -798,6 +798,9 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) { } } +// ---------------------------------------------------------------------------- +// chat loop + void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *cli_user_prompt, char *cli_system_prompt, int steps) { @@ -840,7 +843,7 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, } // render user/system prompts into the Llama 2 Chat schema if (pos == 0 && system_prompt[0] != '\0') { - char system_template[] = "[INST] <>\n%s\n<>\n\n%s\n[/INST]"; + char system_template[] = "[INST] <>\n%s\n<>\n\n%s [/INST]"; sprintf(rendered_prompt, system_template, system_prompt, user_prompt); } else { char user_template[] = "[INST] %s [/INST]"; From 3d787b24635a7031f933ed42afc58e8117ee4504 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 24 Aug 2023 04:31:06 +0000 Subject: [PATCH 3/4] ok getting closer, and manually verified correctness of the schema matching python. still some weirdness in the printing to chase down, and also have to tune the buffer lengths and make them sensible and such --- run.c | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/run.c b/run.c index 15b2683..40f68e6 100644 --- a/run.c +++ b/run.c @@ -839,7 +839,7 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, strcpy(user_prompt, cli_user_prompt); } else { // otherwise get user prompt from stdin - read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt)); + read_stdin("User: ", user_prompt, sizeof(user_prompt)); } // render user/system prompts into the Llama 2 Chat schema if (pos == 0 && system_prompt[0] != '\0') { @@ -853,10 +853,10 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens); user_idx = 0; // reset the user index user_turn = 0; + printf("Assistant: "); } // determine the token to pass into the transformer next - prev_token = token; if (user_idx < num_prompt_tokens) { // if we are still processing the input prompt, force the next prompt token token = prompt_tokens[user_idx++]; @@ -872,10 +872,13 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, next = sample(sampler, logits); pos++; - // print the token as string, decode it with the Tokenizer object - char* piece = decode(tokenizer, prev_token, token); - safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes - fflush(stdout); + if (user_idx >= num_prompt_tokens && next != 2) { + // the Assistant is responding, so print its output + char* piece = decode(tokenizer, token, next); + safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes + fflush(stdout); + } + if (next == 2) { printf("\n"); } } printf("\n"); free(prompt_tokens); From fbe324fc5ab61eaab3b8f74694be5b3870d3d5ee Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 25 Aug 2023 14:54:05 +0000 Subject: [PATCH 4/4] adjust things a bit --- README.md | 19 ++++++++++++++----- run.c | 8 ++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index e9df1f6..8b05b49 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,18 @@ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo! +You can also chat with the Llama Chat models. Export the chat model exactly as above: + +```bash +python export.py llama2_7b_chat.bin --meta-llama /path/to/7B-chat +``` + +Then chat with it by specifying the chat mode using the `-m` flag, e.g.: + +```bash +./run llama2_7b_chat.bin -m chat +``` + ## hugginface models We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file. @@ -207,8 +219,7 @@ You can also experiment with replacing `gcc` with `clang`. If compiling with gcc, try experimenting with `-funroll-all-loops`, see PR [#183](https://github.com/karpathy/llama2.c/pull/183) -### OpenMP -Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors. +**OpenMP**. Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors. You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). Then you can compile with `make runomp`, which does: ```bash @@ -324,12 +335,10 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg ## unsorted todos -- support Llama 2 7B Chat models with a Chat UI/UX in run.c, very similar to llama.cpp -- ability to calculate perplexity in run.c, exactly as done in llama.cpp - add support in run.c of reading version 1+ files from export, later deprecate "version 0" -- add more tests inside [test.c](test.c) (call for help!) - runq.c (int8 quantization) add - run.cu (CUDA) investigate and merge +- add more tests inside [test.c](test.c) - make it easier to add a new dataset with not too much pain - (LoRA) finetuning and export of Llama 2 models diff --git a/run.c b/run.c index 40f68e6..9df918e 100644 --- a/run.c +++ b/run.c @@ -800,16 +800,20 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) { // ---------------------------------------------------------------------------- // chat loop +// I manually inspected the tokens for a few chat conversations compared to +// python reference and that seemed ok, but this was not thoroughly tested and +// is not safely implemented, it's more a proof of concept atm. void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *cli_user_prompt, char *cli_system_prompt, int steps) { // buffers for reading the system prompt and user prompt from stdin + // you'll notice they are soomewhat haphazardly and unsafely set atm char system_prompt[512]; char user_prompt[512]; - char rendered_prompt[512]; + char rendered_prompt[1152]; int num_prompt_tokens = 0; - int* prompt_tokens = (int*)malloc(512 * sizeof(int)); + int* prompt_tokens = (int*)malloc(1152 * sizeof(int)); int user_idx; // start the main loop