|
|
|
@@ -39,11 +39,11 @@ typedef struct {
|
|
|
|
|
// weights for rmsnorms
|
|
|
|
|
float* rms_att_weight; // (layer, dim) rmsnorm weights
|
|
|
|
|
float* rms_ffn_weight; // (layer, dim)
|
|
|
|
|
// weights for matmuls
|
|
|
|
|
float* wq; // (layer, dim, dim)
|
|
|
|
|
float* wk; // (layer, dim, dim)
|
|
|
|
|
float* wv; // (layer, dim, dim)
|
|
|
|
|
float* wo; // (layer, dim, dim)
|
|
|
|
|
// weights for matmuls. note dim == n_heads * head_size
|
|
|
|
|
float* wq; // (layer, dim, n_heads * head_size)
|
|
|
|
|
float* wk; // (layer, dim, n_kv_heads * head_size)
|
|
|
|
|
float* wv; // (layer, dim, n_kv_heads * head_size)
|
|
|
|
|
float* wo; // (layer, n_heads * head_size, dim)
|
|
|
|
|
// weights for ffn
|
|
|
|
|
float* w1; // (layer, hidden_dim, dim)
|
|
|
|
|
float* w2; // (layer, dim, hidden_dim)
|
|
|
|
@@ -82,6 +82,7 @@ typedef struct {
|
|
|
|
|
|
|
|
|
|
void malloc_run_state(RunState* s, Config* p) {
|
|
|
|
|
// we calloc instead of malloc to keep valgrind happy
|
|
|
|
|
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
|
|
|
|
s->x = calloc(p->dim, sizeof(float));
|
|
|
|
|
s->xb = calloc(p->dim, sizeof(float));
|
|
|
|
|
s->xb2 = calloc(p->dim, sizeof(float));
|
|
|
|
@@ -93,8 +94,8 @@ void malloc_run_state(RunState* s, Config* p) {
|
|
|
|
|
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
|
|
|
|
s->logits = calloc(p->vocab_size, sizeof(float));
|
|
|
|
|
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
|
|
|
|
|
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
|
|
|
|
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
|
|
|
|
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
|
|
|
|
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
|
|
|
|
// ensure all mallocs went fine
|
|
|
|
|
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
|
|
|
|
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
|
|
|
@@ -124,19 +125,20 @@ void free_run_state(RunState* s) {
|
|
|
|
|
// initialization: read from checkpoint
|
|
|
|
|
|
|
|
|
|
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
|
|
|
|
|
int head_size = p->dim / p->n_heads;
|
|
|
|
|
float* ptr = f;
|
|
|
|
|
w->token_embedding_table = ptr;
|
|
|
|
|
ptr += p->vocab_size * p->dim;
|
|
|
|
|
w->rms_att_weight = ptr;
|
|
|
|
|
ptr += p->n_layers * p->dim;
|
|
|
|
|
w->wq = ptr;
|
|
|
|
|
ptr += p->n_layers * p->dim * p->dim;
|
|
|
|
|
ptr += p->n_layers * p->dim * (p->n_heads * head_size);
|
|
|
|
|
w->wk = ptr;
|
|
|
|
|
ptr += p->n_layers * p->dim * p->dim;
|
|
|
|
|
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
|
|
|
|
w->wv = ptr;
|
|
|
|
|
ptr += p->n_layers * p->dim * p->dim;
|
|
|
|
|
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
|
|
|
|
w->wo = ptr;
|
|
|
|
|
ptr += p->n_layers * p->dim * p->dim;
|
|
|
|
|
ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
|
|
|
|
|
w->rms_ffn_weight = ptr;
|
|
|
|
|
ptr += p->n_layers * p->dim;
|
|
|
|
|
w->w1 = ptr;
|
|
|
|
@@ -148,7 +150,6 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int sha
|
|
|
|
|
w->rms_final_weight = ptr;
|
|
|
|
|
ptr += p->dim;
|
|
|
|
|
w->freq_cis_real = ptr;
|
|
|
|
|
int head_size = p->dim / p->n_heads;
|
|
|
|
|
ptr += p->seq_len * head_size / 2;
|
|
|
|
|
w->freq_cis_imag = ptr;
|
|
|
|
|
ptr += p->seq_len * head_size / 2;
|
|
|
|
@@ -218,6 +219,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|
|
|
|
// a few convenience variables
|
|
|
|
|
float *x = s->x;
|
|
|
|
|
int dim = p->dim;
|
|
|
|
|
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
|
|
|
|
int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
|
|
|
|
|
int hidden_dim = p->hidden_dim;
|
|
|
|
|
int head_size = dim / p->n_heads;
|
|
|
|
|
|
|
|
|
@@ -237,29 +240,33 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|
|
|
|
|
|
|
|
|
// qkv matmuls for this position
|
|
|
|
|
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
|
|
|
|
|
matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
|
|
|
|
|
matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);
|
|
|
|
|
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
|
|
|
|
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
|
|
|
|
|
|
|
|
|
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
|
|
|
|
|
for (int i = 0; i < dim; i+=2) {
|
|
|
|
|
float q0 = s->q[i];
|
|
|
|
|
float q1 = s->q[i+1];
|
|
|
|
|
float k0 = s->k[i];
|
|
|
|
|
float k1 = s->k[i+1];
|
|
|
|
|
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
|
|
|
|
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
|
|
|
|
s->q[i] = q0 * fcr - q1 * fci;
|
|
|
|
|
s->q[i+1] = q0 * fci + q1 * fcr;
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < kv_dim; i+=2) {
|
|
|
|
|
float k0 = s->k[i];
|
|
|
|
|
float k1 = s->k[i+1];
|
|
|
|
|
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
|
|
|
|
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
|
|
|
|
s->k[i] = k0 * fcr - k1 * fci;
|
|
|
|
|
s->k[i+1] = k0 * fci + k1 * fcr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// save key,value at this time step (pos) to our kv cache
|
|
|
|
|
int loff = l * p->seq_len * dim; // kv cache layer offset for convenience
|
|
|
|
|
float* key_cache_row = s->key_cache + loff + pos * dim;
|
|
|
|
|
float* value_cache_row = s->value_cache + loff + pos * dim;
|
|
|
|
|
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
|
|
|
|
|
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
|
|
|
|
|
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
|
|
|
|
|
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
|
|
|
|
|
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
|
|
|
|
|
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
|
|
|
|
|
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
|
|
|
|
|
|
|
|
|
|
// multihead attention. iterate over all heads
|
|
|
|
|
int h;
|
|
|
|
@@ -272,7 +279,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|
|
|
|
// iterate over all timesteps, including the current one
|
|
|
|
|
for (int t = 0; t <= pos; t++) {
|
|
|
|
|
// get the key vector for this head and at this timestep
|
|
|
|
|
float* k = s->key_cache + loff + t * dim + h * head_size;
|
|
|
|
|
float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
|
|
|
|
// calculate the attention score as the dot product of q and k
|
|
|
|
|
float score = 0.0f;
|
|
|
|
|
for (int i = 0; i < head_size; i++) {
|
|
|
|
@@ -291,7 +298,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|
|
|
|
memset(xb, 0, head_size * sizeof(float));
|
|
|
|
|
for (int t = 0; t <= pos; t++) {
|
|
|
|
|
// get the value vector for this head and at this timestep
|
|
|
|
|
float* v = s->value_cache + loff + t * dim + h * head_size;
|
|
|
|
|
float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
|
|
|
|
// get the attention weight for this timestep
|
|
|
|
|
float a = att[t];
|
|
|
|
|
// accumulate the weighted value into xb
|
|
|
|
|