oops fix double indent on quantize def
This commit is contained in:
@@ -37,30 +37,30 @@ def serialize_int8(file, tensor):
|
|||||||
file.write(b)
|
file.write(b)
|
||||||
|
|
||||||
def quantize_q80(w, group_size):
    """
    Take a tensor and return its Q8_0 quantized version, i.e. symmetric
    quantization into int8 with range [-127, 127].

    The tensor is flattened into rows of `group_size` consecutive elements
    and every row gets its own scale, chosen such that float = quant * scale.

    Args:
        w: tensor whose number of elements is a multiple of `group_size`.
        group_size: number of consecutive elements that share one scale.

    Returns:
        A tuple (int8val, scale, maxerr):
        int8val - int8 tensor of shape (-1, group_size) with the quantized values
        scale   - float32 tensor with one scaling factor per group
        maxerr  - largest absolute round-trip (quantize -> dequantize) error
                  over all elements, as a Python number

    Raises:
        AssertionError: if w.numel() is not divisible by group_size.
    """
    assert w.numel() % group_size == 0
    w = w.float()  # convert to float32
    w = w.reshape(-1, group_size)
    # find the max magnitude in each group
    wmax = torch.abs(w).max(dim=1).values
    # calculate the scaling factor such that float = quant * scale
    scale = wmax / 127.0
    # A group made only of zeros has scale == 0; dividing by it would give
    # 0/0 = NaN, and casting NaN to int8 is undefined. Divide by 1 for those
    # groups instead (their values are all 0 anyway). The returned scale is
    # left untouched so the output stays the same for callers.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    # scale into range [-127, 127]
    quant = w / safe_scale[:, None]
    # round to nearest integer
    int8val = torch.round(quant).to(torch.int8)
    # dequantize by rescaling (already shaped as (-1, group_size))
    fp32valr = int8val.float() * scale[:, None]
    # calculate the max error in each group
    err = torch.abs(fp32valr - w).max(dim=1).values
    # find the max error across all groups
    maxerr = err.max().item()
    return int8val, scale, maxerr
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# legacy
|
# legacy
|
||||||
|
|||||||
Reference in New Issue
Block a user