Small changes to ROPE & comments
This commit is contained in:
@@ -216,7 +216,7 @@ class Transformer(nn.Module):
|
||||
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
|
||||
|
||||
# some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
|
||||
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
||||
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ typedef struct {
|
||||
// final rmsnorm
|
||||
float* rms_final_weight; // (dim,)
|
||||
// freq_cis for RoPE relatively positional embeddings
|
||||
float* freq_cis_real; // (seq_len, dim/2)
|
||||
float* freq_cis_imag; // (seq_len, dim/2)
|
||||
float* freq_cis_real; // (seq_len, head_size/2)
|
||||
float* freq_cis_imag; // (seq_len, head_size/2)
|
||||
// (optional) classifier weights for the logits, on the last layer
|
||||
float* wcls;
|
||||
} TransformerWeights;
|
||||
|
||||
Reference in New Issue
Block a user