ok getting closer, and manually verified correctness of the schema matching python. still some weirdness in the printing to chase down, and also have to tune the buffer lengths and make them sensible and such

This commit is contained in:
Andrej Karpathy
2023-08-24 04:31:06 +00:00
parent 40fb902cf0
commit 3d787b2463
+9 -6
View File
@@ -839,7 +839,7 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
strcpy(user_prompt, cli_user_prompt); strcpy(user_prompt, cli_user_prompt);
} else { } else {
// otherwise get user prompt from stdin // otherwise get user prompt from stdin
read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt)); read_stdin("User: ", user_prompt, sizeof(user_prompt));
} }
// render user/system prompts into the Llama 2 Chat schema // render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') { if (pos == 0 && system_prompt[0] != '\0') {
@@ -853,10 +853,10 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens); encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
user_idx = 0; // reset the user index user_idx = 0; // reset the user index
user_turn = 0; user_turn = 0;
printf("Assistant: ");
} }
// determine the token to pass into the transformer next // determine the token to pass into the transformer next
prev_token = token;
if (user_idx < num_prompt_tokens) { if (user_idx < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token // if we are still processing the input prompt, force the next prompt token
token = prompt_tokens[user_idx++]; token = prompt_tokens[user_idx++];
@@ -872,10 +872,13 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
next = sample(sampler, logits); next = sample(sampler, logits);
pos++; pos++;
// print the token as string, decode it with the Tokenizer object if (user_idx >= num_prompt_tokens && next != 2) {
char* piece = decode(tokenizer, prev_token, token); // the Assistant is responding, so print its output
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes char* piece = decode(tokenizer, token, next);
fflush(stdout); safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout);
}
if (next == 2) { printf("\n"); }
} }
printf("\n"); printf("\n");
free(prompt_tokens); free(prompt_tokens);