diff --git a/README.md b/README.md index d08f611..6f092ce 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,18 @@ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo! +You can also chat with the Llama Chat models. Export the chat model exactly as above: + +```bash +python export.py llama2_7b_chat.bin --meta-llama /path/to/7B-chat +``` + +Then chat with it by specifying the chat mode using the `-m` flag, e.g.: + +```bash +./run llama2_7b_chat.bin -m chat +``` + ## hugginface models We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file. @@ -207,8 +219,7 @@ You can also experiment with replacing `gcc` with `clang`. If compiling with gcc, try experimenting with `-funroll-all-loops`, see PR [#183](https://github.com/karpathy/llama2.c/pull/183) -### OpenMP -Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors. +**OpenMP**. Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors. You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). Then you can compile with `make runomp`, which does: ```bash @@ -324,13 +335,11 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg ## unsorted todos -- support Llama 2 7B Chat models with a Chat UI/UX in run.c, very similar to llama.cpp -- ability to calculate perplexity in run.c, exactly as done in llama.cpp - add support in run.c of reading version 1+ files from export, later deprecate "version 0" -- add more tests in [test.c](test.c) - runq.c (int8 quantization) add - run.cu (CUDA) investigate and merge -- add an Engine class that serves the model ~efficiently but in PyTorch (see [Issue 346](https://github.com/karpathy/llama2.c/issues/346)) +- add more tests inside [test.c](test.c) +- add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping - make it easier to add a new dataset with not too much pain - (LoRA) finetuning and export of Llama 2 models diff --git a/run.c b/run.c index 0eaa655..9329b93 100644 --- a/run.c +++ b/run.c @@ -732,6 +732,8 @@ long time_in_ms() { // generation loop void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) { + char *empty_prompt = ""; + if (prompt == NULL) { prompt = empty_prompt; } // encode the (string) prompt into tokens sequence int num_prompt_tokens = 0; @@ -785,6 +787,108 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, free(prompt_tokens); } +void read_stdin(const char* guide, char* buffer, size_t bufsize) { + // read a line from stdin, up to but not including \n + printf("%s", guide); + if (fgets(buffer, bufsize, stdin) != NULL) { + size_t len = strlen(buffer); + if (len > 0 && buffer[len - 1] == '\n') { + buffer[len - 1] = '\0'; // strip newline + } + } +} + +// ---------------------------------------------------------------------------- +// chat loop +// I manually inspected the tokens for a few chat conversations compared to +// python reference and that seemed ok, but this was not thoroughly tested and +// is not safely implemented, it's more a proof of concept atm. + +void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, + char *cli_user_prompt, char *cli_system_prompt, int steps) { + + // buffers for reading the system prompt and user prompt from stdin + // you'll notice they are soomewhat haphazardly and unsafely set atm + char system_prompt[512]; + char user_prompt[512]; + char rendered_prompt[1152]; + int num_prompt_tokens = 0; + int* prompt_tokens = (int*)malloc(1152 * sizeof(int)); + int user_idx; + + // start the main loop + int8_t user_turn = 1; // user starts + int next; // will store the next token in the sequence + int token; // stores the current token to feed into the transformer + int prev_token; + int pos = 0; // position in the sequence + while (pos < steps) { + + // when it is the user's turn to contribute tokens to the dialog... + if (user_turn) { + // get the (optional) system prompt at position 0 + if (pos == 0) { + // at position 0, the user can also contribute a system prompt + if (cli_system_prompt == NULL) { + // system prompt was not passed in, attempt to get it from stdin + read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt)); + } else { + // system prompt was passed in, use it + strcpy(system_prompt, cli_system_prompt); + } + } + // get the user prompt + if (pos == 0 && cli_user_prompt != NULL) { + // user prompt for position 0 was passed in, use it + strcpy(user_prompt, cli_user_prompt); + } else { + // otherwise get user prompt from stdin + read_stdin("User: ", user_prompt, sizeof(user_prompt)); + } + // render user/system prompts into the Llama 2 Chat schema + if (pos == 0 && system_prompt[0] != '\0') { + char system_template[] = "[INST] <>\n%s\n<>\n\n%s [/INST]"; + sprintf(rendered_prompt, system_template, system_prompt, user_prompt); + } else { + char user_template[] = "[INST] %s [/INST]"; + sprintf(rendered_prompt, user_template, user_prompt); + } + // encode the rendered prompt into tokens + encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens); + user_idx = 0; // reset the user index + user_turn = 0; + printf("Assistant: "); + } + + // determine the token to pass into the transformer next + if (user_idx < num_prompt_tokens) { + // if we are still processing the input prompt, force the next prompt token + token = prompt_tokens[user_idx++]; + } else { + // otherwise use the next token sampled from previous turn + token = next; + } + // EOS (=2) token ends the Assistant turn + if (token == 2) { user_turn = 1; } + + // forward the transformer to get logits for the next token + float* logits = forward(transformer, token, pos); + next = sample(sampler, logits); + pos++; + + if (user_idx >= num_prompt_tokens && next != 2) { + // the Assistant is responding, so print its output + char* piece = decode(tokenizer, token, next); + safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes + fflush(stdout); + } + if (next == 2) { printf("\n"); } + } + printf("\n"); + free(prompt_tokens); +} + + // ---------------------------------------------------------------------------- // CLI, include only if not testing #ifndef TESTING @@ -799,6 +903,8 @@ void error_usage() { fprintf(stderr, " -n number of steps to run for, default 256. 0 = max_seq_len\n"); fprintf(stderr, " -i input prompt\n"); fprintf(stderr, " -z optional path to custom tokenizer\n"); + fprintf(stderr, " -m mode: generate|chat, default: generate\n"); + fprintf(stderr, " -y (optional) system prompt in chat mode\n"); exit(EXIT_FAILURE); } @@ -807,11 +913,13 @@ int main(int argc, char *argv[]) { // default parameters char *checkpoint_path = NULL; // e.g. out/model.bin char *tokenizer_path = "tokenizer.bin"; - float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher - float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower - int steps = 256; // number of steps to run for - char *prompt = ""; // prompt string + float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher + float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower + int steps = 256; // number of steps to run for + char *prompt = NULL; // prompt string unsigned long long rng_seed = 0; // seed rng with time by default + char *mode = "generate"; // generate|chat + char *system_prompt = NULL; // the (optional) system prompt to use in chat mode // poor man's C argparse so we can override the defaults above from the command line if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); } @@ -827,6 +935,8 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } + else if (argv[i][1] == 'm') { mode = argv[i + 1]; } + else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; } else { error_usage(); } } @@ -850,7 +960,14 @@ int main(int argc, char *argv[]) { build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed); // run! - generate(&transformer, &tokenizer, &sampler, prompt, steps); + if (strcmp(mode, "generate") == 0) { + generate(&transformer, &tokenizer, &sampler, prompt, steps); + } else if (strcmp(mode, "chat") == 0) { + chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps); + } else { + fprintf(stderr, "unknown mode: %s\n", mode); + error_usage(); + } // memory and file handles cleanup free_sampler(&sampler);