diff --git a/export.py b/export.py new file mode 100644 index 0000000..4710649 --- /dev/null +++ b/export.py @@ -0,0 +1,243 @@ +""" +This script has functions and utilties for model export. +Basically, we have a bunch of versions of the model, and we +want to export them to .bin files to be read from and inferenced in C. + +Among the "input" versions of PyTorch files/models: +- Official Llama 2 weights released by Meta +- Huggingface weights available on the hub +- llama2.c (this repo) trained models + +Among the "output" versions of .bin files: +- v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED) +- v1-vN: Improved .bin files with a proper header, cache alignment, etc. + +This script aspires to provide all of these conversions. +""" +import struct +import argparse +import torch +import numpy as np + +from model import ModelArgs, Transformer + +# ----------------------------------------------------------------------------- +# common utilities + +def serialize_fp32(file, tensor): + """ writes one fp32 tensor to file that is open in wb mode """ + d = tensor.detach().cpu().view(-1).numpy().astype(np.float32) + b = struct.pack(f'{len(d)}f', *d) + file.write(b) + +def serialize_int8(file, tensor): + """ writes one int8 tensor to file that is open in wb mode """ + d = tensor.detach().cpu().view(-1).numpy().astype(np.int8) + b = struct.pack(f'{len(d)}b', *d) + file.write(b) + +def quantize_q80(w, group_size): + """ + takes a tensor and returns the Q8_0 quantized version + i.e. symmetric quantization into int8, range [-127,127] + """ + assert w.numel() % group_size == 0 + ori_shape = w.shape + w = w.float() # convert to float32 + w = w.reshape(-1, group_size) + # find the max in each group + wmax = torch.abs(w).max(dim=1).values + # calculate the scaling factor such that float = quant * scale + scale = wmax / 127.0 + # scale into range [-127, 127] + quant = w / scale[:,None] + # round to nearest integer + int8val = torch.round(quant).to(torch.int8) + # dequantize by rescaling + fp32val = (int8val.float() * scale[:,None]).view(-1) + fp32valr = fp32val.reshape(-1, group_size) + # calculate the max error in each group + err = torch.abs(fp32valr - w).max(dim=1).values + # find the max error across all groups + maxerr = err.max().item() + return int8val, scale, maxerr + +# ----------------------------------------------------------------------------- +# legacy + +def legacy_export(model, filepath): + """ Original export of llama2.c bin files, i.e. version v0 """ + out_file = open(filepath, 'wb') + + # first write out the header + hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] + p = model.params + n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads + header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, + n_kv_heads, p.vocab_size, p.max_seq_len) + out_file.write(header) + + # next write out the embedding weights + serialize_fp32(out_file, model.tok_embeddings.weight) + + # now all the layers + # attention weights + for layer in model.layers: + serialize_fp32(out_file, layer.attention_norm.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wq.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wk.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wv.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wo.weight) + # ffn weights + for layer in model.layers: + serialize_fp32(out_file, layer.ffn_norm.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.feed_forward.w1.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.feed_forward.w2.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.feed_forward.w3.weight) + # final rmsnorm + serialize_fp32(out_file, model.norm.weight) + # note: no need to write final classifier weights due to weight sharing + # freqs_cis + serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len]) + serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len]) + + # write to binary file + out_file.close() + print(f"wrote {filepath}") + +# ----------------------------------------------------------------------------- +# new version + +def version1_export(model, filepath, group_size=64): + """ + Export the model weights in Q8_0 into .bin file to be read from C. + That is: + - quantize all weights to symmetric int8, in range [-127, 127] + - all other tensors (the rmsnorm params) are kept and exported in fp32 + - quantization is done in groups of group_size to reduce the effects of any outliers + """ + version = 1 + + # let's first do some validation for this export type + while model.params.dim % group_size != 0: + group_size //= 2 + print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim") + weights = [ + model.tok_embeddings.weight, + *[layer.attention.wq.weight for layer in model.layers], + *[layer.attention.wk.weight for layer in model.layers], + *[layer.attention.wv.weight for layer in model.layers], + *[layer.attention.wo.weight for layer in model.layers], + *[layer.feed_forward.w1.weight for layer in model.layers], + *[layer.feed_forward.w2.weight for layer in model.layers], + *[layer.feed_forward.w3.weight for layer in model.layers], + ] + for w in weights: + assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}" + + # write + out_file = open(filepath, 'wb') + # first write out the header. the header will be 256 bytes + nbytes = 0 + # 1) write magic, which will be uint32 of "ak42" in ASCII + out_file.write(struct.pack('I', 0x616b3432)) + nbytes += 4 + # 2) write version, which will be int + out_file.write(struct.pack('i', version)) + nbytes += 4 + # 3) write the params, which will be 7 ints + p = model.params + hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] + n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads + header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, + n_kv_heads, p.vocab_size, p.max_seq_len) + out_file.write(header) + nbytes += 7*4 + # 4) write some other flags + shared_classifier = 1 # we do share a classifier, write flag as a byte + out_file.write(struct.pack('B', shared_classifier)) + nbytes += 1 + out_file.write(struct.pack('i', group_size)) # group size used for quantization + nbytes += 4 + pad = 256 - nbytes # pad the rest with zeros + assert pad >= 0 + out_file.write(b'\0' * pad) + # now that the header is done, let's write out the model + + # first let's write out all the params that we are keeping in fp32: the norms + for layer in model.layers: # attention norms + serialize_fp32(out_file, layer.attention_norm.weight) + for layer in model.layers: # MLP norms + serialize_fp32(out_file, layer.ffn_norm.weight) + serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm + + # now let's write out all the params that we are quantizing to Q8_0 + # note we skip classifier weights, which are shared with the embedding + ew = [] + scales = [] + for i, w in enumerate(weights): + # quantize this weight + q, s, err = quantize_q80(w, group_size) + # save the int8 weights to file + serialize_int8(out_file, q) # save the tensor in int8 + scales.append(s) # we'll do all the scales after all the qs + # logging + ew.append((err, w.shape)) + print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}") + + # save the scaling factors in fp32 here + # this is done to keep all the weights contiquous, making pointer arithmetic easier in C + for s in scales: + serialize_fp32(out_file, s) + + # print the highest error across all weights, should be very small, e.g. O(~0.001) + ew.sort(reverse=True) + print(f"max quantization group error across all weights: {ew[0][0]}") + + # write to binary file + out_file.close() + print(f"wrote {filepath}") + +# ----------------------------------------------------------------------------- +# API entrypoint + +def model_export(model, filepath, version): + if version == 0: + legacy_export(model, filepath) + elif version == 1: + version1_export(model, filepath) + else: + raise ValueError(f"unknown version {version}") + +# ----------------------------------------------------------------------------- +# CLI entrypoint + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("filepath", type=str, help="the output filepath") + parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file") + parser.add_argument("--version", default=0, type=int, help="the version to export with") + args = parser.parse_args() + + # load the provided model checkpoint + checkpoint_dict = torch.load(args.checkpoint, map_location='cpu') + gptconf = ModelArgs(**checkpoint_dict['model_args']) + model = Transformer(gptconf) + state_dict = checkpoint_dict['model'] + unwanted_prefix = '_orig_mod.' + for k,v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) + model.load_state_dict(state_dict, strict=False) + model.eval() + + # export + model_export(model, args.filepath, args.version) diff --git a/model.py b/model.py index c8c82a9..044712f 100644 --- a/model.py +++ b/model.py @@ -338,55 +338,3 @@ class Transformer(nn.Module): idx = torch.cat((idx, idx_next), dim=1) return idx - - def export(self, filepath='model.bin'): - """export the model weights in fp32 into .bin file to be read from C""" - f = open(filepath, 'wb') - - def serialize(t): - d = t.detach().cpu().view(-1).numpy().astype(np.float32) - b = struct.pack(f'{len(d)}f', *d) - f.write(b) - - # first write out the header - hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0] - p = self.params - n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads - header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, - n_kv_heads, p.vocab_size, p.max_seq_len) - f.write(header) - - # next write out the embedding weights - serialize(self.tok_embeddings.weight) - - # now all the layers - # attention weights - for layer in self.layers: - serialize(layer.attention_norm.weight) - for layer in self.layers: - serialize(layer.attention.wq.weight) - for layer in self.layers: - serialize(layer.attention.wk.weight) - for layer in self.layers: - serialize(layer.attention.wv.weight) - for layer in self.layers: - serialize(layer.attention.wo.weight) - # ffn weights - for layer in self.layers: - serialize(layer.ffn_norm.weight) - for layer in self.layers: - serialize(layer.feed_forward.w1.weight) - for layer in self.layers: - serialize(layer.feed_forward.w2.weight) - for layer in self.layers: - serialize(layer.feed_forward.w3.weight) - # final rmsnorm - serialize(self.norm.weight) - # note: no need to write final classifier weights due to weight sharing - # freqs_cis - serialize(self.freqs_cos[:p.max_seq_len]) - serialize(self.freqs_sin[:p.max_seq_len]) - - # write to binary file - f.close() - print(f"wrote {filepath}") diff --git a/train.py b/train.py index b1972dc..e958538 100644 --- a/train.py +++ b/train.py @@ -29,6 +29,7 @@ from torch.distributed import destroy_process_group, init_process_group from torch.nn.parallel import DistributedDataParallel as DDP from tinystories import Task +from export import model_export # ----------------------------------------------------------------------------- # I/O @@ -287,7 +288,7 @@ while True: } print(f"saving checkpoint to {out_dir}") torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt")) - raw_model.export(os.path.join(out_dir, "model.bin")) + model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0) if iter_num == 0 and eval_only: break