And finally, split off the generate function (run.c). Alongside it will come a chat function. We are close.
This commit is contained in:
@@ -696,6 +696,62 @@ long time_in_ms() {
|
|||||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// generation loop
|
||||||
|
|
||||||
|
/*
 * Run the text-generation loop and stream the decoded text to stdout.
 *
 * transformer, tokenizer, sampler: objects handed through to forward(),
 *     encode()/decode() and sample() respectively.
 * prompt: optional C string used to condition generation; may be NULL, in
 *     which case nothing is encoded and sampling starts right after BOS.
 * steps:  upper bound on the number of forward() calls (sequence positions).
 *
 * Prompt tokens are force-fed first; after that, tokens are sampled from the
 * logits. The loop stops early when the next token is BOS (1). A tokens/sec
 * figure is written to stderr when more than one position was processed.
 * The process exits with EXIT_FAILURE if the prompt token buffer cannot be
 * allocated.
 */
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {

    // encode the (string) prompt into tokens sequence, if any is given
    int *prompt_tokens = NULL; // the sequence of prompt tokens
    int num_prompt_tokens = 0; // the total number of prompt tokens
    if (prompt != NULL) {
        prompt_tokens = malloc((strlen(prompt) + 1) * sizeof *prompt_tokens);
        if (prompt_tokens == NULL) {
            // never hand a NULL buffer to encode(); fail loudly instead
            fprintf(stderr, "malloc failed!\n");
            exit(EXIT_FAILURE);
        }
        encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
    }

    // start the main loop
    long start = 0;  // used to time our code, only initialized after first iteration
    int next;        // will store the next token in the sequence
    int token = 1;   // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
    int pos = 0;     // position in the sequence
    while (pos < steps) {

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);

        // advance the state machine
        if (pos < num_prompt_tokens) {
            // if we are still processing the input prompt, force the next prompt token
            next = prompt_tokens[pos];
        } else {
            // otherwise sample the next token from the logits
            next = sample(sampler, logits);
        }
        pos++;

        // data-dependent terminating condition: the BOS (1) token delimits sequences
        if (next == 1) { break; }

        // print the token as string, decode it with the Tokenizer object
        char* piece = decode(tokenizer, token, next);
        printf("%s", piece);
        fflush(stdout); // flush so each piece appears as soon as it is produced
        token = next;

        // init the timer here because the first iteration can be slower
        if (start == 0) { start = time_in_ms(); }
    }
    printf("\n");

    // report achieved tok/s (pos-1 because the timer starts after first iteration)
    if (pos > 1) {
        long end = time_in_ms();
        fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
    }

    // free(NULL) is a no-op, so this is safe when no prompt was given
    free(prompt_tokens);
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// int main
|
// int main
|
||||||
|
|
||||||
@@ -759,56 +815,10 @@ int main(int argc, char *argv[]) {
|
|||||||
Sampler sampler;
|
Sampler sampler;
|
||||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
|
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
|
||||||
|
|
||||||
// encode the (string) prompt into tokens sequence, if any is given
|
// run!
|
||||||
int *prompt_tokens = NULL; // the sequence of prompt tokens
|
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
||||||
int num_prompt_tokens = 0; // the total number of prompt tokens
|
|
||||||
if (prompt != NULL) {
|
|
||||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
|
||||||
encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
// start the main loop
|
|
||||||
long start = 0; // used to time our code, only initialized after first iteration
|
|
||||||
int next; // will store the next token in the sequence
|
|
||||||
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
|
||||||
int pos = 0; // position in the sequence
|
|
||||||
while (pos < steps) {
|
|
||||||
|
|
||||||
// forward the transformer to get logits for the next token
|
|
||||||
float* logits = forward(&transformer, token, pos);
|
|
||||||
|
|
||||||
// advance the state state machine
|
|
||||||
if (pos < num_prompt_tokens) {
|
|
||||||
// if we are still processing the input prompt, force the next prompt token
|
|
||||||
next = prompt_tokens[pos];
|
|
||||||
} else {
|
|
||||||
// otherwise sample the next token from the logits
|
|
||||||
next = sample(&sampler, logits);
|
|
||||||
}
|
|
||||||
pos++;
|
|
||||||
|
|
||||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
|
||||||
if (next == 1) { break; }
|
|
||||||
|
|
||||||
// print the token as string, decode it with the Tokenizer object
|
|
||||||
char* piece = decode(&tokenizer, token, next);
|
|
||||||
printf("%s", piece);
|
|
||||||
fflush(stdout);
|
|
||||||
token = next;
|
|
||||||
|
|
||||||
// init the timer here because the first iteration can be slower
|
|
||||||
if (start == 0) { start = time_in_ms(); }
|
|
||||||
}
|
|
||||||
printf("\n");
|
|
||||||
|
|
||||||
// report achieved tok/s (pos-1 because the timer starts after first iteration)
|
|
||||||
if (pos > 1) {
|
|
||||||
long end = time_in_ms();
|
|
||||||
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
|
|
||||||
}
|
|
||||||
|
|
||||||
// memory and file handles cleanup
|
// memory and file handles cleanup
|
||||||
free(prompt_tokens);
|
|
||||||
free_sampler(&sampler);
|
free_sampler(&sampler);
|
||||||
free_tokenizer(&tokenizer);
|
free_tokenizer(&tokenizer);
|
||||||
free_transformer(&transformer);
|
free_transformer(&transformer);
|
||||||
|
|||||||
Reference in New Issue
Block a user