64 lines
2.0 KiB
Python
64 lines
2.0 KiB
Python
import tensorflow as tf
|
|
|
|
from onnx_tf.handlers.backend_handler import BackendHandler
|
|
from onnx_tf.handlers.handler import onnx_op
|
|
from .conv_mixin import ConvMixin
|
|
|
|
|
|
@onnx_op("ConvInteger")
class ConvInteger(ConvMixin, BackendHandler):
    """Backend handler for the ONNX ``ConvInteger`` operator.

    The optional zero points are subtracted from the input and the weights
    after both have been cast to float32. The convolution itself is
    delegated to ``cls.conv`` (presumably provided by ``ConvMixin``) and its
    first output is cast to int32.
    """

    @classmethod
    def _apply_zero_point(cls, base, zero_point):
        """Return ``base - zero_point`` with both operands cast to float32.

        Args:
            base: Tensor the zero point is removed from.
            zero_point: Tensor subtracted from ``base``; normal TensorFlow
                broadcasting applies to the subtraction.

        Returns:
            A float32 tensor.
        """
        base = tf.cast(base, tf.float32)
        zero_point = tf.cast(zero_point, tf.float32)
        return base - zero_point

    @classmethod
    def _optional_input(cls, node, tensor_dict, index):
        """Return the tensor bound to optional input ``index``, if any.

        Args:
            node: Node whose ``inputs`` sequence holds the input names.
            tensor_dict: Mapping from input name to tensor.
            index: Position of the optional input in ``node.inputs``.

        Returns:
            The tensor, or ``None`` when the input is not supplied.
        """
        if index >= len(node.inputs):
            return None
        name = node.inputs[index]
        # ONNX allows an omitted optional input to be written as an empty
        # name (e.g. when only a later optional input is supplied).
        if not name:
            return None
        return tensor_dict[name]

    @classmethod
    def _conv_without_zero_points(cls, node, new_x, new_w):
        """Run the common conv handling on already zero-point-adjusted data.

        ``cls.conv`` is handed the node with only its first two inputs so
        that the zero-point inputs are not seen by the shared conv code.
        The inputs are dropped by position (not by value, which would
        remove the wrong entry if a name occurred twice) and are put back
        afterwards so the caller's node is left as it was found.

        Args:
            node: The ConvInteger node being converted.
            new_x: Float32 input tensor with its zero point removed.
            new_w: Float32 weight tensor with its zero point removed.

        Returns:
            Whatever ``cls.conv`` returns for the node.
        """
        zero_point_names = list(node.inputs[2:])
        del node.inputs[2:]
        try:
            new_dict = {node.inputs[0]: new_x, node.inputs[1]: new_w}
            # Use common conv handling
            return cls.conv(node, new_dict)
        finally:
            node.inputs.extend(zero_point_names)

    @classmethod
    def version_10(cls, node, **kwargs):
        """Convert a ``ConvInteger`` node.

        Args:
            node: Node whose ``inputs`` are, in order, the names of x, w and
                the optional x_zero_point and w_zero_point.
            **kwargs: Must contain ``tensor_dict``, the mapping from input
                name to tensor.

        Returns:
            A one-element list holding the convolution result cast to int32.

        Raises:
            ValueError: If w_zero_point is neither a scalar nor 1-D.
        """
        tensor_dict = kwargs["tensor_dict"]
        x = tensor_dict[node.inputs[0]]
        w = tensor_dict[node.inputs[1]]
        x_zero_point = cls._optional_input(node, tensor_dict, 2)
        w_zero_point = cls._optional_input(node, tensor_dict, 3)

        # Apply x_zero_point first
        if x_zero_point is None:
            x = tf.cast(x, tf.float32)
        else:
            x = cls._apply_zero_point(x, x_zero_point)

        # Apply w_zero_point next
        if w_zero_point is None:
            # Just cast without processing w
            w = tf.cast(w, tf.float32)
        else:
            zero_point_rank = w_zero_point.shape.rank
            if zero_point_rank == 0:
                # Simply apply w_zero_point for scalar
                w = cls._apply_zero_point(w, w_zero_point)
            elif zero_point_rank == 1:
                # Element i of a 1-D zero point belongs to w[i]. Reshaping
                # it to [-1, 1, ..., 1] lets one broadcast subtraction do
                # this for every slice along axis 0, without a per-slice
                # loop and without needing a static size for that axis.
                per_channel_shape = [-1] + [1] * (w.shape.rank - 1)
                w = cls._apply_zero_point(
                    w, tf.reshape(w_zero_point, per_channel_shape))
            else:
                raise ValueError(
                    "Unsupported w zero point: {}".format(w_zero_point))

        conv_outputs = cls._conv_without_zero_points(node, x, w)
        return [tf.cast(conv_outputs[0], tf.int32)]
|