refactor the Transformer (Config, Weights, RunState) into a single object, with build and free too
This commit is contained in:
@@ -14,7 +14,7 @@
|
|||||||
#include <sys/mman.h>
|
#include <sys/mman.h>
|
||||||
#endif
|
#endif
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Transformer and RunState structs, and related memory management
|
// Transformer model
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int dim; // transformer dimension
|
int dim; // transformer dimension
|
||||||
@@ -64,6 +64,16 @@ typedef struct {
|
|||||||
float* value_cache; // (layer, seq_len, dim)
|
float* value_cache; // (layer, seq_len, dim)
|
||||||
} RunState;
|
} RunState;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
Config config; // the hyperparameters of the architecture (the blueprint)
|
||||||
|
TransformerWeights weights; // the weights of the model
|
||||||
|
RunState state; // buffers for the "wave" of activations in the forward pass
|
||||||
|
// some more state needed to properly clean up the memory mapping (sigh)
|
||||||
|
int fd; // file descriptor for memory mapping
|
||||||
|
float* data; // memory mapped data pointer
|
||||||
|
ssize_t file_size; // size of the checkpoint file in bytes
|
||||||
|
} Transformer;
|
||||||
|
|
||||||
void malloc_run_state(RunState* s, Config* p) {
|
void malloc_run_state(RunState* s, Config* p) {
|
||||||
// we calloc instead of malloc to keep valgrind happy
|
// we calloc instead of malloc to keep valgrind happy
|
||||||
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
||||||
@@ -103,10 +113,7 @@ void free_run_state(RunState* s) {
|
|||||||
free(s->value_cache);
|
free(s->value_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
||||||
// initialization: read from checkpoint
|
|
||||||
|
|
||||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
|
||||||
int head_size = p->dim / p->n_heads;
|
int head_size = p->dim / p->n_heads;
|
||||||
w->token_embedding_table = ptr;
|
w->token_embedding_table = ptr;
|
||||||
ptr += p->vocab_size * p->dim;
|
ptr += p->vocab_size * p->dim;
|
||||||
@@ -154,11 +161,26 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
|
|||||||
*data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
|
*data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
|
||||||
if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
|
if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
|
||||||
float* weights_ptr = *data + sizeof(Config)/sizeof(float);
|
float* weights_ptr = *data + sizeof(Config)/sizeof(float);
|
||||||
checkpoint_init_weights(weights, config, weights_ptr, shared_weights);
|
memory_map_weights(weights, config, weights_ptr, shared_weights);
|
||||||
|
}
|
||||||
|
|
||||||
|
void build_transformer(char* checkpoint_path, Transformer *t) {
|
||||||
|
// read in the Config and the Weights from the checkpoint
|
||||||
|
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
|
||||||
|
// allocate the RunState buffers
|
||||||
|
malloc_run_state(&t->state, &t->config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_transformer(Transformer* t) {
|
||||||
|
// close the memory mapping
|
||||||
|
if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
|
||||||
|
if (t->fd != -1) { close(t->fd); }
|
||||||
|
// free the RunState buffers
|
||||||
|
free_run_state(&t->state);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// neural net blocks
|
// neural net blocks; the dynamics of the Transformer
|
||||||
|
|
||||||
void rmsnorm(float* o, float* x, float* weight, int size) {
|
void rmsnorm(float* o, float* x, float* weight, int size) {
|
||||||
// calculate sum of squares
|
// calculate sum of squares
|
||||||
@@ -209,9 +231,12 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
|
float* forward(Transformer* transformer, int token, int pos) {
|
||||||
|
|
||||||
// a few convenience variables
|
// a few convenience variables
|
||||||
|
Config* p = &transformer->config;
|
||||||
|
TransformerWeights* w = &transformer->weights;
|
||||||
|
RunState* s = &transformer->state;
|
||||||
float *x = s->x;
|
float *x = s->x;
|
||||||
int dim = p->dim;
|
int dim = p->dim;
|
||||||
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
||||||
@@ -338,6 +363,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
|
|
||||||
// classifier into logits
|
// classifier into logits
|
||||||
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
||||||
|
return s->logits;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
@@ -351,7 +377,7 @@ typedef struct {
|
|||||||
char byte_piece[2];
|
char byte_piece[2];
|
||||||
} Tokenizer;
|
} Tokenizer;
|
||||||
|
|
||||||
void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
|
void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) {
|
||||||
// i should have written the vocab_size into the tokenizer file... sigh
|
// i should have written the vocab_size into the tokenizer file... sigh
|
||||||
t->vocab_size = vocab_size;
|
t->vocab_size = vocab_size;
|
||||||
// malloc space to hold the scores and the strings
|
// malloc space to hold the scores and the strings
|
||||||
@@ -359,8 +385,8 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
|
|||||||
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
|
||||||
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
|
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
|
||||||
// read in the file
|
// read in the file
|
||||||
FILE *file = fopen(tokenizer, "rb");
|
FILE *file = fopen(tokenizer_path, "rb");
|
||||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); }
|
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
|
||||||
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
|
||||||
int len;
|
int len;
|
||||||
for (int i = 0; i < vocab_size; i++) {
|
for (int i = 0; i < vocab_size; i++) {
|
||||||
@@ -374,9 +400,7 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void free_tokenizer(Tokenizer* t) {
|
void free_tokenizer(Tokenizer* t) {
|
||||||
for (int i = 0; i < t->vocab_size; i++) {
|
for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
|
||||||
free(t->vocab[i]);
|
|
||||||
}
|
|
||||||
free(t->vocab);
|
free(t->vocab);
|
||||||
free(t->vocab_scores);
|
free(t->vocab_scores);
|
||||||
}
|
}
|
||||||
@@ -669,26 +693,17 @@ int main(int argc, char *argv[]) {
|
|||||||
}
|
}
|
||||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
||||||
|
|
||||||
// read in the model.bin file
|
// build the Transformer via the model .bin file
|
||||||
Config config;
|
Transformer transformer;
|
||||||
TransformerWeights weights;
|
build_transformer(checkpoint_path, &transformer);
|
||||||
int fd = 0; // file descriptor for memory mapping
|
int vocab_size = transformer.config.vocab_size; // convenience copy
|
||||||
float* data = NULL; // memory mapped data pointer
|
|
||||||
ssize_t file_size; // size of the checkpoint file in bytes
|
|
||||||
read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size);
|
|
||||||
|
|
||||||
// right now we cannot run for more than config.seq_len steps
|
// build the Tokenizer via the tokenizer .bin file
|
||||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
|
||||||
|
|
||||||
// read in the tokenizer .bin file
|
|
||||||
Tokenizer tokenizer;
|
Tokenizer tokenizer;
|
||||||
build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size);
|
build_tokenizer(tokenizer_path, &tokenizer, vocab_size);
|
||||||
|
|
||||||
// create and init the application RunState
|
// create and init the application RunState
|
||||||
RunState state;
|
ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
|
||||||
malloc_run_state(&state, &config);
|
|
||||||
ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
|
|
||||||
|
|
||||||
// process the prompt, if any
|
// process the prompt, if any
|
||||||
int *prompt_tokens = NULL;
|
int *prompt_tokens = NULL;
|
||||||
int num_prompt_tokens = 0;
|
int num_prompt_tokens = 0;
|
||||||
@@ -705,7 +720,7 @@ int main(int argc, char *argv[]) {
|
|||||||
while (pos < steps) {
|
while (pos < steps) {
|
||||||
|
|
||||||
// forward the transformer to get logits for the next token
|
// forward the transformer to get logits for the next token
|
||||||
transformer(token, pos, &config, &state, &weights);
|
float* logits = forward(&transformer, token, pos);
|
||||||
|
|
||||||
// advance the state state machine
|
// advance the state state machine
|
||||||
if(pos < num_prompt_tokens) {
|
if(pos < num_prompt_tokens) {
|
||||||
@@ -715,19 +730,19 @@ int main(int argc, char *argv[]) {
|
|||||||
// sample the next token
|
// sample the next token
|
||||||
if (temperature == 0.0f) {
|
if (temperature == 0.0f) {
|
||||||
// greedy argmax sampling: take the token with the highest probability
|
// greedy argmax sampling: take the token with the highest probability
|
||||||
next = argmax(state.logits, config.vocab_size);
|
next = argmax(logits, vocab_size);
|
||||||
} else {
|
} else {
|
||||||
// apply the temperature to the logits
|
// apply the temperature to the logits
|
||||||
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
|
for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; }
|
||||||
// apply softmax to the logits to get the probabilities for next token
|
// apply softmax to the logits to get the probabilities for next token
|
||||||
softmax(state.logits, config.vocab_size);
|
softmax(logits, vocab_size);
|
||||||
// we sample from this distribution to get the next token
|
// we sample from this distribution to get the next token
|
||||||
if (topp <= 0 || topp >= 1) {
|
if (topp <= 0 || topp >= 1) {
|
||||||
// simply sample from the predicted probability distribution
|
// simply sample from the predicted probability distribution
|
||||||
next = sample(state.logits, config.vocab_size);
|
next = sample(logits, vocab_size);
|
||||||
} else {
|
} else {
|
||||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||||
next = sample_topp(state.logits, config.vocab_size, topp, probindex);
|
next = sample_topp(logits, vocab_size, topp, probindex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -754,11 +769,9 @@ int main(int argc, char *argv[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// memory and file handles cleanup
|
// memory and file handles cleanup
|
||||||
free_run_state(&state);
|
|
||||||
free(probindex);
|
free(probindex);
|
||||||
|
if (prompt_tokens != NULL) { free(prompt_tokens); }
|
||||||
free_tokenizer(&tokenizer);
|
free_tokenizer(&tokenizer);
|
||||||
if (prompt_tokens != NULL) free(prompt_tokens);
|
free_transformer(&transformer);
|
||||||
if (data != MAP_FAILED) munmap(data, file_size);
|
|
||||||
if (fd != -1) close(fd);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user