From 79791f39b49703f14fb015b558a2d8d6e692eb49 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 6 Aug 2023 16:33:23 +0000 Subject: [PATCH] let's start respecting the BOS token. Don't print it explicitly, and terminate sequence if it appears. This makes sense especially after the recent addition of prompting. Also be careful with timings and making sure they come out right if we exit early in this data-dependent manner --- run.c | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/run.c b/run.c index 8202534..02cc877 100644 --- a/run.c +++ b/run.c @@ -603,12 +603,12 @@ int main(int argc, char *argv[]) { int next; // will store the next token in the sequence int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer int pos = 0; // position in the sequence - printf("\n"); // explicit print the initial BOS token for stylistic symmetry reasons while (pos < steps) { // forward the transformer to get logits for the next token transformer(token, pos, &config, &state, &weights); + // advance the state state machine if(pos < num_prompt_tokens) { // if we are still processing the input prompt, force the next prompt token next = prompt_tokens[pos]; @@ -632,22 +632,27 @@ int main(int argc, char *argv[]) { } } } + pos++; - // following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89) + // data-dependent terminating condition: the BOS (1) token delimits sequences + if (next == 1) { break; } + + // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; printf("%s", token_str); fflush(stdout); - - // advance forward token = next; - pos++; - // init our timer here because the first iteration is slow due to memmap + + // init the timer here because the first iteration can be slower if (start == 0) { start = time_in_ms(); } } + printf("\n"); - // report achieved tok/s - long end = time_in_ms(); - fprintf(stderr, "\nachieved tok/s: %f\n", (steps-1) / (double)(end-start)*1000); + // report achieved tok/s (pos-1 because the timer starts after first iteration) + if (pos > 1) { + long end = time_in_ms(); + fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000); + } // memory and file handles cleanup free_run_state(&state);