small improvements to comments and warnings and increase header size during model export
This commit is contained in:
@@ -380,10 +380,10 @@ class Transformer(nn.Module):
|
||||
maxerr = err.max().item()
|
||||
return int8val, wmax, maxerr
|
||||
|
||||
# first write out the header. the header will be 128 bytes
|
||||
# first write out the header. the header will be 256 bytes
|
||||
# 1) write magic, which will be uint32 of "ak42" in ASCII
|
||||
out_file.write(struct.pack('I', 0x616b3432))
|
||||
# 2) write version, which will be uint32 of 1
|
||||
# 2) write version, which will be uint32
|
||||
out_file.write(struct.pack('I', 1))
|
||||
# 3) write the params, which will be 7 ints
|
||||
p = self.params
|
||||
@@ -395,8 +395,8 @@ class Transformer(nn.Module):
|
||||
shared_classifier = 1 # we do share a classifier, write flag as a byte
|
||||
out_file.write(struct.pack('B', shared_classifier))
|
||||
# ok so we so far used 4 + 4 + 7*4 + 1 = 37 bytes
|
||||
# let's pad the rest of the header to exactly 128 bytes
|
||||
out_file.write(struct.pack('B'*91, *[0]*91))
|
||||
pad = 256 - 37 # pad the rest with zeros
|
||||
out_file.write(b'\0' * pad)
|
||||
# now that the header is done, let's write out the model
|
||||
|
||||
# first let's write out all the params that we are keeping in fp32: the norms
|
||||
@@ -421,16 +421,27 @@ class Transformer(nn.Module):
|
||||
|
||||
ew = []
|
||||
for i, w in enumerate(weights):
|
||||
|
||||
# find a good group size for this weight tensor
|
||||
gs = 64 # group size we want
|
||||
while w.numel() % gs != 0:
|
||||
gs //= 2 # but fall back as needed
|
||||
if gs <= 8:
|
||||
print(f"WARNING: weight of shape {tuple(w.shape)} caused group size to fall down to {gs}")
|
||||
|
||||
# quantize this weight
|
||||
q, s, err = quantize_q80(w, group_size=gs)
|
||||
out_file.write(struct.pack('I', gs))
|
||||
|
||||
# save to file
|
||||
out_file.write(struct.pack('I', gs)) # save the group size as uint32
|
||||
serialize_int8(q) # save the tensor in int8
|
||||
serialize_fp32(s) # save the scaling factors in fp32
|
||||
|
||||
# logging
|
||||
ew.append((err, w.shape))
|
||||
print(f"{i:3d} quantized {tuple(w.shape)} to Q8_0 with group size {gs} and max error {err}")
|
||||
|
||||
# print the highest error across all weights, should be very small, e.g. O(~0.001)
|
||||
ew.sort(reverse=True)
|
||||
print(f"max quantization group error across all weights: {ew[0][0]}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user