refactor the Transformer (Config, Weights, RunState) into a single object, with build and free too

This commit is contained in:
Andrej Karpathy
2023-08-21 03:55:12 +00:00
parent ae2e4f8d88
commit 8a377a1d31
+53 -40
View File
@@ -14,7 +14,7 @@
#include <sys/mman.h> #include <sys/mman.h>
#endif #endif
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Transformer and RunState structs, and related memory management // Transformer model
typedef struct { typedef struct {
int dim; // transformer dimension int dim; // transformer dimension
@@ -64,6 +64,16 @@ typedef struct {
float* value_cache; // (layer, seq_len, dim) float* value_cache; // (layer, seq_len, dim)
} RunState; } RunState;
// Transformer: bundles everything needed to run the model into one object.
// build_transformer() fills it in (config, weights and the mapping bookkeeping
// via read_checkpoint(), the activation buffers via malloc_run_state());
// free_transformer() releases the mapping, the file descriptor and the buffers.
typedef struct {
Config config; // the hyperparameters of the architecture (the blueprint)
TransformerWeights weights; // the weights of the model; set up by memory_map_weights() from a pointer into the mapped checkpoint
RunState state; // buffers for the "wave" of activations in the forward pass
// bookkeeping kept only so that free_transformer() can undo the memory mapping
int fd; // file descriptor for memory mapping; closed by free_transformer() unless it is -1
float* data; // memory mapped data pointer (the mmap result); unmapped by free_transformer() unless it is MAP_FAILED
ssize_t file_size; // size of the checkpoint file in bytes; used as the length for mmap/munmap
} Transformer;
void malloc_run_state(RunState* s, Config* p) { void malloc_run_state(RunState* s, Config* p) {
// we calloc instead of malloc to keep valgrind happy // we calloc instead of malloc to keep valgrind happy
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
@@ -103,10 +113,7 @@ void free_run_state(RunState* s) {
free(s->value_cache); free(s->value_cache);
} }
// ---------------------------------------------------------------------------- void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
// initialization: read from checkpoint
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
int head_size = p->dim / p->n_heads; int head_size = p->dim / p->n_heads;
w->token_embedding_table = ptr; w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim; ptr += p->vocab_size * p->dim;
@@ -154,11 +161,26 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
*data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
float* weights_ptr = *data + sizeof(Config)/sizeof(float); float* weights_ptr = *data + sizeof(Config)/sizeof(float);
checkpoint_init_weights(weights, config, weights_ptr, shared_weights); memory_map_weights(weights, config, weights_ptr, shared_weights);
}
// Populate the Transformer `t` from the checkpoint file at `checkpoint_path`.
// First the Config and the weights are loaded through read_checkpoint(), which
// also records the fd / mapping pointer / file size inside `t`; afterwards the
// RunState activation buffers are allocated for that Config.
// Pair every call with free_transformer(t).
void build_transformer(char* checkpoint_path, Transformer *t) {
    Config* config = &t->config;
    TransformerWeights* weights = &t->weights;
    RunState* run_state = &t->state;

    // step 1: Config + Weights come out of the checkpoint
    read_checkpoint(checkpoint_path, config, weights, &t->fd, &t->data, &t->file_size);

    // step 2: the RunState buffers are sized from the Config that was just read
    malloc_run_state(run_state, config);
}
// Release everything held by the Transformer `t`: the checkpoint memory
// mapping, the checkpoint file descriptor, and the RunState buffers, in that
// order. The mapping and descriptor are skipped when they hold their
// respective "invalid" sentinels (MAP_FAILED and -1).
void free_transformer(Transformer* t) {
    // drop the memory mapping of the checkpoint, if there is one
    const int mapping_is_valid = (t->data != MAP_FAILED);
    if (mapping_is_valid) {
        munmap(t->data, t->file_size);
    }

    // then close the file that backed the mapping
    const int fd_is_valid = (t->fd != -1);
    if (fd_is_valid) {
        close(t->fd);
    }

    // finally give back the activation buffers
    RunState* run_state = &t->state;
    free_run_state(run_state);
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// neural net blocks // neural net blocks; the dynamics of the Transformer
void rmsnorm(float* o, float* x, float* weight, int size) { void rmsnorm(float* o, float* x, float* weight, int size) {
// calculate sum of squares // calculate sum of squares
@@ -209,9 +231,12 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
} }
} }
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) { float* forward(Transformer* transformer, int token, int pos) {
// a few convenience variables // a few convenience variables
Config* p = &transformer->config;
TransformerWeights* w = &transformer->weights;
RunState* s = &transformer->state;
float *x = s->x; float *x = s->x;
int dim = p->dim; int dim = p->dim;
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
@@ -338,6 +363,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// classifier into logits // classifier into logits
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
return s->logits;
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@@ -351,7 +377,7 @@ typedef struct {
char byte_piece[2]; char byte_piece[2];
} Tokenizer; } Tokenizer;
void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) {
// i should have written the vocab_size into the tokenizer file... sigh // i should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size; t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings // malloc space to hold the scores and the strings
@@ -359,8 +385,8 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
t->byte_piece[1] = '\0'; // null terminate the byte_piece string t->byte_piece[1] = '\0'; // null terminate the byte_piece string
// read in the file // read in the file
FILE *file = fopen(tokenizer, "rb"); FILE *file = fopen(tokenizer_path, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); } if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
int len; int len;
for (int i = 0; i < vocab_size; i++) { for (int i = 0; i < vocab_size; i++) {
@@ -374,9 +400,7 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
} }
void free_tokenizer(Tokenizer* t) { void free_tokenizer(Tokenizer* t) {
for (int i = 0; i < t->vocab_size; i++) { for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
free(t->vocab[i]);
}
free(t->vocab); free(t->vocab);
free(t->vocab_scores); free(t->vocab_scores);
} }
@@ -669,26 +693,17 @@ int main(int argc, char *argv[]) {
} }
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
// read in the model.bin file // build the Transformer via the model .bin file
Config config; Transformer transformer;
TransformerWeights weights; build_transformer(checkpoint_path, &transformer);
int fd = 0; // file descriptor for memory mapping int vocab_size = transformer.config.vocab_size; // convenience copy
float* data = NULL; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size);
// right now we cannot run for more than config.seq_len steps // build the Tokenizer via the tokenizer .bin file
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
// read in the tokenizer .bin file
Tokenizer tokenizer; Tokenizer tokenizer;
build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size); build_tokenizer(tokenizer_path, &tokenizer, vocab_size);
// create and init the application RunState // create and init the application RunState
RunState state; ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
malloc_run_state(&state, &config);
ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
// process the prompt, if any // process the prompt, if any
int *prompt_tokens = NULL; int *prompt_tokens = NULL;
int num_prompt_tokens = 0; int num_prompt_tokens = 0;
@@ -705,7 +720,7 @@ int main(int argc, char *argv[]) {
while (pos < steps) { while (pos < steps) {
// forward the transformer to get logits for the next token // forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights); float* logits = forward(&transformer, token, pos);
// advance the state machine // advance the state machine
if(pos < num_prompt_tokens) { if(pos < num_prompt_tokens) {
@@ -715,19 +730,19 @@ int main(int argc, char *argv[]) {
// sample the next token // sample the next token
if (temperature == 0.0f) { if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability // greedy argmax sampling: take the token with the highest probability
next = argmax(state.logits, config.vocab_size); next = argmax(logits, vocab_size);
} else { } else {
// apply the temperature to the logits // apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; } for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token // apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size); softmax(logits, vocab_size);
// we sample from this distribution to get the next token // we sample from this distribution to get the next token
if (topp <= 0 || topp >= 1) { if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution // simply sample from the predicted probability distribution
next = sample(state.logits, config.vocab_size); next = sample(logits, vocab_size);
} else { } else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero // top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(state.logits, config.vocab_size, topp, probindex); next = sample_topp(logits, vocab_size, topp, probindex);
} }
} }
} }
@@ -754,11 +769,9 @@ int main(int argc, char *argv[]) {
} }
// memory and file handles cleanup // memory and file handles cleanup
free_run_state(&state);
free(probindex); free(probindex);
if (prompt_tokens != NULL) { free(prompt_tokens); }
free_tokenizer(&tokenizer); free_tokenizer(&tokenizer);
if (prompt_tokens != NULL) free(prompt_tokens); free_transformer(&transformer);
if (data != MAP_FAILED) munmap(data, file_size);
if (fd != -1) close(fd);
return 0; return 0;
} }