add pt2tf tool
This commit is contained in:
@@ -0,0 +1,188 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.common import IS_PYTHON3
|
||||
from onnx_tf.common import get_data_format
|
||||
from onnx_tf.common import get_perm_from_formats
|
||||
from onnx_tf.common import supports_device
|
||||
from .handler import Handler
|
||||
|
||||
|
||||
class BackendHandler(Handler):
|
||||
""" This class is base backend handler class.
|
||||
All backend operator handler class MUST inherit this class.
|
||||
In backend, operator handler class's name should be pascal case of file name
|
||||
which should be snake case.
|
||||
Use ONNX operator name as class name.
|
||||
"""
|
||||
|
||||
TF_FUNC = None
|
||||
|
||||
@classmethod
|
||||
def get_attrs_processor_param(cls):
|
||||
""" Get param for attrs processor.
|
||||
|
||||
:return: Dict.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def _process_attrs(cls, attrs):
|
||||
""" Private method for processing attrs.
|
||||
Param for this processor got from `get_attrs_processor_param`.
|
||||
Param is dict contains two key: `default` and `raname`.
|
||||
First add default value to attrs if key does not exist.
|
||||
Second rename key to new key.
|
||||
|
||||
For example:
|
||||
attrs = {"keep_dims": True}
|
||||
param = {"default": {"axis": 1},
|
||||
"rename": {"keep_dims": "keepdims"}}
|
||||
|
||||
processed_attrs = {"axis": "1", "keepdims": True}
|
||||
|
||||
:param attrs: Process target attrs.
|
||||
:return: Processed attrs.
|
||||
"""
|
||||
param = {"rename": {}, "default": {}}
|
||||
param.update(cls.get_attrs_processor_param())
|
||||
|
||||
for k, v in param["default"].items():
|
||||
attrs.setdefault(k, v)
|
||||
|
||||
for k, new_k in param["rename"].items():
|
||||
if k in attrs:
|
||||
attrs[new_k] = attrs.pop(k)
|
||||
|
||||
return attrs
|
||||
|
||||
@classmethod
|
||||
def make_tensor_from_onnx_node(cls,
|
||||
node,
|
||||
tf_func=None,
|
||||
inputs=None,
|
||||
attrs=None,
|
||||
name="",
|
||||
c_first_cuda_only=False,
|
||||
c_last_only=False,
|
||||
**kwargs):
|
||||
""" Helper method to make tensor.
|
||||
|
||||
:param node: OnnxNode object.
|
||||
:param tf_func: Callable Tf function. Default is cls.TF_FUNC.
|
||||
:param inputs: Inputs tensor. Default is got from node.inputs.
|
||||
:param attrs: Attributes. Default is node.attrs.
|
||||
:param name: Node name.
|
||||
:param c_first_cuda_only: If channel first is only supported by cuda.
|
||||
If true and not cuda, do pre and post transpose.
|
||||
:param c_last_only: If only channel last is support,
|
||||
do pre and post transpose.
|
||||
:param kwargs: Other args.
|
||||
:return: Tensor.
|
||||
"""
|
||||
tensor_dict = kwargs.get("tensor_dict", {})
|
||||
tf_func = tf_func or cls.TF_FUNC
|
||||
if tf_func is None:
|
||||
raise RuntimeError("No Tensorflow function is given.")
|
||||
if inputs is None:
|
||||
inputs = [tensor_dict.get(inp, None) for inp in node.inputs]
|
||||
if attrs is None:
|
||||
attrs = copy.deepcopy(node.attrs)
|
||||
name = name or node.name
|
||||
if name != "":
|
||||
attrs["name"] = name
|
||||
|
||||
if c_first_cuda_only and c_last_only:
|
||||
raise ValueError(
|
||||
"c_first_cuda_only and c_last_only can not both be True.")
|
||||
|
||||
if c_first_cuda_only:
|
||||
return cls.c_first_cuda_only(tf_func, inputs, attrs)
|
||||
elif c_last_only:
|
||||
return cls.c_last_only(tf_func, inputs, attrs)
|
||||
|
||||
return cls._run_tf_func(tf_func, inputs, attrs)
|
||||
|
||||
@classmethod
|
||||
def c_first_cuda_only(cls, tf_func, inputs, attrs):
|
||||
""" Handle operator that channel first is only supported by CUDA.
|
||||
When using CPU, two transposes should be added.
|
||||
|
||||
:param tf_func: Callable Tf function.
|
||||
:param inputs: Inputs tensor.
|
||||
:param attrs: Attributes.
|
||||
:return: Tensor.
|
||||
"""
|
||||
support_cuda = supports_device("CUDA")
|
||||
if not support_cuda:
|
||||
return cls._tuck_transpose(tf_func, inputs, attrs)
|
||||
return cls._run_tf_func(tf_func, inputs, attrs)
|
||||
|
||||
@classmethod
|
||||
def c_last_only(cls, tf_func, inputs, attrs):
|
||||
""" Handle operator that channel last only is supported.
|
||||
Add two transposes anyway.
|
||||
|
||||
:param tf_func: Callable Tf function.
|
||||
:param inputs: Inputs tensor.
|
||||
:param attrs: Attributes.
|
||||
:return: Tensor.
|
||||
"""
|
||||
storage_format, compute_format = get_data_format(len(inputs[0].get_shape()))
|
||||
compute_format = compute_format.replace("C", "") + "C"
|
||||
return cls._tuck_transpose(tf_func, inputs, attrs,
|
||||
(storage_format, compute_format))
|
||||
|
||||
@classmethod
|
||||
def _tuck_transpose(cls, tf_func, inputs, attrs, data_format=None):
|
||||
x = inputs[0]
|
||||
x_rank = len(x.get_shape())
|
||||
if not data_format:
|
||||
data_format = get_data_format(x_rank)
|
||||
pre_perm = get_perm_from_formats(data_format[0], data_format[1])
|
||||
post_perm = get_perm_from_formats(data_format[1], data_format[0])
|
||||
attrs["data_format"] = data_format[1]
|
||||
if pre_perm != list(range(x_rank)):
|
||||
x_t = tf.transpose(x, perm=pre_perm)
|
||||
y = cls._run_tf_func(tf_func, [x_t] + inputs[1:], attrs)
|
||||
y_t = tf.transpose(y, perm=post_perm)
|
||||
return y_t
|
||||
return cls._run_tf_func(tf_func, inputs, attrs)
|
||||
|
||||
@classmethod
|
||||
def _run_tf_func(cls, tf_func, inputs, attrs):
|
||||
""" Run Tensorflow function.
|
||||
Use only acceptable attributes of function from attrs.
|
||||
|
||||
:param tf_func: Tensorflow function.
|
||||
:param inputs: Inputs.
|
||||
:param attrs: Attributes.
|
||||
:return: Tensor.
|
||||
"""
|
||||
if IS_PYTHON3:
|
||||
params = list(inspect.signature(tf_func).parameters.keys())
|
||||
else:
|
||||
# use closure to get args for function using decorator
|
||||
if tf_func.__closure__ is not None:
|
||||
while "__wrapped__" in tf_func.func_dict:
|
||||
tf_func = tf_func.func_dict["__wrapped__"]
|
||||
params = inspect.getargspec(tf_func).args
|
||||
else:
|
||||
params = inspect.getargspec(tf_func).args
|
||||
|
||||
attrs = cls._process_attrs(attrs)
|
||||
attrs = {p: v for p, v in attrs.items() if p in params}
|
||||
kwargs = dict(zip(params, inputs))
|
||||
ambiguous_arguments = any(kwargs.get(p) is not None and v is not None
|
||||
for p, v in attrs.items())
|
||||
if ambiguous_arguments:
|
||||
raise TypeError('Ambiguous arguments for {}()'.format(tf_func.__name__))
|
||||
kwargs.update((p, v) for p, v in attrs.items() if v is not None)
|
||||
return tf_func(**kwargs)
|
||||
Reference in New Issue
Block a user