diff --git a/model.py b/model.py index a246087..9ca91fe 100644 --- a/model.py +++ b/model.py @@ -215,7 +215,7 @@ class Transformer(nn.Module): # share the unembedding parameters with the embedding parameters self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying - # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse + # some useful precompute for the RoPE relative positional embeddings freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False)