Rewrite the model export to use int8 (Q8_0) quantization in groups, with group-size fallback, and replace the header with a versioned 128-byte header that starts with a magic number
This commit is contained in:
@@ -340,53 +340,100 @@ class Transformer(nn.Module):
|
|||||||
return idx
|
return idx
|
||||||
|
|
||||||
def export(self, filepath='model.bin'):
|
def export(self, filepath='model.bin'):
|
||||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
"""export the model weights in Q8_0 into .bin file to be read from C"""
|
||||||
f = open(filepath, 'wb')
|
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
||||||
|
out_file = open(filepath, 'wb')
|
||||||
|
|
||||||
def serialize(t):
|
def serialize_fp32(t):
|
||||||
|
""" writes one fp32 tensor to file """
|
||||||
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
|
d = t.detach().cpu().view(-1).numpy().astype(np.float32)
|
||||||
b = struct.pack(f'{len(d)}f', *d)
|
b = struct.pack(f'{len(d)}f', *d)
|
||||||
f.write(b)
|
out_file.write(b)
|
||||||
|
|
||||||
# first write out the header
|
def serialize_int8(t):
|
||||||
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
|
""" writes one int8 tensor to file """
|
||||||
|
d = t.detach().cpu().view(-1).numpy().astype(np.int8)
|
||||||
|
b = struct.pack(f'{len(d)}b', *d)
|
||||||
|
out_file.write(b)
|
||||||
|
|
||||||
|
def quantize_q80(w, group_size=32):
|
||||||
|
"""
|
||||||
|
takes a tensor and returns the Q8_0 quantized version
|
||||||
|
i.e. symmetric quantization into int8, range [-127,127]
|
||||||
|
"""
|
||||||
|
assert w.numel() % group_size == 0
|
||||||
|
ori_shape = w.shape
|
||||||
|
w = w.float() # convert to float32
|
||||||
|
w = w.reshape(-1, group_size)
|
||||||
|
# find the max in each group
|
||||||
|
wmax = torch.abs(w).max(dim=1).values
|
||||||
|
# scale into range [-127, 127]
|
||||||
|
scaled = w/wmax[:,None]*127
|
||||||
|
# round to nearest integer
|
||||||
|
int8val = torch.round(scaled).to(torch.int8)
|
||||||
|
# dequantize by rescaling
|
||||||
|
fp32val = (int8val.float()*wmax[:,None]/127.0).view(-1)
|
||||||
|
fp32valr = fp32val.reshape(-1, group_size)
|
||||||
|
# calculate the max error in each group
|
||||||
|
err = torch.abs(fp32valr - w).max(dim=1).values
|
||||||
|
# find the max error across all groups
|
||||||
|
maxerr = err.max().item()
|
||||||
|
return int8val, wmax, maxerr
|
||||||
|
|
||||||
|
# first write out the header. the header will be 128 bytes
|
||||||
|
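# 1) write magic: the uint32 0x616b3432, i.e. the ASCII codes of "ak42" (packed in native byte order, so the on-disk bytes read "24ka" on little-endian machines)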
|
||||||
|
out_file.write(struct.pack('I', 0x616b3432))
|
||||||
|
# 2) write version, which will be uint32 of 1
|
||||||
|
out_file.write(struct.pack('I', 1))
|
||||||
|
# 3) write the params, which will be 7 ints
|
||||||
p = self.params
|
p = self.params
|
||||||
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
|
||||||
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
header = struct.pack('IIIIIII', p.dim, hidden_dim, p.n_layers, p.n_heads,
|
||||||
n_kv_heads, p.vocab_size, p.max_seq_len)
|
n_kv_heads, p.vocab_size, p.max_seq_len)
|
||||||
f.write(header)
|
out_file.write(header)
|
||||||
|
# 4) write some other flags
|
||||||
|
shared_classifier = 1 # we do share a classifier, write flag as a byte
|
||||||
|
out_file.write(struct.pack('B', shared_classifier))
|
||||||
|
# so far we have used 4 + 4 + 7*4 + 1 = 37 bytes
|
||||||
|
# let's pad the rest of the header to exactly 128 bytes
|
||||||
|
out_file.write(struct.pack('B'*91, *[0]*91))
|
||||||
|
# now that the header is done, let's write out the model
|
||||||
|
|
||||||
# next write out the embedding weights
|
# first let's write out all the params that we are keeping in fp32: the norms
|
||||||
serialize(self.tok_embeddings.weight)
|
for layer in self.layers: # attention norms
|
||||||
|
serialize_fp32(layer.attention_norm.weight)
|
||||||
|
for layer in self.layers: # MLP norms
|
||||||
|
serialize_fp32(layer.ffn_norm.weight)
|
||||||
|
serialize_fp32(self.norm.weight) # final pre-classifier norm
|
||||||
|
|
||||||
# now all the layers
|
# now let's write out all the params that we are quantizing to Q8_0
|
||||||
# attention weights
|
# note we skip classifier weights, which are shared with the embedding
|
||||||
for layer in self.layers:
|
weights = [
|
||||||
serialize(layer.attention_norm.weight)
|
self.tok_embeddings.weight,
|
||||||
for layer in self.layers:
|
*[layer.attention.wq.weight for layer in self.layers],
|
||||||
serialize(layer.attention.wq.weight)
|
*[layer.attention.wk.weight for layer in self.layers],
|
||||||
for layer in self.layers:
|
*[layer.attention.wv.weight for layer in self.layers],
|
||||||
serialize(layer.attention.wk.weight)
|
*[layer.attention.wo.weight for layer in self.layers],
|
||||||
for layer in self.layers:
|
*[layer.feed_forward.w1.weight for layer in self.layers],
|
||||||
serialize(layer.attention.wv.weight)
|
*[layer.feed_forward.w2.weight for layer in self.layers],
|
||||||
for layer in self.layers:
|
*[layer.feed_forward.w3.weight for layer in self.layers],
|
||||||
serialize(layer.attention.wo.weight)
|
]
|
||||||
# ffn weights
|
|
||||||
for layer in self.layers:
|
ew = []
|
||||||
serialize(layer.ffn_norm.weight)
|
for i, w in enumerate(weights):
|
||||||
for layer in self.layers:
|
gs = 64 # group size we want
|
||||||
serialize(layer.feed_forward.w1.weight)
|
while w.numel() % gs != 0:
|
||||||
for layer in self.layers:
|
gs //= 2 # but fall back as needed
|
||||||
serialize(layer.feed_forward.w2.weight)
|
q, s, err = quantize_q80(w, group_size=gs)
|
||||||
for layer in self.layers:
|
out_file.write(struct.pack('I', gs))
|
||||||
serialize(layer.feed_forward.w3.weight)
|
serialize_int8(q) # save the tensor in int8
|
||||||
# final rmsnorm
|
serialize_fp32(s) # save the scaling factors in fp32
|
||||||
serialize(self.norm.weight)
|
ew.append((err, w.shape))
|
||||||
# note: no need to write final classifier weights due to weight sharing
|
print(f"{i:3d} quantized {tuple(w.shape)} to Q8_0 with group size {gs} and max error {err}")
|
||||||
# freqs_cis
|
|
||||||
serialize(self.freqs_cos[:p.max_seq_len])
|
ew.sort(reverse=True)
|
||||||
serialize(self.freqs_sin[:p.max_seq_len])
|
print(f"max quantization group error across all weights: {ew[0][0]}")
|
||||||
|
|
||||||
# write to binary file
|
# write to binary file
|
||||||
f.close()
|
out_file.close()
|
||||||
print(f"wrote {filepath}")
|
print(f"wrote {filepath}")
|
||||||
|
|||||||
Reference in New Issue
Block a user