collapsing copy paste code because it's driving my ocd crazy

This commit is contained in:
Andrej Karpathy
2023-08-15 16:03:11 +00:00
parent 88eb238255
commit a47f9b3969
+11 -15
View File
@@ -239,21 +239,17 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
for (int i = 0; i < dim; i+=2) {
float q0 = s->q[i];
float q1 = s->q[i+1];
float fcr = freq_cis_real_row[(i % head_size) / 2];
float fci = freq_cis_imag_row[(i % head_size) / 2];
s->q[i] = q0 * fcr - q1 * fci;
s->q[i+1] = q0 * fci + q1 * fcr;
}
for (int i = 0; i < kv_dim; i+=2) {
float k0 = s->k[i];
float k1 = s->k[i+1];
float fcr = freq_cis_real_row[(i % head_size) / 2];
float fci = freq_cis_imag_row[(i % head_size) / 2];
s->k[i] = k0 * fcr - k1 * fci;
s->k[i+1] = k0 * fci + k1 * fcr;
for (int v = 0; v < 2; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
int vec_size = v == 0 ? dim : kv_dim; // the size of the vector
for (int i = 0; i < vec_size; i+=2) {
float v0 = vec[i];
float v1 = vec[i+1];
float fcr = freq_cis_real_row[(i % head_size) / 2];
float fci = freq_cis_imag_row[(i % head_size) / 2];
vec[i] = v0 * fcr - v1 * fci;
vec[i+1] = v0 * fci + v1 * fcr;
}
}
// save key,value at this time step (pos) to our kv cache