git push origin masterMerge branch 'admu-progvar-master'
This commit is contained in:
@@ -60,7 +60,7 @@ typedef struct {
|
|||||||
float *q; // query (dim,)
|
float *q; // query (dim,)
|
||||||
float *k; // key (dim,)
|
float *k; // key (dim,)
|
||||||
float *v; // value (dim,)
|
float *v; // value (dim,)
|
||||||
float *att; // buffer for scores/attention values (seq_len,)
|
float *att; // buffer for scores/attention values (n_heads, seq_len)
|
||||||
float *logits; // output logits
|
float *logits; // output logits
|
||||||
// kv cache
|
// kv cache
|
||||||
float* key_cache; // (layer, seq_len, dim)
|
float* key_cache; // (layer, seq_len, dim)
|
||||||
@@ -77,7 +77,7 @@ void malloc_run_state(RunState* s, Config* p) {
|
|||||||
s->q = calloc(p->dim, sizeof(float));
|
s->q = calloc(p->dim, sizeof(float));
|
||||||
s->k = calloc(p->dim, sizeof(float));
|
s->k = calloc(p->dim, sizeof(float));
|
||||||
s->v = calloc(p->dim, sizeof(float));
|
s->v = calloc(p->dim, sizeof(float));
|
||||||
s->att = calloc(p->seq_len, sizeof(float));
|
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||||
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
||||||
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
||||||
@@ -278,9 +278,12 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
|
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
|
||||||
|
|
||||||
// multihead attention. iterate over all heads
|
// multihead attention. iterate over all heads
|
||||||
|
#pragma omp parallel for
|
||||||
for (int h = 0; h < p->n_heads; h++) {
|
for (int h = 0; h < p->n_heads; h++) {
|
||||||
// get the query vector for this head
|
// get the query vector for this head
|
||||||
float* q = s->q + h * head_size;
|
float* q = s->q + h * head_size;
|
||||||
|
// attention scores for this head
|
||||||
|
float* att = s->att + h * p->seq_len;
|
||||||
// iterate over all timesteps, including the current one
|
// iterate over all timesteps, including the current one
|
||||||
for (int t = 0; t <= pos; t++) {
|
for (int t = 0; t <= pos; t++) {
|
||||||
// get the key vector for this head and at this timestep
|
// get the key vector for this head and at this timestep
|
||||||
@@ -292,17 +295,17 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
}
|
}
|
||||||
score /= sqrtf(head_size);
|
score /= sqrtf(head_size);
|
||||||
// save the score to the attention buffer
|
// save the score to the attention buffer
|
||||||
s->att[t] = score;
|
att[t] = score;
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax the scores to get attention weights, from 0..pos inclusively
|
// softmax the scores to get attention weights, from 0..pos inclusively
|
||||||
softmax(s->att, pos + 1);
|
softmax(att, pos + 1);
|
||||||
|
|
||||||
// weighted sum of the values, store back into xb
|
// weighted sum of the values, store back into xb
|
||||||
for (int i = 0; i < head_size; i++) {
|
for (int i = 0; i < head_size; i++) {
|
||||||
float val = 0.0f;
|
float val = 0.0f;
|
||||||
for (int t = 0; t <= pos; t++) {
|
for (int t = 0; t <= pos; t++) {
|
||||||
val += s->att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality
|
val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality
|
||||||
}
|
}
|
||||||
s->xb[h * head_size + i] = val;
|
s->xb[h * head_size + i] = val;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user