From 8a377a1d3110875ce3d6fdeda31a86489303b12a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 03:55:12 +0000 Subject: [PATCH] refactor the Transformer (Config, Weights, RunState) into a single object, with build and free too --- run.c | 95 +++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 54 insertions(+), 41 deletions(-) diff --git a/run.c b/run.c index 314c700..1242596 100644 --- a/run.c +++ b/run.c @@ -14,7 +14,7 @@ #include <sys/mman.h> #endif // ---------------------------------------------------------------------------- -// Transformer and RunState structs, and related memory management +// Transformer model typedef struct { int dim; // transformer dimension @@ -64,6 +64,16 @@ typedef struct { float* value_cache; // (layer, seq_len, dim) } RunState; +typedef struct { + Config config; // the hyperparameters of the architecture (the blueprint) + TransformerWeights weights; // the weights of the model + RunState state; // buffers for the "wave" of activations in the forward pass + // some more state needed to properly clean up the memory mapping (sigh) + int fd; // file descriptor for memory mapping + float* data; // memory mapped data pointer + ssize_t file_size; // size of the checkpoint file in bytes +} Transformer; + void malloc_run_state(RunState* s, Config* p) { // we calloc instead of malloc to keep valgrind happy int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; @@ -103,10 +113,7 @@ void free_run_state(RunState* s) { free(s->value_cache); } -// ---------------------------------------------------------------------------- -// initialization: read from checkpoint - -void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) { +void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) { int head_size = p->dim / p->n_heads; w->token_embedding_table = ptr; ptr += p->vocab_size * p->dim; @@ -154,11 +161,26 @@ void read_checkpoint(char* checkpoint, Config* 
config, TransformerWeights* weigh *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } float* weights_ptr = *data + sizeof(Config)/sizeof(float); - checkpoint_init_weights(weights, config, weights_ptr, shared_weights); + memory_map_weights(weights, config, weights_ptr, shared_weights); +} + +void build_transformer(char* checkpoint_path, Transformer *t) { + // read in the Config and the Weights from the checkpoint + read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); + // allocate the RunState buffers + malloc_run_state(&t->state, &t->config); +} + +void free_transformer(Transformer* t) { + // close the memory mapping + if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); } + if (t->fd != -1) { close(t->fd); } + // free the RunState buffers + free_run_state(&t->state); } // ---------------------------------------------------------------------------- -// neural net blocks +// neural net blocks; the dynamics of the Transformer void rmsnorm(float* o, float* x, float* weight, int size) { // calculate sum of squares @@ -209,9 +231,12 @@ void matmul(float* xout, float* x, float* w, int n, int d) { } } -void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) { +float* forward(Transformer* transformer, int token, int pos) { // a few convenience variables + Config* p = &transformer->config; + TransformerWeights* w = &transformer->weights; + RunState* s = &transformer->state; float *x = s->x; int dim = p->dim; int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; @@ -338,6 +363,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // classifier into logits matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); + return s->logits; } // ---------------------------------------------------------------------------- @@ -351,7 +377,7 @@ typedef struct { char byte_piece[2]; } Tokenizer; 
-void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { +void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) { // i should have written the vocab_size into the tokenizer file... sigh t->vocab_size = vocab_size; // malloc space to hold the scores and the strings @@ -359,8 +385,8 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); t->byte_piece[1] = '\0'; // null terminate the byte_piece string // read in the file - FILE *file = fopen(tokenizer, "rb"); - if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); } + FILE *file = fopen(tokenizer_path, "rb"); + if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } int len; for (int i = 0; i < vocab_size; i++) { @@ -374,9 +400,7 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { } void free_tokenizer(Tokenizer* t) { - for (int i = 0; i < t->vocab_size; i++) { - free(t->vocab[i]); - } + for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); } free(t->vocab); free(t->vocab_scores); } @@ -667,28 +691,19 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } else { error_usage(); } } - if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} + if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} - // read in the model.bin file - Config config; - TransformerWeights weights; - int fd = 0; // file descriptor for memory mapping - float* data = NULL; // memory mapped data pointer - ssize_t file_size; // size of the checkpoint file in bytes - read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size); + // build the Transformer via the model .bin file + Transformer transformer; + build_transformer(checkpoint_path, &transformer); + int 
vocab_size = transformer.config.vocab_size; // convenience copy - // right now we cannot run for more than config.seq_len steps - if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } - - // read in the tokenizer .bin file + // build the Tokenizer via the tokenizer .bin file Tokenizer tokenizer; - build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size); + build_tokenizer(tokenizer_path, &tokenizer, vocab_size); // create and init the application RunState - RunState state; - malloc_run_state(&state, &config); - ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling - + ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling // process the prompt, if any int *prompt_tokens = NULL; int num_prompt_tokens = 0; @@ -705,7 +720,7 @@ int main(int argc, char *argv[]) { while (pos < steps) { // forward the transformer to get logits for the next token - transformer(token, pos, &config, &state, &weights); + float* logits = forward(&transformer, token, pos); // advance the state state machine if(pos < num_prompt_tokens) { @@ -715,19 +730,19 @@ int main(int argc, char *argv[]) { // sample the next token if (temperature == 0.0f) { // greedy argmax sampling: take the token with the highest probability - next = argmax(state.logits, config.vocab_size); + next = argmax(logits, vocab_size); } else { // apply the temperature to the logits - for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; } + for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; } // apply softmax to the logits to get the probabilities for next token - softmax(state.logits, config.vocab_size); + softmax(logits, vocab_size); // we sample from this distribution to get the next token if (topp <= 0 || topp >= 1) { // simply sample from the predicted probability distribution - next = sample(state.logits, config.vocab_size); + next = sample(logits, vocab_size); } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp(state.logits, config.vocab_size, topp, probindex); + next = sample_topp(logits, vocab_size, topp, probindex); } } } @@ -754,11 +769,9 @@ int main(int argc, char *argv[]) { } // memory and file handles cleanup - free_run_state(&state); free(probindex); + if (prompt_tokens != 
NULL) { free(prompt_tokens); } free_tokenizer(&tokenizer); - if (prompt_tokens != NULL) free(prompt_tokens); - if (data != MAP_FAILED) munmap(data, file_size); - if (fd != -1) close(fd); + free_transformer(&transformer); return 0; }