Fix serialization of freqs_cos and freqs_sin
This commit is contained in:
@@ -57,10 +57,10 @@ def export(p, state_dict, filepath='model.bin'):
|
|||||||
serialize('norm.weight')
|
serialize('norm.weight')
|
||||||
# freqs_cos, freqs_sin
|
# freqs_cos, freqs_sin
|
||||||
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
|
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
|
||||||
state_dict['freqs_cis.real'] = freqs_cos[:p['max_seq_len']]
|
state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
|
||||||
state_dict['freqs_cis.imag'] = freqs_sin[:p['max_seq_len']]
|
state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
|
||||||
serialize('freqs_cis.real')
|
serialize('freqs_cos')
|
||||||
serialize('freqs_cis.imag')
|
serialize('freqs_sin')
|
||||||
|
|
||||||
# finally write the output weights
|
# finally write the output weights
|
||||||
serialize('output.weight')
|
serialize('output.weight')
|
||||||
|
|||||||
@@ -376,8 +376,8 @@ class Transformer(nn.Module):
|
|||||||
serialize(self.norm.weight)
|
serialize(self.norm.weight)
|
||||||
# note: no need to write final classifier weights due to weight sharing
|
# note: no need to write final classifier weights due to weight sharing
|
||||||
# freqs_cis
|
# freqs_cis
|
||||||
serialize(self.freqs_cis.real[:p.max_seq_len])
|
serialize(self.freqs_cos[:p.max_seq_len])
|
||||||
serialize(self.freqs_cis.imag[:p.max_seq_len])
|
serialize(self.freqs_sin[:p.max_seq_len])
|
||||||
|
|
||||||
# write to binary file
|
# write to binary file
|
||||||
f.close()
|
f.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user