simplify rope
This commit is contained in:
@@ -232,24 +232,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
|
||||
matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);
|
||||
|
||||
// apply RoPE rotation to the q and k vectors for each head
|
||||
for (int h = 0; h < p->n_heads; h++) {
|
||||
// get the q and k vectors for this head
|
||||
float* q = s->q + h * head_size;
|
||||
float* k = s->k + h * head_size;
|
||||
// rotate q and k by the freq_cis_real and freq_cis_imag
|
||||
for (int i = 0; i < head_size; i+=2) {
|
||||
float q0 = q[i];
|
||||
float q1 = q[i+1];
|
||||
float k0 = k[i];
|
||||
float k1 = k[i+1];
|
||||
float fcr = freq_cis_real_row[i/2];
|
||||
float fci = freq_cis_imag_row[i/2];
|
||||
q[i] = q0 * fcr - q1 * fci;
|
||||
q[i+1] = q0 * fci + q1 * fcr;
|
||||
k[i] = k0 * fcr - k1 * fci;
|
||||
k[i+1] = k0 * fci + k1 * fcr;
|
||||
}
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
|
||||
for (int i = 0; i < dim; i+=2) {
|
||||
float q0 = s->q[i];
|
||||
float q1 = s->q[i+1];
|
||||
float k0 = s->k[i];
|
||||
float k1 = s->k[i+1];
|
||||
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
||||
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
||||
s->q[i] = q0 * fcr - q1 * fci;
|
||||
s->q[i+1] = q0 * fci + q1 * fcr;
|
||||
s->k[i] = k0 * fcr - k1 * fci;
|
||||
s->k[i+1] = k0 * fci + k1 * fcr;
|
||||
}
|
||||
|
||||
// save key,value at this time step (pos) to our kv cache
|
||||
|
||||
Reference in New Issue
Block a user