diff --git a/run.c b/run.c
index da1fbc4..6f281f8 100644
--- a/run.c
+++ b/run.c
@@ -732,6 +732,8 @@ long time_in_ms() {
// generation loop
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
+ char *empty_prompt = "";
+ if (prompt == NULL) { prompt = empty_prompt; }
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
@@ -785,6 +787,98 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
free(prompt_tokens);
}
+void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+ // read a line from stdin, up to but not including \n
+ printf("%s", guide);
+ if (fgets(buffer, bufsize, stdin) != NULL) {
+ size_t len = strlen(buffer);
+ if (len > 0 && buffer[len - 1] == '\n') {
+ buffer[len - 1] = '\0'; // strip newline
+ }
+ }
+}
+
+void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
+ char *cli_user_prompt, char *cli_system_prompt, int steps) {
+
+ // buffers for reading the system prompt and user prompt from stdin
+ char system_prompt[512];
+ char user_prompt[512];
+ char rendered_prompt[512];
+ int num_prompt_tokens = 0;
+ int* prompt_tokens = (int*)malloc(512 * sizeof(int));
+ int user_idx;
+
+ // start the main loop
+ int8_t user_turn = 1; // user starts
+ int next; // will store the next token in the sequence
+ int token; // stores the current token to feed into the transformer
+ int prev_token;
+ int pos = 0; // position in the sequence
+ while (pos < steps) {
+
+ // when it is the user's turn to contribute tokens to the dialog...
+ if (user_turn) {
+ // get the (optional) system prompt at position 0
+ if (pos == 0) {
+ // at position 0, the user can also contribute a system prompt
+ if (cli_system_prompt == NULL) {
+ // system prompt was not passed in, attempt to get it from stdin
+ read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+ } else {
+ // system prompt was passed in, use it
+ strcpy(system_prompt, cli_system_prompt);
+ }
+ }
+ // get the user prompt
+ if (pos == 0 && cli_user_prompt != NULL) {
+ // user prompt for position 0 was passed in, use it
+ strcpy(user_prompt, cli_user_prompt);
+ } else {
+ // otherwise get user prompt from stdin
+ read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt));
+ }
+ // render user/system prompts into the Llama 2 Chat schema
+ if (pos == 0 && system_prompt[0] != '\0') {
+ char system_template[] = "[INST] <>\n%s\n<>\n\n%s\n[/INST]";
+ sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
+ } else {
+ char user_template[] = "[INST] %s [/INST]";
+ sprintf(rendered_prompt, user_template, user_prompt);
+ }
+ // encode the rendered prompt into tokens
+ encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
+ user_idx = 0; // reset the user index
+ user_turn = 0;
+ }
+
+ // determine the token to pass into the transformer next
+ prev_token = token;
+ if (user_idx < num_prompt_tokens) {
+ // if we are still processing the input prompt, force the next prompt token
+ token = prompt_tokens[user_idx++];
+ } else {
+ // otherwise use the next token sampled from previous turn
+ token = next;
+ }
+ // EOS (=2) token ends the Assistant turn
+ if (token == 2) { user_turn = 1; }
+
+ // forward the transformer to get logits for the next token
+ float* logits = forward(transformer, token, pos);
+ next = sample(sampler, logits);
+ pos++;
+
+ // print the token as string, decode it with the Tokenizer object
+ char* piece = decode(tokenizer, prev_token, token);
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+ fflush(stdout);
+ }
+ printf("\n");
+ free(prompt_tokens);
+}
+
+
// ----------------------------------------------------------------------------
// CLI, include only if not testing
#ifndef TESTING
@@ -799,6 +893,8 @@ void error_usage() {
fprintf(stderr, " -n number of steps to run for, default 256. 0 = max_seq_len\n");
fprintf(stderr, " -i input prompt\n");
fprintf(stderr, " -z optional path to custom tokenizer\n");
+ fprintf(stderr, " -m mode: generate|chat, default: generate\n");
+ fprintf(stderr, " -y (optional) system prompt in chat mode\n");
exit(EXIT_FAILURE);
}
@@ -807,11 +903,13 @@ int main(int argc, char *argv[]) {
// default parameters
char *checkpoint_path = NULL; // e.g. out/model.bin
char *tokenizer_path = "tokenizer.bin";
- float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
- float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
- int steps = 256; // number of steps to run for
- char *prompt = ""; // prompt string
+ float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+ float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
+ int steps = 256; // number of steps to run for
+ char *prompt = NULL; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
+ char *mode = "generate"; // generate|chat
+ char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -827,6 +925,8 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
+ else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
+ else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
else { error_usage(); }
}
@@ -850,7 +950,14 @@ int main(int argc, char *argv[]) {
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
// run!
- generate(&transformer, &tokenizer, &sampler, prompt, steps);
+ if (strcmp(mode, "generate") == 0) {
+ generate(&transformer, &tokenizer, &sampler, prompt, steps);
+ } else if (strcmp(mode, "chat") == 0) {
+ chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
+ } else {
+ fprintf(stderr, "unknown mode: %s\n", mode);
+ error_usage();
+ }
// memory and file handles cleanup
free_sampler(&sampler);