diff --git a/run.c b/run.c index 8f565cd..fb9a428 100644 --- a/run.c +++ b/run.c @@ -239,21 +239,17 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim); // RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head - for (int i = 0; i < dim; i+=2) { - float q0 = s->q[i]; - float q1 = s->q[i+1]; - float fcr = freq_cis_real_row[(i % head_size) / 2]; - float fci = freq_cis_imag_row[(i % head_size) / 2]; - s->q[i] = q0 * fcr - q1 * fci; - s->q[i+1] = q0 * fci + q1 * fcr; - } - for (int i = 0; i < kv_dim; i+=2) { - float k0 = s->k[i]; - float k1 = s->k[i+1]; - float fcr = freq_cis_real_row[(i % head_size) / 2]; - float fci = freq_cis_imag_row[(i % head_size) / 2]; - s->k[i] = k0 * fcr - k1 * fci; - s->k[i+1] = k0 * fci + k1 * fcr; + for (int v = 0; v < 2; v++) { + float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key) + int vec_size = v == 0 ? dim : kv_dim; // the size of the vector + for (int i = 0; i < vec_size; i+=2) { + float v0 = vec[i]; + float v1 = vec[i+1]; + float fcr = freq_cis_real_row[(i % head_size) / 2]; + float fci = freq_cis_imag_row[(i % head_size) / 2]; + vec[i] = v0 * fcr - v1 * fci; + vec[i+1] = v0 * fci + v1 * fcr; + } } // save key,value at this time step (pos) to our kv cache