"""Helpers for collecting backend handler classes and their ONNX coverage."""

from onnx import defs

import onnx_tf.common as common
from onnx_tf.handlers.backend import *  # noqa
from onnx_tf.handlers.backend_handler import BackendHandler


def get_all_backend_handlers(opset_dict):
    """ Get a dict of all backend handler classes.
    e.g. {'domain': {'Abs': Abs handler class}, ...}, }.

    As a side effect, ``VERSION`` and ``SINCE_VERSION`` are assigned on every
    handler class that is visited.

    :param opset_dict: A dict of opset. e.g. {'domain': version, ...}
        Domains missing from the dict are treated as version 1.
    :return: Dict.
    """
    handlers = {}
    # NOTE: __subclasses__() only yields *direct* subclasses of BackendHandler.
    for handler in BackendHandler.__subclasses__():
        handler.check_cls()

        domain = handler.DOMAIN
        op_name = handler.ONNX_OP
        version = opset_dict.get(domain, 1)
        handler.VERSION = version

        # Use the same readable fallback in every log message when the
        # handler's domain is falsy (e.g. an empty string).
        domain_label = domain or "ai.onnx"

        # Stays at 1 when the op is unknown or the schema lookup fails.
        since_version = 1
        if defs.has(op_name, domain=domain):
            try:
                since_version = defs.get_schema(
                    op_name,
                    domain=domain,
                    max_inclusive_version=version).since_version
            except RuntimeError:
                common.logger.debug(
                    "Fail to get since_version of {} in domain `{}` "
                    "with max_inclusive_version={}. Set to 1.".format(
                        op_name, domain_label, version))
        else:
            common.logger.debug("Unknown op {} in domain `{}`.".format(
                op_name, domain_label))

        handler.SINCE_VERSION = since_version
        handlers.setdefault(domain, {})[op_name] = handler
    return handlers


def get_backend_coverage():
    """ Get backend coverage for document.

    :return: A tuple ``(onnx_coverage, experimental_op)``.
        onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
        experimental_op: set of ONNX_OP names whose handler has a truthy
        ``EXPERIMENTAL`` attribute.
    """
    onnx_coverage = {}
    experimental_op = set()
    for handler in BackendHandler.__subclasses__():
        handler.check_cls()

        versions = handler.get_versions()
        if getattr(handler, "EXPERIMENTAL", False):
            experimental_op.add(handler.ONNX_OP)
        _update_coverage(onnx_coverage, handler.DOMAIN, handler.ONNX_OP,
                         versions)
    return onnx_coverage, experimental_op


def _update_coverage(coverage, domain, key, versions):
    """ Merge ``versions`` into ``coverage[domain][key]`` in place.

    The stored value is always a sorted list without duplicates.

    :param coverage: Dict of {'domain': {'key': [versions], ...}, ...};
        it is mutated.
    :param domain: Domain under which the entry is stored.
    :param key: Entry name inside the domain (the ONNX op name).
    :param versions: Iterable of versions to add.
    """
    domain_coverage = coverage.setdefault(domain, {})
    known_versions = domain_coverage.get(key, [])
    domain_coverage[key] = sorted(set(known_versions) | set(versions))


def get_backend_partial_support_detail():
    """ Get the partial-support descriptions of the default-domain handlers.

    Handlers are collected for ``defs.ONNX_DOMAIN`` at the opset version
    reported by ``defs.onnx_opset_version()``.

    :return: Dict of {'ONNX_OP': PS_DESCRIPTION, ...} for every handler whose
        ``PARTIAL_SUPPORT`` is truthy. Empty when no handler is registered
        for the domain.
    """
    opset_dict = {defs.ONNX_DOMAIN: defs.onnx_opset_version()}
    # .get() avoids a KeyError when no handler exists for the default domain.
    handlers = get_all_backend_handlers(opset_dict).get(defs.ONNX_DOMAIN, {})
    return {
        op_name: handler.PS_DESCRIPTION
        for op_name, handler in handlers.items()
        if handler.PARTIAL_SUPPORT
    }