重命名 pt2tf 为 pt2pb

This commit is contained in:
zhutian
2020-10-14 08:55:07 +08:00
committed by Gitee
parent 324ab60a5d
commit 90ae190559
407 changed files with 0 additions and 0 deletions
@@ -0,0 +1,70 @@
import tensorflow as tf
from onnx_tf.common import exception
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import partial_support
from onnx_tf.handlers.handler import ps_description
@onnx_op("Clip")
@partial_support(True)
@ps_description("Clip input in uint64 is not supported in Tensorflow.")
class Clip(BackendHandler):
@classmethod
def args_check(cls, node, **kwargs):
x = kwargs["tensor_dict"][node.inputs[0]]
# uint64 cannot upcast to any tensorflow supported datatype
# for tf.clip_by_value that didn't lose precision
if x.dtype == tf.uint64:
exception.OP_UNSUPPORTED_EXCEPT(
"Clip input, min and max in " + str(x.dtype) + " datatype",
"Tensorflow")
@classmethod
def _common(cls, node, **kwargs):
tensor_dict = kwargs["tensor_dict"]
x = tensor_dict[node.inputs[0]]
x_dtype = x.dtype
if cls.SINCE_VERSION < 11:
# min/max were required and passed as attributes
clip_value_min = node.attrs.get("min", tf.reduce_min(x))
clip_value_max = node.attrs.get("max", tf.reduce_max(x))
else:
# min/max are optional and passed as inputs
clip_value_min = tensor_dict[node.inputs[1]] if len(
node.inputs) > 1 and node.inputs[1] != "" else x_dtype.min
clip_value_max = tensor_dict[node.inputs[2]] if len(
node.inputs) > 2 and node.inputs[2] != "" else x_dtype.max
# tf.clip_by_value doesn't support uint8, uint16, uint32, int8 and int16
# dtype for x, therefore need to upcast it to tf.int32 or tf.int64
if x_dtype in [tf.uint8, tf.uint16, tf.uint32, tf.int8, tf.int16]:
cast_to = tf.int64 if x_dtype == tf.uint32 else tf.int32
x = tf.cast(x, cast_to)
clip_value_min = tf.cast(clip_value_min, cast_to)
clip_value_max = tf.cast(clip_value_max, cast_to)
y = tf.clip_by_value(x, clip_value_min, clip_value_max)
y = tf.cast(y, x_dtype)
else:
y = tf.clip_by_value(x, clip_value_min, clip_value_max)
return [y]
@classmethod
def version_1(cls, node, **kwargs):
return cls._common(node, **kwargs)
@classmethod
def version_6(cls, node, **kwargs):
return cls._common(node, **kwargs)
@classmethod
def version_11(cls, node, **kwargs):
return cls._common(node, **kwargs)
@classmethod
def version_12(cls, node, **kwargs):
return cls._common(node, **kwargs)