Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/common/__init__.py
T
2020-10-14 08:55:07 +08:00

197 lines
5.2 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
import inspect
import logging
import re
import sys
import uuid
import warnings

from onnx.backend.base import DeviceType
from tensorflow.python.client import device_lib
# True when the interpreter's major version is 3 or newer.
IS_PYTHON3 = sys.version_info[0] >= 3

# Package logger. Records go to a console StreamHandler (which writes to
# sys.stderr when no stream is given) formatted as
# "<time> - <logger name> - <level> - <message>".
logger = logging.getLogger('onnx-tf')
formatter = logging.Formatter(
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console = logging.StreamHandler()
console.setFormatter(formatter)
logger.addHandler(console)
class Deprecated:
    """Decorator factory that emits deprecation warnings.

    ``deprecated`` (the module-level instance of this class) supports three
    call forms:

    Usage:
      from onnx_tf.common import deprecated

      @deprecated
      def func():
        pass
      # Every call to func() warns:
      UserWarning: <module>.func is deprecated. It will be removed in future release.

      @deprecated("Message")
      def func():
        pass
      UserWarning: Message

      @deprecated({"arg": "Message",
                   "arg_1": deprecated.MSG_WILL_REMOVE,
                   "arg_2": "",})
      def func(arg, arg_1, arg_2):
        pass
      UserWarning: Message
      UserWarning: arg_1 of <module>.func is deprecated. It will be removed in future release.
      UserWarning: arg_2 of <module>.func is deprecated.

    NOTE: only the bare ``@deprecated`` form wraps the function and warns on
    each call. The string and dict forms emit their warning(s) once, at the
    moment the decorator is applied, and return the function unchanged.
    """

    # Standard suffix appended to generated messages. Every class attribute
    # whose name starts with "MSG" is treated as such a suffix (see
    # ``messages``).
    MSG_WILL_REMOVE = " It will be removed in future release."

    def __call__(self, *args, **kwargs):
        """Forward to ``deprecated_decorator`` so the instance is a decorator."""
        return self.deprecated_decorator(*args, **kwargs)

    @staticmethod
    def messages():
        """Return the set of values of all ``MSG*`` attributes of the class."""
        return {
            value for key, value in inspect.getmembers(Deprecated)
            if key.startswith("MSG")
        }

    @staticmethod
    def deprecated_decorator(arg=None):
        """Build the decorator for one of the three supported call forms.

        Args:
          arg: The decorated function itself (bare ``@deprecated``), a message
            string, a ``{argument name: message}`` dict, or ``None`` (treated
            as ``MSG_WILL_REMOVE``).

        Returns:
          A wrapper around the function when ``arg`` is a function; otherwise
          a decorator that warns and returns its function unchanged.
        """
        # deprecate function with default message MSG_WILL_REMOVE
        # @deprecated
        if inspect.isfunction(arg):

            # functools.wraps keeps __name__/__doc__ of the deprecated
            # function visible on the wrapper.
            @functools.wraps(arg)
            def wrapper(*args, **kwargs):
                # stacklevel=2 attributes the warning to the caller of the
                # deprecated function rather than to this wrapper.
                warnings.warn(
                    "{} is deprecated.{}".format(
                        arg.__module__ + "." + arg.__name__,
                        Deprecated.MSG_WILL_REMOVE),
                    stacklevel=2)
                return arg(*args, **kwargs)

            return wrapper

        deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE

        def deco(func):
            # deprecate arg
            # @deprecated({...})
            if isinstance(deprecated_arg, dict):
                for name, message in deprecated_arg.items():
                    # An empty message or a standard MSG* suffix is expanded
                    # into a generated sentence; any other message is used
                    # verbatim.
                    if not message or message in Deprecated.messages():
                        message = "{} of {} is deprecated.{}".format(
                            name, func.__module__ + "." + func.__name__,
                            message or "")
                    # stacklevel=2 points at the place the decorator is applied.
                    warnings.warn(message, stacklevel=2)
            # deprecate function with message
            # @deprecated("message")
            elif isinstance(deprecated_arg, str):
                message = deprecated_arg
                if not message or message in Deprecated.messages():
                    message = "{} is deprecated.{}".format(
                        func.__module__ + "." + func.__name__, message)
                warnings.warn(message, stacklevel=2)
            return func

        return deco


# Shared decorator instance; use as ``@deprecated`` / ``@deprecated(...)``.
deprecated = Deprecated()
def op_name_to_lower(name):
    """Convert a CamelCase op name to lower snake case.

    An underscore is placed in front of every ASCII upper-case letter that is
    not the first character, then the whole string is lower-cased, e.g.
    ``"MaxPool"`` becomes ``"max_pool"``.

    :param name: Op name string.
    :return: Lower-cased name with underscores inserted.
    """
    pieces = []
    for position, char in enumerate(name):
        # Never prefix the very first character.
        if position > 0 and 'A' <= char <= 'Z':
            pieces.append('_')
        pieces.append(char)
    return ''.join(pieces).lower()
def get_unique_suffix():
    """ Build a short random suffix for making identity names unique.

    :return: The first 8 hexadecimal characters of a new ``uuid.uuid4``.
    """
    # The first 8 characters of str(uuid) are its first 8 hex digits, i.e.
    # the same characters as the start of ``.hex``.
    return uuid.uuid4().hex[:8]
def get_perm_from_formats(from_, to_):
    """ Compute the transpose permutation between two data formats.

    For example:
      get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]

    :param from_: Source data format string.
    :param to_: Target data format string.
    :return: List of ints; entry i is the index in ``from_`` of the i-th
      character of ``to_`` (-1 when that character is absent from ``from_``).
    """
    return [from_.find(axis) for axis in to_]
# TODO: allow more flexible placement
def get_device_option(device):
    """ Map an ONNX device to a device string prefix.

    :param device: Object whose ``type`` attribute is a ``DeviceType``.
    :return: ``'/cpu'`` for ``DeviceType.CPU``, ``'/gpu'`` for
      ``DeviceType.CUDA``; any other type raises ``KeyError``.
    """
    prefix_by_type = {
        DeviceType.CPU: '/cpu',
        DeviceType.CUDA: '/gpu',
    }
    return prefix_by_type[device.type]
def get_data_format(x_rank):
    """ Derive storage and compute data formats from the input rank.

    The storage format is always channels-first ("NC" + spatial axes). The
    compute format is channels-first when ``supports_device("CUDA")`` is
    true, otherwise channels-last ("N" + spatial axes + "C").

    :param x_rank: Input rank; ``x_rank - 2`` spatial axes are taken from the
      tail of D, H, W (more than three raises ``IndexError``).
    :return: Tuple ``(storage_format, compute_format)``.
    """
    axis_letters = ["D", "H", "W"]
    spatial_count = x_rank - 2
    # Walk offsets spatial_count..1 from the end so the letters come out in
    # D, H, W order (e.g. rank 4 -> "HW", rank 5 -> "DHW").
    spatial = "".join(
        axis_letters[-offset] for offset in range(spatial_count, 0, -1))
    channels_first = "NC" + spatial
    if supports_device("CUDA"):
        return channels_first, channels_first
    return channels_first, "N" + spatial + "C"
def supports_device(device):
    """ Tell whether the target device is available.

    :param device: "CUDA" or "CPU".
    :return: For "CUDA", True when the local device list contains at least
      one device of type 'GPU'; True for "CPU"; False for anything else.
    """
    if device == "CUDA":
        local_devices = device_lib.list_local_devices()
        return any(proto.device_type == 'GPU' for proto in local_devices)
    return device == "CPU"
# The message names this function (get_output_node_names); it is not one of
# the standard MSG* suffixes, so it is emitted verbatim.
@deprecated("onnx_tf.common.get_output_node_names is deprecated.{} {}".format(
    deprecated.MSG_WILL_REMOVE,
    "Use TensorflowGraph.get_outputs_names instead."))
def get_output_node_names(graph_def):
    """Get output node names from GraphDef.

    A node is an output when no node in the graph lists it as an input.

    Args:
      graph_def: GraphDef object.

    Returns:
      List of output node names, in the order the nodes appear in
      ``graph_def.node``.
    """
    consumed = set()
    node_names = []
    for node in graph_def.node:
        node_names.append(node.name)
        for input_ref in node.input:
            # NodeDef inputs may be written "node:src_output" or, for control
            # dependencies, "^node"; reduce them to the producing node's name
            # so such consumers are not mistaken for missing.
            consumed.add(input_ref.lstrip("^").split(":", 1)[0])
    # dict.fromkeys de-duplicates repeated names while keeping graph order.
    return list(
        dict.fromkeys(name for name in node_names if name not in consumed))
# Reserved names carrying the "_onnx_tf_internal_" prefix.
# NOTE(review): presumably the names of shared constant tensors (-1, 0, 1 as
# int32 and 1 as fp32) created elsewhere in the package -- confirm against the
# code that uses them.
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"