let's start respecting the BOS token. Don't print it explicitly, and terminate sequence if it appears. This makes sense especially after the recent addition of prompting. Also be careful with timings and making sure they come out right if we exit early in this data-dependent manner
This commit is contained in:
@@ -603,12 +603,12 @@ int main(int argc, char *argv[]) {
|
||||
int next; // will store the next token in the sequence
|
||||
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
||||
int pos = 0; // position in the sequence
|
||||
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
|
||||
while (pos < steps) {
|
||||
|
||||
// forward the transformer to get logits for the next token
|
||||
transformer(token, pos, &config, &state, &weights);
|
||||
|
||||
// advance the state state machine
|
||||
if(pos < num_prompt_tokens) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
@@ -632,22 +632,27 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
}
|
||||
}
|
||||
pos++;
|
||||
|
||||
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
printf("%s", token_str);
|
||||
fflush(stdout);
|
||||
|
||||
// advance forward
|
||||
token = next;
|
||||
pos++;
|
||||
// init our timer here because the first iteration is slow due to memmap
|
||||
|
||||
// init the timer here because the first iteration can be slower
|
||||
if (start == 0) { start = time_in_ms(); }
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
// report achieved tok/s
|
||||
long end = time_in_ms();
|
||||
fprintf(stderr, "\nachieved tok/s: %f\n", (steps-1) / (double)(end-start)*1000);
|
||||
// report achieved tok/s (pos-1 because the timer starts after first iteration)
|
||||
if (pos > 1) {
|
||||
long end = time_in_ms();
|
||||
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
|
||||
}
|
||||
|
||||
// memory and file handles cleanup
|
||||
free_run_state(&state);
|
||||
|
||||
Reference in New Issue
Block a user