Files
2020-10-14 08:55:07 +08:00

75 lines
2.2 KiB
Python

import numpy as np
from onnx import numpy_helper
import tensorflow as tf
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
from onnx_tf.common import data_type
@onnx_op("Constant")
@tf_func(tf.constant)
class Constant(BackendHandler):
    """Backend handler that lowers the ONNX ``Constant`` operator.

    The constant payload is read from the node's attributes; which attribute
    carries it depends on the opset version (see the ``version_*`` methods).
    """

    @classmethod
    def _common(cls, node, **kwargs):
        """Build a dense constant from the tensor-valued ``value`` attribute.

        Args:
            node: the ONNX node. ``node.attrs["value"]`` is expected to be a
                TensorProto (it is passed to ``numpy_helper.to_array`` and its
                ``data_type`` is mapped with ``data_type.onnx2tf``).

        Returns:
            A one-element list holding the tensor built by
            ``make_tensor_from_onnx_node`` from the NumPy value and the mapped
            TensorFlow dtype.
        """
        attr_value = node.attrs["value"]
        dtype = data_type.onnx2tf(attr_value.data_type)
        value = numpy_helper.to_array(attr_value)
        return [
            cls.make_tensor_from_onnx_node(node,
                                           inputs=[value],
                                           attrs={"dtype": dtype})
        ]

    @classmethod
    def version_1(cls, node, **kwargs):
        """Opset 1: the constant is always given by the ``value`` attribute."""
        return cls._common(node, **kwargs)

    @classmethod
    def version_9(cls, node, **kwargs):
        """Opset 9: the constant is always given by the ``value`` attribute."""
        return cls._common(node, **kwargs)

    @classmethod
    def _sparse_tensor(cls, sparse_value):
        """Convert an ONNX SparseTensorProto into a ``tf.SparseTensor``.

        ONNX stores ``indices`` in one of two layouts: ``[NNZ, rank]``
        coordinates, or ``[NNZ]`` row-major linearized offsets into ``dims``.
        ``tf.SparseTensor`` only accepts the coordinate layout, so linearized
        indices are unravelled against ``dims`` first.
        """
        indices = numpy_helper.to_array(sparse_value.indices)
        values = numpy_helper.to_array(sparse_value.values)
        shape = np.array(sparse_value.dims, dtype=np.int64)
        if indices.ndim == 1:
            # Linearized form: turn each flat offset into a coordinate tuple.
            coords = np.unravel_index(indices, tuple(shape))
            indices = np.stack(coords, axis=-1).astype(np.int64)
        return tf.SparseTensor(indices, values, shape)

    @classmethod
    def version_11(cls, node, **kwargs):
        """Opset 11: the constant is given by ``value`` or ``sparse_value``.

        Raises:
            KeyError: if neither attribute is present on the node.
        """
        if "value" in node.attrs:
            return cls._common(node, **kwargs)
        return [cls._sparse_tensor(node.attrs["sparse_value"])]

    @classmethod
    def version_12(cls, node, **kwargs):
        """Opset 12: adds scalar/list ``value_*`` attribute variants.

        ``value`` / ``sparse_value`` are handled as in opset 11. Otherwise the
        first ``value_*`` attribute found (in the order below) is materialised
        with the TensorFlow dtype paired with it.

        Raises:
            ValueError: if the node carries none of the recognised attributes.
        """
        if "value" in node.attrs or "sparse_value" in node.attrs:
            return cls.version_11(node, **kwargs)
        # Checked in this order; the first attribute present wins.
        typed_value_attrs = (
            ("value_float", tf.float32),
            ("value_floats", tf.float32),
            ("value_int", tf.int64),
            ("value_ints", tf.int64),
            ("value_string", tf.string),
            ("value_strings", tf.string),
        )
        for attr_name, dtype in typed_value_attrs:
            if attr_name in node.attrs:
                return [
                    cls.make_tensor_from_onnx_node(
                        node,
                        inputs=[node.attrs[attr_name]],
                        attrs={"dtype": dtype})
                ]
        # Previously this fell through to an UnboundLocalError on `value`.
        expected = ["value", "sparse_value"]
        expected.extend(name for name, _ in typed_value_attrs)
        raise ValueError("Constant node must have exactly one of the "
                         "attributes: " + ", ".join(expected))