Added load_meta_model() to export.py
This commit is contained in:
@@ -19,6 +19,9 @@ import gzip
|
|||||||
import shutil
|
import shutil
|
||||||
import struct
|
import struct
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -30,7 +33,7 @@ from model import ModelArgs, Transformer
|
|||||||
|
|
||||||
def serialize_fp32(file, tensor):
|
def serialize_fp32(file, tensor):
|
||||||
""" writes one fp32 tensor to file that is open in wb mode """
|
""" writes one fp32 tensor to file that is open in wb mode """
|
||||||
d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
|
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
|
||||||
b = struct.pack(f'{len(d)}f', *d)
|
b = struct.pack(f'{len(d)}f', *d)
|
||||||
file.write(b)
|
file.write(b)
|
||||||
|
|
||||||
@@ -281,6 +284,71 @@ def load_checkpoint(checkpoint):
|
|||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def load_meta_model(model_path):
|
||||||
|
params_path = os.path.join(model_path, 'params.json')
|
||||||
|
with open(params_path) as f:
|
||||||
|
params = json.load(f)
|
||||||
|
print(params)
|
||||||
|
|
||||||
|
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||||
|
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||||
|
|
||||||
|
def concat_weights(models):
|
||||||
|
state_dict = {}
|
||||||
|
for name in list(models[0]):
|
||||||
|
tensors = [model[name] for model in models]
|
||||||
|
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||||
|
state_dict[name] = tensors[0]
|
||||||
|
continue
|
||||||
|
is_axis_1 = (
|
||||||
|
name.startswith('tok_embeddings.')
|
||||||
|
or name.endswith('.attention.wo.weight')
|
||||||
|
or name.endswith('.feed_forward.w2.weight')
|
||||||
|
)
|
||||||
|
axis = 1 if is_axis_1 else 0
|
||||||
|
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||||
|
for model in models:
|
||||||
|
del model[name]
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
state_dict = concat_weights(models)
|
||||||
|
del models
|
||||||
|
|
||||||
|
# set ModelArgs
|
||||||
|
config = ModelArgs()
|
||||||
|
config.dim = params["dim"]
|
||||||
|
config.n_layers = params["n_layers"]
|
||||||
|
config.n_heads = params["n_heads"]
|
||||||
|
config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
|
||||||
|
config.multiple_of = params["multiple_of"]
|
||||||
|
config.norm_eps = params["norm_eps"]
|
||||||
|
|
||||||
|
config.vocab_size = 32000
|
||||||
|
config.max_seq_len = 2048
|
||||||
|
|
||||||
|
# create a new Transformer object and set weights
|
||||||
|
model = Transformer(config)
|
||||||
|
|
||||||
|
model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
|
||||||
|
model.norm.weight = nn.Parameter(state_dict['norm.weight'])
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
i = layer.layer_id
|
||||||
|
layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
|
||||||
|
layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
|
||||||
|
layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
|
||||||
|
layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
|
||||||
|
layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
|
||||||
|
layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
|
||||||
|
layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
|
||||||
|
layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
|
||||||
|
layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
|
||||||
|
|
||||||
|
# final classifier
|
||||||
|
model.output.weight = nn.Parameter(state_dict['output.weight'])
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
def load_hf_model(model_path):
|
def load_hf_model(model_path):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -381,17 +449,19 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("filepath", type=str, help="the output filepath")
|
parser.add_argument("filepath", type=str, help="the output filepath")
|
||||||
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
|
||||||
parser.add_argument("--hf", type=str, help="huggingface model")
|
|
||||||
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
||||||
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
|
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
||||||
|
group.add_argument("--meta-llama", type=str, help="meta llama model path")
|
||||||
|
group.add_argument("--hf", type=str, help="huggingface model path")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.checkpoint:
|
if args.checkpoint:
|
||||||
model = load_checkpoint(args.checkpoint)
|
model = load_checkpoint(args.checkpoint)
|
||||||
|
elif args.meta_llama:
|
||||||
|
model = load_meta_model(args.meta_llama)
|
||||||
elif args.hf:
|
elif args.hf:
|
||||||
model = load_hf_model(args.hf)
|
model = load_hf_model(args.hf)
|
||||||
else:
|
|
||||||
parser.error("Input model missing: --checkpoint or --hf is required")
|
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
parser.error("Can't load input model!")
|
parser.error("Can't load input model!")
|
||||||
|
|||||||
Reference in New Issue
Block a user