From 6a61831e19a5cacc0d3825d579632f403ba4c909 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 24 Jul 2023 04:22:32 +0000 Subject: [PATCH] make init code much less sketchy --- run.c | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/run.c b/run.c index a479bc1..70b55c7 100644 --- a/run.c +++ b/run.c @@ -1,10 +1,8 @@ /* Inference for Llama-2 Transformer model in pure C. -Compile simply with: -$ gcc -o run run.c -Or if that doesn't work then: -$ gcc -o run run.c -lm +Example compile: (see README for more details) +$ gcc -O3 -o run run.c -lm Then run with: $ ./run @@ -402,16 +400,24 @@ int main(int argc, char *argv[]) { srand((unsigned int)current_time); } - // read in the config header + // read in the checkpoint .bin file Config config; - FILE *file = fopen(checkpoint, "rb"); - if (!file) { - printf("Unable to open file!"); - return 1; + TransformerWeights weights; + { + FILE *file = fopen(checkpoint, "rb"); + if (!file) { + printf("Unable to open file!"); + return 1; + } + // read in the config header + fread(&config, sizeof(Config), 1, file); + // read in the Transformer weights + malloc_weights(&weights, &config); + checkpoint_init_weights(&weights, &config, file); + fclose(file); } - fread(&config, sizeof(Config), 1, file); - // init the Tokenizer + // read in the tokenizer vocab char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); { FILE *file = fopen("tokenizer.bin", "r"); @@ -427,14 +433,9 @@ int main(int argc, char *argv[]) { fread(vocab[i], len, 1, file); vocab[i][len] = '\0'; // add the string terminating token } + fclose(file); } - // create and init the Transformer - TransformerWeights weights; - malloc_weights(&weights, &config); - checkpoint_init_weights(&weights, &config, file); - fclose(file); - // create and init the application RunState RunState state; malloc_run_state(&state, &config); @@ -469,7 +470,7 @@ int main(int argc, char *argv[]) { pos++; } printf("\n"); - + // report our achieved tok/s clock_t end = clock(); double elapsed = (double)(end - start) / CLOCKS_PER_SEC;