From 3d787b24635a7031f933ed42afc58e8117ee4504 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 24 Aug 2023 04:31:06 +0000 Subject: [PATCH] ok getting closer, and manually verified correctness of the schema matching python. still some weirdness in the printing to chase down, and also have to tune the buffer lengths and make them sensible and such --- run.c | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/run.c b/run.c index 15b2683..40f68e6 100644 --- a/run.c +++ b/run.c @@ -839,7 +839,7 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, strcpy(user_prompt, cli_user_prompt); } else { // otherwise get user prompt from stdin - read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt)); + read_stdin("User: ", user_prompt, sizeof(user_prompt)); } // render user/system prompts into the Llama 2 Chat schema if (pos == 0 && system_prompt[0] != '\0') { @@ -853,10 +853,10 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens); user_idx = 0; // reset the user index user_turn = 0; + printf("Assistant: "); } // determine the token to pass into the transformer next - prev_token = token; if (user_idx < num_prompt_tokens) { // if we are still processing the input prompt, force the next prompt token token = prompt_tokens[user_idx++]; @@ -872,10 +872,13 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, next = sample(sampler, logits); pos++; - // print the token as string, decode it with the Tokenizer object - char* piece = decode(tokenizer, prev_token, token); - safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes - fflush(stdout); + if (user_idx >= num_prompt_tokens && next != 2) { + // the Assistant is responding, so print its output + char* piece = decode(tokenizer, token, next); + safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes + fflush(stdout); + } + if (next == 2) { printf("\n"); } } printf("\n"); free(prompt_tokens);