Improve locality
This commit is contained in:
@@ -277,14 +277,16 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
|
|
||||||
// softmax the scores to get attention weights, from 0..pos inclusively
|
// softmax the scores to get attention weights, from 0..pos inclusively
|
||||||
softmax(att, pos + 1);
|
softmax(att, pos + 1);
|
||||||
|
|
||||||
// weighted sum of the values, store back into xb
|
// weighted sum of the values, store back into xb
|
||||||
for (int i = 0; i < head_size; i++) {
|
float* xb = s->xb + h * head_size;
|
||||||
float val = 0.0f;
|
memset(xb, 0, head_size * sizeof(float));
|
||||||
for (int t = 0; t <= pos; t++) {
|
for (int t = 0; t <= pos; t += 1) {
|
||||||
val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality
|
float* v = s->value_cache + loff + t * dim + h * head_size;
|
||||||
|
float a = s->att[h * p->seq_len + t];
|
||||||
|
for (int i = 0; i < head_size; i += 1) {
|
||||||
|
xb[i] += a * v[i];
|
||||||
}
|
}
|
||||||
s->xb[h * head_size + i] = val;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user