diff --git a/export_meta_llama_bin.py b/export_meta_llama_bin.py index 53ca652..41c1705 100644 --- a/export_meta_llama_bin.py +++ b/export_meta_llama_bin.py @@ -57,10 +57,10 @@ def export(p, state_dict, filepath='model.bin'): serialize('norm.weight') # freqs_cos, freqs_sin freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) - state_dict['freqs_cis.real'] = freqs_cos[:p['max_seq_len']] - state_dict['freqs_cis.imag'] = freqs_sin[:p['max_seq_len']] - serialize('freqs_cis.real') - serialize('freqs_cis.imag') + state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']] + state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']] + serialize('freqs_cos') + serialize('freqs_sin') # finally write the output weights serialize('output.weight') diff --git a/model.py b/model.py index cafbbd6..1600f5b 100644 --- a/model.py +++ b/model.py @@ -376,8 +376,8 @@ class Transformer(nn.Module): serialize(self.norm.weight) # note: no need to write final classifier weights due to weight sharing # freqs_cis - serialize(self.freqs_cis.real[:p.max_seq_len]) - serialize(self.freqs_cis.imag[:p.max_seq_len]) + serialize(self.freqs_cos[:p.max_seq_len]) + serialize(self.freqs_sin[:p.max_seq_len]) # write to binary file f.close()