collapsing copy-pasted code because it's driving my OCD crazy
This commit is contained in:
@@ -239,21 +239,17 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
|||||||
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
||||||
|
|
||||||
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
|
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
|
||||||
for (int i = 0; i < dim; i+=2) {
|
for (int v = 0; v < 2; v++) {
|
||||||
float q0 = s->q[i];
|
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
|
||||||
float q1 = s->q[i+1];
|
int vec_size = v == 0 ? dim : kv_dim; // the size of the vector
|
||||||
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
for (int i = 0; i < vec_size; i+=2) {
|
||||||
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
float v0 = vec[i];
|
||||||
s->q[i] = q0 * fcr - q1 * fci;
|
float v1 = vec[i+1];
|
||||||
s->q[i+1] = q0 * fci + q1 * fcr;
|
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
||||||
}
|
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
||||||
for (int i = 0; i < kv_dim; i+=2) {
|
vec[i] = v0 * fcr - v1 * fci;
|
||||||
float k0 = s->k[i];
|
vec[i+1] = v0 * fci + v1 * fcr;
|
||||||
float k1 = s->k[i+1];
|
}
|
||||||
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
|
||||||
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
|
||||||
s->k[i] = k0 * fcr - k1 * fci;
|
|
||||||
s->k[i+1] = k0 * fci + k1 * fcr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// save key,value at this time step (pos) to our kv cache
|
// save key,value at this time step (pos) to our kv cache
|
||||||
|
|||||||
Reference in New Issue
Block a user