diff --git a/run.c b/run.c index 15b2683..40f68e6 100644 --- a/run.c +++ b/run.c @@ -839,7 +839,7 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, strcpy(user_prompt, cli_user_prompt); } else { // otherwise get user prompt from stdin - read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt)); + read_stdin("User: ", user_prompt, sizeof(user_prompt)); } // render user/system prompts into the Llama 2 Chat schema if (pos == 0 && system_prompt[0] != '\0') { @@ -853,10 +853,10 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens); user_idx = 0; // reset the user index user_turn = 0; + printf("Assistant: "); } // determine the token to pass into the transformer next - prev_token = token; if (user_idx < num_prompt_tokens) { // if we are still processing the input prompt, force the next prompt token token = prompt_tokens[user_idx++]; @@ -872,10 +872,13 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, next = sample(sampler, logits); pos++; - // print the token as string, decode it with the Tokenizer object - char* piece = decode(tokenizer, prev_token, token); - safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes - fflush(stdout); + if (user_idx >= num_prompt_tokens && next != 2) { + // the Assistant is responding, so print its output + char* piece = decode(tokenizer, token, next); + safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes + fflush(stdout); + } + if (next == 2) { printf("\n"); } } printf("\n"); free(prompt_tokens);