Merge pull request #343 from karpathy/feature/chat
Add interactive loop to enable nice chat with a Llama 2 Chat model
This commit is contained in:
@@ -83,6 +83,18 @@ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my
|
||||
|
||||
base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo!
|
||||
|
||||
You can also chat with the Llama Chat models. Export the chat model exactly as above:
|
||||
|
||||
```bash
|
||||
python export.py llama2_7b_chat.bin --meta-llama /path/to/7B-chat
|
||||
```
|
||||
|
||||
Then chat with it by specifying the chat mode using the `-m` flag, e.g.:
|
||||
|
||||
```bash
|
||||
./run llama2_7b_chat.bin -m chat
|
||||
```
|
||||
|
||||
## huggingface models
|
||||
|
||||
We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
|
||||
@@ -207,8 +219,7 @@ You can also experiment with replacing `gcc` with `clang`.
|
||||
|
||||
If compiling with gcc, try experimenting with `-funroll-all-loops`, see PR [#183](https://github.com/karpathy/llama2.c/pull/183)
|
||||
|
||||
### OpenMP
|
||||
Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
|
||||
**OpenMP**. Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
|
||||
You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). Then you can compile with `make runomp`, which does:
|
||||
|
||||
```bash
|
||||
@@ -324,13 +335,11 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
|
||||
## unsorted todos
|
||||
|
||||
- support Llama 2 7B Chat models with a Chat UI/UX in run.c, very similar to llama.cpp
|
||||
- ability to calculate perplexity in run.c, exactly as done in llama.cpp
|
||||
- add support in run.c of reading version 1+ files from export, later deprecate "version 0"
|
||||
- add more tests in [test.c](test.c)
|
||||
- runq.c (int8 quantization) add
|
||||
- run.cu (CUDA) investigate and merge
|
||||
- add an Engine class that serves the model ~efficiently but in PyTorch (see [Issue 346](https://github.com/karpathy/llama2.c/issues/346))
|
||||
- add more tests inside [test.c](test.c)
|
||||
- add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping
|
||||
- make it easier to add a new dataset with not too much pain
|
||||
- (LoRA) finetuning and export of Llama 2 models
|
||||
|
||||
|
||||
@@ -732,6 +732,8 @@ long time_in_ms() {
|
||||
// generation loop
|
||||
|
||||
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
|
||||
char *empty_prompt = "";
|
||||
if (prompt == NULL) { prompt = empty_prompt; }
|
||||
|
||||
// encode the (string) prompt into tokens sequence
|
||||
int num_prompt_tokens = 0;
|
||||
@@ -785,6 +787,108 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
free(prompt_tokens);
|
||||
}
|
||||
|
||||
// Print `guide` as a prompt and read one line from stdin into `buffer`.
//
//   guide   - text shown to the user before reading (printed verbatim)
//   buffer  - destination; receives at most bufsize - 1 characters plus '\0'
//   bufsize - capacity of `buffer` in bytes
//
// The trailing '\n', if one was read, is stripped. On EOF or a read error the
// buffer is set to the empty string, so callers never observe stale or
// uninitialized contents. A NULL buffer or a zero bufsize is a no-op.
void read_stdin(const char* guide, char* buffer, size_t bufsize) {
    printf("%s", guide);
    // the guide has no trailing newline, so flush to make sure it is visible
    // before we block waiting for input
    fflush(stdout);

    if (buffer == NULL || bufsize == 0) { return; }
    buffer[0] = '\0'; // well-defined result even if nothing can be read

    if (fgets(buffer, (int)bufsize, stdin) == NULL) {
        // EOF or error: after a read error the contents are indeterminate,
        // so reset to an empty string explicitly
        buffer[0] = '\0';
        return;
    }

    size_t len = strlen(buffer);
    if (len > 0 && buffer[len - 1] == '\n') {
        buffer[len - 1] = '\0'; // strip newline
    }
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// chat loop
|
||||
// I manually inspected the tokens for a few chat conversations compared to
|
||||
// python reference and that seemed ok, but this was not thoroughly tested and
|
||||
// is not safely implemented, it's more a proof of concept atm.
|
||||
|
||||
// Run an interactive chat session in the Llama 2 Chat prompt schema.
//
//   transformer, tokenizer, sampler - model state passed through to
//                                     encode() / forward() / sample() / decode()
//   cli_user_prompt   - if non-NULL, used as the first user message instead of
//                       reading it from stdin
//   cli_system_prompt - if non-NULL, used as the system prompt instead of
//                       reading it from stdin
//   steps             - maximum number of sequence positions to process
//
// The loop alternates between a user turn (read a prompt, render it into the
// chat template, tokenize it) and an assistant turn (sample tokens and print
// them until EOS). Prompts longer than the internal buffers are truncated
// rather than overflowing them. The session ends when `steps` positions have
// been consumed or when stdin reaches EOF while waiting for a user message.
void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
          char *cli_user_prompt, char *cli_system_prompt, int steps) {

    // buffers for the system prompt, the user prompt and the rendered template.
    // RENDERED_BUF_SIZE leaves room for both prompts plus the template text.
    enum { PROMPT_BUF_SIZE = 512, RENDERED_BUF_SIZE = 1152 };
    char system_prompt[PROMPT_BUF_SIZE] = {0};
    char user_prompt[PROMPT_BUF_SIZE] = {0};
    char rendered_prompt[RENDERED_BUF_SIZE] = {0};
    int num_prompt_tokens = 0;
    // one token slot per byte of the rendered prompt
    // NOTE(review): presumably encode() never emits more tokens than that -- confirm
    int* prompt_tokens = (int*)malloc(RENDERED_BUF_SIZE * sizeof(int));
    if (prompt_tokens == NULL) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
    int user_idx = 0; // index of the next prompt token to feed to the model

    // start the main loop
    int8_t user_turn = 1; // user starts
    int next = 0;         // will store the next token in the sequence
    int token = 0;        // stores the current token to feed into the transformer
    int pos = 0;          // position in the sequence
    while (pos < steps) {

        // when it is the user's turn to contribute tokens to the dialog...
        if (user_turn) {
            // get the (optional) system prompt at position 0
            if (pos == 0) {
                if (cli_system_prompt == NULL) {
                    // system prompt was not passed in, attempt to get it from stdin
                    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
                } else {
                    // system prompt was passed in, use it (bounded copy, truncates if too long)
                    snprintf(system_prompt, sizeof(system_prompt), "%s", cli_system_prompt);
                }
            }
            // get the user prompt
            if (pos == 0 && cli_user_prompt != NULL) {
                // user prompt for position 0 was passed in, use it (bounded copy)
                snprintf(user_prompt, sizeof(user_prompt), "%s", cli_user_prompt);
            } else {
                // otherwise get user prompt from stdin; clear it first so that
                // a failed read cannot replay the previous turn's message
                user_prompt[0] = '\0';
                read_stdin("User: ", user_prompt, sizeof(user_prompt));
                if (user_prompt[0] == '\0' && (feof(stdin) || ferror(stdin))) {
                    // no more input is coming: end the conversation
                    break;
                }
            }
            // render user/system prompts into the Llama 2 Chat schema
            if (pos == 0 && system_prompt[0] != '\0') {
                snprintf(rendered_prompt, sizeof(rendered_prompt),
                         "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
                         system_prompt, user_prompt);
            } else {
                snprintf(rendered_prompt, sizeof(rendered_prompt),
                         "[INST] %s [/INST]", user_prompt);
            }
            // encode the rendered prompt into tokens
            encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
            user_idx = 0; // reset the user index
            user_turn = 0;
            printf("Assistant: ");
        }

        // determine the token to pass into the transformer next
        if (user_idx < num_prompt_tokens) {
            // if we are still processing the input prompt, force the next prompt token
            token = prompt_tokens[user_idx++];
        } else {
            // otherwise use the next token sampled from previous turn
            token = next;
        }
        // EOS (=2) token ends the Assistant turn
        if (token == 2) { user_turn = 1; }

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);
        next = sample(sampler, logits);
        pos++;

        if (user_idx >= num_prompt_tokens && next != 2) {
            // the Assistant is responding, so print its output
            char* piece = decode(tokenizer, token, next);
            safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
            fflush(stdout);
        }
        if (next == 2) { printf("\n"); }
    }
    printf("\n");
    free(prompt_tokens);
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// CLI, include only if not testing
|
||||
#ifndef TESTING
|
||||
@@ -799,6 +903,8 @@ void error_usage() {
|
||||
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
||||
fprintf(stderr, " -i <string> input prompt\n");
|
||||
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
|
||||
fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
|
||||
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
@@ -807,11 +913,13 @@ int main(int argc, char *argv[]) {
|
||||
// default parameters
|
||||
char *checkpoint_path = NULL; // e.g. out/model.bin
|
||||
char *tokenizer_path = "tokenizer.bin";
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = ""; // prompt string
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
unsigned long long rng_seed = 0; // seed rng with time by default
|
||||
char *mode = "generate"; // generate|chat
|
||||
char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
|
||||
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
|
||||
@@ -827,6 +935,8 @@ int main(int argc, char *argv[]) {
|
||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
|
||||
@@ -850,7 +960,14 @@ int main(int argc, char *argv[]) {
|
||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
|
||||
|
||||
// run!
|
||||
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
||||
if (strcmp(mode, "generate") == 0) {
|
||||
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
||||
} else if (strcmp(mode, "chat") == 0) {
|
||||
chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
|
||||
} else {
|
||||
fprintf(stderr, "unknown mode: %s\n", mode);
|
||||
error_usage();
|
||||
}
|
||||
|
||||
// memory and file handles cleanup
|
||||
free_sampler(&sampler);
|
||||
|
||||
Reference in New Issue
Block a user