Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/handlers/backend/quantize_linear.py
T
2020-10-14 08:55:07 +08:00

30 lines
781 B
Python

import tensorflow as tf
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
@onnx_op("QuantizeLinear")
class QuantizeLinear(BackendHandler):
  """Backend handler for the ONNX ``QuantizeLinear`` operator.

  Computes ``y = saturate(round(x / y_scale) + y_zero_point)``. The output
  dtype is the dtype of ``y_zero_point`` when that input is supplied,
  otherwise ``uint8``. Saturation clamps to that dtype's range.
  """

  @classmethod
  def version_10(cls, node, **kwargs):
    """Convert an opset-10 ``QuantizeLinear`` node to TensorFlow ops.

    Args:
      node: The ONNX node wrapper. ``node.inputs`` holds the input tensor
        names ``[x, y_scale]`` or ``[x, y_scale, y_zero_point]``. The third
        name may be empty when the zero point is omitted.
      **kwargs: Must contain ``tensor_dict``, mapping input names to
        TensorFlow tensors.

    Returns:
      A one-element list holding the quantized tensor.
    """
    tensor_dict = kwargs["tensor_dict"]
    x = tf.cast(tensor_dict[node.inputs[0]], tf.float32)
    # Defensive cast so the division cannot fail on a dtype mismatch if the
    # scale tensor is not already float32 (a no-op when it is).
    y_scale = tf.cast(tensor_dict[node.inputs[1]], tf.float32)
    # tf.round rounds half to even, which is the rounding ONNX specifies.
    y = tf.round(tf.divide(x, y_scale))

    # An omitted optional input is either left off the end of the input list
    # or given an empty name; treat both forms as "no zero point" instead of
    # looking up tensor_dict[""].
    zero_point_name = node.inputs[2] if len(node.inputs) > 2 else ""
    if zero_point_name:
      y_zero_point = tensor_dict[zero_point_name]
      y_dtype = y_zero_point.dtype
      y = tf.add(y, tf.cast(y_zero_point, tf.float32))
    else:
      # Without a zero point the output dtype defaults to uint8.
      y_dtype = tf.uint8
    # Clamp into the output dtype's range while converting.
    return [tf.saturate_cast(y, y_dtype)]