ok getting closer, and manually verified correctness of the schema matching python. still some weirdness in the printing to chase down, and also have to tune the buffer lengths and make them sensible and such

This commit is contained in:
Andrej Karpathy
2023-08-24 04:31:06 +00:00
parent 40fb902cf0
commit 3d787b2463
+9 -6
View File
@@ -839,7 +839,7 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
strcpy(user_prompt, cli_user_prompt);
} else {
// otherwise get user prompt from stdin
read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt));
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}
// render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') {
@@ -853,10 +853,10 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
user_idx = 0; // reset the user index
user_turn = 0;
printf("Assistant: ");
}
// determine the token to pass into the transformer next
prev_token = token;
if (user_idx < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
token = prompt_tokens[user_idx++];
@@ -872,10 +872,13 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
next = sample(sampler, logits);
pos++;
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, prev_token, token);
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout);
if (user_idx >= num_prompt_tokens && next != 2) {
// the Assistant is responding, so print its output
char* piece = decode(tokenizer, token, next);
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout);
}
if (next == 2) { printf("\n"); }
}
printf("\n");
free(prompt_tokens);