diff --git a/run.c b/run.c
index da1fbc4..6f281f8 100644
--- a/run.c
+++ b/run.c
@@ -732,6 +732,8 @@ long time_in_ms() {
 // generation loop
 
 void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
+    char *empty_prompt = "";
+    if (prompt == NULL) { prompt = empty_prompt; }
 
     // encode the (string) prompt into tokens sequence
     int num_prompt_tokens = 0;
@@ -785,6 +787,117 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
     free(prompt_tokens);
 }
 
+// Print `guide` as a prompt, then read one line from stdin into `buffer`
+// (at most bufsize - 1 characters), dropping the trailing '\n' if present.
+// On EOF or a read error `buffer` is set to the empty string, so callers
+// never read uninitialized memory.
+void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+    printf("%s", guide);
+    if (buffer == NULL || bufsize == 0) { return; } // nowhere to store a line
+    if (fgets(buffer, (int)bufsize, stdin) == NULL) {
+        buffer[0] = '\0'; // EOF or read error: hand back an empty line
+        return;
+    }
+    size_t len = strlen(buffer);
+    if (len > 0 && buffer[len - 1] == '\n') {
+        buffer[len - 1] = '\0'; // strip newline
+    }
+}
+
+// Interactive chat loop that renders each user turn into the Llama 2 Chat schema.
+//   cli_user_prompt   - optional first user message (NULL = ask on stdin)
+//   cli_system_prompt - optional system prompt (NULL = ask on stdin)
+//   steps             - number of positions (prompt and sampled tokens alike)
+//                       to run before the loop stops
+// Prompts given on the command line are truncated to fit the 512-byte buffers.
+void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
+          char *cli_user_prompt, char *cli_system_prompt, int steps) {
+
+    // buffers for reading the system prompt and user prompt from stdin
+    char system_prompt[512];
+    char user_prompt[512];
+    // large enough for both prompts plus the template text wrapped around them
+    char rendered_prompt[sizeof(system_prompt) + sizeof(user_prompt) + 64];
+    int num_prompt_tokens = 0;
+    // sized from the text buffer plus a little headroom for special tokens;
+    // assumes encode() emits at most one token per input byte -- TODO confirm
+    int* prompt_tokens = malloc((sizeof(rendered_prompt) + 3) * sizeof(*prompt_tokens));
+    if (prompt_tokens == NULL) {
+        fprintf(stderr, "malloc failed!\n");
+        exit(EXIT_FAILURE);
+    }
+    int user_idx = 0;     // index of the next prompt token to feed
+
+    // start the main loop
+    int8_t user_turn = 1; // user starts
+    int next = 0;         // will store the next token in the sequence
+    int token = 0;        // stores the current token to feed into the transformer
+    int prev_token = 0;   // token fed at the previous position, passed to decode()
+    int pos = 0;          // position in the sequence
+    while (pos < steps) {
+
+        // when it is the user's turn to contribute tokens to the dialog...
+        if (user_turn) {
+            // get the (optional) system prompt at position 0
+            if (pos == 0) {
+                // at position 0, the user can also contribute a system prompt
+                if (cli_system_prompt == NULL) {
+                    // system prompt was not passed in, attempt to get it from stdin
+                    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+                } else {
+                    // system prompt was passed in, use it (bounded copy, may truncate)
+                    snprintf(system_prompt, sizeof(system_prompt), "%s", cli_system_prompt);
+                }
+            }
+            // get the user prompt
+            if (pos == 0 && cli_user_prompt != NULL) {
+                // user prompt for position 0 was passed in, use it (bounded copy, may truncate)
+                snprintf(user_prompt, sizeof(user_prompt), "%s", cli_user_prompt);
+            } else {
+                // otherwise get user prompt from stdin
+                read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt));
+            }
+            // render user/system prompts into the Llama 2 Chat schema
+            if (pos == 0 && system_prompt[0] != '\0') {
+                snprintf(rendered_prompt, sizeof(rendered_prompt),
+                         "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s\n[/INST]", system_prompt, user_prompt);
+            } else {
+                snprintf(rendered_prompt, sizeof(rendered_prompt),
+                         "[INST] %s [/INST]", user_prompt);
+            }
+            // encode the rendered prompt into tokens
+            encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
+            user_idx = 0; // reset the user index
+            user_turn = 0;
+        }
+
+        // determine the token to pass into the transformer next
+        prev_token = token;
+        if (user_idx < num_prompt_tokens) {
+            // if we are still processing the input prompt, force the next prompt token
+            token = prompt_tokens[user_idx++];
+        } else {
+            // otherwise use the next token sampled from previous turn
+            token = next;
+        }
+        // EOS (=2) token ends the Assistant turn
+        if (token == 2) { user_turn = 1; }
+
+        // forward the transformer to get logits for the next token
+        float* logits = forward(transformer, token, pos);
+        next = sample(sampler, logits);
+        pos++;
+
+        // print the token as string, decode it with the Tokenizer object
+        char* piece = decode(tokenizer, prev_token, token);
+        safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+        fflush(stdout);
+    }
+    printf("\n");
+    free(prompt_tokens);
+}
+
+
 // ----------------------------------------------------------------------------
 // CLI, include only if not testing
 #ifndef TESTING
@@ -799,6 +912,8 @@ void error_usage() {
     fprintf(stderr, " -n number of steps to run for, default 256. 0 = max_seq_len\n");
     fprintf(stderr, " -i input prompt\n");
     fprintf(stderr, " -z optional path to custom tokenizer\n");
+    fprintf(stderr, " -m mode: generate|chat, default: generate\n");
+    fprintf(stderr, " -y (optional) system prompt in chat mode\n");
     exit(EXIT_FAILURE);
 }
 
@@ -807,11 +922,13 @@ int main(int argc, char *argv[]) {
     // default parameters
     char *checkpoint_path = NULL; // e.g. out/model.bin
     char *tokenizer_path = "tokenizer.bin";
-    float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
-    float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
-    int steps = 256; // number of steps to run for
-    char *prompt = ""; // prompt string
+    float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+    float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
+    int steps = 256; // number of steps to run for
+    char *prompt = NULL; // prompt string
     unsigned long long rng_seed = 0; // seed rng with time by default
+    char *mode = "generate"; // generate|chat
+    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
 
     // poor man's C argparse so we can override the defaults above from the command line
     if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -827,6 +944,8 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
         else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
         else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
+        else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
+        else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
         else { error_usage(); }
     }
 
@@ -850,7 +969,14 @@ int main(int argc, char *argv[]) {
     build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
 
     // run!
-    generate(&transformer, &tokenizer, &sampler, prompt, steps);
+    if (strcmp(mode, "generate") == 0) {
+        generate(&transformer, &tokenizer, &sampler, prompt, steps);
+    } else if (strcmp(mode, "chat") == 0) {
+        chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
+    } else {
+        fprintf(stderr, "unknown mode: %s\n", mode);
+        error_usage();
+    }
 
     // memory and file handles cleanup
     free_sampler(&sampler);