calculate the freq_cis online, no need to write/read them to/from checkpoints
This commit is contained in:
@@ -43,7 +43,7 @@ typedef struct {
|
||||
float* w3; // (layer, hidden_dim, dim)
|
||||
// final rmsnorm
|
||||
float* rms_final_weight; // (dim,)
|
||||
// freq_cis for RoPE relatively positional embeddings
|
||||
// freq_cis for RoPE relatively positional embeddings (not used anymore)
|
||||
float* freq_cis_real; // (seq_len, head_size/2)
|
||||
float* freq_cis_imag; // (seq_len, head_size/2)
|
||||
// (optional) classifier weights for the logits, on the last layer
|
||||
@@ -214,10 +214,6 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
float* content_row = &(w->token_embedding_table[token * dim]);
|
||||
memcpy(x, content_row, dim*sizeof(*x));
|
||||
|
||||
// pluck out the "pos" row of freq_cis_real and freq_cis_imag
|
||||
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
|
||||
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;
|
||||
|
||||
// forward all the layers
|
||||
for(int l = 0; l < p->n_layers; l++) {
|
||||
|
||||
@@ -229,15 +225,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
||||
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
||||
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
|
||||
for (int v = 0; v < 2; v++) {
|
||||
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
|
||||
int vec_size = v == 0 ? dim : kv_dim; // the size of the vector
|
||||
for (int i = 0; i < vec_size; i+=2) {
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k in each head
|
||||
for (int i = 0; i < dim; i+=2) {
|
||||
int head_dim = i % head_size;
|
||||
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
|
||||
float val = pos * freq;
|
||||
float fcr = cosf(val);
|
||||
float fci = sinf(val);
|
||||
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
|
||||
for (int v = 0; v < rotn; v++) {
|
||||
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
|
||||
float v0 = vec[i];
|
||||
float v1 = vec[i+1];
|
||||
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
||||
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
||||
vec[i] = v0 * fcr - v1 * fci;
|
||||
vec[i+1] = v0 * fci + v1 * fcr;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user