add pt2tf tool

This commit is contained in:
zxros10
2020-09-23 09:09:49 +08:00
parent 7f7b7df65d
commit 18aefa4dd0
407 changed files with 16211 additions and 0 deletions
@@ -0,0 +1,63 @@
import tensorflow as tf
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from .conv_mixin import ConvMixin
@onnx_op("ConvInteger")
class ConvInteger(ConvMixin, BackendHandler):
@classmethod
def _apply_zero_point(cls, base, zero_point):
base = tf.cast(base, tf.float32)
zero_point = tf.cast(zero_point, tf.float32)
return base - zero_point
@classmethod
def version_10(cls, node, **kwargs):
tensor_dict = kwargs["tensor_dict"]
x = tensor_dict[node.inputs[0]]
w = tensor_dict[node.inputs[1]]
def process_conv(new_x, new_w):
# Remove zero-points from inputs
if len(node.inputs) == 4:
node.inputs.remove(node.inputs[3])
if len(node.inputs) == 3:
node.inputs.remove(node.inputs[2])
new_dict = { node.inputs[0]:new_x, node.inputs[1]:new_w }
# Use common conv handling
conv_node = cls.conv(node, new_dict)
return conv_node
# Apply x_zero_point first
x = cls._apply_zero_point(
x, tensor_dict[node.inputs[2]]) if len(node.inputs) > 2 else tf.cast(
x, tf.float32)
# Apply w_zero_point next
if len(node.inputs) == 4:
w_zero_point = tensor_dict[node.inputs[3]]
if w_zero_point.shape.rank == 0:
# Simply apply w_zero_point for scalar
w = cls._apply_zero_point(w, w_zero_point)
elif w_zero_point.shape.rank == 1:
# Need additional processing for 1d w_zero_point
tensor_list = []
process_shape = [1] + [w.shape[i] for i in range(1, len(w.shape))]
for i in range(w.shape.as_list()[0]):
# Apply w_zero_point for each element in 1d tensor
out_tensor = cls._apply_zero_point(w[i], w_zero_point[i])
tensor_list.append(tf.reshape(out_tensor, process_shape))
w = tf.concat(tensor_list, 0)
else:
raise ValueError("Unsupported w zero point: {}".format(w_zero_point))
else:
# Just cast without processing w
w = tf.cast(w, tf.float32)
return [tf.cast(process_conv(x, w)[0], tf.int32)]