重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
from .conv_mixin import ConvMixin
|
||||
|
||||
|
||||
@onnx_op("ConvInteger")
|
||||
class ConvInteger(ConvMixin, BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def _apply_zero_point(cls, base, zero_point):
|
||||
base = tf.cast(base, tf.float32)
|
||||
zero_point = tf.cast(zero_point, tf.float32)
|
||||
return base - zero_point
|
||||
|
||||
@classmethod
|
||||
def version_10(cls, node, **kwargs):
|
||||
tensor_dict = kwargs["tensor_dict"]
|
||||
x = tensor_dict[node.inputs[0]]
|
||||
w = tensor_dict[node.inputs[1]]
|
||||
|
||||
def process_conv(new_x, new_w):
|
||||
# Remove zero-points from inputs
|
||||
if len(node.inputs) == 4:
|
||||
node.inputs.remove(node.inputs[3])
|
||||
if len(node.inputs) == 3:
|
||||
node.inputs.remove(node.inputs[2])
|
||||
|
||||
new_dict = { node.inputs[0]:new_x, node.inputs[1]:new_w }
|
||||
|
||||
# Use common conv handling
|
||||
conv_node = cls.conv(node, new_dict)
|
||||
|
||||
return conv_node
|
||||
|
||||
# Apply x_zero_point first
|
||||
x = cls._apply_zero_point(
|
||||
x, tensor_dict[node.inputs[2]]) if len(node.inputs) > 2 else tf.cast(
|
||||
x, tf.float32)
|
||||
|
||||
# Apply w_zero_point next
|
||||
if len(node.inputs) == 4:
|
||||
w_zero_point = tensor_dict[node.inputs[3]]
|
||||
if w_zero_point.shape.rank == 0:
|
||||
# Simply apply w_zero_point for scalar
|
||||
w = cls._apply_zero_point(w, w_zero_point)
|
||||
elif w_zero_point.shape.rank == 1:
|
||||
# Need additional processing for 1d w_zero_point
|
||||
tensor_list = []
|
||||
process_shape = [1] + [w.shape[i] for i in range(1, len(w.shape))]
|
||||
for i in range(w.shape.as_list()[0]):
|
||||
# Apply w_zero_point for each element in 1d tensor
|
||||
out_tensor = cls._apply_zero_point(w[i], w_zero_point[i])
|
||||
tensor_list.append(tf.reshape(out_tensor, process_shape))
|
||||
w = tf.concat(tensor_list, 0)
|
||||
else:
|
||||
raise ValueError("Unsupported w zero point: {}".format(w_zero_point))
|
||||
else:
|
||||
# Just cast without processing w
|
||||
w = tf.cast(w, tf.float32)
|
||||
|
||||
return [tf.cast(process_conv(x, w)[0], tf.int32)]
|
||||
Reference in New Issue
Block a user