diff --git a/run.c b/run.c index 15352ae..3193d89 100644 --- a/run.c +++ b/run.c @@ -60,7 +60,7 @@ typedef struct { float *q; // query (dim,) float *k; // key (dim,) float *v; // value (dim,) - float *att; // buffer for scores/attention values (seq_len,) + float *att; // buffer for scores/attention values (n_heads, seq_len) float *logits; // output logits // kv cache float* key_cache; // (layer, seq_len, dim) @@ -77,7 +77,7 @@ void malloc_run_state(RunState* s, Config* p) { s->q = calloc(p->dim, sizeof(float)); s->k = calloc(p->dim, sizeof(float)); s->v = calloc(p->dim, sizeof(float)); - s->att = calloc(p->seq_len, sizeof(float)); + s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); s->logits = calloc(p->vocab_size, sizeof(float)); s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); @@ -278,9 +278,12 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row)); // multihead attention. iterate over all heads + #pragma omp parallel for for (int h = 0; h < p->n_heads; h++) { // get the query vector for this head float* q = s->q + h * head_size; + // attention scores for this head + float* att = s->att + h * p->seq_len; // iterate over all timesteps, including the current one for (int t = 0; t <= pos; t++) { // get the key vector for this head and at this timestep @@ -292,17 +295,17 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* } score /= sqrtf(head_size); // save the score to the attention buffer - s->att[t] = score; + att[t] = score; } // softmax the scores to get attention weights, from 0..pos inclusively - softmax(s->att, pos + 1); + softmax(att, pos + 1); // weighted sum of the values, store back into xb for (int i = 0; i < head_size; i++) { float val = 0.0f; for (int t = 0; t <= pos; t++) { - val += s->att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality + val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality } s->xb[h * head_size + i] = val; }