diff --git a/run.c b/run.c index 513eda9..10d468b 100644 --- a/run.c +++ b/run.c @@ -43,7 +43,7 @@ typedef struct { float* w3; // (layer, hidden_dim, dim) // final rmsnorm float* rms_final_weight; // (dim,) - // freq_cis for RoPE relatively positional embeddings + // freq_cis for RoPE relatively positional embeddings (not used anymore) float* freq_cis_real; // (seq_len, head_size/2) float* freq_cis_imag; // (seq_len, head_size/2) // (optional) classifier weights for the logits, on the last layer @@ -214,10 +214,6 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* float* content_row = &(w->token_embedding_table[token * dim]); memcpy(x, content_row, dim*sizeof(*x)); - // pluck out the "pos" row of freq_cis_real and freq_cis_imag - float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2; - float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2; - // forward all the layers for(int l = 0; l < p->n_layers; l++) { @@ -229,15 +225,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim); matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim); - // RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head - for (int v = 0; v < 2; v++) { - float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key) - int vec_size = v == 0 ? dim : kv_dim; // the size of the vector - for (int i = 0; i < vec_size; i+=2) { + // RoPE relative positional encoding: complex-valued rotate q and k in each head + for (int i = 0; i < dim; i+=2) { + int head_dim = i % head_size; + float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size); + float val = pos * freq; + float fcr = cosf(val); + float fci = sinf(val); + int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only + for (int v = 0; v < rotn; v++) { + float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key) float v0 = vec[i]; float v1 = vec[i+1]; - float fcr = freq_cis_real_row[(i % head_size) / 2]; - float fci = freq_cis_imag_row[(i % head_size) / 2]; vec[i] = v0 * fcr - v1 * fci; vec[i+1] = v0 * fci + v1 * fcr; }