and finallygit add run.c split off the generate function. alongside it will come a chat function. we are close

This commit is contained in:
Andrej Karpathy
2023-08-22 02:22:36 +00:00
parent d73b917d3b
commit 0e362f735f
+58 -48
View File
@@ -696,6 +696,62 @@ long time_in_ms() {
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
}
// ----------------------------------------------------------------------------
// generation loop
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL; // the sequence of prompt tokens
int num_prompt_tokens = 0; // the total number of prompt tokens
if (prompt != NULL) {
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
encode(tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
}
// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
while (pos < steps) {
// forward the transformer to get logits for the next token
float* logits = forward(transformer, token, pos);
// advance the state state machine
if (pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
} else {
// otherwise sample the next token from the logits
next = sample(sampler, logits);
}
pos++;
// data-dependent terminating condition: the BOS (1) token delimits sequences
if (next == 1) { break; }
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next);
printf("%s", piece);
fflush(stdout);
token = next;
// init the timer here because the first iteration can be slower
if (start == 0) { start = time_in_ms(); }
}
printf("\n");
// report achieved tok/s (pos-1 because the timer starts after first iteration)
if (pos > 1) {
long end = time_in_ms();
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
}
free(prompt_tokens);
}
// ----------------------------------------------------------------------------
// int main
@@ -759,56 +815,10 @@ int main(int argc, char *argv[]) {
Sampler sampler;
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
// encode the (string) prompt into tokens sequence, if any is given
int *prompt_tokens = NULL; // the sequence of prompt tokens
int num_prompt_tokens = 0; // the total number of prompt tokens
if (prompt != NULL) {
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
}
// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
while (pos < steps) {
// forward the transformer to get logits for the next token
float* logits = forward(&transformer, token, pos);
// advance the state state machine
if (pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
} else {
// otherwise sample the next token from the logits
next = sample(&sampler, logits);
}
pos++;
// data-dependent terminating condition: the BOS (1) token delimits sequences
if (next == 1) { break; }
// print the token as string, decode it with the Tokenizer object
char* piece = decode(&tokenizer, token, next);
printf("%s", piece);
fflush(stdout);
token = next;
// init the timer here because the first iteration can be slower
if (start == 0) { start = time_in_ms(); }
}
printf("\n");
// report achieved tok/s (pos-1 because the timer starts after first iteration)
if (pos > 1) {
long end = time_in_ms();
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
}
// run!
generate(&transformer, &tokenizer, &sampler, prompt, steps);
// memory and file handles cleanup
free(prompt_tokens);
free_sampler(&sampler);
free_tokenizer(&tokenizer);
free_transformer(&transformer);