diff --git a/export_meta_llama_bin.py b/export_meta_llama_bin.py index 801077b..53ca652 100644 --- a/export_meta_llama_bin.py +++ b/export_meta_llama_bin.py @@ -55,10 +55,10 @@ def export(p, state_dict, filepath='model.bin'): # final rmsnorm serialize('norm.weight') - # freqs_cis - freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) - state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']] - state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']] + # freqs_cos, freqs_sin + freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) + state_dict['freqs_cis.real'] = freqs_cos[:p['max_seq_len']] + state_dict['freqs_cis.imag'] = freqs_sin[:p['max_seq_len']] serialize('freqs_cis.real') serialize('freqs_cis.imag')