358 lines
14 KiB
Python
358 lines
14 KiB
Python
"""
|
|
This script has functions and utilties for model export.
|
|
Basically, we have a bunch of versions of the model, and we
|
|
want to export them to .bin files to be read from and inferenced in C.
|
|
|
|
Among the "input" versions of PyTorch files/models:
|
|
- Official Llama 2 weights released by Meta
|
|
- Huggingface weights available on the hub
|
|
- llama2.c (this repo) trained models
|
|
|
|
Among the "output" versions of .bin files:
|
|
- v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
|
|
- v1-vN: Improved .bin files with a proper header, cache alignment, etc.
|
|
|
|
This script aspires to provide all of these conversions.
|
|
"""
|
|
import struct
|
|
import argparse
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from model import ModelArgs, Transformer
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# common utilities
|
|
|
|
def serialize_fp32(file, tensor):
|
|
""" writes one fp32 tensor to file that is open in wb mode """
|
|
d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
|
|
b = struct.pack(f'{len(d)}f', *d)
|
|
file.write(b)
|
|
|
|
def serialize_int8(file, tensor):
|
|
""" writes one int8 tensor to file that is open in wb mode """
|
|
d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
|
|
b = struct.pack(f'{len(d)}b', *d)
|
|
file.write(b)
|
|
|
|
def quantize_q80(w, group_size):
|
|
"""
|
|
takes a tensor and returns the Q8_0 quantized version
|
|
i.e. symmetric quantization into int8, range [-127,127]
|
|
"""
|
|
assert w.numel() % group_size == 0
|
|
ori_shape = w.shape
|
|
w = w.float() # convert to float32
|
|
w = w.reshape(-1, group_size)
|
|
# find the max in each group
|
|
wmax = torch.abs(w).max(dim=1).values
|
|
# calculate the scaling factor such that float = quant * scale
|
|
scale = wmax / 127.0
|
|
# scale into range [-127, 127]
|
|
quant = w / scale[:,None]
|
|
# round to nearest integer
|
|
int8val = torch.round(quant).to(torch.int8)
|
|
# dequantize by rescaling
|
|
fp32val = (int8val.float() * scale[:,None]).view(-1)
|
|
fp32valr = fp32val.reshape(-1, group_size)
|
|
# calculate the max error in each group
|
|
err = torch.abs(fp32valr - w).max(dim=1).values
|
|
# find the max error across all groups
|
|
maxerr = err.max().item()
|
|
return int8val, scale, maxerr
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# legacy
|
|
|
|
def legacy_export(model, filepath):
|
|
""" Original export of llama2.c bin files, i.e. version v0 """
|
|
out_file = open(filepath, 'wb')
|
|
|
|
# first write out the header
|
|
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
|
|
p = model.params
|
|
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
|
|
# legacy format uses negative/positive vocab size as a shared classifier flag
|
|
if not shared_classifier:
|
|
p.vocab_size = -p.vocab_size
|
|
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
|
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
|
n_kv_heads, p.vocab_size, p.max_seq_len)
|
|
out_file.write(header)
|
|
|
|
# next write out the embedding weights
|
|
serialize_fp32(out_file, model.tok_embeddings.weight)
|
|
|
|
# now all the layers
|
|
# attention weights
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.attention_norm.weight)
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.attention.wq.weight)
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.attention.wk.weight)
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.attention.wv.weight)
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.attention.wo.weight)
|
|
# ffn weights
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.ffn_norm.weight)
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.feed_forward.w1.weight)
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.feed_forward.w2.weight)
|
|
for layer in model.layers:
|
|
serialize_fp32(out_file, layer.feed_forward.w3.weight)
|
|
# final rmsnorm
|
|
serialize_fp32(out_file, model.norm.weight)
|
|
# freqs_cis
|
|
serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
|
|
serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
|
|
|
|
# final classifier weights
|
|
if not shared_classifier:
|
|
serialize_fp32(out_file, model.output.weight)
|
|
|
|
# write to binary file
|
|
out_file.close()
|
|
print(f"wrote {filepath}")
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# new version
|
|
|
|
def version1_export(model, filepath):
|
|
"""
|
|
Export the model weights in full float32 .bin file to be read from C.
|
|
This is same as legacy_export, but with a proper header.
|
|
"""
|
|
version = 1
|
|
|
|
out_file = open(filepath, 'wb')
|
|
# first write out the header. the header will be 256 bytes
|
|
# 1) write magic, which will be uint32 of "ak42" in ASCII
|
|
out_file.write(struct.pack('I', 0x616b3432))
|
|
# 2) write version, which will be int
|
|
out_file.write(struct.pack('i', version))
|
|
# 3) write the params, which will be 7 ints
|
|
p = model.params
|
|
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
|
|
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
|
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
|
n_kv_heads, p.vocab_size, p.max_seq_len)
|
|
out_file.write(header)
|
|
# 4) write some other flags
|
|
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
|
|
out_file.write(struct.pack('B', int(shared_classifier)))
|
|
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
|
|
assert pad >= 0
|
|
out_file.write(b'\0' * pad)
|
|
|
|
# now let's write out all the params
|
|
weights = [
|
|
*[layer.attention_norm.weight for layer in model.layers],
|
|
*[layer.ffn_norm.weight for layer in model.layers],
|
|
model.norm.weight,
|
|
model.tok_embeddings.weight,
|
|
*[layer.attention.wq.weight for layer in model.layers],
|
|
*[layer.attention.wk.weight for layer in model.layers],
|
|
*[layer.attention.wv.weight for layer in model.layers],
|
|
*[layer.attention.wo.weight for layer in model.layers],
|
|
*[layer.feed_forward.w1.weight for layer in model.layers],
|
|
*[layer.feed_forward.w2.weight for layer in model.layers],
|
|
*[layer.feed_forward.w3.weight for layer in model.layers],
|
|
]
|
|
if not shared_classifier:
|
|
weights.append(model.output.weight)
|
|
for w in weights:
|
|
serialize_fp32(out_file, w)
|
|
|
|
# write to binary file
|
|
out_file.close()
|
|
print(f"wrote {filepath}")
|
|
|
|
def version2_export(model, filepath, group_size=64):
|
|
"""
|
|
Export the model weights in Q8_0 into .bin file to be read from C.
|
|
That is:
|
|
- quantize all weights to symmetric int8, in range [-127, 127]
|
|
- all other tensors (the rmsnorm params) are kept and exported in fp32
|
|
- quantization is done in groups of group_size to reduce the effects of any outliers
|
|
"""
|
|
version = 2
|
|
|
|
# let's first do some validation for this export type
|
|
while model.params.dim % group_size != 0:
|
|
group_size //= 2
|
|
print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim")
|
|
weights = [
|
|
model.tok_embeddings.weight,
|
|
*[layer.attention.wq.weight for layer in model.layers],
|
|
*[layer.attention.wk.weight for layer in model.layers],
|
|
*[layer.attention.wv.weight for layer in model.layers],
|
|
*[layer.attention.wo.weight for layer in model.layers],
|
|
*[layer.feed_forward.w1.weight for layer in model.layers],
|
|
*[layer.feed_forward.w2.weight for layer in model.layers],
|
|
*[layer.feed_forward.w3.weight for layer in model.layers],
|
|
]
|
|
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
|
|
if not shared_classifier:
|
|
weights.append(model.output.weight)
|
|
for w in weights:
|
|
assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
|
|
|
|
# write
|
|
out_file = open(filepath, 'wb')
|
|
# first write out the header. the header will be 256 bytes
|
|
# 1) write magic, which will be uint32 of "ak42" in ASCII
|
|
out_file.write(struct.pack('I', 0x616b3432))
|
|
# 2) write version, which will be int
|
|
out_file.write(struct.pack('i', version))
|
|
# 3) write the params, which will be 7 ints
|
|
p = model.params
|
|
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
|
|
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
|
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
|
n_kv_heads, p.vocab_size, p.max_seq_len)
|
|
out_file.write(header)
|
|
# 4) write some other flags
|
|
out_file.write(struct.pack('B', int(shared_classifier)))
|
|
out_file.write(struct.pack('i', group_size)) # group size used for quantization
|
|
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
|
|
assert pad >= 0
|
|
out_file.write(b'\0' * pad)
|
|
# now that the header is done, let's write out the model
|
|
|
|
# first let's write out all the params that we are keeping in fp32: the norms
|
|
for layer in model.layers: # attention norms
|
|
serialize_fp32(out_file, layer.attention_norm.weight)
|
|
for layer in model.layers: # MLP norms
|
|
serialize_fp32(out_file, layer.ffn_norm.weight)
|
|
serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
|
|
|
|
# now let's write out all the params that we are quantizing to Q8_0
|
|
# note we skip classifier weights, which are shared with the embedding
|
|
ew = []
|
|
scales = []
|
|
for i, w in enumerate(weights):
|
|
# quantize this weight
|
|
q, s, err = quantize_q80(w, group_size)
|
|
# save the int8 weights to file
|
|
serialize_int8(out_file, q) # save the tensor in int8
|
|
scales.append(s) # we'll do all the scales after all the qs
|
|
# logging
|
|
ew.append((err, w.shape))
|
|
print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
|
|
|
|
# save the scaling factors in fp32 here
|
|
# this is done to keep all the weights contiquous, making pointer arithmetic easier in C
|
|
for s in scales:
|
|
serialize_fp32(out_file, s)
|
|
|
|
# print the highest error across all weights, should be very small, e.g. O(~0.001)
|
|
ew.sort(reverse=True)
|
|
print(f"max quantization group error across all weights: {ew[0][0]}")
|
|
|
|
# write to binary file
|
|
out_file.close()
|
|
print(f"wrote {filepath}")
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Load / import functions
|
|
|
|
def load_checkpoint(checkpoint):
|
|
|
|
# load the provided model checkpoint
|
|
checkpoint_dict = torch.load(checkpoint, map_location='cpu')
|
|
gptconf = ModelArgs(**checkpoint_dict['model_args'])
|
|
model = Transformer(gptconf)
|
|
state_dict = checkpoint_dict['model']
|
|
unwanted_prefix = '_orig_mod.'
|
|
for k,v in list(state_dict.items()):
|
|
if k.startswith(unwanted_prefix):
|
|
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
|
model.load_state_dict(state_dict, strict=False)
|
|
model.eval()
|
|
return model
|
|
|
|
def load_hf_model(model_path):
|
|
|
|
from transformers import AutoModelForCausalLM
|
|
|
|
# load HF model
|
|
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
|
|
hf_dict = hf_model.state_dict()
|
|
|
|
# convert LlamaConfig to ModelArgs
|
|
config = ModelArgs()
|
|
config.dim = hf_model.config.hidden_size
|
|
config.n_layers = hf_model.config.num_hidden_layers
|
|
config.n_heads = hf_model.config.num_attention_heads
|
|
config.n_kv_heads = hf_model.config.num_attention_heads
|
|
config.vocab_size = hf_model.config.vocab_size
|
|
config.hidden_dim = hf_model.config.intermediate_size
|
|
config.norm_eps = hf_model.config.rms_norm_eps
|
|
config.max_seq_len = hf_model.config.max_position_embeddings
|
|
|
|
# create a new Transformer object and set weights
|
|
model = Transformer(config)
|
|
|
|
model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
|
|
model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
|
|
|
|
for layer in model.layers:
|
|
i = layer.layer_id
|
|
layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
|
|
layer.attention.wq.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight'])
|
|
layer.attention.wk.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight'])
|
|
layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
|
|
layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
|
|
layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
|
|
layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
|
|
layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
|
|
layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
|
|
|
|
# final classifier
|
|
model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
|
|
model.eval()
|
|
return model
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# API entrypoint
|
|
|
|
def model_export(model, filepath, version):
|
|
if version == 0:
|
|
legacy_export(model, filepath)
|
|
elif version == 1:
|
|
version1_export(model, filepath)
|
|
elif version == 2:
|
|
version2_export(model, filepath)
|
|
else:
|
|
raise ValueError(f"unknown version {version}")
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# CLI entrypoint
|
|
|
|
if __name__ == "__main__":
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("filepath", type=str, help="the output filepath")
|
|
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
|
parser.add_argument("--hf", type=str, help="huggingface model")
|
|
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
|
args = parser.parse_args()
|
|
|
|
if args.checkpoint:
|
|
model = load_checkpoint(args.checkpoint)
|
|
elif args.hf:
|
|
model = load_hf_model(args.hf)
|
|
else:
|
|
parser.error("Input model missing: --checkpoint or --hf is required")
|
|
|
|
# export
|
|
model_export(model, args.filepath, args.version)
|