Small changes to ROPE & comments

This commit is contained in:
rahulschand
2023-08-03 20:13:50 +05:30
parent af8708d87b
commit 02cf3c7311
2 changed files with 3 additions and 3 deletions
+1 -1
View File
@@ -216,7 +216,7 @@ class Transformer(nn.Module):
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
# some useful precomputation for the RoPE relative positional embeddings. TODO: why `* 2` here? (confusing)
freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+2 -2
View File
@@ -51,8 +51,8 @@ typedef struct {
// final rmsnorm
float* rms_final_weight; // (dim,)
// freq_cis for RoPE relative positional embeddings
float* freq_cis_real; // (seq_len, dim/2)
float* freq_cis_imag; // (seq_len, dim/2)
float* freq_cis_real; // (seq_len, head_size/2)
float* freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
float* wcls;
} TransformerWeights;