61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
import copy
|
|
import tensorflow as tf
|
|
|
|
from onnx_tf.common import exception
|
|
from onnx_tf.handlers.backend_handler import BackendHandler
|
|
from onnx_tf.handlers.handler import onnx_op
|
|
from onnx_tf.handlers.handler import partial_support
|
|
from onnx_tf.handlers.handler import ps_description
|
|
from onnx_tf.handlers.handler import tf_func
|
|
|
|
|
|
@onnx_op("OneHot")
@tf_func(tf.one_hot)
@partial_support(True)
@ps_description("OneHot indices in uint16/uint32/uint64/int8/int16/" +
                "float16/float/double, or " +
                "OneHot depth in uint8/uint16/uint32/uint64/int8/" +
                "int16/int64/float16/float/double " +
                "are not supported in Tensorflow.")
class OneHot(BackendHandler):
  """Backend handler that lowers the ONNX ``OneHot`` op to ``tf.one_hot``.

  The node's inputs are read positionally from ``tensor_dict``:
  ``inputs[0]`` is the indices tensor, ``inputs[1]`` is the depth and
  ``inputs[2]`` is a two-element tensor laid out as
  ``[off_value, on_value]``.
  """

  @classmethod
  def args_check(cls, node, **kwargs):
    """Report index/depth dtypes that this handler does not support.

    Args:
      node: ONNX node being converted; only ``node.inputs`` is read.
      **kwargs: must contain ``tensor_dict``, mapping input names to
        tensors.

    Unsupported dtypes are reported through
    ``exception.OP_UNSUPPORTED_EXCEPT``.
    """
    tensor_dict = kwargs["tensor_dict"]
    indices = tensor_dict[node.inputs[0]]
    depth = tensor_dict[node.inputs[1]]
    if indices.dtype not in [tf.uint8, tf.int32, tf.int64]:
      exception.OP_UNSUPPORTED_EXCEPT(
          "OneHot indices must be in uint8 or int32 or int64 " +
          "but it is currently in " + str(indices.dtype) + " which",
          "Tensorflow")
    if depth.dtype not in [tf.int32]:
      exception.OP_UNSUPPORTED_EXCEPT(
          "OneHot depth must be in int32 but it is currently in " + str(
              depth.dtype) + " which", "Tensorflow")

  @classmethod
  def _common(cls, node, wrap_negative_indices=False, **kwargs):
    """Build the ``tf.one_hot`` call shared by every opset version.

    Args:
      node: ONNX node being converted; ``node.inputs`` and ``node.attrs``
        are read.
      wrap_negative_indices: when True, indices below zero are shifted by
        ``depth`` before the op is built, so ``-1`` selects the last
        class. Defaults to False, which leaves the indices untouched.
      **kwargs: must contain ``tensor_dict``; forwarded unchanged to
        ``make_tensor_from_onnx_node``.

    Returns:
      A one-element list holding the result of
      ``make_tensor_from_onnx_node``.
    """
    attrs = copy.deepcopy(node.attrs)
    tensor_dict = kwargs["tensor_dict"]
    indices = tensor_dict[node.inputs[0]]
    depth = tensor_dict[node.inputs[1]]
    values = tensor_dict[node.inputs[2]]

    # ONNX permits depth to be a one-element rank-1 tensor, whereas
    # tf.one_hot needs a scalar. Reshape unless the static rank is already
    # known to be 0 (ndims is None when the rank is unknown).
    if depth.shape.ndims != 0:
      depth = tf.reshape(depth, [])

    # Only signed index types can hold negative values. Indices that are
    # still negative after the shift (i.e. below -depth) stay out of range.
    if wrap_negative_indices and indices.dtype in (tf.int32, tf.int64):
      shifted = indices + tf.cast(depth, indices.dtype)
      indices = tf.where(indices < 0, shifted, indices)

    # The values input is ordered [off_value, on_value]; tf.one_hot takes
    # them in the opposite order (on first).
    off_value = values[0]
    on_value = values[1]
    attrs["dtype"] = on_value.dtype
    return [
        cls.make_tensor_from_onnx_node(
            node,
            inputs=[indices, depth, on_value, off_value],
            attrs=attrs,
            **kwargs)
    ]

  @classmethod
  def version_9(cls, node, **kwargs):
    """Convert an opset-9 ``OneHot`` node; indices are passed as given."""
    return cls._common(node, **kwargs)

  @classmethod
  def version_11(cls, node, **kwargs):
    """Convert an opset-11 ``OneHot`` node, wrapping negative indices."""
    return cls._common(node, wrap_negative_indices=True, **kwargs)
|