"""
This script has functions and utilities for model export.
Basically, we have a bunch of versions of the model, and we
want to export them to .bin files to be read from and inferenced in C.

Among the "input" versions of PyTorch files/models:
- Official Llama 2 weights released by Meta
- Huggingface weights available on the hub
- llama2.c (this repo) trained models

Among the "output" versions of .bin files:
- v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
- v1-vN: Improved .bin files with a proper header, cache alignment, etc.

This script aspires to provide all of these conversions.
"""
import argparse
import struct
from operator import attrgetter

import numpy as np
import torch

from model import ModelArgs, Transformer

# -----------------------------------------------------------------------------
# common utilities

def serialize_fp32(file, tensor):
    """ writes one fp32 tensor to file that is open in wb mode

    The tensor is flattened in row-major order and cast to float32 inside torch
    (so e.g. bfloat16 tensors, which `.numpy()` cannot convert, also work).
    `reshape` rather than `view` so non-contiguous tensors are accepted too.
    """
    d = tensor.detach().cpu().reshape(-1).to(torch.float32).numpy()
    # same bytes as struct.pack(f'{len(d)}f', *d) (native byte order), but
    # without materializing one Python argument per element
    file.write(d.tobytes())

def serialize_int8(file, tensor):
    """ writes one int8 tensor to file that is open in wb mode """
    d = tensor.detach().cpu().reshape(-1).numpy().astype(np.int8)
    file.write(d.tobytes())

def quantize_q80(w, group_size):
    """
    takes a tensor and returns the Q8_0 quantized version
    i.e. symmetric quantization into int8, range [-127,127]

    Returns (int8val, scale, maxerr):
    - int8val: int8 tensor of shape (numel // group_size, group_size)
    - scale: fp32 tensor of shape (numel // group_size,), float = int8 * scale
    - maxerr: python float, the largest absolute round-trip error over all groups
    """
    assert w.numel() % group_size == 0
    w = w.detach().float().reshape(-1, group_size)  # one row per group, in fp32
    # find the max in each group
    wmax = torch.abs(w).max(dim=1).values
    # calculate the scaling factor such that float = quant * scale
    scale = wmax / 127.0
    # all-zero (or underflowing) groups have scale 0: divide by 1 there instead to
    # avoid 0/0 = NaN being cast to int8; the stored scale stays 0, so such a
    # group still dequantizes to exactly 0
    safe_scale = torch.where(scale > 0, scale, torch.ones_like(scale))
    # scale into range [-127, 127] and round to nearest integer
    int8val = torch.round(w / safe_scale[:, None]).to(torch.int8)
    # dequantize by rescaling, then measure the max error in each group
    fp32val = int8val.float() * scale[:, None]
    err = torch.abs(fp32val - w).max(dim=1).values
    # find the max error across all groups
    maxerr = err.max().item()
    return int8val, scale, maxerr

def _hidden_dim(model):
    """ the FFN hidden size, read from the shape of the first layer's w1 """
    return model.layers[0].feed_forward.w1.weight.shape[0]

def _pack_config(model):
    """ packs the 7 int32 model params shared by the v0 and v1 headers:
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len """
    p = model.params
    # n_kv_heads is None when the model does not use grouped KV heads
    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
    return struct.pack('iiiiiii', p.dim, _hidden_dim(model), p.n_layers, p.n_heads,
                       n_kv_heads, p.vocab_size, p.max_seq_len)

def _layer_weights(model, path):
    """ returns `<path>.weight` of every layer in model.layers, in layer order """
    get = attrgetter(f'{path}.weight')
    return [get(layer) for layer in model.layers]

# -----------------------------------------------------------------------------
# legacy

def legacy_export(model, filepath):
    """ Original export of llama2.c bin files, i.e. version v0

    Layout: the 7 int32 params (no magic/version), then all tensors in fp32.
    Per-layer tensors are grouped by type across layers (every layer's
    attention_norm, then every layer's wq, ...).
    """
    p = model.params
    with open(filepath, 'wb') as out_file:
        # first write out the header
        out_file.write(_pack_config(model))
        # next write out the embedding weights
        serialize_fp32(out_file, model.tok_embeddings.weight)
        # now all the layers: attention weights, then ffn weights
        for path in ('attention_norm', 'attention.wq', 'attention.wk',
                     'attention.wv', 'attention.wo', 'ffn_norm',
                     'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3'):
            for w in _layer_weights(model, path):
                serialize_fp32(out_file, w)
        # final rmsnorm
        serialize_fp32(out_file, model.norm.weight)
        # note: no need to write final classifier weights due to weight sharing
        # freqs_cis
        serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
        serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
    print(f"wrote {filepath}")

# -----------------------------------------------------------------------------
# new version

def version1_export(model, filepath, group_size=64):
    """
    Export the model weights in Q8_0 into .bin file to be read from C.
    That is:
    - quantize all weights to symmetric int8, in range [-127, 127]
    - all other tensors (the rmsnorm params) are kept and exported in fp32
    - quantization is done in groups of group_size to reduce the effects of any outliers

    group_size is halved until it divides both dim and hidden_dim, which also
    makes it divide the element count of every quantized weight.

    File layout: a 256-byte header (uint32 magic, int32 version, the 7 int32
    params, a uint8 shared-classifier flag, int32 group_size, zero padding),
    the fp32 norm weights, the int8 values of every quantized weight, then the
    fp32 scales of every quantized weight in the same order.

    Raises ValueError if group_size < 1.
    """
    version = 1

    # let's first do some validation for this export type
    if group_size < 1:
        raise ValueError(f"group_size must be a positive integer, got {group_size}")
    dim = model.params.dim
    hidden_dim = _hidden_dim(model)
    while dim % group_size != 0 or hidden_dim % group_size != 0:
        group_size //= 2  # terminates: 1 divides everything
        print(f"BACKOFF: reducing group size to {group_size} to fit dim and hidden_dim")

    # all weights quantized to Q8_0
    # note we skip classifier weights, which are shared with the embedding
    weights = [
        model.tok_embeddings.weight,
        *_layer_weights(model, 'attention.wq'),
        *_layer_weights(model, 'attention.wk'),
        *_layer_weights(model, 'attention.wv'),
        *_layer_weights(model, 'attention.wo'),
        *_layer_weights(model, 'feed_forward.w1'),
        *_layer_weights(model, 'feed_forward.w2'),
        *_layer_weights(model, 'feed_forward.w3'),
    ]
    for i, w in enumerate(weights):
        assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"

    with open(filepath, 'wb') as out_file:
        # first write out the header, padded with zeros to exactly 256 bytes.
        # Each field is packed on its own, so no alignment padding goes between them.
        shared_classifier = 1  # we do share a classifier, write flag as a byte
        header = b''.join([
            struct.pack('I', 0x616b3432),         # 1) magic: uint32 0x616b3432 ("ak42" in ASCII)
            struct.pack('i', version),            # 2) version: int32
            _pack_config(model),                  # 3) the params: 7 int32
            struct.pack('B', shared_classifier),  # 4) flags
            struct.pack('i', group_size),         # group size used for quantization
        ])
        pad = 256 - len(header)
        assert pad >= 0
        out_file.write(header + b'\0' * pad)

        # now that the header is done, let's write out the model
        # first let's write out all the params that we are keeping in fp32: the norms
        for w in _layer_weights(model, 'attention_norm'):  # attention norms
            serialize_fp32(out_file, w)
        for w in _layer_weights(model, 'ffn_norm'):  # MLP norms
            serialize_fp32(out_file, w)
        serialize_fp32(out_file, model.norm.weight)  # final pre-classifier norm

        # now let's write out all the params that we are quantizing to Q8_0
        scales = []
        errors = []
        for i, w in enumerate(weights):
            q, s, err = quantize_q80(w, group_size)
            serialize_int8(out_file, q)  # save the int8 weights to file
            scales.append(s)  # we'll write all the scales after all the qs
            errors.append(err)
            print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")

        # save the scaling factors in fp32 here
        # this is done to keep all the weights contiguous, making pointer arithmetic easier in C
        for s in scales:
            serialize_fp32(out_file, s)

    # print the highest error across all weights, should be very small, e.g. O(~0.001)
    print(f"max quantization group error across all weights: {max(errors)}")
    print(f"wrote {filepath}")

# -----------------------------------------------------------------------------
# API entrypoint

def model_export(model, filepath, version):
    """ exports model to filepath in .bin format `version`
    (0 = legacy fp32, 1 = Q8_0); raises ValueError for any other version """
    if version == 0:
        legacy_export(model, filepath)
    elif version == 1:
        version1_export(model, filepath)
    else:
        raise ValueError(f"unknown version {version}")

# -----------------------------------------------------------------------------
# CLI entrypoint

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("filepath", type=str, help="the output filepath")
    parser.add_argument("--checkpoint", required=True, type=str, help="model checkpoint, .pt file")
    parser.add_argument("--version", default=0, type=int, help="the version to export with")
    args = parser.parse_args()

    # load the provided model checkpoint
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints
    checkpoint_dict = torch.load(args.checkpoint, map_location='cpu')
    gptconf = ModelArgs(**checkpoint_dict['model_args'])
    model = Transformer(gptconf)
    state_dict = checkpoint_dict['model']
    # strip the prefix that torch.compile adds to parameter names
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    # export
    model_export(model, args.filepath, args.version)