diff --git a/run.c b/run.c index fc24dbd..809d96c 100644 --- a/run.c +++ b/run.c @@ -277,14 +277,16 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // softmax the scores to get attention weights, from 0..pos inclusively softmax(att, pos + 1); - + // weighted sum of the values, store back into xb - for (int i = 0; i < head_size; i++) { - float val = 0.0f; - for (int t = 0; t <= pos; t++) { - val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality + float* xb = s->xb + h * head_size; + memset(xb, 0, head_size * sizeof(float)); + for (int t = 0; t <= pos; t += 1) { + float* v = s->value_cache + loff + t * dim + h * head_size; + float a = s->att[h * p->seq_len + t]; + for (int i = 0; i < head_size; i += 1) { + xb[i] += a * v[i]; } - s->xb[h * head_size + i] = val; } }