Merge pull request #85 from python273/export-llama-without-llama
Export llama without llama
This commit is contained in:
+91
-68
@@ -1,91 +1,114 @@
|
||||
"""
|
||||
This script exports the Llama 2 weights in llama2c.bin format.
|
||||
|
||||
Place it into the root directory of:
|
||||
https://github.com/facebookresearch/llama
|
||||
|
||||
And then run it similar to their other examples, via torchrun sadly:
|
||||
torchrun --nproc_per_node 1 export_meta_llama_bin.py
|
||||
"""
|
||||
import sys
|
||||
import struct
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from llama import Llama
|
||||
import torch
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def export(self, filepath='model.bin'):
|
||||
from model import precompute_freqs_cis
|
||||
|
||||
|
||||
def export(p, state_dict, filepath='model.bin'):
|
||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
||||
|
||||
f = open(filepath, 'wb')
|
||||
import struct
|
||||
import numpy as np
|
||||
|
||||
def serialize(t):
|
||||
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
|
||||
b = struct.pack(f'{len(d)}f', *d)
|
||||
f.write(b)
|
||||
def serialize(key):
|
||||
print(f"writing {key}...")
|
||||
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
|
||||
f.write(memoryview(t))
|
||||
del state_dict[key]
|
||||
|
||||
# first write out the header
|
||||
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
||||
p = self.params
|
||||
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
||||
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
||||
n_kv_heads, -p.vocab_size, p.max_seq_len)
|
||||
hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
|
||||
p['vocab_size'] = 32000
|
||||
p['max_seq_len'] = 2048
|
||||
|
||||
n_kv_heads = p.get('n_kv_heads') or p['n_heads']
|
||||
header = struct.pack(
|
||||
'iiiiiii',
|
||||
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
|
||||
n_kv_heads, -p['vocab_size'], p['max_seq_len']
|
||||
)
|
||||
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
|
||||
# in the checkpoint and should be loaded.
|
||||
f.write(header)
|
||||
|
||||
# next write out the embedding weights
|
||||
print("writing tok_embeddings...")
|
||||
serialize(self.tok_embeddings.weight)
|
||||
|
||||
serialize('tok_embeddings.weight')
|
||||
|
||||
# now all the layers
|
||||
# attention weights
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing attention_norm layer {i}...")
|
||||
serialize(layer.attention_norm.weight)
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing attention.wq layer {i}...")
|
||||
serialize(layer.attention.wq.weight)
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing attention.wk layer {i}...")
|
||||
serialize(layer.attention.wk.weight)
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing attention.wv layer {i}...")
|
||||
serialize(layer.attention.wv.weight)
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing attention.wo layer {i}...")
|
||||
serialize(layer.attention.wo.weight)
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
|
||||
# ffn weights
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing ffn_norm layer {i}...")
|
||||
serialize(layer.ffn_norm.weight)
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing feed_forward.w1 layer {i}...")
|
||||
serialize(layer.feed_forward.w1.weight)
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing feed_forward.w2 layer {i}...")
|
||||
serialize(layer.feed_forward.w2.weight)
|
||||
for i, layer in enumerate(self.layers):
|
||||
print(f"writing feed_forward.w3 layer {i}...")
|
||||
serialize(layer.feed_forward.w3.weight)
|
||||
# final rmsnorm
|
||||
print("writing final rmsnorm, classifier and freq_cis...")
|
||||
serialize(self.norm.weight)
|
||||
# freqs_cis
|
||||
serialize(self.freqs_cis.real[:p.max_seq_len])
|
||||
serialize(self.freqs_cis.imag[:p.max_seq_len])
|
||||
# finally write the output weights
|
||||
serialize(self.output.weight)
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
|
||||
|
||||
# final rmsnorm
|
||||
serialize('norm.weight')
|
||||
# freqs_cis
|
||||
freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
|
||||
state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
|
||||
state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
|
||||
serialize('freqs_cis.real')
|
||||
serialize('freqs_cis.imag')
|
||||
|
||||
# finally write the output weights
|
||||
serialize('output.weight')
|
||||
|
||||
# write to binary file
|
||||
f.close()
|
||||
print(f"wrote {filepath}")
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# init Llama as normal
|
||||
generator = Llama.build(
|
||||
ckpt_dir="llama-2-7b",
|
||||
tokenizer_path="tokenizer.model",
|
||||
max_seq_len=4096,
|
||||
max_batch_size=1,
|
||||
)
|
||||
export(generator.model, "llama2_7b.bin")
|
||||
|
||||
def concat_weights(models):
|
||||
state_dict = {}
|
||||
for name in list(models[0]):
|
||||
tensors = [model[name] for model in models]
|
||||
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||
state_dict[name] = tensors[0]
|
||||
continue
|
||||
is_axis_1 = (
|
||||
name.startswith('tok_embeddings.')
|
||||
or name.endswith('.attention.wo.weight')
|
||||
or name.endswith('.feed_forward.w2.weight')
|
||||
)
|
||||
axis = 1 if is_axis_1 else 0
|
||||
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||
for model in models:
|
||||
del model[name]
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_and_export(model_path, output_path):
|
||||
with open(model_path + 'params.json') as f:
|
||||
params = json.load(f)
|
||||
print(params)
|
||||
|
||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||
models = []
|
||||
for i in model_paths:
|
||||
print(f'Loading {i}')
|
||||
models.append(torch.load(i, map_location='cpu'))
|
||||
|
||||
state_dict = concat_weights(models)
|
||||
del models
|
||||
export(params, state_dict, output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) == 1:
|
||||
print('[Llama model folder path] [output path]')
|
||||
exit()
|
||||
|
||||
model_path = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
load_and_export(model_path, output_path)
|
||||
|
||||
Reference in New Issue
Block a user