attempt at chat function, but it was 8AM and I didn't have coffee yet. Seems to work but it's probably subtly broken or too complex. version 1 only, lots of hard-coded non-sensical buffer sizes. Have to go to work now

This commit is contained in:
Andrej Karpathy
2023-08-23 16:27:48 +00:00
parent 7ac65cb2c2
commit c5e0e7fce4
+112 -5
View File
@@ -732,6 +732,8 @@ long time_in_ms() {
// generation loop
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
char *empty_prompt = "";
if (prompt == NULL) { prompt = empty_prompt; }
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
@@ -785,6 +787,98 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
free(prompt_tokens);
}
void read_stdin(const char* guide, char* buffer, size_t bufsize) {
// read a line from stdin, up to but not including \n
printf("%s", guide);
if (fgets(buffer, bufsize, stdin) != NULL) {
size_t len = strlen(buffer);
if (len > 0 && buffer[len - 1] == '\n') {
buffer[len - 1] = '\0'; // strip newline
}
}
}
void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
char *cli_user_prompt, char *cli_system_prompt, int steps) {
// buffers for reading the system prompt and user prompt from stdin
char system_prompt[512];
char user_prompt[512];
char rendered_prompt[512];
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc(512 * sizeof(int));
int user_idx;
// start the main loop
int8_t user_turn = 1; // user starts
int next; // will store the next token in the sequence
int token; // stores the current token to feed into the transformer
int prev_token;
int pos = 0; // position in the sequence
while (pos < steps) {
// when it is the user's turn to contribute tokens to the dialog...
if (user_turn) {
// get the (optional) system prompt at position 0
if (pos == 0) {
// at position 0, the user can also contribute a system prompt
if (cli_system_prompt == NULL) {
// system prompt was not passed in, attempt to get it from stdin
read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
} else {
// system prompt was passed in, use it
strcpy(system_prompt, cli_system_prompt);
}
}
// get the user prompt
if (pos == 0 && cli_user_prompt != NULL) {
// user prompt for position 0 was passed in, use it
strcpy(user_prompt, cli_user_prompt);
} else {
// otherwise get user prompt from stdin
read_stdin("Enter user prompt: ", user_prompt, sizeof(user_prompt));
}
// render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') {
char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s\n[/INST]";
sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
} else {
char user_template[] = "[INST] %s [/INST]";
sprintf(rendered_prompt, user_template, user_prompt);
}
// encode the rendered prompt into tokens
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
user_idx = 0; // reset the user index
user_turn = 0;
}
// determine the token to pass into the transformer next
prev_token = token;
if (user_idx < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
token = prompt_tokens[user_idx++];
} else {
// otherwise use the next token sampled from previous turn
token = next;
}
// EOS (=2) token ends the Assistant turn
if (token == 2) { user_turn = 1; }
// forward the transformer to get logits for the next token
float* logits = forward(transformer, token, pos);
next = sample(sampler, logits);
pos++;
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, prev_token, token);
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout);
}
printf("\n");
free(prompt_tokens);
}
// ----------------------------------------------------------------------------
// CLI, include only if not testing
#ifndef TESTING
@@ -799,6 +893,8 @@ void error_usage() {
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
fprintf(stderr, " -i <string> input prompt\n");
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
exit(EXIT_FAILURE);
}
@@ -807,11 +903,13 @@ int main(int argc, char *argv[]) {
// default parameters
char *checkpoint_path = NULL; // e.g. out/model.bin
char *tokenizer_path = "tokenizer.bin";
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
int steps = 256; // number of steps to run for
char *prompt = ""; // prompt string
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
char *mode = "generate"; // generate|chat
char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -827,6 +925,8 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
else { error_usage(); }
}
@@ -850,7 +950,14 @@ int main(int argc, char *argv[]) {
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
// run!
generate(&transformer, &tokenizer, &sampler, prompt, steps);
if (strcmp(mode, "generate") == 0) {
generate(&transformer, &tokenizer, &sampler, prompt, steps);
} else if (strcmp(mode, "chat") == 0) {
chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
} else {
fprintf(stderr, "unknown mode: %s\n", mode);
error_usage();
}
// memory and file handles cleanup
free_sampler(&sampler);