75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
import numpy as np
|
|
|
|
from onnx import numpy_helper
|
|
import tensorflow as tf
|
|
|
|
from onnx_tf.handlers.backend_handler import BackendHandler
|
|
from onnx_tf.handlers.handler import onnx_op
|
|
from onnx_tf.handlers.handler import tf_func
|
|
from onnx_tf.common import data_type
|
|
|
|
|
|
@onnx_op("Constant")
|
|
@tf_func(tf.constant)
|
|
class Constant(BackendHandler):
|
|
|
|
@classmethod
|
|
def _common(cls, node, **kwargs):
|
|
attr_value = node.attrs["value"]
|
|
dtype = data_type.onnx2tf(attr_value.data_type)
|
|
value = numpy_helper.to_array(attr_value)
|
|
return [
|
|
cls.make_tensor_from_onnx_node(
|
|
node, inputs=[value], attrs={
|
|
"dtype": dtype
|
|
})
|
|
]
|
|
|
|
@classmethod
|
|
def version_1(cls, node, **kwargs):
|
|
return cls._common(node, **kwargs)
|
|
|
|
@classmethod
|
|
def version_9(cls, node, **kwargs):
|
|
return cls._common(node, **kwargs)
|
|
|
|
@classmethod
|
|
def version_11(cls, node, **kwargs):
|
|
# either value or sparse_value
|
|
if "value" in node.attrs:
|
|
return cls._common(node, **kwargs)
|
|
else:
|
|
sparse_value = node.attrs["sparse_value"]
|
|
indices = numpy_helper.to_array(sparse_value.indices)
|
|
values = numpy_helper.to_array(sparse_value.values)
|
|
shape = np.array(sparse_value.dims)
|
|
return [tf.SparseTensor(indices, values, shape)]
|
|
|
|
@classmethod
|
|
def version_12(cls, node, **kwargs):
|
|
if "value" in node.attrs or "sparse_value" in node.attrs:
|
|
return cls.version_11(node, **kwargs)
|
|
elif "value_float" in node.attrs:
|
|
value = node.attrs["value_float"]
|
|
dtype = tf.float32
|
|
elif "value_floats" in node.attrs:
|
|
value = node.attrs["value_floats"]
|
|
dtype = tf.float32
|
|
elif "value_int" in node.attrs:
|
|
value = node.attrs["value_int"]
|
|
dtype = tf.int64
|
|
elif "value_ints" in node.attrs:
|
|
value = node.attrs["value_ints"]
|
|
dtype = tf.int64
|
|
elif "value_string" in node.attrs:
|
|
value = node.attrs["value_string"]
|
|
dtype = tf.string
|
|
elif "value_strings" in node.attrs:
|
|
value = node.attrs["value_strings"]
|
|
dtype = tf.string
|
|
return [
|
|
cls.make_tensor_from_onnx_node(node,
|
|
inputs=[value],
|
|
attrs={"dtype": dtype})
|
|
]
|