From 5e2e5b28f48467bc83e4ee10e6f90af581053f43 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 17 Aug 2023 05:56:20 +0000
Subject: [PATCH] re-write the model export to do int8 quantization in groups,
 with group size fallback, and also change the header to be much better

---
Notes (not part of the commit message):
- Line breaks and indentation were restored from a whitespace-collapsed copy
  of this patch. If a context line fails to match, apply with
  `git apply --ignore-whitespace`.
- Changed vs. the original patch: quantize_q80 no longer divides an all-zero
  group by wmax == 0 (0/0 = NaN before the int8 cast), and the unused
  `ori_shape` local is dropped. The hunk still adds 85 lines, so the diffstat
  and hunk header are unchanged; the post-image blob id in the `index` line
  predates this change.

 model.py | 123 ++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 85 insertions(+), 38 deletions(-)

diff --git a/model.py b/model.py
index c8c82a9..b9584df 100644
--- a/model.py
+++ b/model.py
@@ -340,53 +340,100 @@ class Transformer(nn.Module):
         return idx
 
     def export(self, filepath='model.bin'):
-        """export the model weights in fp32 into .bin file to be read from C"""
-        f = open(filepath, 'wb')
+        """export the model weights in Q8_0 into .bin file to be read from C"""
+        hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
+        out_file = open(filepath, 'wb')
 
-        def serialize(t):
+        def serialize_fp32(t):
+            """ writes one fp32 tensor to file """
             d = t.detach().cpu().view(-1).numpy().astype(np.float32)
             b = struct.pack(f'{len(d)}f', *d)
-            f.write(b)
+            out_file.write(b)
 
-        # first write out the header
-        hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
+        def serialize_int8(t):
+            """ writes one int8 tensor to file """
+            d = t.detach().cpu().view(-1).numpy().astype(np.int8)
+            b = struct.pack(f'{len(d)}b', *d)
+            out_file.write(b)
+
+        def quantize_q80(w, group_size=32):
+            """
+            takes a tensor and returns the Q8_0 quantized version
+            i.e. symmetric quantization into int8, range [-127,127]
+            """
+            assert w.numel() % group_size == 0
+            w = w.float() # convert to float32
+            w = w.reshape(-1, group_size)
+            # find the max in each group
+            wmax = torch.abs(w).max(dim=1).values
+            # scale into range [-127, 127]; an all-zero group (wmax == 0) is
+            # divided by 1 instead, so it quantizes to 0 rather than 0/0 = NaN
+            scaled = w/torch.where(wmax > 0, wmax, torch.ones_like(wmax))[:,None]*127
+            # round to nearest integer
+            int8val = torch.round(scaled).to(torch.int8)
+            # dequantize by rescaling
+            fp32val = (int8val.float()*wmax[:,None]/127.0).view(-1)
+            fp32valr = fp32val.reshape(-1, group_size)
+            # calculate the max error in each group
+            err = torch.abs(fp32valr - w).max(dim=1).values
+            # find the max error across all groups
+            maxerr = err.max().item()
+            return int8val, wmax, maxerr
+
+        # first write out the header. the header will be 128 bytes
+        # 1) write magic, which will be uint32 of "ak42" in ASCII
+        out_file.write(struct.pack('I', 0x616b3432))
+        # 2) write version, which will be uint32 of 1
+        out_file.write(struct.pack('I', 1))
+        # 3) write the params, which will be 7 ints
         p = self.params
         n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
-        header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
+        header = struct.pack('IIIIIII', p.dim, hidden_dim, p.n_layers, p.n_heads,
                                     n_kv_heads, p.vocab_size, p.max_seq_len)
-        f.write(header)
+        out_file.write(header)
+        # 4) write some other flags
+        shared_classifier = 1 # we do share a classifier, write flag as a byte
+        out_file.write(struct.pack('B', shared_classifier))
+        # ok so we so far used 4 + 4 + 7*4 + 1 = 37 bytes
+        # let's pad the rest of the header to exactly 128 bytes
+        out_file.write(struct.pack('B'*91, *[0]*91))
+        # now that the header is done, let's write out the model
 
-        # next write out the embedding weights
-        serialize(self.tok_embeddings.weight)
+        # first let's write out all the params that we are keeping in fp32: the norms
+        for layer in self.layers: # attention norms
+            serialize_fp32(layer.attention_norm.weight)
+        for layer in self.layers: # MLP norms
+            serialize_fp32(layer.ffn_norm.weight)
+        serialize_fp32(self.norm.weight) # final pre-classifier norm
 
-        # now all the layers
-        # attention weights
-        for layer in self.layers:
-            serialize(layer.attention_norm.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wq.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wk.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wv.weight)
-        for layer in self.layers:
-            serialize(layer.attention.wo.weight)
-        # ffn weights
-        for layer in self.layers:
-            serialize(layer.ffn_norm.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w1.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w2.weight)
-        for layer in self.layers:
-            serialize(layer.feed_forward.w3.weight)
-        # final rmsnorm
-        serialize(self.norm.weight)
-        # note: no need to write final classifier weights due to weight sharing
-        # freqs_cis
-        serialize(self.freqs_cos[:p.max_seq_len])
-        serialize(self.freqs_sin[:p.max_seq_len])
+        # now let's write out all the params that we are quantizing to Q8_0
+        # note we skip classifier weights, which are shared with the embedding
+        weights = [
+            self.tok_embeddings.weight,
+            *[layer.attention.wq.weight for layer in self.layers],
+            *[layer.attention.wk.weight for layer in self.layers],
+            *[layer.attention.wv.weight for layer in self.layers],
+            *[layer.attention.wo.weight for layer in self.layers],
+            *[layer.feed_forward.w1.weight for layer in self.layers],
+            *[layer.feed_forward.w2.weight for layer in self.layers],
+            *[layer.feed_forward.w3.weight for layer in self.layers],
+        ]
+
+        ew = []
+        for i, w in enumerate(weights):
+            gs = 64 # group size we want
+            while w.numel() % gs != 0:
+                gs //= 2 # but fall back as needed
+            q, s, err = quantize_q80(w, group_size=gs)
+            out_file.write(struct.pack('I', gs))
+            serialize_int8(q) # save the tensor in int8
+            serialize_fp32(s) # save the scaling factors in fp32
+            ew.append((err, w.shape))
+            print(f"{i:3d} quantized {tuple(w.shape)} to Q8_0 with group size {gs} and max error {err}")
+
+        ew.sort(reverse=True)
+        print(f"max quantization group error across all weights: {ew[0][0]}")
 
         # write to binary file
-        f.close()
+        out_file.close()
         print(f"wrote {filepath}")