重命名 pt2tf 为 pt2pb

This commit is contained in:
zhutian
2020-10-14 08:55:07 +08:00
committed by Gitee
parent 324ab60a5d
commit 90ae190559
407 changed files with 0 additions and 0 deletions
@@ -0,0 +1,114 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import inspect
from onnx import defs
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
import onnx_tf.common as common
class Handler(object):
""" This class is base handler class.
Base backend and frontend base handler class inherit this class.
All operator handler MUST put decorator @onnx_op to register corresponding op.
"""
ONNX_OP = None
DOMAIN = defs.ONNX_DOMAIN
VERSION = 0
SINCE_VERSION = 0
PARTIAL_SUPPORT = False
PS_DESCRIPTION = ''
@classmethod
def check_cls(cls):
if not cls.ONNX_OP:
common.logger.warning(
"{} doesn't have ONNX_OP. "
"Please use Handler.onnx_op decorator to register ONNX_OP.".format(
cls.__name__))
@classmethod
def args_check(cls, node, **kwargs):
""" Check args. e.g. if shape info is in graph.
Raise exception if failed.
:param node: NodeProto for backend.
:param kwargs: Other args.
"""
pass
@classmethod
def handle(cls, node, **kwargs):
""" Main method in handler. It will find corresponding versioned handle method,
whose name format is `version_%d`. So prefix `version_` is reserved in onnx-tensorflow.
DON'T use it for other purpose.
:param node: NodeProto for backend.
:param kwargs: Other args.
:return: TensorflowNode for backend.
"""
ver_handle = getattr(cls, "version_{}".format(cls.SINCE_VERSION), None)
if ver_handle:
cls.args_check(node, **kwargs)
return ver_handle(node, **kwargs)
raise BackendIsNotSupposedToImplementIt("{} version {} is not implemented.".format(node.op_type, cls.SINCE_VERSION))
@classmethod
def get_versions(cls):
""" Get all support versions.
:return: Version list.
"""
versions = []
for k, v in inspect.getmembers(cls, inspect.ismethod):
if k.startswith("version_"):
versions.append(int(k.replace("version_", "")))
return versions
@staticmethod
def onnx_op(op):
return Handler.property_register("ONNX_OP", op)
@staticmethod
def tf_func(func):
return Handler.property_register("TF_FUNC", func)
@staticmethod
def domain(d):
return Handler.property_register("DOMAIN", d)
@staticmethod
def partial_support(ps):
return Handler.property_register("PARTIAL_SUPPORT", ps)
@staticmethod
def ps_description(psd):
return Handler.property_register("PS_DESCRIPTION", psd)
@staticmethod
def property_register(name, value):
def deco(cls):
if inspect.isfunction(value) and not common.IS_PYTHON3:
setattr(cls, name, staticmethod(value))
else:
setattr(cls, name, value)
return cls
return deco
domain = Handler.domain
onnx_op = Handler.onnx_op
tf_func = Handler.tf_func
partial_support = Handler.partial_support
ps_description = Handler.ps_description
property_register = Handler.property_register