Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/handlers/backend/onehot.py
T
2020-10-14 08:55:07 +08:00

61 lines
2.0 KiB
Python

import copy
import tensorflow as tf
from onnx_tf.common import exception
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import partial_support
from onnx_tf.handlers.handler import ps_description
from onnx_tf.handlers.handler import tf_func
@onnx_op("OneHot")
@tf_func(tf.one_hot)
@partial_support(True)
@ps_description("OneHot indices in uint16/uint32/uint64/int8/int16/" +
                "float16/float/double, or " +
                "OneHot depth in uint8/uint16/uint32/uint64/int8/" +
                "int16/int64/float16/float/double " +
                "are not supported in Tensorflow.")
class OneHot(BackendHandler):
  """Backend handler mapping the ONNX ``OneHot`` operator onto ``tf.one_hot``.

  The ONNX node's inputs are ``(indices, depth, values)``, where ``values``
  holds ``[off_value, on_value]``. They are reordered into the
  ``(indices, depth, on_value, off_value)`` order passed to the TF op.
  """

  @classmethod
  def args_check(cls, node, **kwargs):
    """Reject input dtypes this handler does not support.

    Args:
      node: The ONNX node being converted.
      **kwargs: Must contain ``tensor_dict`` mapping input names to tensors.
    """
    tensor_dict = kwargs["tensor_dict"]
    indices = tensor_dict[node.inputs[0]]
    depth = tensor_dict[node.inputs[1]]
    # Only uint8/int32/int64 indices are accepted (see ps_description above).
    if indices.dtype not in [tf.uint8, tf.int32, tf.int64]:
      exception.OP_UNSUPPORTED_EXCEPT(
          "OneHot indices must be in uint8 or int32 or int64 " +
          "but it is currently in " + str(indices.dtype) + " which",
          "Tensorflow")
    # Only an int32 depth is accepted.
    if depth.dtype not in [tf.int32]:
      exception.OP_UNSUPPORTED_EXCEPT(
          "OneHot depth must be in int32 but it is currently in " + str(
              depth.dtype) + " which", "Tensorflow")

  @classmethod
  def _common(cls, node, support_negative_indices=False, **kwargs):
    """Build the one-hot tensor shared by all supported opset versions.

    Args:
      node: The ONNX node being converted.
      support_negative_indices: When True (opset 11+), indices in
        ``[-depth, -1]`` are shifted by ``depth`` so they count from the
        end, as the ONNX spec for OneHot-11 requires. Defaults to False,
        which leaves indices untouched.
      **kwargs: Must contain ``tensor_dict``; forwarded unchanged to
        ``make_tensor_from_onnx_node``.

    Returns:
      A single-element list holding the tensor produced by
      ``make_tensor_from_onnx_node`` (presumably the ``tf.one_hot`` bound via
      ``@tf_func``).
    """
    attrs = copy.deepcopy(node.attrs)
    tensor_dict = kwargs["tensor_dict"]
    indices = tensor_dict[node.inputs[0]]
    # ONNX allows depth to be a scalar or a rank-1 tensor with exactly one
    # element, while tf.one_hot needs a scalar: normalize to shape [].
    depth = tf.reshape(tensor_dict[node.inputs[1]], [])
    values = tensor_dict[node.inputs[2]]
    # ONNX packs the fill values as [off_value, on_value].
    off_value = values[0]
    on_value = values[1]
    if support_negative_indices and indices.dtype != tf.uint8:
      # tf.one_hot emits an all-off row for a negative index, but ONNX
      # OneHot-11 maps i in [-depth, -1] to i + depth. Indices below -depth
      # are still negative after the shift and therefore stay all-off,
      # matching the spec's out-of-range behaviour. (uint8 cannot be
      # negative, so it is skipped.)
      shifted = indices + tf.cast(depth, indices.dtype)
      indices = tf.where(tf.less(indices, 0), shifted, indices)
    # The output dtype follows the dtype of the supplied fill values.
    attrs["dtype"] = on_value.dtype
    return [
        cls.make_tensor_from_onnx_node(
            node,
            inputs=[indices, depth, on_value, off_value],
            attrs=attrs,
            **kwargs)
    ]

  @classmethod
  def version_9(cls, node, **kwargs):
    """Convert OneHot-9, where any index outside [0, depth) yields all-off."""
    return cls._common(node, **kwargs)

  @classmethod
  def version_11(cls, node, **kwargs):
    """Convert OneHot-11, which additionally accepts negative indices."""
    return cls._common(node, support_negative_indices=True, **kwargs)