Register freqs_cis as non-persistent buffer
This commit is contained in:
@@ -195,7 +195,8 @@ class Transformer(nn.Module):
|
|||||||
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
|
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
|
||||||
|
|
||||||
# Precompute the complex rotation frequencies for the RoPE relative positional embeddings. TODO: why `* 2` here? (unclear)
|
# Precompute the complex rotation frequencies for the RoPE relative positional embeddings. TODO: why `* 2` here? (unclear)
|
||||||
self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
|
freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
|
||||||
|
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||||
|
|
||||||
# init all weights
|
# init all weights
|
||||||
self.apply(self._init_weights)
|
self.apply(self._init_weights)
|
||||||
@@ -215,7 +216,6 @@ class Transformer(nn.Module):
|
|||||||
def forward(self, tokens, targets=None):
|
def forward(self, tokens, targets=None):
|
||||||
_bsz, seqlen = tokens.shape
|
_bsz, seqlen = tokens.shape
|
||||||
h = self.tok_embeddings(tokens)
|
h = self.tok_embeddings(tokens)
|
||||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
|
||||||
freqs_cis = self.freqs_cis[:seqlen]
|
freqs_cis = self.freqs_cis[:seqlen]
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
|
|||||||
Reference in New Issue
Block a user