Added load_meta_model() to export.py
This commit is contained in:
@@ -19,6 +19,9 @@ import gzip
|
||||
import shutil
|
||||
import struct
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -30,7 +33,7 @@ from model import ModelArgs, Transformer
|
||||
|
||||
def serialize_fp32(file, tensor):
|
||||
""" writes one fp32 tensor to file that is open in wb mode """
|
||||
d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
|
||||
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
|
||||
b = struct.pack(f'{len(d)}f', *d)
|
||||
file.write(b)
|
||||
|
||||
@@ -281,6 +284,71 @@ def load_checkpoint(checkpoint):
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def load_meta_model(model_path):
|
||||
params_path = os.path.join(model_path, 'params.json')
|
||||
with open(params_path) as f:
|
||||
params = json.load(f)
|
||||
print(params)
|
||||
|
||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||
|
||||
def concat_weights(models):
|
||||
state_dict = {}
|
||||
for name in list(models[0]):
|
||||
tensors = [model[name] for model in models]
|
||||
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||
state_dict[name] = tensors[0]
|
||||
continue
|
||||
is_axis_1 = (
|
||||
name.startswith('tok_embeddings.')
|
||||
or name.endswith('.attention.wo.weight')
|
||||
or name.endswith('.feed_forward.w2.weight')
|
||||
)
|
||||
axis = 1 if is_axis_1 else 0
|
||||
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||
for model in models:
|
||||
del model[name]
|
||||
return state_dict
|
||||
|
||||
state_dict = concat_weights(models)
|
||||
del models
|
||||
|
||||
# set ModelArgs
|
||||
config = ModelArgs()
|
||||
config.dim = params["dim"]
|
||||
config.n_layers = params["n_layers"]
|
||||
config.n_heads = params["n_heads"]
|
||||
config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
|
||||
config.multiple_of = params["multiple_of"]
|
||||
config.norm_eps = params["norm_eps"]
|
||||
|
||||
config.vocab_size = 32000
|
||||
config.max_seq_len = 2048
|
||||
|
||||
# create a new Transformer object and set weights
|
||||
model = Transformer(config)
|
||||
|
||||
model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
|
||||
model.norm.weight = nn.Parameter(state_dict['norm.weight'])
|
||||
|
||||
for layer in model.layers:
|
||||
i = layer.layer_id
|
||||
layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
|
||||
layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
|
||||
layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
|
||||
layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
|
||||
layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
|
||||
layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
|
||||
layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
|
||||
layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
|
||||
layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
|
||||
|
||||
# final classifier
|
||||
model.output.weight = nn.Parameter(state_dict['output.weight'])
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def load_hf_model(model_path):
|
||||
|
||||
try:
|
||||
@@ -381,17 +449,19 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("filepath", type=str, help="the output filepath")
|
||||
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
||||
parser.add_argument("--hf", type=str, help="huggingface model")
|
||||
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
||||
group.add_argument("--meta-llama", type=str, help="meta llama model path")
|
||||
group.add_argument("--hf", type=str, help="huggingface model path")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.checkpoint:
|
||||
model = load_checkpoint(args.checkpoint)
|
||||
elif args.meta_llama:
|
||||
model = load_meta_model(args.meta_llama)
|
||||
elif args.hf:
|
||||
model = load_hf_model(args.hf)
|
||||
else:
|
||||
parser.error("Input model missing: --checkpoint or --hf is required")
|
||||
|
||||
if model is None:
|
||||
parser.error("Can't load input model!")
|
||||
|
||||
Reference in New Issue
Block a user