重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.common import exception
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
from onnx_tf.handlers.handler import partial_support
|
||||
from onnx_tf.handlers.handler import ps_description
|
||||
|
||||
|
||||
@onnx_op("Clip")
|
||||
@partial_support(True)
|
||||
@ps_description("Clip input in uint64 is not supported in Tensorflow.")
|
||||
class Clip(BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def args_check(cls, node, **kwargs):
|
||||
x = kwargs["tensor_dict"][node.inputs[0]]
|
||||
# uint64 cannot upcast to any tensorflow supported datatype
|
||||
# for tf.clip_by_value that didn't lose precision
|
||||
if x.dtype == tf.uint64:
|
||||
exception.OP_UNSUPPORTED_EXCEPT(
|
||||
"Clip input, min and max in " + str(x.dtype) + " datatype",
|
||||
"Tensorflow")
|
||||
|
||||
@classmethod
|
||||
def _common(cls, node, **kwargs):
|
||||
tensor_dict = kwargs["tensor_dict"]
|
||||
x = tensor_dict[node.inputs[0]]
|
||||
x_dtype = x.dtype
|
||||
|
||||
if cls.SINCE_VERSION < 11:
|
||||
# min/max were required and passed as attributes
|
||||
clip_value_min = node.attrs.get("min", tf.reduce_min(x))
|
||||
clip_value_max = node.attrs.get("max", tf.reduce_max(x))
|
||||
else:
|
||||
# min/max are optional and passed as inputs
|
||||
clip_value_min = tensor_dict[node.inputs[1]] if len(
|
||||
node.inputs) > 1 and node.inputs[1] != "" else x_dtype.min
|
||||
clip_value_max = tensor_dict[node.inputs[2]] if len(
|
||||
node.inputs) > 2 and node.inputs[2] != "" else x_dtype.max
|
||||
|
||||
# tf.clip_by_value doesn't support uint8, uint16, uint32, int8 and int16
|
||||
# dtype for x, therefore need to upcast it to tf.int32 or tf.int64
|
||||
if x_dtype in [tf.uint8, tf.uint16, tf.uint32, tf.int8, tf.int16]:
|
||||
cast_to = tf.int64 if x_dtype == tf.uint32 else tf.int32
|
||||
x = tf.cast(x, cast_to)
|
||||
clip_value_min = tf.cast(clip_value_min, cast_to)
|
||||
clip_value_max = tf.cast(clip_value_max, cast_to)
|
||||
y = tf.clip_by_value(x, clip_value_min, clip_value_max)
|
||||
y = tf.cast(y, x_dtype)
|
||||
else:
|
||||
y = tf.clip_by_value(x, clip_value_min, clip_value_max)
|
||||
|
||||
return [y]
|
||||
|
||||
@classmethod
|
||||
def version_1(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_6(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_11(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_12(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
Reference in New Issue
Block a user