重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,196 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
import logging
|
||||
|
||||
from onnx.backend.base import DeviceType
|
||||
from tensorflow.python.client import device_lib
|
||||
|
||||
IS_PYTHON3 = sys.version_info > (3,)
|
||||
logger = logging.getLogger('onnx-tf')
|
||||
|
||||
# create console handler and formatter for logger
|
||||
console = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
console.setFormatter(formatter)
|
||||
logger.addHandler(console)
|
||||
|
||||
|
||||
class Deprecated:
|
||||
"""Add deprecated message when function is called.
|
||||
|
||||
Usage:
|
||||
from onnx_tf.common import deprecated
|
||||
|
||||
@deprecated
|
||||
def func():
|
||||
pass
|
||||
|
||||
UserWarning: func is deprecated. It will be removed in future release.
|
||||
|
||||
@deprecated("Message")
|
||||
def func():
|
||||
pass
|
||||
|
||||
UserWarning: Message
|
||||
|
||||
@deprecated({"arg": "Message",
|
||||
"arg_1": deprecated.MSG_WILL_REMOVE,
|
||||
"arg_2": "",})
|
||||
def func(arg, arg_1, arg_2):
|
||||
pass
|
||||
|
||||
UserWarning: Message
|
||||
UserWarning: arg_1 of func is deprecated. It will be removed in future release.
|
||||
UserWarning: arg_2 of func is deprecated.
|
||||
"""
|
||||
|
||||
MSG_WILL_REMOVE = " It will be removed in future release."
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.deprecated_decorator(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def messages():
|
||||
return {v for k, v in inspect.getmembers(Deprecated) if k.startswith("MSG")}
|
||||
|
||||
@staticmethod
|
||||
def deprecated_decorator(arg=None):
|
||||
# deprecate function with default message MSG_WILL_REMOVE
|
||||
# @deprecated
|
||||
if inspect.isfunction(arg):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
warnings.warn("{} is deprecated.{}".format(
|
||||
arg.__module__ + "." + arg.__name__, Deprecated.MSG_WILL_REMOVE))
|
||||
return arg(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE
|
||||
|
||||
def deco(func):
|
||||
# deprecate arg
|
||||
# @deprecated({...})
|
||||
if isinstance(deprecated_arg, dict):
|
||||
for name, message in deprecated_arg.items():
|
||||
if message in Deprecated.messages():
|
||||
message = "{} of {} is deprecated.{}".format(
|
||||
name, func.__module__ + "." + func.__name__, message or "")
|
||||
warnings.warn(message)
|
||||
# deprecate function with message
|
||||
# @deprecated("message")
|
||||
elif isinstance(deprecated_arg, str):
|
||||
message = deprecated_arg
|
||||
if message in Deprecated.messages():
|
||||
message = "{} is deprecated.{}".format(
|
||||
func.__module__ + "." + func.__name__, message)
|
||||
warnings.warn(message)
|
||||
return func
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
deprecated = Deprecated()
|
||||
|
||||
|
||||
# This function inserts an underscore before every upper
|
||||
# case letter and lowers that upper case letter except for
|
||||
# the first letter.
|
||||
def op_name_to_lower(name):
|
||||
return re.sub('(?<!^)(?=[A-Z])', '_', name).lower()
|
||||
|
||||
|
||||
def get_unique_suffix():
|
||||
""" Get unique suffix by using first 8 chars from uuid.uuid4
|
||||
to make unique identity name.
|
||||
|
||||
:return: Unique suffix string.
|
||||
"""
|
||||
return str(uuid.uuid4())[:8]
|
||||
|
||||
|
||||
def get_perm_from_formats(from_, to_):
|
||||
""" Get perm from data formats.
|
||||
For example:
|
||||
get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]
|
||||
|
||||
:param from_: From data format string.
|
||||
:param to_: To data format string.
|
||||
:return: Perm. Int list.
|
||||
"""
|
||||
return list(map(lambda x: from_.find(x), to_))
|
||||
|
||||
|
||||
# TODO: allow more flexible placement
|
||||
def get_device_option(device):
|
||||
m = {DeviceType.CPU: '/cpu', DeviceType.CUDA: '/gpu'}
|
||||
return m[device.type]
|
||||
|
||||
|
||||
def get_data_format(x_rank):
|
||||
""" Get data format by input rank.
|
||||
Channel first if support CUDA.
|
||||
|
||||
:param x_rank: Input rank.
|
||||
:return: Data format.
|
||||
"""
|
||||
sp_dim_names = ["D", "H", "W"]
|
||||
sp_dim_lst = []
|
||||
for i in range(x_rank - 2):
|
||||
sp_dim_lst.append(sp_dim_names[-i - 1])
|
||||
|
||||
sp_dim_string = "".join(reversed(sp_dim_lst))
|
||||
storage_format = "NC" + sp_dim_string
|
||||
|
||||
if supports_device("CUDA"):
|
||||
compute_format = "NC" + sp_dim_string
|
||||
else:
|
||||
compute_format = "N" + sp_dim_string + "C"
|
||||
return storage_format, compute_format
|
||||
|
||||
|
||||
def supports_device(device):
|
||||
""" Check if support target device.
|
||||
|
||||
:param device: CUDA or CPU.
|
||||
:return: If supports.
|
||||
"""
|
||||
if device == "CUDA":
|
||||
local_device_protos = device_lib.list_local_devices()
|
||||
return len([x.name for x in local_device_protos if x.device_type == 'GPU'
|
||||
]) > 0
|
||||
elif device == "CPU":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@deprecated("onnx_tf.common.get_outputs_names is deprecated.{} {}".format(
|
||||
deprecated.MSG_WILL_REMOVE,
|
||||
"Use TensorflowGraph.get_outputs_names instead."))
|
||||
def get_output_node_names(graph_def):
|
||||
"""Get output node names from GraphDef.
|
||||
Args:
|
||||
graph_def: GraphDef object.
|
||||
Returns:
|
||||
List of output node names.
|
||||
"""
|
||||
nodes, input_names = dict(), set()
|
||||
for node in graph_def.node:
|
||||
nodes[node.name] = node
|
||||
input_names.update(set(node.input))
|
||||
return list(set(nodes) - input_names)
|
||||
|
||||
|
||||
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
|
||||
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
|
||||
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
|
||||
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from onnx_tf.common import IS_PYTHON3
|
||||
|
||||
|
||||
def convert_tf(attr):
|
||||
return __convert_tf_attr_value(attr)
|
||||
|
||||
|
||||
def convert_onnx(attr):
|
||||
return __convert_onnx_attribute_proto(attr)
|
||||
|
||||
|
||||
def __convert_tf_attr_value(attr):
|
||||
""" convert Tensorflow AttrValue object to Python object
|
||||
"""
|
||||
if attr.HasField('list'):
|
||||
return __convert_tf_list_value(attr.list)
|
||||
if attr.HasField('s'):
|
||||
return attr.s
|
||||
elif attr.HasField('i'):
|
||||
return attr.i
|
||||
elif attr.HasField('f'):
|
||||
return attr.f
|
||||
elif attr.HasField('b'):
|
||||
return attr.b
|
||||
elif attr.HasField('type'):
|
||||
return attr.type
|
||||
elif attr.HasField('shape'):
|
||||
return attr.type
|
||||
elif attr.HasField('tensor'):
|
||||
return attr.tensor
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
|
||||
|
||||
|
||||
def __convert_tf_list_value(list_value):
|
||||
""" convert Tensorflow ListValue object to Python object
|
||||
"""
|
||||
if list_value.s:
|
||||
return list_value.s
|
||||
elif list_value.i:
|
||||
return list_value.i
|
||||
elif list_value.f:
|
||||
return list_value.f
|
||||
elif list_value.b:
|
||||
return list_value.b
|
||||
elif list_value.tensor:
|
||||
return list_value.tensor
|
||||
elif list_value.type:
|
||||
return list_value.type
|
||||
elif list_value.shape:
|
||||
return list_value.shape
|
||||
elif list_value.func:
|
||||
return list_value.func
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(list_value))
|
||||
|
||||
|
||||
def __convert_onnx_attribute_proto(attr_proto):
|
||||
"""
|
||||
Convert an ONNX AttributeProto into an appropriate Python object
|
||||
for the type.
|
||||
NB: Tensor attribute gets returned as the straight proto.
|
||||
"""
|
||||
if attr_proto.HasField('f'):
|
||||
return attr_proto.f
|
||||
elif attr_proto.HasField('i'):
|
||||
return attr_proto.i
|
||||
elif attr_proto.HasField('s'):
|
||||
return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
|
||||
elif attr_proto.HasField('t'):
|
||||
return attr_proto.t # this is a proto!
|
||||
elif attr_proto.HasField('g'):
|
||||
return attr_proto.g
|
||||
elif attr_proto.floats:
|
||||
return list(attr_proto.floats)
|
||||
elif attr_proto.ints:
|
||||
return list(attr_proto.ints)
|
||||
elif attr_proto.strings:
|
||||
str_list = list(attr_proto.strings)
|
||||
if IS_PYTHON3:
|
||||
str_list = list(map(lambda x: str(x, 'utf-8'), str_list))
|
||||
return str_list
|
||||
elif attr_proto.HasField('sparse_tensor'):
|
||||
return attr_proto.sparse_tensor
|
||||
else:
|
||||
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
|
||||
@@ -0,0 +1,37 @@
|
||||
from tensorflow.python.framework.tensor_util import MakeNdarray
|
||||
|
||||
from onnx_tf.common import data_type
|
||||
|
||||
# Keyed by old attribute names.
|
||||
__tf_attr_translator = {
|
||||
"_output_shapes": lambda x: list(map(lambda shape: get_tf_shape_as_list(shape.dim), x.list.shape)),
|
||||
"shape": lambda x: get_tf_shape_as_list(x.shape.dim),
|
||||
"T": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"dtype": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"component_types": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"value": lambda x: MakeNdarray(x.tensor),
|
||||
"seed2": lambda x: float(x.i),
|
||||
"seed": lambda x: float(x.i),
|
||||
"keep_dims": lambda x: int(x.b),
|
||||
"squeeze_dims": lambda x: list(x.list.i),
|
||||
}
|
||||
|
||||
__onnx_attr_translator = {
|
||||
"axis": lambda x: int(x),
|
||||
"axes": lambda x: [int(a) for a in x],
|
||||
"dtype": lambda x: data_type.onnx2tf(x),
|
||||
"keepdims": lambda x: bool(x),
|
||||
"to": lambda x: data_type.onnx2tf(x),
|
||||
}
|
||||
|
||||
|
||||
def translate_tf(key, val):
|
||||
return __tf_attr_translator.get(key, lambda x: x)(val)
|
||||
|
||||
|
||||
def translate_onnx(key, val):
|
||||
return __onnx_attr_translator.get(key, lambda x: x)(val)
|
||||
|
||||
|
||||
def get_tf_shape_as_list(tf_shape_dim):
|
||||
return list(map(lambda x: x.size, list(tf_shape_dim)))
|
||||
@@ -0,0 +1,71 @@
|
||||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
from onnx import mapping
|
||||
from onnx import TensorProto
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def tf2onnx(dtype):
|
||||
if isinstance(dtype, Number):
|
||||
tf_dype = tf.as_dtype(dtype)
|
||||
elif isinstance(dtype, tf.DType):
|
||||
tf_dype = dtype
|
||||
elif isinstance(dtype, list):
|
||||
return [tf2onnx(t) for t in dtype]
|
||||
else:
|
||||
raise RuntimeError("dtype should be number or tf.DType.")
|
||||
|
||||
# Usually, tf2onnx is done via tf_type->numpy_type->onnx_type
|
||||
# to leverage existing type conversion infrastructure;
|
||||
# However, we need to intercept the string type early because
|
||||
# lowering tf.string type to numpy dtype results in loss of
|
||||
# information. <class 'object'> is returned instead of the
|
||||
# numpy string type desired.
|
||||
if tf_dype is tf.string:
|
||||
return TensorProto.STRING
|
||||
|
||||
onnx_dtype = None
|
||||
try:
|
||||
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(
|
||||
tf_dype.as_numpy_dtype)]
|
||||
finally:
|
||||
if onnx_dtype is None:
|
||||
common.logger.warning(
|
||||
"Can't convert tf dtype {} to ONNX dtype. Return 0 (TensorProto.UNDEFINED)."
|
||||
.format(tf_dype))
|
||||
onnx_dtype = TensorProto.UNDEFINED
|
||||
return onnx_dtype
|
||||
|
||||
|
||||
def onnx2tf(dtype):
|
||||
return tf.as_dtype(mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)])
|
||||
|
||||
|
||||
def onnx2field(dtype):
|
||||
return mapping.STORAGE_TENSOR_TYPE_TO_FIELD[_onnx_dtype(dtype)]
|
||||
|
||||
|
||||
def _onnx_dtype(dtype):
|
||||
if isinstance(dtype, Number):
|
||||
onnx_dype = dtype
|
||||
elif isinstance(dtype, str):
|
||||
onnx_dype = TensorProto.DataType.Value(dtype)
|
||||
else:
|
||||
raise RuntimeError("dtype should be number or str.")
|
||||
return onnx_dype
|
||||
|
||||
|
||||
# TODO (tjingrant) unify _onnx_dtype into any_dtype_to_onnx_dtype
|
||||
def any_dtype_to_onnx_dtype(np_dtype=None, tf_dtype=None, onnx_dtype=None):
|
||||
dtype_mask = [1 if val else 0 for val in [np_dtype, tf_dtype, onnx_dtype]]
|
||||
num_type_set = sum(dtype_mask)
|
||||
assert num_type_set == 1, "One and only one type must be set. However, {} set.".format(
|
||||
sum(num_type_set))
|
||||
|
||||
if np_dtype:
|
||||
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np_dtype]
|
||||
if tf_dtype:
|
||||
onnx_dtype = tf2onnx(tf_dtype)
|
||||
|
||||
return onnx_dtype
|
||||
@@ -0,0 +1,73 @@
|
||||
import inspect
|
||||
import onnx_tf.common as common
|
||||
|
||||
|
||||
class CustomException(object):
|
||||
|
||||
def __init__(self):
|
||||
self._func = RuntimeError
|
||||
self._message = ""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if inspect.isclass(self._func) and issubclass(self._func, Exception):
|
||||
raise self._func(self.get_message(*args, **kwargs))
|
||||
elif callable(self._func):
|
||||
self._func(self.get_message(*args, **kwargs))
|
||||
|
||||
def get_message(self, *args, **kwargs):
|
||||
return self._message
|
||||
|
||||
|
||||
class OpUnimplementedException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnimplementedException, self).__init__()
|
||||
self._func = NotImplementedError
|
||||
self._message = "{} is not implemented."
|
||||
|
||||
def __call__(self, op, version=None, domain=None):
|
||||
if IGNORE_UNIMPLEMENTED:
|
||||
self._func = common.logger.warning
|
||||
super(OpUnimplementedException, self).__call__(op, version, domain)
|
||||
|
||||
def get_message(self, op, version=None, domain=None):
|
||||
insert_message = op
|
||||
if version is not None:
|
||||
insert_message += " version {}".format(version)
|
||||
if domain is not None:
|
||||
insert_message += " in domain `{}`".format(domain)
|
||||
return self._message.format(insert_message)
|
||||
|
||||
|
||||
class OpUnsupportedException(object):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnsupportedException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} is not supported in {}."
|
||||
|
||||
def __call__(self, op, framework):
|
||||
raise self._func(self.get_message(op, framework))
|
||||
|
||||
def get_message(self, op, framework):
|
||||
return self._message.format(op, framework)
|
||||
|
||||
|
||||
class ConstNotFoundException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(ConstNotFoundException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} of {} is not found in graph consts."
|
||||
|
||||
def __call__(self, name, op):
|
||||
super(ConstNotFoundException, self).__call__(name, op)
|
||||
|
||||
def get_message(self, name, op):
|
||||
return self._message.format(name, op)
|
||||
|
||||
|
||||
IGNORE_UNIMPLEMENTED = False
|
||||
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
|
||||
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
|
||||
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
|
||||
@@ -0,0 +1,76 @@
|
||||
from onnx import defs
|
||||
|
||||
import onnx_tf.common as common
|
||||
from onnx_tf.handlers.backend import * # noqa
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
|
||||
import onnx_tf.common as common
|
||||
|
||||
def get_all_backend_handlers(opset_dict):
|
||||
""" Get a dict of all backend handler classes.
|
||||
e.g. {'domain': {'Abs': Abs handler class}, ...}, }.
|
||||
|
||||
:param opset_dict: A dict of opset. e.g. {'domain': version, ...}
|
||||
:return: Dict.
|
||||
"""
|
||||
handlers = {}
|
||||
for handler in BackendHandler.__subclasses__():
|
||||
handler.check_cls()
|
||||
|
||||
domain = handler.DOMAIN
|
||||
version = opset_dict[domain] if domain in opset_dict else 1
|
||||
handler.VERSION = version
|
||||
|
||||
since_version = 1
|
||||
if defs.has(handler.ONNX_OP, domain=handler.DOMAIN):
|
||||
try:
|
||||
since_version = defs.get_schema(
|
||||
handler.ONNX_OP,
|
||||
domain=handler.DOMAIN,
|
||||
max_inclusive_version=version).since_version
|
||||
except RuntimeError:
|
||||
common.logger.debug("Fail to get since_version of {} in domain `{}` "
|
||||
"with max_inclusive_version={}. Set to 1.".format(
|
||||
handler.ONNX_OP, handler.DOMAIN, version))
|
||||
else:
|
||||
common.logger.debug("Unknown op {} in domain `{}`.".format(
|
||||
handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
|
||||
handler.SINCE_VERSION = since_version
|
||||
handlers.setdefault(domain, {})[handler.ONNX_OP] = handler
|
||||
return handlers
|
||||
|
||||
|
||||
def get_backend_coverage():
|
||||
""" Get backend coverage for document.
|
||||
|
||||
:return: onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
|
||||
"""
|
||||
|
||||
onnx_coverage = {}
|
||||
experimental_op = set()
|
||||
for handler in BackendHandler.__subclasses__():
|
||||
handler.check_cls()
|
||||
|
||||
versions = handler.get_versions()
|
||||
domain = handler.DOMAIN
|
||||
if getattr(handler, "EXPERIMENTAL", False):
|
||||
experimental_op.add(handler.ONNX_OP)
|
||||
_update_coverage(onnx_coverage, domain, handler.ONNX_OP, versions)
|
||||
return onnx_coverage, experimental_op
|
||||
|
||||
|
||||
def _update_coverage(coverage, domain, key, versions):
|
||||
domain_coverage = coverage.setdefault(domain, {})
|
||||
vers = domain_coverage.get(key, [])
|
||||
vers.extend(versions)
|
||||
domain_coverage[key] = sorted(list(set(vers)))
|
||||
|
||||
|
||||
def get_backend_partial_support_detail():
|
||||
ps_dict = {}
|
||||
opset_dict = dict([(defs.ONNX_DOMAIN, defs.onnx_opset_version())])
|
||||
handlers = get_all_backend_handlers(opset_dict)[defs.ONNX_DOMAIN]
|
||||
for op_name in handlers:
|
||||
if handlers[op_name].PARTIAL_SUPPORT:
|
||||
ps_dict[op_name] = handlers[op_name].PS_DESCRIPTION
|
||||
return ps_dict
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def get_onnx_version():
|
||||
return tuple(map(int, onnx.version.version.split(".")))
|
||||
|
||||
|
||||
# Returns whether onnx version is prior to major.minor.patch
|
||||
def legacy_onnx_pre_ver(major=0, minor=0, patch=0):
|
||||
return get_onnx_version() < (major, minor, patch)
|
||||
|
||||
|
||||
# Returns whether the opset version accompanying the
|
||||
# onnx installation is prior to version passed.
|
||||
def legacy_opset_pre_ver(version):
|
||||
return onnx.defs.onnx_opset_version() < version
|
||||
@@ -0,0 +1,263 @@
|
||||
from __future__ import division
|
||||
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import itertools
|
||||
|
||||
|
||||
pad_ops = namedtuple("pad_ops",
|
||||
["max_op", "ceil_op", "floor_op", "cast_int_op"])
|
||||
|
||||
pad_numpy_ops = pad_ops(np.maximum, np.ceil, np.floor,
|
||||
lambda arr: arr.astype(np.int64))
|
||||
pad_tf_ops = pad_ops(tf.maximum, tf.math.ceil, tf.math.floor,
|
||||
lambda tensor: tf.cast(tensor, tf.int64))
|
||||
|
||||
|
||||
def calc_pads_same(in_spatial_shape, kernel_shape, strides,
|
||||
dilations, padding, padding_ops=pad_numpy_ops,
|
||||
pads_order=1):
|
||||
"""
|
||||
Calculates the SAME paddings that need to be added to the input
|
||||
|
||||
Args:
|
||||
in_spatial_shape: input spatial shape
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis
|
||||
padding: padding to calculate: SAME_UPPER or
|
||||
SAME_LOWER
|
||||
padding_ops: namedtuple with ops to be used during
|
||||
calculations. there are two sets of ops
|
||||
defined pad_numpy_ops and pad_tf_ops with
|
||||
numpy and tensorflow ops
|
||||
pads_order: order of returned pads. possible options are:
|
||||
1 - b1, b2, ..., bn, e1, e2, ..., en
|
||||
2 - b1, e1, b2, e2, ..., bn, en
|
||||
where n = len(kernel_shape) * 2,
|
||||
b1, b2, ..., bn define pads at the begging of
|
||||
axis
|
||||
e1, e2, ..., en define pads at the end of
|
||||
axis
|
||||
Return:
|
||||
pads: array with calculated pads. the order of the
|
||||
values is determined by `pads_order`
|
||||
|
||||
"""
|
||||
spatial_size = len(kernel_shape)
|
||||
pads = [0] * (spatial_size * 2)
|
||||
for i in range(spatial_size):
|
||||
in_size = in_spatial_shape[i]
|
||||
filter_size = (kernel_shape[i] - 1) * dilations[i] + 1
|
||||
|
||||
out_size = padding_ops.ceil_op(in_size / strides[i])
|
||||
out_size = padding_ops.cast_int_op(out_size)
|
||||
pad_along_axis = \
|
||||
padding_ops.max_op((out_size - 1) * strides[i] +
|
||||
filter_size - in_size, 0)
|
||||
if padding.lower() == "same_lower":
|
||||
pad_op = padding_ops.ceil_op
|
||||
else:
|
||||
pad_op = padding_ops.floor_op
|
||||
pad_begin = pad_op(pad_along_axis / 2)
|
||||
|
||||
pad_begin = padding_ops.cast_int_op(pad_begin)
|
||||
pad_along_axis = padding_ops.cast_int_op(pad_along_axis)
|
||||
|
||||
pad_end = pad_along_axis - pad_begin
|
||||
|
||||
pads[i * pads_order] = pad_begin
|
||||
pads[i * pads_order +
|
||||
(spatial_size if pads_order == 1 else 1)] = pad_end
|
||||
|
||||
return pads
|
||||
|
||||
|
||||
def calc_output_shape(input_spatial_shape, kernel_shape, strides, dilations,
|
||||
padding, ceil_mode=False):
|
||||
"""
|
||||
Calculate output shape
|
||||
|
||||
Args:
|
||||
input_spatial_shape: input spatial shape
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis
|
||||
padding: can be explicit paddings, "SAME_UPPER" or
|
||||
"SAME_LOWER"
|
||||
Return:
|
||||
output_shape: calculated output shape
|
||||
"""
|
||||
spatial_size = len(input_spatial_shape)
|
||||
|
||||
if type(padding) is not list and type(padding) is not np.ndarray:
|
||||
if padding.lower().startswith("same"):
|
||||
padding = calc_pads_same(input_spatial_shape, kernel_shape,
|
||||
strides, dilations, padding)
|
||||
else:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
output_shape = []
|
||||
for dim in range(spatial_size):
|
||||
output_shape.append(_pooling_output_shape(input_spatial_shape[dim],
|
||||
kernel_shape[dim], strides[dim], dilations[dim],
|
||||
padding[dim] + padding[dim + spatial_size],
|
||||
ceil_mode))
|
||||
|
||||
return output_shape
|
||||
|
||||
|
||||
def _pooling_output_shape(input_size, ksize, stride, dilation, pad, ceil_mode):
|
||||
output_size = (input_size + pad - ((ksize - 1) * dilation + 1) +
|
||||
((stride-1) if ceil_mode else 0)) // stride + 1
|
||||
if (pad):
|
||||
if ((output_size - 1) * stride >= input_size + pad):
|
||||
output_size -= 1
|
||||
return output_size
|
||||
|
||||
|
||||
def py_pool(input, kernel_shape, strides=None, dilations=None,
|
||||
padding=None, ceil_mode=False, pooling_type="MAX",
|
||||
include_indices=True, p=2):
|
||||
"""
|
||||
Implementation of Max and Average pool operations in Python
|
||||
Args:
|
||||
input: input N-D data array in NC* format
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis of filter
|
||||
padding: padding for the beginning and ending along each
|
||||
spatial axis. `padding` format should be as follow
|
||||
[x1_begin, x2_begin...x1_end, x2_end,...]
|
||||
ceil_mode: whether to use ceil or floor (default) to compute
|
||||
the output shape.
|
||||
pooling_type: specifies pooling type. Values can be "MAX", "AVG" or
|
||||
"LP"
|
||||
include_indices: should indices be included in the output
|
||||
p: specifies the p parameter for LpPooling
|
||||
Return:
|
||||
pooled: output data from max pooling across the input
|
||||
ind: indices of the selected max values from the input
|
||||
"""
|
||||
|
||||
if type(pooling_type) is not str:
|
||||
pooling_type = pooling_type.decode("UTF-8")
|
||||
|
||||
input_shape = np.shape(input)
|
||||
inp_sp_shape = input_shape[2:]
|
||||
input_dtype = input.dtype
|
||||
if np.issubdtype(input_dtype, np.integer):
|
||||
input_dtype_min = np.iinfo(input_dtype).min
|
||||
else:
|
||||
input_dtype_min = np.finfo(input_dtype).min
|
||||
|
||||
if pooling_type == "LP":
|
||||
rootN = (1.0 / p)
|
||||
|
||||
def _loop_over_output(batch, channel):
|
||||
dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
|
||||
for counters in itertools.product(*dims):
|
||||
input_ranges = []
|
||||
for dim in range(spatial_size):
|
||||
dim_start = \
|
||||
counters[dim] * strides[dim] - pads[dim * 2]
|
||||
dim_end = \
|
||||
min(dim_start + (kernel_shape[dim] - 1) * dilations[dim]
|
||||
+ 1, inp_sp_shape[dim])
|
||||
while dim_start < 0:
|
||||
dim_start += dilations[dim]
|
||||
|
||||
cur_range = [i for i in range(dim_start,
|
||||
dim_end, dilations[dim])]
|
||||
input_ranges.append(cur_range)
|
||||
if pooling_type in ["AVG", "LP"]:
|
||||
val_sum = 0
|
||||
val_count = 0
|
||||
else:
|
||||
maxval = input_dtype_min
|
||||
maxind = -1
|
||||
for input_ind in itertools.product(*input_ranges):
|
||||
ind = (batch, channel) + input_ind
|
||||
val = input[ind]
|
||||
if pooling_type == "AVG":
|
||||
val_sum += val
|
||||
val_count += 1
|
||||
elif pooling_type == "LP":
|
||||
val_sum += abs(val ** p)
|
||||
else:
|
||||
if val > maxval:
|
||||
maxval = val
|
||||
ind = 0
|
||||
for i in range(spatial_size):
|
||||
coef = 1
|
||||
for j in range(i+1, spatial_size):
|
||||
coef *= inp_sp_shape[j]
|
||||
ind += input_ind[i] * coef
|
||||
maxind = ind
|
||||
ind = (batch, channel) + counters
|
||||
if pooling_type == "AVG":
|
||||
out_pool[ind] = val_sum / val_count
|
||||
elif pooling_type == "LP":
|
||||
out_pool[ind] = val_sum ** rootN
|
||||
else:
|
||||
out_pool[ind] = maxval
|
||||
out_ind[ind] = maxind
|
||||
|
||||
spatial_size = len(kernel_shape)
|
||||
|
||||
batch_size = input_shape[0]
|
||||
channels_num = input_shape[1]
|
||||
|
||||
if strides is None:
|
||||
strides = kernel_shape
|
||||
|
||||
if dilations is None:
|
||||
dilations = [1] * spatial_size
|
||||
|
||||
if padding is None:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
if type(padding) is bytes:
|
||||
padding = padding.decode()
|
||||
|
||||
if type(padding) is not list and type(padding) is not np.ndarray:
|
||||
if type(padding) is not str:
|
||||
padding = padding.decode("UTF-8")
|
||||
|
||||
if padding.lower().startswith("same"):
|
||||
padding = calc_pads_same(inp_sp_shape, kernel_shape, strides,
|
||||
dilations, padding)
|
||||
else:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
pads = []
|
||||
pad_along_axis = []
|
||||
output_sp_shape = []
|
||||
|
||||
for dim in range(spatial_size):
|
||||
pads.append(padding[dim])
|
||||
pads.append(padding[dim + spatial_size])
|
||||
pad_along_axis.append(padding[dim] + padding[dim + spatial_size])
|
||||
|
||||
input_size = input_shape[dim + 2]
|
||||
output_size = \
|
||||
_pooling_output_shape(input_size, kernel_shape[dim],
|
||||
strides[dim], dilations[dim],
|
||||
pad_along_axis[dim], ceil_mode)
|
||||
output_sp_shape.append(output_size)
|
||||
|
||||
out_pool = np.zeros([input_shape[0], input_shape[1]] +
|
||||
output_sp_shape, input_dtype)
|
||||
out_ind = np.zeros([input_shape[0], input_shape[1]] +
|
||||
output_sp_shape, np.int64)
|
||||
|
||||
for batch in range(batch_size):
|
||||
for channel in range(channels_num):
|
||||
_loop_over_output(batch, channel)
|
||||
|
||||
if not include_indices:
|
||||
return out_pool
|
||||
else:
|
||||
return out_pool, out_ind
|
||||
@@ -0,0 +1,45 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
def tf_shape(tensor):
|
||||
"""
|
||||
Helper function returning the shape of a Tensor.
|
||||
The function will check for fully defined shape and will return
|
||||
numpy array or if the shape is not fully defined will use tf.shape()
|
||||
to return the shape as a Tensor.
|
||||
"""
|
||||
if tensor.shape.is_fully_defined():
|
||||
return np.array(tensor.shape.as_list(), dtype=np.int64)
|
||||
else:
|
||||
return tf.shape(tensor, out_type=tf.int64)
|
||||
|
||||
|
||||
def tf_product(a, b):
|
||||
"""
|
||||
Calculates the cartesian product of two column vectors a and b
|
||||
|
||||
Example:
|
||||
|
||||
a = [[1]
|
||||
[2]
|
||||
[3]]
|
||||
|
||||
b = [[0]
|
||||
[1]]
|
||||
|
||||
result = [[1 0]
|
||||
[1 1]
|
||||
[2 0]
|
||||
[2 1]
|
||||
[3 0]
|
||||
[3 1]]
|
||||
"""
|
||||
tile_a = tf.tile(a, [1, tf.shape(b)[0]])
|
||||
tile_a = tf.expand_dims(tile_a, 2)
|
||||
tile_a = tf.reshape(tile_a, [-1, 1])
|
||||
|
||||
b = tf.tile(b, [tf.shape(a)[0], 1])
|
||||
b = tf.concat([tile_a, b], axis=1)
|
||||
|
||||
return b
|
||||
Reference in New Issue
Block a user