add pt2tf tool
This commit is contained in:
@@ -0,0 +1,74 @@
|
||||
import numpy as np
|
||||
|
||||
from onnx import numpy_helper
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
from onnx_tf.handlers.handler import tf_func
|
||||
from onnx_tf.common import data_type
|
||||
|
||||
|
||||
@onnx_op("Constant")
|
||||
@tf_func(tf.constant)
|
||||
class Constant(BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def _common(cls, node, **kwargs):
|
||||
attr_value = node.attrs["value"]
|
||||
dtype = data_type.onnx2tf(attr_value.data_type)
|
||||
value = numpy_helper.to_array(attr_value)
|
||||
return [
|
||||
cls.make_tensor_from_onnx_node(
|
||||
node, inputs=[value], attrs={
|
||||
"dtype": dtype
|
||||
})
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def version_1(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_9(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_11(cls, node, **kwargs):
|
||||
# either value or sparse_value
|
||||
if "value" in node.attrs:
|
||||
return cls._common(node, **kwargs)
|
||||
else:
|
||||
sparse_value = node.attrs["sparse_value"]
|
||||
indices = numpy_helper.to_array(sparse_value.indices)
|
||||
values = numpy_helper.to_array(sparse_value.values)
|
||||
shape = np.array(sparse_value.dims)
|
||||
return [tf.SparseTensor(indices, values, shape)]
|
||||
|
||||
@classmethod
|
||||
def version_12(cls, node, **kwargs):
|
||||
if "value" in node.attrs or "sparse_value" in node.attrs:
|
||||
return cls.version_11(node, **kwargs)
|
||||
elif "value_float" in node.attrs:
|
||||
value = node.attrs["value_float"]
|
||||
dtype = tf.float32
|
||||
elif "value_floats" in node.attrs:
|
||||
value = node.attrs["value_floats"]
|
||||
dtype = tf.float32
|
||||
elif "value_int" in node.attrs:
|
||||
value = node.attrs["value_int"]
|
||||
dtype = tf.int64
|
||||
elif "value_ints" in node.attrs:
|
||||
value = node.attrs["value_ints"]
|
||||
dtype = tf.int64
|
||||
elif "value_string" in node.attrs:
|
||||
value = node.attrs["value_string"]
|
||||
dtype = tf.string
|
||||
elif "value_strings" in node.attrs:
|
||||
value = node.attrs["value_strings"]
|
||||
dtype = tf.string
|
||||
return [
|
||||
cls.make_tensor_from_onnx_node(node,
|
||||
inputs=[value],
|
||||
attrs={"dtype": dtype})
|
||||
]
|
||||
Reference in New Issue
Block a user