diff --git a/model.py b/model.py index d04fc76..8d76d5a 100644 --- a/model.py +++ b/model.py @@ -195,7 +195,8 @@ class Transformer(nn.Module): self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse - self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) + freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) + self.register_buffer("freqs_cis", freqs_cis, persistent=False) # init all weights self.apply(self._init_weights) @@ -215,7 +216,6 @@ class Transformer(nn.Module): def forward(self, tokens, targets=None): _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) - self.freqs_cis = self.freqs_cis.to(h.device) freqs_cis = self.freqs_cis[:seqlen] for layer in self.layers: