74 lines
2.1 KiB
Python
74 lines
2.1 KiB
Python
import inspect
|
|
import onnx_tf.common as common
|
|
|
|
|
|
class CustomException(object):
    """Callable error reporter that raises or forwards a formatted message.

    ``_func`` decides what happens on call: if it is an ``Exception``
    subclass it is raised with the message; if it is any other callable it
    is invoked with the message; otherwise the call does nothing.
    Subclasses customise the text by overriding ``get_message``.
    """

    def __init__(self):
        self._message = ""
        self._func = RuntimeError

    def __call__(self, *args, **kwargs):
        """Report ``get_message(*args, **kwargs)`` through ``self._func``.

        Raises:
          self._func: when ``_func`` is a subclass of ``Exception``.
        """
        handler = self._func
        is_exception_type = inspect.isclass(handler) and issubclass(handler, Exception)
        # Nothing usable configured: bail out before building the message.
        if not (is_exception_type or callable(handler)):
            return
        text = self.get_message(*args, **kwargs)
        if is_exception_type:
            raise handler(text)
        handler(text)

    def get_message(self, *args, **kwargs):
        """Return the message template; the base class ignores its arguments."""
        return self._message
class OpUnimplementedException(CustomException):
    """Reports an op that has no implementation.

    By default calling an instance raises ``NotImplementedError``. When the
    module-level ``IGNORE_UNIMPLEMENTED`` flag is true at call time, the
    message is instead passed to ``common.logger.warning``.
    """

    def __init__(self):
        super(OpUnimplementedException, self).__init__()
        self._func = NotImplementedError
        self._message = "{} is not implemented."

    def __call__(self, op, version=None, domain=None):
        """Raise (or log, if ignoring) that ``op`` is not implemented.

        Args:
          op: Name of the op.
          version: Optional op version, appended to the message when given.
          domain: Optional op domain, appended to the message when given.

        Raises:
          NotImplementedError: unless ``IGNORE_UNIMPLEMENTED`` is true.
        """
        # Decide per call rather than overwriting self._func: assigning the
        # logger to self._func made this shared instance stop raising forever
        # once the flag had been set, even after it was cleared again.
        if IGNORE_UNIMPLEMENTED:
            common.logger.warning(self.get_message(op, version, domain))
            return
        super(OpUnimplementedException, self).__call__(op, version, domain)

    def get_message(self, op, version=None, domain=None):
        """Format e.g. "Foo version 3 in domain `ai.onnx` is not implemented."."""
        # str() lets non-string op identifiers be reported too; for str input
        # it is a no-op.
        insert_message = str(op)
        if version is not None:
            insert_message += " version {}".format(version)
        if domain is not None:
            insert_message += " in domain `{}`".format(domain)
        return self._message.format(insert_message)
class OpUnsupportedException(CustomException):
    """Reports an op that cannot be expressed in a given framework.

    Now derives from ``CustomException`` like the other reporters in this
    module. With the default ``_func`` it raises
    ``RuntimeError("<op> is not supported in <framework>.")``.
    """

    def __init__(self):
        super(OpUnsupportedException, self).__init__()
        self._func = RuntimeError
        self._message = "{} is not supported in {}."

    def __call__(self, op, framework):
        """Report that ``op`` is not supported in ``framework``.

        Raises:
          RuntimeError: with the formatted message (default ``_func``).
        """
        super(OpUnsupportedException, self).__call__(op, framework)

    def get_message(self, op, framework):
        """Return "<op> is not supported in <framework>."."""
        return self._message.format(op, framework)
class ConstNotFoundException(CustomException):
    """Reports that input ``name`` of ``op`` is absent from the graph consts.

    Calling an instance raises ``RuntimeError`` with the formatted message.
    """

    def __init__(self):
        super(ConstNotFoundException, self).__init__()
        self._message = "{} of {} is not found in graph consts."
        self._func = RuntimeError

    def __call__(self, name, op):
        """Raise for the missing constant ``name`` belonging to ``op``."""
        # The base __call__ never returns a value, so this returns None
        # exactly like the original.
        return super(ConstNotFoundException, self).__call__(name, op)

    def get_message(self, name, op):
        """Return "<name> of <op> is not found in graph consts."."""
        template = self._message
        return template.format(name, op)
# When true, OP_UNIMPLEMENTED_EXCEPT logs a warning instead of raising.
# The flag is read inside the call, so it may be toggled after import.
IGNORE_UNIMPLEMENTED = False

# Module-level singleton reporters; call them to raise/report the error.
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()