重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
import inspect
|
||||
import onnx_tf.common as common
|
||||
|
||||
|
||||
class CustomException(object):
|
||||
|
||||
def __init__(self):
|
||||
self._func = RuntimeError
|
||||
self._message = ""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if inspect.isclass(self._func) and issubclass(self._func, Exception):
|
||||
raise self._func(self.get_message(*args, **kwargs))
|
||||
elif callable(self._func):
|
||||
self._func(self.get_message(*args, **kwargs))
|
||||
|
||||
def get_message(self, *args, **kwargs):
|
||||
return self._message
|
||||
|
||||
|
||||
class OpUnimplementedException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnimplementedException, self).__init__()
|
||||
self._func = NotImplementedError
|
||||
self._message = "{} is not implemented."
|
||||
|
||||
def __call__(self, op, version=None, domain=None):
|
||||
if IGNORE_UNIMPLEMENTED:
|
||||
self._func = common.logger.warning
|
||||
super(OpUnimplementedException, self).__call__(op, version, domain)
|
||||
|
||||
def get_message(self, op, version=None, domain=None):
|
||||
insert_message = op
|
||||
if version is not None:
|
||||
insert_message += " version {}".format(version)
|
||||
if domain is not None:
|
||||
insert_message += " in domain `{}`".format(domain)
|
||||
return self._message.format(insert_message)
|
||||
|
||||
|
||||
class OpUnsupportedException(object):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnsupportedException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} is not supported in {}."
|
||||
|
||||
def __call__(self, op, framework):
|
||||
raise self._func(self.get_message(op, framework))
|
||||
|
||||
def get_message(self, op, framework):
|
||||
return self._message.format(op, framework)
|
||||
|
||||
|
||||
class ConstNotFoundException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(ConstNotFoundException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} of {} is not found in graph consts."
|
||||
|
||||
def __call__(self, name, op):
|
||||
super(ConstNotFoundException, self).__call__(name, op)
|
||||
|
||||
def get_message(self, name, op):
|
||||
return self._message.format(name, op)
|
||||
|
||||
|
||||
IGNORE_UNIMPLEMENTED = False
|
||||
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
|
||||
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
|
||||
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
|
||||
Reference in New Issue
Block a user