Merge pull request #50 from karpathy/memmap
candidate memmap implementation
This commit is contained in:
@@ -13,6 +13,9 @@ $ ./run
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Transformer and RunState structs, and related memory management
|
||||
@@ -104,68 +107,39 @@ void free_run_state(RunState* s) {
|
||||
free(s->value_cache);
|
||||
}
|
||||
|
||||
/*
 * Allocate a zero-filled float buffer for every weight tensor in `w`.
 *
 * w: struct whose weight pointers are overwritten with fresh allocations.
 * p: hyperparameters (vocab_size, dim, hidden_dim, n_layers, seq_len) that
 *    determine each buffer's element count.
 *
 * If any allocation fails, prints "malloc failed!" and terminates the process
 * with exit status 1. On success every buffer can be released with free().
 */
void malloc_weights(TransformerWeights* w, Config* p) {
    // Widen the hyperparameters once so that every product below is evaluated
    // in size_t. Multiplying the raw fields directly (they appear to be plain
    // ints) can overflow for large models before the value reaches calloc.
    const size_t vocab_size = (size_t)p->vocab_size;
    const size_t dim        = (size_t)p->dim;
    const size_t hidden_dim = (size_t)p->hidden_dim;
    const size_t n_layers   = (size_t)p->n_layers;
    const size_t seq_len    = (size_t)p->seq_len;

    const size_t norm_elems = n_layers * dim;              // per-layer rmsnorm weights
    const size_t attn_elems = n_layers * dim * dim;        // wq / wk / wv / wo
    const size_t ffn_elems  = n_layers * hidden_dim * dim; // w1 / w2 / w3
    const size_t freq_elems = seq_len * dim / 2;           // freq_cis_real / freq_cis_imag

    // we calloc instead of malloc to keep valgrind happy
    w->token_embedding_table = calloc(vocab_size * dim, sizeof(float));
    w->rms_att_weight = calloc(norm_elems, sizeof(float));
    w->rms_ffn_weight = calloc(norm_elems, sizeof(float));
    w->wq = calloc(attn_elems, sizeof(float));
    w->wk = calloc(attn_elems, sizeof(float));
    w->wv = calloc(attn_elems, sizeof(float));
    w->wo = calloc(attn_elems, sizeof(float));
    w->w1 = calloc(ffn_elems, sizeof(float));
    w->w2 = calloc(ffn_elems, sizeof(float));
    w->w3 = calloc(ffn_elems, sizeof(float));
    w->rms_final_weight = calloc(dim, sizeof(float));
    w->freq_cis_real = calloc(freq_elems, sizeof(float));
    w->freq_cis_imag = calloc(freq_elems, sizeof(float));

    // ensure all mallocs went fine
    if (!w->token_embedding_table || !w->rms_att_weight || !w->rms_ffn_weight
     || !w->wq || !w->wk || !w->wv || !w->wo || !w->w1 || !w->w2 || !w->w3
     || !w->rms_final_weight || !w->freq_cis_real || !w->freq_cis_imag) {
        printf("malloc failed!\n");
        exit(1);
    }
}
|
||||
|
||||
/*
 * Release every weight buffer referenced by `w`.
 *
 * Each pointer field is handed to free() exactly once, in the order listed
 * in the table below. The fields themselves are left unchanged afterwards.
 */
void free_weights(TransformerWeights* w) {
    // Snapshot the pointers into a table so the release is a single loop.
    void* buffers[] = {
        w->token_embedding_table,
        w->rms_att_weight,
        w->rms_ffn_weight,
        w->wq,
        w->wk,
        w->wv,
        w->wo,
        w->w1,
        w->w2,
        w->w3,
        w->rms_final_weight,
        w->freq_cis_real,
        w->freq_cis_imag,
    };
    const size_t n_buffers = sizeof(buffers) / sizeof(buffers[0]);

    for (size_t i = 0; i < n_buffers; i++) {
        free(buffers[i]);
    }
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// initialization: read from checkpoint
|
||||
|
||||
int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
|
||||
if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != p->vocab_size * p->dim) return 1;
|
||||
if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != p->n_layers * p->dim) return 1;
|
||||
if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
|
||||
if (fread(w->wk, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
|
||||
if (fread(w->wv, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
|
||||
if (fread(w->wo, sizeof(float), p->n_layers * p->dim * p->dim, f) != p->n_layers * p->dim * p->dim) return 1;
|
||||
if (fread(w->rms_ffn_weight, sizeof(float), p->n_layers * p->dim, f) != p->n_layers * p->dim) return 1;
|
||||
if (fread(w->w1, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != p->n_layers * p->dim * p->hidden_dim) return 1;
|
||||
if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != p->n_layers * p->hidden_dim * p->dim) return 1;
|
||||
if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != p->n_layers * p->dim * p->hidden_dim) return 1;
|
||||
if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != p->dim) return 1;
|
||||
/*
 * Point every weight tensor in `w` at its slice of one contiguous float buffer.
 *
 * w: struct whose weight pointers are overwritten.
 * p: hyperparameters that determine how many floats each tensor occupies.
 * f: first float of the weight data. The tensors are laid out back to back in
 *    the order assigned below.
 *
 * Nothing is copied or allocated: the fields of `w` alias the buffer at `f`,
 * so that buffer must stay valid for as long as `w` is in use, and the
 * pointers must not be passed to free().
 */
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
    // Compute offsets in size_t so large models cannot overflow the products
    // of the hyperparameter fields.
    const size_t vocab_size = (size_t)p->vocab_size;
    const size_t dim        = (size_t)p->dim;
    const size_t hidden_dim = (size_t)p->hidden_dim;
    const size_t n_layers   = (size_t)p->n_layers;
    const size_t seq_len    = (size_t)p->seq_len;

    float* ptr = f;
    w->token_embedding_table = ptr;
    ptr += vocab_size * dim;
    w->rms_att_weight = ptr;
    ptr += n_layers * dim;
    w->wq = ptr;
    ptr += n_layers * dim * dim;
    w->wk = ptr;
    ptr += n_layers * dim * dim;
    w->wv = ptr;
    ptr += n_layers * dim * dim;
    w->wo = ptr;
    ptr += n_layers * dim * dim;
    w->rms_ffn_weight = ptr;
    ptr += n_layers * dim;
    w->w1 = ptr;
    ptr += n_layers * dim * hidden_dim;
    w->w2 = ptr;
    ptr += n_layers * hidden_dim * dim;
    w->w3 = ptr;
    ptr += n_layers * dim * hidden_dim;
    w->rms_final_weight = ptr;
    ptr += dim;
    w->freq_cis_real = ptr;
    // The two freq_cis tables each hold seq_len * head_size / 2 floats.
    int head_size = p->dim / p->n_heads;
    ptr += seq_len * (size_t)head_size / 2;
    w->freq_cis_imag = ptr;
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// neural net blocks
|
||||
|
||||
@@ -410,6 +384,9 @@ int main(int argc, char *argv[]) {
|
||||
// read in the model.bin file
|
||||
Config config;
|
||||
TransformerWeights weights;
|
||||
int fd = 0;
|
||||
float* data = NULL;
|
||||
long file_size;
|
||||
{
|
||||
FILE *file = fopen(checkpoint, "rb");
|
||||
if (!file) {
|
||||
@@ -418,10 +395,16 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
// read in the config header
|
||||
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||
// read in the Transformer weights
|
||||
malloc_weights(&weights, &config);
|
||||
if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
|
||||
// figure out the file size
|
||||
fseek(file, 0, SEEK_END); // move file pointer to end of file
|
||||
file_size = ftell(file); // get the file size, in bytes
|
||||
fclose(file);
|
||||
// memory map the Transformer weights into the data pointer
|
||||
fd = open(checkpoint, O_RDONLY); // open in read only mode
|
||||
if (fd == -1) { printf("open failed!\n"); return 1; }
|
||||
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
|
||||
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
|
||||
checkpoint_init_weights(&weights, &config, data + sizeof(Config)/sizeof(float));
|
||||
}
|
||||
// right now we cannot run for more than config.seq_len steps
|
||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||
@@ -484,10 +467,11 @@ int main(int argc, char *argv[]) {
|
||||
long end = time_in_ms();
|
||||
printf("\nachieved tok/s: %f\n", config.seq_len / (double)(end-start)*1000);
|
||||
|
||||
// memory cleanup
|
||||
// memory and file handles cleanup
|
||||
free_run_state(&state);
|
||||
free_weights(&weights);
|
||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
||||
free(vocab);
|
||||
if (data != MAP_FAILED) munmap(data, file_size);
|
||||
if (fd != -1) close(fd);
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user