add pt2tf tool
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
import copy
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.common import exception
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
from onnx_tf.handlers.handler import partial_support
|
||||
from onnx_tf.handlers.handler import ps_description
|
||||
from onnx_tf.handlers.handler import tf_func
|
||||
|
||||
|
||||
@onnx_op("OneHot")
|
||||
@tf_func(tf.one_hot)
|
||||
@partial_support(True)
|
||||
@ps_description("OneHot indices in uint16/uint32/uint64/int8/int16/"+
|
||||
"float16/float/double, or " +
|
||||
"OneHot depth in uint8/uint16/uint32/uint64/int8/" +
|
||||
"int16/int64/float16/float/double " +
|
||||
"are not supported in Tensorflow.")
|
||||
class OneHot(BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def args_check(cls, node, **kwargs):
|
||||
tensor_dict = kwargs["tensor_dict"]
|
||||
indices = tensor_dict[node.inputs[0]]
|
||||
depth = tensor_dict[node.inputs[1]]
|
||||
if indices.dtype not in [tf.uint8, tf.int32, tf.int64]:
|
||||
exception.OP_UNSUPPORTED_EXCEPT(
|
||||
"OneHot indices must be in uint8 or int32 or int64 " +
|
||||
"but it is currently in " + str(indices.dtype) + " which",
|
||||
"Tensorflow")
|
||||
if depth.dtype not in [tf.int32]:
|
||||
exception.OP_UNSUPPORTED_EXCEPT(
|
||||
"OneHot depth must be in int32 but it is currently in " + str(
|
||||
depth.dtype) + " which", "Tensorflow")
|
||||
|
||||
@classmethod
|
||||
def _common(cls, node, **kwargs):
|
||||
attrs = copy.deepcopy(node.attrs)
|
||||
tensor_dict = kwargs["tensor_dict"]
|
||||
indices = tensor_dict[node.inputs[0]]
|
||||
depth = tensor_dict[node.inputs[1]]
|
||||
off_value = tensor_dict[node.inputs[2]][0]
|
||||
on_value = tensor_dict[node.inputs[2]][1]
|
||||
attrs["dtype"] = on_value.dtype
|
||||
return [
|
||||
cls.make_tensor_from_onnx_node(
|
||||
node,
|
||||
inputs=[indices, depth, on_value, off_value],
|
||||
attrs=attrs,
|
||||
**kwargs)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def version_9(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_11(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
Reference in New Issue
Block a user