From e6e3f1322b79f209395e0f38a7d67f2991a58d5c Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 24 Jul 2023 22:54:49 +0000 Subject: [PATCH] candidate memmap implementation --- run.c | 106 +++++++++++++++++++++++++--------------------------------- 1 file changed, 45 insertions(+), 61 deletions(-) diff --git a/run.c b/run.c index 11af4dc..5733a2e 100644 --- a/run.c +++ b/run.c @@ -14,6 +14,9 @@ $ ./run #include #include #include +#include <fcntl.h> +#include <unistd.h> +#include <sys/mman.h> // ---------------------------------------------------------------------------- // Transformer and RunState structs, and related memory management @@ -105,68 +108,39 @@ void free_run_state(RunState* s) { free(s->value_cache); } -void malloc_weights(TransformerWeights* w, Config* p) { - // we calloc instead of malloc to keep valgrind happy - w->token_embedding_table = calloc(p->vocab_size * p->dim, sizeof(float)); - w->rms_att_weight = calloc(p->n_layers * p->dim, sizeof(float)); - w->rms_ffn_weight = calloc(p->n_layers * p->dim, sizeof(float)); - w->wq = calloc(p->n_layers * p->dim * p->dim, sizeof(float)); - w->wk = calloc(p->n_layers * p->dim * p->dim, sizeof(float)); - w->wv = calloc(p->n_layers * p->dim * p->dim, sizeof(float)); - w->wo = calloc(p->n_layers * p->dim * p->dim, sizeof(float)); - w->w1 = calloc(p->n_layers * p->hidden_dim * p->dim, sizeof(float)); - w->w2 = calloc(p->n_layers * p->dim * p->hidden_dim, sizeof(float)); - w->w3 = calloc(p->n_layers * p->hidden_dim * p->dim, sizeof(float)); - w->rms_final_weight = calloc(p->dim, sizeof(float)); - w->freq_cis_real = calloc(p->seq_len * p->dim / 2, sizeof(float)); - w->freq_cis_imag = calloc(p->seq_len * p->dim / 2, sizeof(float)); - // ensure all mallocs went fine - if (!w->token_embedding_table || !w->rms_att_weight || !w->rms_ffn_weight - || !w->wq || !w->wk || !w->wv || !w->wo || !w->w1 || !w->w2 || !w->w3 || - !w->rms_final_weight || !w->freq_cis_real || !w->freq_cis_imag) { - printf("malloc failed!\n"); - exit(1); - } -} - -void 
free_weights(TransformerWeights* w) { - free(w->token_embedding_table); - free(w->rms_att_weight); - free(w->rms_ffn_weight); - free(w->wq); - free(w->wk); - free(w->wv); - free(w->wo); - free(w->w1); - free(w->w2); - free(w->w3); - free(w->rms_final_weight); - free(w->freq_cis_real); - free(w->freq_cis_imag); -} - // ---------------------------------------------------------------------------- // initialization: read from checkpoint -int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) { - if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != p->vocab_size * p->dim) return 1; - if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != p->n_layers * p->dim) return 1; - if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1; - if (fread(w->wk, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1; - if (fread(w->wv, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1; - if (fread(w->wo, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1; - if (fread(w->rms_ffn_weight, sizeof(float), p->n_layers * p->dim, f) != p->n_layers * p->dim) return 1; - if (fread(w->w1, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != p->n_layers * p->dim * p->hidden_dim) return 1; - if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != p->n_layers * p->hidden_dim * p->dim) return 1; - if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != p->n_layers * p->dim * p->hidden_dim) return 1; - if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != p->dim) return 1; +void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) { + float* ptr = f; + w->token_embedding_table = ptr; + ptr += p->vocab_size * p->dim; + w->rms_att_weight = ptr; + ptr += p->n_layers * p->dim; + w->wq = ptr; + ptr += 
p->n_layers * p->dim * p->dim; + w->wk = ptr; + ptr += p->n_layers * p->dim * p->dim; + w->wv = ptr; + ptr += p->n_layers * p->dim * p->dim; + w->wo = ptr; + ptr += p->n_layers * p->dim * p->dim; + w->rms_ffn_weight = ptr; + ptr += p->n_layers * p->dim; + w->w1 = ptr; + ptr += p->n_layers * p->dim * p->hidden_dim; + w->w2 = ptr; + ptr += p->n_layers * p->hidden_dim * p->dim; + w->w3 = ptr; + ptr += p->n_layers * p->dim * p->hidden_dim; + w->rms_final_weight = ptr; + ptr += p->dim; + w->freq_cis_real = ptr; int head_size = p->dim / p->n_heads; - if (fread(w->freq_cis_real, sizeof(float), p->seq_len * head_size / 2, f) != p->seq_len * head_size / 2) return 1; - if (fread(w->freq_cis_imag, sizeof(float), p->seq_len * head_size / 2, f) != p->seq_len * head_size / 2) return 1; - return 0; + ptr += p->seq_len * head_size / 2; + w->freq_cis_imag = ptr; } - // ---------------------------------------------------------------------------- // neural net blocks @@ -411,6 +385,9 @@ int main(int argc, char *argv[]) { // read in the model.bin file Config config; TransformerWeights weights; + int fd = 0; + float* data = NULL; + long file_size; { FILE *file = fopen(checkpoint, "rb"); if (!file) { @@ -419,10 +396,16 @@ int main(int argc, char *argv[]) { } // read in the config header if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; } - // read in the Transformer weights - malloc_weights(&weights, &config); - if(checkpoint_init_weights(&weights, &config, file)) { return 1; } + // figure out the file size + fseek(file, 0, SEEK_END); // move file pointer to end of file + file_size = ftell(file); // get the file size, in bytes fclose(file); + // memory map the Transformer weights into the data pointer + fd = open(checkpoint, O_RDONLY); // open in read only mode + if (fd == -1) { printf("open failed!\n"); return 1; } + data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0); + if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; } + 
checkpoint_init_weights(&weights, &config, data + sizeof(Config)/sizeof(float)); } // right now we cannot run for more than config.seq_len steps if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } @@ -485,10 +468,11 @@ int main(int argc, char *argv[]) { long end = time_in_ms(); printf("\nachieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000); - // memory cleanup + // memory and file handles cleanup free_run_state(&state); - free_weights(&weights); for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); } free(vocab); + if (data != MAP_FAILED) munmap(data, file_size); + if (fd != -1) close(fd); return 0; }