From 0e362f735f097a7c8306b67b56cc41c57ef9a091 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 22 Aug 2023 02:22:36 +0000 Subject: [PATCH] and finallygit add run.c split off the generate function. alongside it will come a chat function. we are close --- run.c | 106 ++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 58 insertions(+), 48 deletions(-) diff --git a/run.c b/run.c index ae78f67..328519a 100644 --- a/run.c +++ b/run.c @@ -696,6 +696,62 @@ long time_in_ms() { return time.tv_sec * 1000 + time.tv_nsec / 1000000; } +// ---------------------------------------------------------------------------- +// generation loop + +void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) { + + // encode the (string) prompt into tokens sequence, if any is given + int *prompt_tokens = NULL; // the sequence of prompt tokens + int num_prompt_tokens = 0; // the total number of prompt tokens + if (prompt != NULL) { + prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); + encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens); + } + + // start the main loop + long start = 0; // used to time our code, only initialized after first iteration + int next; // will store the next token in the sequence + int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer + int pos = 0; // position in the sequence + while (pos < steps) { + + // forward the transformer to get logits for the next token + float* logits = forward(transformer, token, pos); + + // advance the state state machine + if (pos < num_prompt_tokens) { + // if we are still processing the input prompt, force the next prompt token + next = prompt_tokens[pos]; + } else { + // otherwise sample the next token from the logits + next = sample(sampler, logits); + } + pos++; + + // data-dependent terminating condition: the BOS (1) token delimits sequences + if (next == 1) { break; } + + // print the token as string, decode it with the Tokenizer object + char* piece = decode(tokenizer, token, next); + printf("%s", piece); + fflush(stdout); + token = next; + + // init the timer here because the first iteration can be slower + if (start == 0) { start = time_in_ms(); } + } + printf("\n"); + + // report achieved tok/s (pos-1 because the timer starts after first iteration) + if (pos > 1) { + long end = time_in_ms(); + fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000); + } + + free(prompt_tokens); +} + // ---------------------------------------------------------------------------- // int main @@ -759,56 +815,10 @@ int main(int argc, char *argv[]) { Sampler sampler; build_sampler(&sampler, transformer.config.vocab_size, temperature, topp); - // encode the (string) prompt into tokens sequence, if any is given - int *prompt_tokens = NULL; // the sequence of prompt tokens - int num_prompt_tokens = 0; // the total number of prompt tokens - if (prompt != NULL) { - prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); - encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens); - } - - // start the main loop - long start = 0; // used to time our code, only initialized after first iteration - int next; // will store the next token in the sequence - int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer - int pos = 0; // position in the sequence - while (pos < steps) { - - // forward the transformer to get logits for the next token - float* logits = forward(&transformer, token, pos); - - // advance the state state machine - if (pos < num_prompt_tokens) { - // if we are still processing the input prompt, force the next prompt token - next = prompt_tokens[pos]; - } else { - // otherwise sample the next token from the logits - next = sample(&sampler, logits); - } - pos++; - - // data-dependent terminating condition: the BOS (1) token delimits sequences - if (next == 1) { break; } - - // print the token as string, decode it with the Tokenizer object - char* piece = decode(&tokenizer, token, next); - printf("%s", piece); - fflush(stdout); - token = next; - - // init the timer here because the first iteration can be slower - if (start == 0) { start = time_in_ms(); } - } - printf("\n"); - - // report achieved tok/s (pos-1 because the timer starts after first iteration) - if (pos > 1) { - long end = time_in_ms(); - fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000); - } + // run! + generate(&transformer, &tokenizer, &sampler, prompt, steps); // memory and file handles cleanup - free(prompt_tokens); free_sampler(&sampler); free_tokenizer(&tokenizer); free_transformer(&transformer);