ok this works but is super slow because we are doing all the work in fp32 still

This commit is contained in:
Andrej Karpathy
2023-08-18 03:40:18 +00:00
parent e9cbe3e84f
commit 591f1353c7
2 changed files with 169 additions and 94 deletions
+30 -22
View File
@@ -339,7 +339,7 @@ class Transformer(nn.Module):
return idx
def export(self, filepath='model.bin'):
def export(self, filepath='model.bin', group_size=64):
"""export the model weights in Q8_0 into .bin file to be read from C"""
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
out_file = open(filepath, 'wb')
@@ -356,7 +356,7 @@ class Transformer(nn.Module):
b = struct.pack(f'{len(d)}b', *d)
out_file.write(b)
def quantize_q80(w, group_size=32):
def quantize_q80(w):
"""
takes a tensor and returns the Q8_0 quantized version
i.e. symmetric quantization into int8, range [-127,127]
@@ -367,35 +367,44 @@ class Transformer(nn.Module):
w = w.reshape(-1, group_size)
# find the max in each group
wmax = torch.abs(w).max(dim=1).values
# calculate the scaling factor such that float = quant * scale
scale = wmax / 127.0
# scale into range [-127, 127]
scaled = w/wmax[:,None]*127
quant = w / scale[:,None]
# round to nearest integer
int8val = torch.round(scaled).to(torch.int8)
int8val = torch.round(quant).to(torch.int8)
# dequantize by rescaling
fp32val = (int8val.float()*wmax[:,None]/127.0).view(-1)
fp32val = (int8val.float() * scale[:,None]).view(-1)
fp32valr = fp32val.reshape(-1, group_size)
# calculate the max error in each group
err = torch.abs(fp32valr - w).max(dim=1).values
# find the max error across all groups
maxerr = err.max().item()
return int8val, wmax, maxerr
return int8val, scale, maxerr
# first write out the header. the header will be 256 bytes
nbytes = 0
# 1) write magic, which will be uint32 of "ak42" in ASCII
out_file.write(struct.pack('I', 0x616b3432))
# 2) write version, which will be uint32
out_file.write(struct.pack('I', 1))
nbytes += 4
# 2) write version, which will be int
out_file.write(struct.pack('i', 1))
nbytes += 4
# 3) write the params, which will be 7 ints
p = self.params
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('IIIIIII', p.dim, hidden_dim, p.n_layers, p.n_heads,
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
out_file.write(header)
nbytes += 7*4
# 4) write some other flags
shared_classifier = 1 # we do share a classifier, write flag as a byte
out_file.write(struct.pack('B', shared_classifier))
# ok so we so far used 4 + 4 + 7*4 + 1 = 37 bytes
pad = 256 - 37 # pad the rest with zeros
nbytes += 1
out_file.write(struct.pack('i', group_size)) # group size used for quantization
nbytes += 4
pad = 256 - nbytes # pad the rest with zeros
assert pad >= 0
out_file.write(b'\0' * pad)
# now that the header is done, let's write out the model
@@ -420,26 +429,25 @@ class Transformer(nn.Module):
]
ew = []
scales = []
for i, w in enumerate(weights):
# find a good group size for this weight tensor
gs = 64 # group size we want
while w.numel() % gs != 0:
gs //= 2 # but fall back as needed
if gs <= 8:
print(f"WARNING: weight of shape {tuple(w.shape)} caused group size to fall down to {gs}")
assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
# quantize this weight
q, s, err = quantize_q80(w, group_size=gs)
q, s, err = quantize_q80(w)
# save to file
out_file.write(struct.pack('I', gs)) # save the group size as uint32
serialize_int8(q) # save the tensor in int8
serialize_fp32(s) # save the scaling factors in fp32
scales.append(s) # we'll do all the scales after all the qs
# logging
ew.append((err, w.shape))
print(f"{i:3d} quantized {tuple(w.shape)} to Q8_0 with group size {gs} and max error {err}")
print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
# save the scaling factors in fp32 here
# this is done to keep all the weights contiquous, making pointer arithmetic easier in C
for s in scales:
serialize_fp32(s)
# print the highest error across all weights, should be very small, e.g. O(~0.001)
ew.sort(reverse=True)