small improvements to comments and warnings and increase header size during model export

This commit is contained in:
Andrej Karpathy
2023-08-17 14:32:22 +00:00
parent 5e2e5b28f4
commit e9cbe3e84f
+16 -5
View File
@@ -380,10 +380,10 @@ class Transformer(nn.Module):
maxerr = err.max().item()
return int8val, wmax, maxerr
# first write out the header. the header will be 128 bytes
# first write out the header. the header will be 256 bytes
# 1) write magic, which will be uint32 of "ak42" in ASCII
out_file.write(struct.pack('I', 0x616b3432))
# 2) write version, which will be uint32 of 1
# 2) write version, which will be uint32
out_file.write(struct.pack('I', 1))
# 3) write the params, which will be 7 ints
p = self.params
@@ -395,8 +395,8 @@ class Transformer(nn.Module):
shared_classifier = 1 # we do share a classifier, write flag as a byte
out_file.write(struct.pack('B', shared_classifier))
# so far we have used 4 + 4 + 7*4 + 1 = 37 bytes
# let's pad the rest of the header to exactly 128 bytes
out_file.write(struct.pack('B'*91, *[0]*91))
pad = 256 - 37 # pad the rest with zeros
out_file.write(b'\0' * pad)
# now that the header is done, let's write out the model
# first let's write out all the params that we are keeping in fp32: the norms
@@ -421,16 +421,27 @@ class Transformer(nn.Module):
ew = []
for i, w in enumerate(weights):
# find a good group size for this weight tensor
gs = 64 # group size we want
while w.numel() % gs != 0:
gs //= 2 # but fall back as needed
if gs <= 8:
print(f"WARNING: weight of shape {tuple(w.shape)} caused group size to fall down to {gs}")
# quantize this weight
q, s, err = quantize_q80(w, group_size=gs)
out_file.write(struct.pack('I', gs))
# save to file
out_file.write(struct.pack('I', gs)) # save the group size as uint32
serialize_int8(q) # save the tensor in int8
serialize_fp32(s) # save the scaling factors in fp32
# logging
ew.append((err, w.shape))
print(f"{i:3d} quantized {tuple(w.shape)} to Q8_0 with group size {gs} and max error {err}")
# print the highest error across all weights; it should be very small, e.g. on the order of 0.001
ew.sort(reverse=True)
print(f"max quantization group error across all weights: {ew[0][0]}")