Improve locality

This commit is contained in:
aegkmq
2023-07-26 13:24:27 +09:00
parent f5650891d5
commit 36c522a0d8
+7 -5
View File
@@ -279,12 +279,14 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
softmax(att, pos + 1); softmax(att, pos + 1);
// weighted sum of the values, store back into xb // weighted sum of the values, store back into xb
for (int i = 0; i < head_size; i++) { float* xb = s->xb + h * head_size;
float val = 0.0f; memset(xb, 0, head_size * sizeof(float));
for (int t = 0; t <= pos; t++) { for (int t = 0; t <= pos; t += 1) {
val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality float* v = s->value_cache + loff + t * dim + h * head_size;
float a = s->att[h * p->seq_len + t];
for (int i = 0; i < head_size; i += 1) {
xb[i] += a * v[i];
} }
s->xb[h * head_size + i] = val;
} }
} }