Merge pull request #85 from python273/export-llama-without-llama
Export llama without llama
This commit is contained in:
+90
-67
@@ -1,91 +1,114 @@
|
|||||||
"""
|
"""
|
||||||
This script exports the Llama 2 weights in llama2c.bin format.
|
This script exports the Llama 2 weights in llama2c.bin format.
|
||||||
|
|
||||||
Place it into the root directory of:
|
|
||||||
https://github.com/facebookresearch/llama
|
|
||||||
|
|
||||||
And then run it similar to their other examples, via torchrun sadly:
|
|
||||||
torchrun --nproc_per_node 1 export_meta_llama_bin.py
|
|
||||||
"""
|
"""
|
||||||
|
import sys
|
||||||
|
import struct
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
|
||||||
from llama import Llama
|
import torch
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
from model import precompute_freqs_cis
|
||||||
def export(p, state_dict, filepath='model.bin'):
    """Export the model weights in fp32 into a .bin file to be read from C.

    The file starts with a header of seven C ints (dim, hidden_dim, n_layers,
    n_heads, n_kv_heads, -vocab_size, max_seq_len). The weights follow as
    flat float32 arrays, in this order: token embeddings, per-layer attention
    weights, per-layer ffn weights, final rmsnorm, the real and imaginary
    parts of freqs_cis, and finally the output (classifier) weights.

    Args:
        p: dict of model hyperparameters. 'dim', 'n_layers' and 'n_heads' are
            read and 'n_kv_heads' is optional (falls back to 'n_heads').
            'vocab_size' and 'max_seq_len' are overwritten in place with
            32000 and 2048.
        state_dict: dict mapping weight names to torch tensors. Each weight
            is deleted from the dict right after it has been written
            (presumably to keep peak memory down), so the dict must not be
            reused afterwards.
        filepath: path of the binary file to create.

    Raises:
        KeyError: if a required hyperparameter or weight is missing.
    """
    # Build the header before touching the filesystem, so a missing key
    # fails without leaving an empty output file behind.
    hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
    p['vocab_size'] = 32000
    p['max_seq_len'] = 2048

    n_kv_heads = p.get('n_kv_heads') or p['n_heads']
    header = struct.pack(
        'iiiiiii',
        p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
        n_kv_heads, -p['vocab_size'], p['max_seq_len']
    )
    # NOTE ABOVE: -ve vocab_size is indicating that the classifier weights
    # are present in the checkpoint and should be loaded.

    # Per-layer weights in file order. Each kind is written for ALL layers
    # before moving on to the next kind.
    per_layer_weights = (
        # attention weights
        'attention_norm',
        'attention.wq',
        'attention.wk',
        'attention.wv',
        'attention.wo',
        # ffn weights
        'ffn_norm',
        'feed_forward.w1',
        'feed_forward.w2',
        'feed_forward.w3',
    )

    # The context manager closes the file even when a weight is missing or
    # cannot be converted half-way through the export.
    with open(filepath, 'wb') as f:

        def serialize(key):
            """Write state_dict[key] as flat float32 and drop it from the dict."""
            print(f"writing {key}...")
            t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
            # memoryview hands the array buffer to write() directly
            f.write(memoryview(t))
            del state_dict[key]

        # first write out the header
        f.write(header)

        # next write out the embedding weights
        print("writing tok_embeddings...")
        serialize('tok_embeddings.weight')

        # now all the layers
        for weight in per_layer_weights:
            for i in range(p['n_layers']):
                serialize(f'layers.{i}.{weight}.weight')

        # final rmsnorm
        serialize('norm.weight')

        # freqs_cis: computed here rather than read from the checkpoint;
        # only the first max_seq_len rows are exported
        freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
        state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
        state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
        serialize('freqs_cis.real')
        serialize('freqs_cis.imag')

        # finally write the output weights
        serialize('output.weight')

    print(f"wrote {filepath}")
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# init Llama as normal
|
|
||||||
def concat_weights(models):
    """Merge the state dicts of all checkpoint shards into a single dict.

    The weight names are taken from the first shard. A weight is copied
    as-is from the first shard when there is only one shard or when the
    tensor is 1-D. Every other weight is concatenated across shards along
    dim 1 for 'tok_embeddings.*', '*.attention.wo.weight' and
    '*.feed_forward.w2.weight', and along dim 0 otherwise.

    Args:
        models: non-empty list of dicts mapping weight names to torch
            tensors, one dict per shard, in shard order. Entries that get
            concatenated are deleted from the shard dicts; entries copied
            from the first shard are left in place.

    Returns:
        dict mapping each weight name to its merged tensor.

    Raises:
        ValueError: if `models` is empty.
    """
    if not models:
        raise ValueError('concat_weights() needs at least one checkpoint shard')

    # weights whose shards are joined along dim 1 instead of dim 0
    axis_1_prefix = 'tok_embeddings.'
    axis_1_suffixes = ('.attention.wo.weight', '.feed_forward.w2.weight')

    state_dict = {}
    # iterate over a snapshot of the keys: the shard dicts shrink in the loop
    for name in list(models[0]):
        tensors = [model[name] for model in models]
        if len(tensors) == 1 or len(tensors[0].shape) == 1:
            state_dict[name] = tensors[0]
            continue
        is_axis_1 = name.startswith(axis_1_prefix) or name.endswith(axis_1_suffixes)
        axis = 1 if is_axis_1 else 0
        state_dict[name] = torch.cat(tensors, dim=axis)
        # the shard pieces are no longer needed once they have been merged
        for model in models:
            del model[name]
    return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_export(model_path, output_path):
    """Load a Llama checkpoint folder and export it as one .bin file.

    Reads 'params.json' and every 'consolidated.*.pth' shard (in sorted
    order) from the folder, merges the shards with concat_weights() and
    writes the result with export().

    Args:
        model_path: folder holding params.json and the consolidated.*.pth
            files. A trailing path separator is optional.
        output_path: path of the .bin file to write.

    Raises:
        FileNotFoundError: if params.json is missing or the folder contains
            no consolidated.*.pth file.
    """
    model_dir = Path(model_path)

    # Join with pathlib so the folder works with or without a trailing
    # separator (plain string concatenation needed the trailing one).
    with open(model_dir / 'params.json') as f:
        params = json.load(f)
        print(params)

    model_paths = sorted(model_dir.glob('consolidated.*.pth'))
    if not model_paths:
        raise FileNotFoundError(
            f'no consolidated.*.pth checkpoint files found in {model_dir}'
        )

    models = []
    for i in model_paths:
        print(f'Loading {i}')
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code - only point this at checkpoints you trust.
        models.append(torch.load(i, map_location='cpu'))

    state_dict = concat_weights(models)
    # drop our reference to the shard dicts before the export starts
    del models
    export(params, state_dict, output_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Both positional arguments are required. Checking for fewer than three
    # argv entries also catches the "only one argument" case, which used to
    # crash with an IndexError on sys.argv[2]. Extra arguments are ignored.
    if len(sys.argv) < 3:
        print('[Llama model folder path] [output path]')
        # sys.exit works without the site module; non-zero status because
        # nothing was exported
        sys.exit(1)

    model_path = sys.argv[1]
    output_path = sys.argv[2]
    load_and_export(model_path, output_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user