Merge pull request #326 from atamurad/import_hf
Added huggingface model loader/importer to export.py
This commit is contained in:
@@ -16,8 +16,9 @@ This script aspires to provide all of these conversions.
|
|||||||
"""
|
"""
|
||||||
import struct
|
import struct
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from model import ModelArgs, Transformer
|
from model import ModelArgs, Transformer
|
||||||
|
|
||||||
@@ -72,6 +73,10 @@ def legacy_export(model, filepath):
|
|||||||
# first write out the header
|
# first write out the header
|
||||||
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
|
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
|
||||||
p = model.params
|
p = model.params
|
||||||
|
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
|
||||||
|
# legacy format uses negative/positive vocab size as a shared classifier flag
|
||||||
|
if not shared_classifier:
|
||||||
|
p.vocab_size = -p.vocab_size
|
||||||
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
||||||
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
||||||
n_kv_heads, p.vocab_size, p.max_seq_len)
|
n_kv_heads, p.vocab_size, p.max_seq_len)
|
||||||
@@ -103,11 +108,14 @@ def legacy_export(model, filepath):
|
|||||||
serialize_fp32(out_file, layer.feed_forward.w3.weight)
|
serialize_fp32(out_file, layer.feed_forward.w3.weight)
|
||||||
# final rmsnorm
|
# final rmsnorm
|
||||||
serialize_fp32(out_file, model.norm.weight)
|
serialize_fp32(out_file, model.norm.weight)
|
||||||
# note: no need to write final classifier weights due to weight sharing
|
|
||||||
# freqs_cis
|
# freqs_cis
|
||||||
serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
|
serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
|
||||||
serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
|
serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
|
||||||
|
|
||||||
|
# final classifier weights
|
||||||
|
if not shared_classifier:
|
||||||
|
serialize_fp32(out_file, model.output.weight)
|
||||||
|
|
||||||
# write to binary file
|
# write to binary file
|
||||||
out_file.close()
|
out_file.close()
|
||||||
print(f"wrote {filepath}")
|
print(f"wrote {filepath}")
|
||||||
@@ -136,8 +144,8 @@ def version1_export(model, filepath):
|
|||||||
n_kv_heads, p.vocab_size, p.max_seq_len)
|
n_kv_heads, p.vocab_size, p.max_seq_len)
|
||||||
out_file.write(header)
|
out_file.write(header)
|
||||||
# 4) write some other flags
|
# 4) write some other flags
|
||||||
shared_classifier = 1 # we do share a classifier, write flag as a byte
|
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
|
||||||
out_file.write(struct.pack('B', shared_classifier))
|
out_file.write(struct.pack('B', int(shared_classifier)))
|
||||||
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
|
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
|
||||||
assert pad >= 0
|
assert pad >= 0
|
||||||
out_file.write(b'\0' * pad)
|
out_file.write(b'\0' * pad)
|
||||||
@@ -156,6 +164,8 @@ def version1_export(model, filepath):
|
|||||||
*[layer.feed_forward.w2.weight for layer in model.layers],
|
*[layer.feed_forward.w2.weight for layer in model.layers],
|
||||||
*[layer.feed_forward.w3.weight for layer in model.layers],
|
*[layer.feed_forward.w3.weight for layer in model.layers],
|
||||||
]
|
]
|
||||||
|
if not shared_classifier:
|
||||||
|
weights.append(model.output.weight)
|
||||||
for w in weights:
|
for w in weights:
|
||||||
serialize_fp32(out_file, w)
|
serialize_fp32(out_file, w)
|
||||||
|
|
||||||
@@ -187,6 +197,9 @@ def version2_export(model, filepath, group_size=64):
|
|||||||
*[layer.feed_forward.w2.weight for layer in model.layers],
|
*[layer.feed_forward.w2.weight for layer in model.layers],
|
||||||
*[layer.feed_forward.w3.weight for layer in model.layers],
|
*[layer.feed_forward.w3.weight for layer in model.layers],
|
||||||
]
|
]
|
||||||
|
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
|
||||||
|
if not shared_classifier:
|
||||||
|
weights.append(model.output.weight)
|
||||||
for w in weights:
|
for w in weights:
|
||||||
assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
|
assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
|
||||||
|
|
||||||
@@ -205,8 +218,7 @@ def version2_export(model, filepath, group_size=64):
|
|||||||
n_kv_heads, p.vocab_size, p.max_seq_len)
|
n_kv_heads, p.vocab_size, p.max_seq_len)
|
||||||
out_file.write(header)
|
out_file.write(header)
|
||||||
# 4) write some other flags
|
# 4) write some other flags
|
||||||
shared_classifier = 1 # we do share a classifier, write flag as a byte
|
out_file.write(struct.pack('B', int(shared_classifier)))
|
||||||
out_file.write(struct.pack('B', shared_classifier))
|
|
||||||
out_file.write(struct.pack('i', group_size)) # group size used for quantization
|
out_file.write(struct.pack('i', group_size)) # group size used for quantization
|
||||||
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
|
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
|
||||||
assert pad >= 0
|
assert pad >= 0
|
||||||
@@ -247,6 +259,77 @@ def version2_export(model, filepath, group_size=64):
|
|||||||
out_file.close()
|
out_file.close()
|
||||||
print(f"wrote {filepath}")
|
print(f"wrote {filepath}")
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Load / import functions
|
||||||
|
|
||||||
|
def load_checkpoint(checkpoint):
|
||||||
|
|
||||||
|
# load the provided model checkpoint
|
||||||
|
checkpoint_dict = torch.load(checkpoint, map_location='cpu')
|
||||||
|
gptconf = ModelArgs(**checkpoint_dict['model_args'])
|
||||||
|
model = Transformer(gptconf)
|
||||||
|
state_dict = checkpoint_dict['model']
|
||||||
|
unwanted_prefix = '_orig_mod.'
|
||||||
|
for k,v in list(state_dict.items()):
|
||||||
|
if k.startswith(unwanted_prefix):
|
||||||
|
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def load_hf_model(model_path):
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
except ImportError:
|
||||||
|
print("Error: transformers package is required to load huggingface models")
|
||||||
|
print("Please run `pip install transformers` to install it")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# load HF model
|
||||||
|
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
|
||||||
|
hf_dict = hf_model.state_dict()
|
||||||
|
|
||||||
|
# convert LlamaConfig to ModelArgs
|
||||||
|
config = ModelArgs()
|
||||||
|
config.dim = hf_model.config.hidden_size
|
||||||
|
config.n_layers = hf_model.config.num_hidden_layers
|
||||||
|
config.n_heads = hf_model.config.num_attention_heads
|
||||||
|
config.n_kv_heads = hf_model.config.num_attention_heads
|
||||||
|
config.vocab_size = hf_model.config.vocab_size
|
||||||
|
config.hidden_dim = hf_model.config.intermediate_size
|
||||||
|
config.norm_eps = hf_model.config.rms_norm_eps
|
||||||
|
config.max_seq_len = hf_model.config.max_position_embeddings
|
||||||
|
|
||||||
|
# create a new Transformer object and set weights
|
||||||
|
model = Transformer(config)
|
||||||
|
|
||||||
|
model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
|
||||||
|
model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
|
||||||
|
|
||||||
|
# huggingface permutes WQ and WK, this function reverses it
|
||||||
|
def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
|
||||||
|
return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
i = layer.layer_id
|
||||||
|
layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
|
||||||
|
layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
|
||||||
|
layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
|
||||||
|
layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
|
||||||
|
layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
|
||||||
|
layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
|
||||||
|
layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
|
||||||
|
layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
|
||||||
|
layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
|
||||||
|
|
||||||
|
# final classifier
|
||||||
|
model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# API entrypoint
|
# API entrypoint
|
||||||
|
|
||||||
@@ -267,21 +350,20 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("filepath", type=str, help="the output filepath")
|
parser.add_argument("filepath", type=str, help="the output filepath")
|
||||||
parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file")
|
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
||||||
|
parser.add_argument("--hf", type=str, help="huggingface model")
|
||||||
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# load the provided model checkpoint
|
if args.checkpoint:
|
||||||
checkpoint_dict = torch.load(args.checkpoint, map_location='cpu')
|
model = load_checkpoint(args.checkpoint)
|
||||||
gptconf = ModelArgs(**checkpoint_dict['model_args'])
|
elif args.hf:
|
||||||
model = Transformer(gptconf)
|
model = load_hf_model(args.hf)
|
||||||
state_dict = checkpoint_dict['model']
|
else:
|
||||||
unwanted_prefix = '_orig_mod.'
|
parser.error("Input model missing: --checkpoint or --hf is required")
|
||||||
for k,v in list(state_dict.items()):
|
|
||||||
if k.startswith(unwanted_prefix):
|
if model is None:
|
||||||
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
parser.error("Can't load input model!")
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# export
|
# export
|
||||||
model_export(model, args.filepath, args.version)
|
model_export(model, args.filepath, args.version)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class ModelArgs:
|
|||||||
n_heads: int = 32
|
n_heads: int = 32
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
vocab_size: int = 32000
|
vocab_size: int = 32000
|
||||||
|
hidden_dim: Optional[int] = None
|
||||||
multiple_of: int = 256 # MLP hidden layer size will be multiple of
|
multiple_of: int = 256 # MLP hidden layer size will be multiple of
|
||||||
norm_eps: float = 1e-5
|
norm_eps: float = 1e-5
|
||||||
max_seq_len: int = 2048
|
max_seq_len: int = 2048
|
||||||
@@ -166,6 +167,8 @@ class Attention(nn.Module):
|
|||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
|
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if hidden_dim is None:
|
||||||
|
hidden_dim = 4 * dim
|
||||||
hidden_dim = int(2 * hidden_dim / 3)
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
@@ -186,7 +189,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self.attention = Attention(args)
|
self.attention = Attention(args)
|
||||||
self.feed_forward = FeedForward(
|
self.feed_forward = FeedForward(
|
||||||
dim=args.dim,
|
dim=args.dim,
|
||||||
hidden_dim=4 * args.dim,
|
hidden_dim=args.hidden_dim,
|
||||||
multiple_of=args.multiple_of,
|
multiple_of=args.multiple_of,
|
||||||
dropout=args.dropout,
|
dropout=args.dropout,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user