Added load_meta_model() to export.py

This commit is contained in:
atamyrat
2023-08-21 14:13:47 +03:00
parent dd61b13e57
commit de005474d3
+75 -5
View File
@@ -19,6 +19,9 @@ import gzip
import shutil import shutil
import struct import struct
import argparse import argparse
import json
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
@@ -30,7 +33,7 @@ from model import ModelArgs, Transformer
def serialize_fp32(file, tensor): def serialize_fp32(file, tensor):
""" writes one fp32 tensor to file that is open in wb mode """ """ writes one fp32 tensor to file that is open in wb mode """
d = tensor.detach().cpu().view(-1).numpy().astype(np.float32) d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
b = struct.pack(f'{len(d)}f', *d) b = struct.pack(f'{len(d)}f', *d)
file.write(b) file.write(b)
@@ -281,6 +284,71 @@ def load_checkpoint(checkpoint):
model.eval() model.eval()
return model return model
def load_meta_model(model_path):
params_path = os.path.join(model_path, 'params.json')
with open(params_path) as f:
params = json.load(f)
print(params)
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
models = [torch.load(p, map_location='cpu') for p in model_paths]
def concat_weights(models):
state_dict = {}
for name in list(models[0]):
tensors = [model[name] for model in models]
if len(tensors) == 1 or len(tensors[0].shape) == 1:
state_dict[name] = tensors[0]
continue
is_axis_1 = (
name.startswith('tok_embeddings.')
or name.endswith('.attention.wo.weight')
or name.endswith('.feed_forward.w2.weight')
)
axis = 1 if is_axis_1 else 0
state_dict[name] = torch.cat(tensors, dim=axis)
for model in models:
del model[name]
return state_dict
state_dict = concat_weights(models)
del models
# set ModelArgs
config = ModelArgs()
config.dim = params["dim"]
config.n_layers = params["n_layers"]
config.n_heads = params["n_heads"]
config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
config.multiple_of = params["multiple_of"]
config.norm_eps = params["norm_eps"]
config.vocab_size = 32000
config.max_seq_len = 2048
# create a new Transformer object and set weights
model = Transformer(config)
model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
model.norm.weight = nn.Parameter(state_dict['norm.weight'])
for layer in model.layers:
i = layer.layer_id
layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
# final classifier
model.output.weight = nn.Parameter(state_dict['output.weight'])
model.eval()
return model
def load_hf_model(model_path): def load_hf_model(model_path):
try: try:
@@ -381,17 +449,19 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("filepath", type=str, help="the output filepath") parser.add_argument("filepath", type=str, help="the output filepath")
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
parser.add_argument("--hf", type=str, help="huggingface model")
parser.add_argument("--version", default=0, type=int, help="the version to export with") parser.add_argument("--version", default=0, type=int, help="the version to export with")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
group.add_argument("--meta-llama", type=str, help="meta llama model path")
group.add_argument("--hf", type=str, help="huggingface model path")
args = parser.parse_args() args = parser.parse_args()
if args.checkpoint: if args.checkpoint:
model = load_checkpoint(args.checkpoint) model = load_checkpoint(args.checkpoint)
elif args.meta_llama:
model = load_meta_model(args.meta_llama)
elif args.hf: elif args.hf:
model = load_hf_model(args.hf) model = load_hf_model(args.hf)
else:
parser.error("Input model missing: --checkpoint or --hf is required")
if model is None: if model is None:
parser.error("Can't load input model!") parser.error("Can't load input model!")