diff --git a/run.c b/run.c index 8202534..02cc877 100644 --- a/run.c +++ b/run.c @@ -603,12 +603,12 @@ int main(int argc, char *argv[]) { int next; // will store the next token in the sequence int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer int pos = 0; // position in the sequence - printf("\n"); // explicit print the initial BOS token for stylistic symmetry reasons while (pos < steps) { // forward the transformer to get logits for the next token transformer(token, pos, &config, &state, &weights); + // advance the state state machine if(pos < num_prompt_tokens) { // if we are still processing the input prompt, force the next prompt token next = prompt_tokens[pos]; @@ -632,22 +632,27 @@ int main(int argc, char *argv[]) { } } } + pos++; - // following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89) + // data-dependent terminating condition: the BOS (1) token delimits sequences + if (next == 1) { break; } + + // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; printf("%s", token_str); fflush(stdout); - - // advance forward token = next; - pos++; - // init our timer here because the first iteration is slow due to memmap + + // init the timer here because the first iteration can be slower if (start == 0) { start = time_in_ms(); } } + printf("\n"); - // report achieved tok/s - long end = time_in_ms(); - fprintf(stderr, "\nachieved tok/s: %f\n", (steps-1) / (double)(end-start)*1000); + // report achieved tok/s (pos-1 because the timer starts after first iteration) + if (pos > 1) { + long end = time_in_ms(); + fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000); + } // memory and file handles cleanup free_run_state(&state);