And finally (git add run.c): split off the generate function. Alongside it will come a chat function. We are close.
This commit is contained in:
@@ -696,6 +696,62 @@ long time_in_ms() {
|
||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// generation loop
|
||||
|
||||
/*
 * Run the generation loop: feed tokens through the transformer one position
 * at a time and stream the decoded text pieces to stdout.
 *
 * transformer  model state handed to forward()
 * tokenizer    used to encode the prompt and to decode each emitted token
 * sampler      used to pick the next token from the logits once the prompt
 *              has been consumed
 * prompt       optional C string to condition on; may be NULL for none
 * steps        upper bound on the number of positions processed
 *
 * The loop starts from token 1 (BOS), forces the prompt tokens for the first
 * num_prompt_tokens positions, then samples. It stops after `steps` positions
 * or as soon as the next token is 1 (BOS). A tokens-per-second figure is
 * written to stderr when more than one position was processed. If the prompt
 * token buffer cannot be allocated, an error is printed to stderr and the
 * process exits with EXIT_FAILURE.
 */
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {

    // encode the (string) prompt into a token sequence, if any is given
    int *prompt_tokens = NULL;  // the sequence of prompt tokens
    int num_prompt_tokens = 0;  // the total number of prompt tokens
    if (prompt != NULL) {
        // one int slot per prompt byte, plus one
        // NOTE(review): assumes encode() never emits more tokens than that - confirm
        prompt_tokens = (int*)malloc((strlen(prompt) + 1) * sizeof(int));
        if (prompt_tokens == NULL) {
            // never hand a NULL buffer to encode()
            fprintf(stderr, "malloc failed!\n");
            exit(EXIT_FAILURE);
        }
        encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
    }

    // start the main loop
    long start = 0;  // used to time our code, only initialized after first iteration
    int next;        // will store the next token in the sequence
    int token = 1;   // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
    int pos = 0;     // position in the sequence
    while (pos < steps) {

        // forward the transformer to get logits for the next token
        float *logits = forward(transformer, token, pos);

        // advance the state machine
        if (pos < num_prompt_tokens) {
            // if we are still processing the input prompt, force the next prompt token
            next = prompt_tokens[pos];
        } else {
            // otherwise sample the next token from the logits
            next = sample(sampler, logits);
        }
        pos++;

        // data-dependent terminating condition: the BOS (1) token delimits sequences
        if (next == 1) { break; }

        // print the token as string, decode it with the Tokenizer object
        char *piece = decode(tokenizer, token, next);
        printf("%s", piece);
        fflush(stdout);
        token = next;

        // init the timer here because the first iteration can be slower
        if (start == 0) { start = time_in_ms(); }
    }
    printf("\n");

    // report achieved tok/s (pos-1 because the timer starts after first iteration)
    if (pos > 1) {
        long end = time_in_ms();
        fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
    }

    free(prompt_tokens);
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// int main
|
||||
|
||||
@@ -759,56 +815,10 @@ int main(int argc, char *argv[]) {
|
||||
Sampler sampler;
|
||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
|
||||
|
||||
// encode the (string) prompt into tokens sequence, if any is given
|
||||
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
||||
int num_prompt_tokens = 0; // the total number of prompt tokens
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||
encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
|
||||
}
|
||||
|
||||
// start the main loop
|
||||
long start = 0; // used to time our code, only initialized after first iteration
|
||||
int next; // will store the next token in the sequence
|
||||
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
||||
int pos = 0; // position in the sequence
|
||||
while (pos < steps) {
|
||||
|
||||
// forward the transformer to get logits for the next token
|
||||
float* logits = forward(&transformer, token, pos);
|
||||
|
||||
// advance the state machine
|
||||
if (pos < num_prompt_tokens) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
} else {
|
||||
// otherwise sample the next token from the logits
|
||||
next = sample(&sampler, logits);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// print the token as string, decode it with the Tokenizer object
|
||||
char* piece = decode(&tokenizer, token, next);
|
||||
printf("%s", piece);
|
||||
fflush(stdout);
|
||||
token = next;
|
||||
|
||||
// init the timer here because the first iteration can be slower
|
||||
if (start == 0) { start = time_in_ms(); }
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
// report achieved tok/s (pos-1 because the timer starts after first iteration)
|
||||
if (pos > 1) {
|
||||
long end = time_in_ms();
|
||||
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
|
||||
}
|
||||
// run!
|
||||
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
||||
|
||||
// memory and file handles cleanup
|
||||
free(prompt_tokens);
|
||||
free_sampler(&sampler);
|
||||
free_tokenizer(&tokenizer);
|
||||
free_transformer(&transformer);
|
||||
|
||||
Reference in New Issue
Block a user