From e9cbe3e84fafa5b31c76a368265c6cf78ac7e564 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Thu, 17 Aug 2023 14:32:22 +0000
Subject: [PATCH] small improvements to comments and warnings and increase header size during model export

---
 model.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/model.py b/model.py
index b9584df..f12f84c 100644
--- a/model.py
+++ b/model.py
@@ -380,10 +380,10 @@ class Transformer(nn.Module):
             maxerr = err.max().item()
             return int8val, wmax, maxerr
 
-        # first write out the header. the header will be 128 bytes
+        # first write out the header. the header will be 256 bytes
         # 1) write magic, which will be uint32 of "ak42" in ASCII
         out_file.write(struct.pack('I', 0x616b3432))
-        # 2) write version, which will be uint32 of 1
+        # 2) write version, which will be uint32
         out_file.write(struct.pack('I', 1))
         # 3) write the params, which will be 7 ints
         p = self.params
@@ -395,8 +395,8 @@ class Transformer(nn.Module):
         shared_classifier = 1 # we do share a classifier, write flag as a byte
         out_file.write(struct.pack('B', shared_classifier))
         # ok so we so far used 4 + 4 + 7*4 + 1 = 37 bytes
-        # let's pad the rest of the header to exactly 128 bytes
-        out_file.write(struct.pack('B'*91, *[0]*91))
+        pad = 256 - 37 # pad the rest with zeros
+        out_file.write(b'\0' * pad)
 
         # now that the header is done, let's write out the model
         # first let's write out all the params that we are keeping in fp32: the norms
@@ -421,16 +421,27 @@ class Transformer(nn.Module):
 
         ew = []
         for i, w in enumerate(weights):
+
+            # find a good group size for this weight tensor
             gs = 64 # group size we want
             while w.numel() % gs != 0:
                 gs //= 2 # but fall back as needed
+            if gs <= 8:
+                print(f"WARNING: weight of shape {tuple(w.shape)} caused group size to fall down to {gs}")
+
+            # quantize this weight
             q, s, err = quantize_q80(w, group_size=gs)
-            out_file.write(struct.pack('I', gs))
+
+            # save to file
+            out_file.write(struct.pack('I', gs)) # save the group size as uint32
             serialize_int8(q) # save the tensor in int8
             serialize_fp32(s) # save the scaling factors in fp32
+
+            # logging
             ew.append((err, w.shape))
             print(f"{i:3d} quantized {tuple(w.shape)} to Q8_0 with group size {gs} and max error {err}")
 
+        # print the highest error across all weights, should be very small, e.g. O(~0.001)
         ew.sort(reverse=True)
         print(f"max quantization group error across all weights: {ew[0][0]}")
 