diff --git a/run.c b/run.c index 2f29f35..22f3fa5 100644 --- a/run.c +++ b/run.c @@ -232,24 +232,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim); matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim); - // apply RoPE rotation to the q and k vectors for each head - for (int h = 0; h < p->n_heads; h++) { - // get the q and k vectors for this head - float* q = s->q + h * head_size; - float* k = s->k + h * head_size; - // rotate q and k by the freq_cis_real and freq_cis_imag - for (int i = 0; i < head_size; i+=2) { - float q0 = q[i]; - float q1 = q[i+1]; - float k0 = k[i]; - float k1 = k[i+1]; - float fcr = freq_cis_real_row[i/2]; - float fci = freq_cis_imag_row[i/2]; - q[i] = q0 * fcr - q1 * fci; - q[i+1] = q0 * fci + q1 * fcr; - k[i] = k0 * fcr - k1 * fci; - k[i+1] = k0 * fci + k1 * fcr; - } + // RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head + for (int i = 0; i < dim; i+=2) { + float q0 = s->q[i]; + float q1 = s->q[i+1]; + float k0 = s->k[i]; + float k1 = s->k[i+1]; + float fcr = freq_cis_real_row[(i % head_size) / 2]; + float fci = freq_cis_imag_row[(i % head_size) / 2]; + s->q[i] = q0 * fcr - q1 * fci; + s->q[i+1] = q0 * fci + q1 * fcr; + s->k[i] = k0 * fcr - k1 * fci; + s->k[i+1] = k0 * fci + k1 * fcr; } // save key,value at this time step (pos) to our kv cache