from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import inspect import re import sys import uuid import warnings import logging from onnx.backend.base import DeviceType from tensorflow.python.client import device_lib IS_PYTHON3 = sys.version_info > (3,) logger = logging.getLogger('onnx-tf') # create console handler and formatter for logger console = logging.StreamHandler() formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') console.setFormatter(formatter) logger.addHandler(console) class Deprecated: """Add deprecated message when function is called. Usage: from onnx_tf.common import deprecated @deprecated def func(): pass UserWarning: func is deprecated. It will be removed in future release. @deprecated("Message") def func(): pass UserWarning: Message @deprecated({"arg": "Message", "arg_1": deprecated.MSG_WILL_REMOVE, "arg_2": "",}) def func(arg, arg_1, arg_2): pass UserWarning: Message UserWarning: arg_1 of func is deprecated. It will be removed in future release. UserWarning: arg_2 of func is deprecated. """ MSG_WILL_REMOVE = " It will be removed in future release." def __call__(self, *args, **kwargs): return self.deprecated_decorator(*args, **kwargs) @staticmethod def messages(): return {v for k, v in inspect.getmembers(Deprecated) if k.startswith("MSG")} @staticmethod def deprecated_decorator(arg=None): # deprecate function with default message MSG_WILL_REMOVE # @deprecated if inspect.isfunction(arg): def wrapper(*args, **kwargs): warnings.warn("{} is deprecated.{}".format( arg.__module__ + "." + arg.__name__, Deprecated.MSG_WILL_REMOVE)) return arg(*args, **kwargs) return wrapper deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE def deco(func): # deprecate arg # @deprecated({...}) if isinstance(deprecated_arg, dict): for name, message in deprecated_arg.items(): if message in Deprecated.messages(): message = "{} of {} is deprecated.{}".format( name, func.__module__ + "." + func.__name__, message or "") warnings.warn(message) # deprecate function with message # @deprecated("message") elif isinstance(deprecated_arg, str): message = deprecated_arg if message in Deprecated.messages(): message = "{} is deprecated.{}".format( func.__module__ + "." + func.__name__, message) warnings.warn(message) return func return deco deprecated = Deprecated() # This function inserts an underscore before every upper # case letter and lowers that upper case letter except for # the first letter. def op_name_to_lower(name): return re.sub('(? 0 elif device == "CPU": return True return False @deprecated("onnx_tf.common.get_outputs_names is deprecated.{} {}".format( deprecated.MSG_WILL_REMOVE, "Use TensorflowGraph.get_outputs_names instead.")) def get_output_node_names(graph_def): """Get output node names from GraphDef. Args: graph_def: GraphDef object. Returns: List of output node names. """ nodes, input_names = dict(), set() for node in graph_def.node: nodes[node.name] = node input_names.update(set(node.input)) return list(set(nodes) - input_names) CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32" CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32" CONST_ONE_INT32 = "_onnx_tf_internal_one_int32" CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"