add pt2tf tool

This commit is contained in:
zxros10
2020-09-23 09:09:49 +08:00
parent 7f7b7df65d
commit 18aefa4dd0
407 changed files with 16211 additions and 0 deletions
@@ -0,0 +1,29 @@
import tensorflow as tf
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
@onnx_op("QuantizeLinear")
class QuantizeLinear(BackendHandler):
@classmethod
def version_10(cls, node, **kwargs):
tensor_dict = kwargs["tensor_dict"]
x = tensor_dict[node.inputs[0]]
y_scale = tensor_dict[node.inputs[1]]
x = tf.cast(x, tf.float32)
y = tf.divide(x, y_scale)
y = tf.round(y)
if len(node.inputs) == 3:
y_zero_point = tensor_dict[node.inputs[2]]
y_dtype = y_zero_point.dtype
y_zero_point = tf.cast(y_zero_point, tf.float32)
y = tf.add(y, y_zero_point)
else: # y_zero_point default dtype = uint8
y_dtype = tf.uint8
y = tf.saturate_cast(y, y_dtype)
return [y]