From de005474d37d0cde1356739b8c79ebe7b42b5973 Mon Sep 17 00:00:00 2001
From: atamyrat
Date: Mon, 21 Aug 2023 14:13:47 +0300
Subject: [PATCH] Added load_meta_model() to export.py

---
 export.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 87 insertions(+), 5 deletions(-)

diff --git a/export.py b/export.py
index e486a81..a60d7cf 100644
--- a/export.py
+++ b/export.py
@@ -19,6 +19,9 @@ import gzip
 import shutil
 import struct
 import argparse
+import json
+from pathlib import Path
+
 import numpy as np
 import torch
 from torch import nn
@@ -30,7 +33,7 @@ from model import ModelArgs, Transformer
 
 def serialize_fp32(file, tensor):
     """ writes one fp32 tensor to file that is open in wb mode """
-    d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
+    d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
     b = struct.pack(f'{len(d)}f', *d)
     file.write(b)
 
@@ -281,6 +284,83 @@ def load_checkpoint(checkpoint):
     model.eval()
     return model
 
+def load_meta_model(model_path):
+    """
+    Loads a Meta Llama checkpoint directory into a Transformer.
+    model_path must contain params.json and one or more consolidated.*.pth
+    shards, which are merged into a single state dict before loading.
+    """
+    params_path = os.path.join(model_path, 'params.json')
+    with open(params_path) as f:
+        params = json.load(f)
+        print(params)
+
+    model_paths = sorted(Path(model_path).glob('consolidated.*.pth'))
+    if not model_paths:
+        raise FileNotFoundError(f"no consolidated.*.pth files found in {model_path}")
+    # NOTE: torch.load unpickles these files, only use checkpoints from a trusted source
+    models = [torch.load(p, map_location='cpu') for p in model_paths]
+
+    def concat_weights(models):
+        """ merges the per-shard state dicts into one state dict """
+        state_dict = {}
+        for name in list(models[0]):
+            tensors = [model[name] for model in models]
+            if len(tensors) == 1 or len(tensors[0].shape) == 1:
+                # single shard, or a 1-D tensor: use the first shard's tensor as is
+                state_dict[name] = tensors[0]
+                continue
+            # these names are concatenated along dim 1, everything else along dim 0
+            is_axis_1 = (
+                name.startswith('tok_embeddings.')
+                or name.endswith('.attention.wo.weight')
+                or name.endswith('.feed_forward.w2.weight')
+            )
+            axis = 1 if is_axis_1 else 0
+            state_dict[name] = torch.cat(tensors, dim=axis)
+            for model in models:
+                del model[name]
+        return state_dict
+
+    state_dict = concat_weights(models)
+    del models
+
+    # set ModelArgs
+    config = ModelArgs()
+    config.dim = params["dim"]
+    config.n_layers = params["n_layers"]
+    config.n_heads = params["n_heads"]
+    config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
+    config.multiple_of = params["multiple_of"]
+    config.norm_eps = params["norm_eps"]
+
+    # take the vocab size from the embedding table instead of hard-coding it
+    config.vocab_size = state_dict['tok_embeddings.weight'].shape[0]
+    config.max_seq_len = 2048
+
+    # create a new Transformer object and set weights
+    model = Transformer(config)
+
+    model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
+    model.norm.weight = nn.Parameter(state_dict['norm.weight'])
+
+    for layer in model.layers:
+        i = layer.layer_id
+        layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
+        layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
+        layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
+        layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
+        layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
+        layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
+        layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
+        layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
+        layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
+
+    # final classifier
+    model.output.weight = nn.Parameter(state_dict['output.weight'])
+    model.eval()
+    return model
+
 
 def load_hf_model(model_path):
     try:
@@ -381,17 +461,19 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
     parser.add_argument("filepath", type=str, help="the output filepath")
-    parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
-    parser.add_argument("--hf", type=str, help="huggingface model")
     parser.add_argument("--version", default=0, type=int, help="the version to export with")
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
+    group.add_argument("--meta-llama", type=str, help="meta llama model path")
+    group.add_argument("--hf", type=str, help="huggingface model path")
     args = parser.parse_args()
 
     if args.checkpoint:
         model = load_checkpoint(args.checkpoint)
+    elif args.meta_llama:
+        model = load_meta_model(args.meta_llama)
     elif args.hf:
         model = load_hf_model(args.hf)
-    else:
-        parser.error("Input model missing: --checkpoint or --hf is required")
 
     if model is None:
         parser.error("Can't load input model!")