let's start respecting the BOS token. Don't print it explicitly, and terminate sequence if it appears. This makes sense especially after the recent addition of prompting. Also be careful with timings and making sure they come out right if we exit early in this data-dependent manner

This commit is contained in:
Andrej Karpathy
2023-08-06 16:33:23 +00:00
parent 4e8a3e8d5d
commit 79791f39b4
+14 -9
View File
@@ -603,12 +603,12 @@ int main(int argc, char *argv[]) {
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
while (pos < steps) {
// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);
// advance the state state machine
if(pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
@@ -632,22 +632,27 @@ int main(int argc, char *argv[]) {
}
}
}
pos++;
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
// data-dependent terminating condition: the BOS (1) token delimits sequences
if (next == 1) { break; }
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
printf("%s", token_str);
fflush(stdout);
// advance forward
token = next;
pos++;
// init our timer here because the first iteration is slow due to memmap
// init the timer here because the first iteration can be slower
if (start == 0) { start = time_in_ms(); }
}
printf("\n");
// report achieved tok/s
long end = time_in_ms();
fprintf(stderr, "\nachieved tok/s: %f\n", (steps-1) / (double)(end-start)*1000);
// report achieved tok/s (pos-1 because the timer starts after first iteration)
if (pos > 1) {
long end = time_in_ms();
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
}
// memory and file handles cleanup
free_run_state(&state);