From 4212bd6d4343ac8a13efaced5609af268e7f4730 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 19 Aug 2023 18:34:49 +0000 Subject: [PATCH] oops fix double indent on quantize def --- export.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/export.py b/export.py index 4710649..db874b0 100644 --- a/export.py +++ b/export.py @@ -37,30 +37,30 @@ def serialize_int8(file, tensor): file.write(b) def quantize_q80(w, group_size): - """ - takes a tensor and returns the Q8_0 quantized version - i.e. symmetric quantization into int8, range [-127,127] - """ - assert w.numel() % group_size == 0 - ori_shape = w.shape - w = w.float() # convert to float32 - w = w.reshape(-1, group_size) - # find the max in each group - wmax = torch.abs(w).max(dim=1).values - # calculate the scaling factor such that float = quant * scale - scale = wmax / 127.0 - # scale into range [-127, 127] - quant = w / scale[:,None] - # round to nearest integer - int8val = torch.round(quant).to(torch.int8) - # dequantize by rescaling - fp32val = (int8val.float() * scale[:,None]).view(-1) - fp32valr = fp32val.reshape(-1, group_size) - # calculate the max error in each group - err = torch.abs(fp32valr - w).max(dim=1).values - # find the max error across all groups - maxerr = err.max().item() - return int8val, scale, maxerr + """ + takes a tensor and returns the Q8_0 quantized version + i.e. symmetric quantization into int8, range [-127,127] + """ + assert w.numel() % group_size == 0 + ori_shape = w.shape + w = w.float() # convert to float32 + w = w.reshape(-1, group_size) + # find the max in each group + wmax = torch.abs(w).max(dim=1).values + # calculate the scaling factor such that float = quant * scale + scale = wmax / 127.0 + # scale into range [-127, 127] + quant = w / scale[:,None] + # round to nearest integer + int8val = torch.round(quant).to(torch.int8) + # dequantize by rescaling + fp32val = (int8val.float() * scale[:,None]).view(-1) + fp32valr = fp32val.reshape(-1, group_size) + # calculate the max error in each group + err = torch.abs(fp32valr - w).max(dim=1).values + # find the max error across all groups + maxerr = err.max().item() + return int8val, scale, maxerr # ----------------------------------------------------------------------------- # legacy