diff --git a/model.py b/model.py index 1600f5b..a246087 100644 --- a/model.py +++ b/model.py @@ -216,7 +216,7 @@ class Transformer(nn.Module): self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse - freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) + freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) diff --git a/run.c b/run.c index 5d2c487..16ecebb 100644 --- a/run.c +++ b/run.c @@ -51,8 +51,8 @@ typedef struct { // final rmsnorm float* rms_final_weight; // (dim,) // freq_cis for RoPE relatively positional embeddings - float* freq_cis_real; // (seq_len, dim/2) - float* freq_cis_imag; // (seq_len, dim/2) + float* freq_cis_real; // (seq_len, head_size/2) + float* freq_cis_imag; // (seq_len, head_size/2) // (optional) classifier weights for the logits, on the last layer float* wcls; } TransformerWeights;