add pt2tf tool

This commit is contained in:
zxros10
2020-09-23 09:09:49 +08:00
parent 7f7b7df65d
commit 18aefa4dd0
407 changed files with 16211 additions and 0 deletions
@@ -0,0 +1,196 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import inspect
import re
import sys
import uuid
import warnings
import logging
from onnx.backend.base import DeviceType
from tensorflow.python.client import device_lib
# True when the running interpreter is Python 3 or newer.
IS_PYTHON3 = sys.version_info > (3,)
# Package-wide logger, shared by name ('onnx-tf').
logger = logging.getLogger('onnx-tf')
# create console handler and formatter for logger
console = logging.StreamHandler()
# Each record is rendered as "<time> - <logger name> - <level> - <message>".
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console.setFormatter(formatter)
# NOTE: the handler is attached as an import side effect of this module.
logger.addHandler(console)
class Deprecated:
    """Decorator factory that emits deprecation warnings.

    The supported forms warn at different moments:

    * ``@deprecated`` (bare) wraps the function and warns with the default
      message every time the function is called.
    * ``@deprecated("Message")`` and ``@deprecated({...})`` warn once, when
      the decorator is applied, and return the function unchanged.

    Generated messages use the fully qualified ``module.function`` name.

    Usage:
        from onnx_tf.common import deprecated

        @deprecated
        def func():
            pass
        UserWarning: mod.func is deprecated. It will be removed in future release.

        @deprecated("Message")
        def func():
            pass
        UserWarning: Message

        @deprecated({"arg": "Message",
                     "arg_1": deprecated.MSG_WILL_REMOVE,
                     "arg_2": "",})
        def func(arg, arg_1, arg_2):
            pass
        UserWarning: Message
        UserWarning: arg_1 of mod.func is deprecated. It will be removed in future release.
        UserWarning: arg_2 of mod.func is deprecated.
    """

    # Standard suffix appended to generated messages. Every class attribute
    # whose name starts with "MSG" is treated as a standard message.
    MSG_WILL_REMOVE = " It will be removed in future release."

    def __call__(self, *args, **kwargs):
        """Make the instance usable directly as ``@deprecated``."""
        return self.deprecated_decorator(*args, **kwargs)

    @staticmethod
    def messages():
        """Return the set of standard messages (all ``MSG*`` attributes)."""
        return {
            value for key, value in inspect.getmembers(Deprecated)
            if key.startswith("MSG")
        }

    @staticmethod
    def deprecated_decorator(arg=None):
        """Build the decorator for the bare, string or dict form.

        :param arg: The function itself (bare form), a message string, a
            dict mapping argument names to messages, or None for the
            default message.
        :return: The wrapped function (bare form) or a decorator.
        """
        import functools

        # deprecate function with default message MSG_WILL_REMOVE
        # @deprecated
        if inspect.isfunction(arg):

            # wraps() keeps __name__/__doc__ of the deprecated function.
            @functools.wraps(arg)
            def wrapper(*args, **kwargs):
                # stacklevel=2 attributes the warning to the caller.
                warnings.warn("{} is deprecated.{}".format(
                    arg.__module__ + "." + arg.__name__,
                    Deprecated.MSG_WILL_REMOVE), stacklevel=2)
                return arg(*args, **kwargs)

            return wrapper

        deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE

        def deco(func):
            standard = Deprecated.messages()
            qualified = func.__module__ + "." + func.__name__
            # deprecate arg
            # @deprecated({...})
            if isinstance(deprecated_arg, dict):
                for name, message in deprecated_arg.items():
                    # An empty message means "deprecated, no extra detail";
                    # it gets the generated prefix like a standard message.
                    if not message or message in standard:
                        message = "{} of {} is deprecated.{}".format(
                            name, qualified, message or "")
                    warnings.warn(message, stacklevel=2)
            # deprecate function with message
            # @deprecated("message")
            elif isinstance(deprecated_arg, str):
                message = deprecated_arg
                if not message or message in standard:
                    message = "{} is deprecated.{}".format(qualified, message)
                warnings.warn(message, stacklevel=2)
            return func

        return deco


deprecated = Deprecated()
def op_name_to_lower(name):
    """Convert a CamelCase op name to snake_case.

    An underscore is inserted before every upper case letter (A-Z) that is
    not the first character, then the whole string is lower-cased.

    :param name: Op name, e.g. 'MaxPool'.
    :return: Lower-cased name, e.g. 'max_pool'.
    """
    with_separators = re.sub('(?<!^)(?=[A-Z])', '_', name)
    return with_separators.lower()
def get_unique_suffix():
    """ Build a short suffix for making identity names unique.

    The suffix is the first 8 hexadecimal characters of a random UUID
    (uuid.uuid4).

    :return: Unique suffix string.
    """
    random_id = uuid.uuid4()
    return random_id.hex[:8]
def get_perm_from_formats(from_, to_):
    """ Compute the transpose permutation between two data formats.

    Each axis letter of `to_` is looked up in `from_`; a letter that does
    not occur in `from_` yields -1 (str.find semantics).

    For example:
      get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]

    :param from_: From data format string.
    :param to_: To data format string.
    :return: Perm. Int list.
    """
    return [from_.find(axis) for axis in to_]
# TODO: allow more flexible placement
def get_device_option(device):
    """ Map an ONNX device to a TensorFlow device string.

    :param device: Object whose `type` attribute is a DeviceType.
    :return: '/cpu' for DeviceType.CPU, '/gpu' for DeviceType.CUDA.
    :raises KeyError: For any other device type.
    """
    device_strings = {
        DeviceType.CPU: '/cpu',
        DeviceType.CUDA: '/gpu',
    }
    requested_type = device.type
    return device_strings[requested_type]
def get_data_format(x_rank):
    """ Derive storage and compute data formats from the input rank.

    The storage format is always channel first ("NC" + spatial dims).
    The compute format is channel first when CUDA is supported and
    channel last otherwise.

    :param x_rank: Input rank (batch and channel dims included).
    :return: Tuple (storage_format, compute_format).
    """
    sp_dim_names = ["D", "H", "W"]
    spatial_rank = x_rank - 2
    # Take the last `spatial_rank` names, keeping their D/H/W order.
    sp_dim_string = "".join(
        sp_dim_names[-i - 1] for i in reversed(range(spatial_rank)))
    storage_format = "NC" + sp_dim_string
    if supports_device("CUDA"):
        compute_format = "NC" + sp_dim_string
    else:
        compute_format = "N" + sp_dim_string + "C"
    return storage_format, compute_format
def supports_device(device):
    """ Tell whether the target device is available.

    "CPU" is always supported. "CUDA" is supported when TensorFlow lists
    at least one local device of type 'GPU'. Any other value is not
    supported.

    :param device: CUDA or CPU.
    :return: If supports.
    """
    if device == "CPU":
        return True
    if device != "CUDA":
        return False
    return any(
        proto.device_type == 'GPU'
        for proto in device_lib.list_local_devices())
@deprecated("onnx_tf.common.get_output_node_names is deprecated.{} {}".format(
    deprecated.MSG_WILL_REMOVE,
    "Use TensorflowGraph.get_outputs_names instead."))
def get_output_node_names(graph_def):
    """Get output node names from GraphDef.

    A node is an output when no node in the graph consumes it. Input
    references are normalised to the producing node's name before the
    comparison: a leading "^" (control input) and a trailing ":<port>"
    are stripped, so "^foo" and "foo:1" both count as uses of "foo".

    Args:
      graph_def: GraphDef object.
    Returns:
      List of output node names, in order of first appearance in the graph.
    """
    consumed = set()
    for node in graph_def.node:
        for tensor_name in node.input:
            if tensor_name.startswith("^"):
                tensor_name = tensor_name[1:]
            consumed.add(tensor_name.split(":", 1)[0])

    outputs, seen = [], set()
    for node in graph_def.node:
        name = node.name
        # `seen` guards against listing a duplicated node name twice.
        if name not in consumed and name not in seen:
            outputs.append(name)
        seen.add(name)
    return outputs
# String identifiers for internal constants. Judging by the names they stand
# for int32 -1 / 0 / 1 and fp32 1; their consumers are outside this module.
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"
@@ -0,0 +1,86 @@
from onnx_tf.common import IS_PYTHON3
def convert_tf(attr):
    """ Public entry point: convert a Tensorflow AttrValue to a Python object
    """
    converted = __convert_tf_attr_value(attr)
    return converted
def convert_onnx(attr):
    """ Public entry point: convert an ONNX AttributeProto to a Python object
    """
    converted = __convert_onnx_attribute_proto(attr)
    return converted
def __convert_tf_attr_value(attr):
    """ convert Tensorflow AttrValue object to Python object

    A 'list' value is delegated to __convert_tf_list_value. Otherwise the
    first populated field among s, i, f, b, type, shape and tensor is
    returned as stored.

    :param attr: Tensorflow AttrValue.
    :return: The value held by the populated field.
    :raises ValueError: If none of the handled fields is populated.
    """
    if attr.HasField('list'):
        return __convert_tf_list_value(attr.list)
    # Looking the value up by the same name that was tested keeps the two in
    # sync (a 'shape' attribute must yield attr.shape, not attr.type).
    for field in ('s', 'i', 'f', 'b', 'type', 'shape', 'tensor'):
        if attr.HasField(field):
            return getattr(attr, field)
    raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
def __convert_tf_list_value(list_value):
    """ convert Tensorflow ListValue object to Python object

    The first non-empty repeated field, probed in the order
    s, i, f, b, tensor, type, shape, func, is returned as is.

    :raises ValueError: If every probed field is empty.
    """
    probe_order = ('s', 'i', 'f', 'b', 'tensor', 'type', 'shape', 'func')
    for field_name in probe_order:
        values = getattr(list_value, field_name)
        if values:
            return values
    raise ValueError("Unsupported Tensorflow attribute: {}".format(list_value))
def __convert_onnx_attribute_proto(attr_proto):
    """
    Turn an ONNX AttributeProto into the matching Python object.

    Fields are probed in this order: f, i, s, t, g, floats, ints, strings,
    sparse_tensor. Under Python 3, byte strings are decoded as UTF-8.
    NB: tensor, graph and sparse tensor attributes are returned as the
    straight proto.

    :raises ValueError: If no probed field is populated.
    """
    if attr_proto.HasField('f'):
        return attr_proto.f
    if attr_proto.HasField('i'):
        return attr_proto.i
    if attr_proto.HasField('s'):
        if IS_PYTHON3:
            return str(attr_proto.s, 'utf-8')
        return attr_proto.s
    if attr_proto.HasField('t'):
        return attr_proto.t  # this is a proto!
    if attr_proto.HasField('g'):
        return attr_proto.g
    if attr_proto.floats:
        return list(attr_proto.floats)
    if attr_proto.ints:
        return list(attr_proto.ints)
    if attr_proto.strings:
        decoded = list(attr_proto.strings)
        if IS_PYTHON3:
            decoded = [str(item, 'utf-8') for item in decoded]
        return decoded
    if attr_proto.HasField('sparse_tensor'):
        return attr_proto.sparse_tensor
    raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
@@ -0,0 +1,37 @@
from tensorflow.python.framework.tensor_util import MakeNdarray
from onnx_tf.common import data_type
# Translators for Tensorflow attributes, keyed by old attribute names.
# Each entry takes the AttrValue and returns the translated Python value.
__tf_attr_translator = {
    "_output_shapes":
        lambda x: [get_tf_shape_as_list(shape.dim) for shape in x.list.shape],
    "shape":
        lambda x: get_tf_shape_as_list(x.shape.dim),
    # Type attributes: use the type list when it is non-empty, else the
    # single type.
    "T":
        lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
    "dtype":
        lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
    "component_types":
        lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
    "value":
        lambda x: MakeNdarray(x.tensor),
    "seed2":
        lambda x: float(x.i),
    "seed":
        lambda x: float(x.i),
    "keep_dims":
        lambda x: int(x.b),
    "squeeze_dims":
        lambda x: list(x.list.i),
}
# Translators for ONNX attributes, keyed by attribute name. Each entry takes
# the already converted attribute value and returns the translated value.
__onnx_attr_translator = {
    "axis": int,
    "axes": lambda x: list(map(int, x)),
    "dtype": lambda x: data_type.onnx2tf(x),
    "keepdims": bool,
    "to": lambda x: data_type.onnx2tf(x),
}
def translate_tf(key, val):
    """Translate a Tensorflow attribute value.

    :param key: Attribute name.
    :param val: Attribute value.
    :return: Result of the translator registered for `key`, or `val`
        unchanged when no translator is registered.
    """
    translator = __tf_attr_translator.get(key)
    if translator is None:
        return val
    return translator(val)
def translate_onnx(key, val):
    """Translate an ONNX attribute value.

    :param key: Attribute name.
    :param val: Attribute value.
    :return: Result of the translator registered for `key`, or `val`
        unchanged when no translator is registered.
    """
    translator = __onnx_attr_translator.get(key)
    if translator is None:
        return val
    return translator(val)
def get_tf_shape_as_list(tf_shape_dim):
    """Collect the `size` of every dim entry into a plain list.

    :param tf_shape_dim: Iterable of objects exposing a `size` attribute.
    :return: List of sizes, in the same order.
    """
    return [dim.size for dim in tf_shape_dim]
@@ -0,0 +1,71 @@
from numbers import Number
import numpy as np
from onnx import mapping
from onnx import TensorProto
import tensorflow as tf
def tf2onnx(dtype):
    """Convert a Tensorflow dtype to the ONNX TensorProto data type.

    :param dtype: A number (Tensorflow DataType enum value), a tf.DType, or
        a (possibly nested) list of those.
    :return: ONNX data type, or a list of them for list input.
        TensorProto.UNDEFINED is returned, with a warning logged, when
        the dtype has no ONNX counterpart.
    :raises RuntimeError: If `dtype` is not a number, tf.DType or list.
    """
    import logging

    if isinstance(dtype, list):
        return [tf2onnx(t) for t in dtype]
    if isinstance(dtype, Number):
        tf_dtype = tf.as_dtype(dtype)
    elif isinstance(dtype, tf.DType):
        tf_dtype = dtype
    else:
        raise RuntimeError("dtype should be number or tf.DType.")

    # Usually, tf2onnx is done via tf_type->numpy_type->onnx_type
    # to leverage existing type conversion infrastructure;
    # However, we need to intercept the string type early because
    # lowering tf.string type to numpy dtype results in loss of
    # information. <class 'object'> is returned instead of the
    # numpy string type desired.
    if tf_dtype is tf.string:
        return TensorProto.STRING

    try:
        return mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(
            tf_dtype.as_numpy_dtype)]
    except (KeyError, TypeError):
        # No numpy or ONNX equivalent. 'onnx-tf' is the package logger name;
        # it is fetched by name because this module does not import
        # onnx_tf.common.
        logging.getLogger('onnx-tf').warning(
            "Can't convert tf dtype {} to ONNX dtype. Return 0 (TensorProto.UNDEFINED)."
            .format(tf_dtype))
        return TensorProto.UNDEFINED
def onnx2tf(dtype):
    """Convert an ONNX data type (number or name) to a Tensorflow dtype.

    The conversion goes ONNX type -> numpy type -> tf dtype.
    """
    np_type = mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)]
    return tf.as_dtype(np_type)
def onnx2field(dtype):
    """Look up the storage field for an ONNX data type (number or name).

    The lookup uses onnx.mapping.STORAGE_TENSOR_TYPE_TO_FIELD.
    """
    onnx_type = _onnx_dtype(dtype)
    return mapping.STORAGE_TENSOR_TYPE_TO_FIELD[onnx_type]
def _onnx_dtype(dtype):
    """Normalise an ONNX data type given as a number or as a name.

    :param dtype: Number (returned unchanged) or str (resolved through
        TensorProto.DataType.Value).
    :return: ONNX data type number.
    :raises RuntimeError: If `dtype` is neither a number nor a str.
    """
    if isinstance(dtype, Number):
        return dtype
    if isinstance(dtype, str):
        return TensorProto.DataType.Value(dtype)
    raise RuntimeError("dtype should be number or str.")
# TODO (tjingrant) unify _onnx_dtype into any_dtype_to_onnx_dtype
def any_dtype_to_onnx_dtype(np_dtype=None, tf_dtype=None, onnx_dtype=None):
    """Convert a numpy, Tensorflow or ONNX dtype to an ONNX dtype.

    Exactly one of the three arguments must be given (i.e. not None).

    :param np_dtype: numpy dtype, or anything np.dtype() accepts.
    :param tf_dtype: Tensorflow dtype, converted with tf2onnx.
    :param onnx_dtype: ONNX dtype, returned unchanged.
    :return: ONNX dtype.
    :raises AssertionError: If zero or more than one argument is given.
    """
    # Compare against None rather than relying on truthiness, so that valid
    # but falsy values (e.g. onnx_dtype=0, TensorProto.UNDEFINED) count as set.
    num_type_set = sum(
        1 for val in (np_dtype, tf_dtype, onnx_dtype) if val is not None)
    if num_type_set != 1:
        raise AssertionError(
            "One and only one type must be set. However, {} set.".format(
                num_type_set))

    if np_dtype is not None:
        # The mapping is keyed by np.dtype instances; normalise so that
        # scalar types such as np.float32 are accepted too.
        return mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(np_dtype)]
    if tf_dtype is not None:
        return tf2onnx(tf_dtype)
    return onnx_dtype
@@ -0,0 +1,73 @@
import inspect
import onnx_tf.common as common
class CustomException(object):
    """Callable that reports a message through a configurable handler.

    `_func` is the handler: when it is an Exception subclass, calling the
    instance raises it with the message; when it is any other callable, it
    is invoked with the message; otherwise the call does nothing.
    `_message` is the message (template) returned by get_message.
    """

    def __init__(self):
        self._func = RuntimeError
        self._message = ""

    def __call__(self, *args, **kwargs):
        handler = self._func
        raises = inspect.isclass(handler) and issubclass(handler, Exception)
        if not raises and not callable(handler):
            return
        # The message is only built when there is a handler to receive it.
        message = self.get_message(*args, **kwargs)
        if raises:
            raise handler(message)
        handler(message)

    def get_message(self, *args, **kwargs):
        """Return the message; subclasses format it with the arguments."""
        return self._message
class OpUnimplementedException(CustomException):
    """Reports an op that has no implementation.

    Calling the instance raises NotImplementedError, unless the module
    level IGNORE_UNIMPLEMENTED flag is true, in which case the message is
    logged as a warning instead.
    """

    def __init__(self):
        super(OpUnimplementedException, self).__init__()
        self._func = NotImplementedError
        self._message = "{} is not implemented."

    def __call__(self, op, version=None, domain=None):
        # Pick the handler on every call: the flag can be switched back off,
        # and the shared instance must then go back to raising.
        if IGNORE_UNIMPLEMENTED:
            self._func = common.logger.warning
        else:
            self._func = NotImplementedError
        super(OpUnimplementedException, self).__call__(op, version, domain)

    def get_message(self, op, version=None, domain=None):
        """Build "<op>[ version <v>][ in domain `<d>`] is not implemented."

        :param op: Op name.
        :param version: Optional op version, appended when not None.
        :param domain: Optional op domain, appended when not None.
        :return: Formatted message.
        """
        insert_message = op
        if version is not None:
            insert_message += " version {}".format(version)
        if domain is not None:
            insert_message += " in domain `{}`".format(domain)
        return self._message.format(insert_message)
class OpUnsupportedException(object):
    """Callable that raises when an op is not supported by a framework.

    Calling the instance always raises `_func` (RuntimeError by default)
    with the message "<op> is not supported in <framework>.".
    """

    def __init__(self):
        super(OpUnsupportedException, self).__init__()
        self._func = RuntimeError
        self._message = "{} is not supported in {}."

    def __call__(self, op, framework):
        error_type = self._func
        message = self.get_message(op, framework)
        raise error_type(message)

    def get_message(self, op, framework):
        """Fill the message template with the op and framework names."""
        template = self._message
        return template.format(op, framework)
class ConstNotFoundException(CustomException):
    """Reports a constant that is missing from the graph consts.

    Calling the instance raises RuntimeError with the message
    "<name> of <op> is not found in graph consts.".
    """

    def __init__(self):
        super(ConstNotFoundException, self).__init__()
        self._func = RuntimeError
        self._message = "{} of {} is not found in graph consts."

    def __call__(self, name, op):
        # Narrows the generic signature to the two values the message needs.
        base = super(ConstNotFoundException, self)
        base.__call__(name, op)

    def get_message(self, name, op):
        """Fill the message template with the const name and the op."""
        template = self._message
        return template.format(name, op)
# When true, OpUnimplementedException switches its handler from raising
# NotImplementedError to logging a warning.
IGNORE_UNIMPLEMENTED = False
# Shared, ready-to-call instances of the exception helpers defined above.
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
@@ -0,0 +1,76 @@
from onnx import defs
import onnx_tf.common as common
from onnx_tf.handlers.backend import * # noqa
from onnx_tf.handlers.backend_handler import BackendHandler
import onnx_tf.common as common
def get_all_backend_handlers(opset_dict):
    """ Collect every backend handler class, grouped by domain.

    Each direct BackendHandler subclass gets its VERSION set from
    `opset_dict` (1 when the domain is absent) and its SINCE_VERSION set
    from the ONNX schema (1 when the op or schema is unknown).

    e.g. {'domain': {'Abs': Abs handler class}, ...}, }.

    :param opset_dict: A dict of opset. e.g. {'domain': version, ...}
    :return: Dict.
    """
    handlers = {}
    for handler in BackendHandler.__subclasses__():
        handler.check_cls()

        domain = handler.DOMAIN
        version = opset_dict.get(domain, 1)
        handler.VERSION = version

        since_version = 1
        if not defs.has(handler.ONNX_OP, domain=handler.DOMAIN):
            common.logger.debug("Unknown op {} in domain `{}`.".format(
                handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
        else:
            try:
                schema = defs.get_schema(
                    handler.ONNX_OP,
                    domain=handler.DOMAIN,
                    max_inclusive_version=version)
                since_version = schema.since_version
            except RuntimeError:
                common.logger.debug(
                    "Fail to get since_version of {} in domain `{}` "
                    "with max_inclusive_version={}. Set to 1.".format(
                        handler.ONNX_OP, handler.DOMAIN, version))
        handler.SINCE_VERSION = since_version

        domain_handlers = handlers.setdefault(domain, {})
        domain_handlers[handler.ONNX_OP] = handler
    return handlers
def get_backend_coverage():
    """ Get backend coverage for document.

    :return: Tuple (onnx_coverage, experimental_op) where
      onnx_coverage is e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
      and experimental_op is the set of ONNX op names whose handler has a
      true EXPERIMENTAL attribute.
    """
    onnx_coverage = {}
    experimental_op = set()
    for handler in BackendHandler.__subclasses__():
        handler.check_cls()
        versions = handler.get_versions()
        domain = handler.DOMAIN
        is_experimental = getattr(handler, "EXPERIMENTAL", False)
        if is_experimental:
            experimental_op.add(handler.ONNX_OP)
        _update_coverage(onnx_coverage, domain, handler.ONNX_OP, versions)
    return onnx_coverage, experimental_op
def _update_coverage(coverage, domain, key, versions):
    """ Merge `versions` into coverage[domain][key], in place.

    The stored value becomes a sorted list without duplicates.

    :param coverage: Dict {'domain': {'key': [versions], ...}, ...}.
    :param domain: Domain name; its entry is created when missing.
    :param key: Op name.
    :param versions: Iterable of versions to add.
    """
    per_domain = coverage.setdefault(domain, {})
    known_versions = per_domain.get(key, [])
    known_versions.extend(versions)
    per_domain[key] = sorted(set(known_versions))
def get_backend_partial_support_detail():
    """ Describe the partially supported ops of the default ONNX domain.

    Handlers are collected for the opset version reported by
    defs.onnx_opset_version().

    :return: Dict {op name: handler PS_DESCRIPTION} for every handler whose
      PARTIAL_SUPPORT is true.
    """
    opset_dict = {defs.ONNX_DOMAIN: defs.onnx_opset_version()}
    handlers = get_all_backend_handlers(opset_dict)[defs.ONNX_DOMAIN]
    return {
        op_name: handlers[op_name].PS_DESCRIPTION
        for op_name in handlers
        if handlers[op_name].PARTIAL_SUPPORT
    }
@@ -0,0 +1,21 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import onnx
def get_onnx_version():
    """Return the installed onnx version as a tuple of ints.

    The version string is split on "." and every component is converted
    with int().
    """
    components = onnx.version.version.split(".")
    return tuple(int(component) for component in components)
def legacy_onnx_pre_ver(major=0, minor=0, patch=0):
    """Returns whether onnx version is prior to major.minor.patch"""
    installed = get_onnx_version()
    return installed < (major, minor, patch)
def legacy_opset_pre_ver(version):
    """Returns whether the opset version accompanying the
    onnx installation is prior to version passed.
    """
    installed_opset = onnx.defs.onnx_opset_version()
    return installed_opset < version
@@ -0,0 +1,263 @@
from __future__ import division
from collections import namedtuple
import numpy as np
import tensorflow as tf
import itertools
# Bundle of the primitive ops used by the padding calculation: element-wise
# maximum, ceil, floor and a cast to 64-bit integers.
pad_ops = namedtuple("pad_ops",
                     ["max_op", "ceil_op", "floor_op", "cast_int_op"])
# NumPy flavour; cast_int_op calls .astype, so it expects numpy values.
pad_numpy_ops = pad_ops(np.maximum, np.ceil, np.floor,
                        lambda arr: arr.astype(np.int64))
# TensorFlow flavour of the same ops.
pad_tf_ops = pad_ops(tf.maximum, tf.math.ceil, tf.math.floor,
                     lambda tensor: tf.cast(tensor, tf.int64))
def calc_pads_same(in_spatial_shape, kernel_shape, strides,
                   dilations, padding, padding_ops=pad_numpy_ops,
                   pads_order=1):
    """
    Compute the SAME paddings to add to the input.

    For every spatial axis the total padding is
    max((ceil(in / stride) - 1) * stride + dilated_kernel - in, 0).
    SAME_LOWER puts the larger half at the beginning of the axis; any
    other value puts it at the end.

    Args:
        in_spatial_shape:   input spatial shape
        kernel_shape:       the size of the kernel along each axis
        strides:            stride along each spatial axis
        dilations:          dilations value along each spatial axis
        padding:            padding to calculate: SAME_UPPER or
                            SAME_LOWER
        padding_ops:        namedtuple with ops to be used during
                            calculations. there are two sets of ops
                            defined pad_numpy_ops and pad_tf_ops with
                            numpy and tensorflow ops
        pads_order:         order of returned pads. possible options are:
                                1 - b1, b2, ..., bn, e1, e2, ..., en
                                2 - b1, e1, b2, e2, ..., bn, en
                            where n = len(kernel_shape),
                            b1, b2, ..., bn define pads at the beginning
                                            of each axis
                            e1, e2, ..., en define pads at the end of
                                            each axis
    Return:
        pads:               list with calculated pads. the order of the
                            values is determined by `pads_order`
    """
    spatial_size = len(kernel_shape)
    # Distance between the begin slot and the end slot of one axis.
    end_offset = spatial_size if pads_order == 1 else 1
    pads = [0] * (spatial_size * 2)

    for axis in range(spatial_size):
        in_size = in_spatial_shape[axis]
        stride = strides[axis]
        dilated_kernel = (kernel_shape[axis] - 1) * dilations[axis] + 1

        out_size = padding_ops.cast_int_op(
            padding_ops.ceil_op(in_size / stride))
        total_pad = padding_ops.max_op(
            (out_size - 1) * stride + dilated_kernel - in_size, 0)

        is_lower = padding.lower() == "same_lower"
        round_op = padding_ops.ceil_op if is_lower else padding_ops.floor_op
        pad_begin = padding_ops.cast_int_op(round_op(total_pad / 2))
        pad_end = padding_ops.cast_int_op(total_pad) - pad_begin

        begin_slot = axis * pads_order
        pads[begin_slot] = pad_begin
        pads[begin_slot + end_offset] = pad_end

    return pads
def calc_output_shape(input_spatial_shape, kernel_shape, strides, dilations,
                      padding, ceil_mode=False):
    """
    Compute the spatial output shape of a pooling operation.

    Args:
        input_spatial_shape: input spatial shape
        kernel_shape:        the size of the kernel along each axis
        strides:             stride along each spatial axis
        dilations:           dilations value along each spatial axis
        padding:             can be explicit paddings (list or ndarray,
                             all begins followed by all ends),
                             "SAME_UPPER" or "SAME_LOWER". any other
                             string means no padding
        ceil_mode:           whether to round the output size up instead
                             of down
    Return:
        output_shape:        calculated output shape
    """
    spatial_size = len(input_spatial_shape)

    if type(padding) not in (list, np.ndarray):
        # String padding: resolve it to explicit per-axis pads.
        if padding.lower().startswith("same"):
            padding = calc_pads_same(input_spatial_shape, kernel_shape,
                                     strides, dilations, padding)
        else:
            padding = [0] * spatial_size * 2

    return [
        _pooling_output_shape(input_spatial_shape[axis],
                              kernel_shape[axis], strides[axis],
                              dilations[axis],
                              padding[axis] + padding[axis + spatial_size],
                              ceil_mode)
        for axis in range(spatial_size)
    ]
def _pooling_output_shape(input_size, ksize, stride, dilation, pad, ceil_mode):
    """
    Compute the pooling output size along one axis.

    Args:
        input_size: input size along the axis
        ksize:      kernel size along the axis
        stride:     stride along the axis
        dilation:   dilation along the axis
        pad:        total padding (begin + end) along the axis
        ceil_mode:  round the division up instead of down
    Return:
        output size along the axis
    """
    dilated_kernel = (ksize - 1) * dilation + 1
    rounding = (stride - 1) if ceil_mode else 0
    output_size = (input_size + pad - dilated_kernel + rounding) // stride + 1
    # With padding, drop the last window when it would start at or beyond
    # the end of the padded input.
    if pad and (output_size - 1) * stride >= input_size + pad:
        output_size -= 1
    return output_size
def py_pool(input, kernel_shape, strides=None, dilations=None,
            padding=None, ceil_mode=False, pooling_type="MAX",
            include_indices=True, p=2):
    """
    Implementation of Max, Average and Lp pool operations in Python
    Args:
        input:              input N-D data array in NC* format
        kernel_shape:       the size of the kernel along each axis
        strides:            stride along each spatial axis. defaults to
                            `kernel_shape`
        dilations:          dilations value along each spatial axis of
                            filter. defaults to 1 for every axis
        padding:            padding for the beginning and ending along each
                            spatial axis. `padding` format should be as
                            follow [x1_begin, x2_begin...x1_end, x2_end,...].
                            a str/bytes value starting with "same" is
                            resolved with calc_pads_same; any other
                            str/bytes value means no padding. defaults to
                            no padding
        ceil_mode:          whether to use ceil or floor (default) to compute
                            the output shape.
        pooling_type:       specifies pooling type. Values can be "MAX",
                            "AVG" or "LP" (str or UTF-8 bytes). any other
                            value is handled like "MAX"
        include_indices:    should indices be included in the output
        p:                  specifies the p parameter for LpPooling
    Return:
        pooled:             output data from pooling across the input
        ind:                (only when `include_indices` is true) for "MAX",
                            index of each selected value, flattened over
                            the spatial dims of its (batch, channel)
                            slice; all zeros for "AVG" and "LP"
    """
    if type(pooling_type) is not str:
        pooling_type = pooling_type.decode("UTF-8")
    input_shape = np.shape(input)
    inp_sp_shape = input_shape[2:]
    input_dtype = input.dtype
    # Smallest representable value of the dtype: start value for the max.
    if np.issubdtype(input_dtype, np.integer):
        input_dtype_min = np.iinfo(input_dtype).min
    else:
        input_dtype_min = np.finfo(input_dtype).min
    if pooling_type == "LP":
        rootN = (1.0 / p)

    def _loop_over_output(batch, channel):
        # Fills out_pool / out_ind for one (batch, channel) slice. Relies on
        # spatial_size, strides, dilations, pads, output_sp_shape, out_pool
        # and out_ind, which the enclosing function binds further below,
        # before this helper is first called.
        dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
        for counters in itertools.product(*dims):
            # Input indices covered by this output position, per axis.
            input_ranges = []
            for dim in range(spatial_size):
                # pads is interleaved: [begin_0, end_0, begin_1, end_1, ...]
                dim_start = \
                    counters[dim] * strides[dim] - pads[dim * 2]
                # Clip the window end to the input size.
                dim_end = \
                    min(dim_start + (kernel_shape[dim] - 1) * dilations[dim]
                        + 1, inp_sp_shape[dim])
                # Skip taps that fall into the begin padding, staying on the
                # dilation grid.
                while dim_start < 0:
                    dim_start += dilations[dim]
                cur_range = [i for i in range(dim_start,
                                              dim_end, dilations[dim])]
                input_ranges.append(cur_range)
            if pooling_type in ["AVG", "LP"]:
                val_sum = 0
                val_count = 0
            else:
                maxval = input_dtype_min
                maxind = -1
            for input_ind in itertools.product(*input_ranges):
                ind = (batch, channel) + input_ind
                val = input[ind]
                if pooling_type == "AVG":
                    # Only in-bounds values are visited, so padding does not
                    # contribute to the average.
                    val_sum += val
                    val_count += 1
                elif pooling_type == "LP":
                    val_sum += abs(val ** p)
                else:
                    if val > maxval:
                        maxval = val
                        # Row-major flattening of input_ind over the spatial
                        # dims only.
                        ind = 0
                        for i in range(spatial_size):
                            coef = 1
                            for j in range(i+1, spatial_size):
                                coef *= inp_sp_shape[j]
                            ind += input_ind[i] * coef
                        maxind = ind
            ind = (batch, channel) + counters
            if pooling_type == "AVG":
                out_pool[ind] = val_sum / val_count
            elif pooling_type == "LP":
                out_pool[ind] = val_sum ** rootN
            else:
                out_pool[ind] = maxval
                out_ind[ind] = maxind

    spatial_size = len(kernel_shape)
    batch_size = input_shape[0]
    channels_num = input_shape[1]
    if strides is None:
        strides = kernel_shape
    if dilations is None:
        dilations = [1] * spatial_size
    if padding is None:
        padding = [0] * spatial_size * 2
    if type(padding) is bytes:
        padding = padding.decode()
    # Anything that is not an explicit list/ndarray is treated as a padding
    # mode name.
    if type(padding) is not list and type(padding) is not np.ndarray:
        if type(padding) is not str:
            padding = padding.decode("UTF-8")
        if padding.lower().startswith("same"):
            padding = calc_pads_same(inp_sp_shape, kernel_shape, strides,
                                     dilations, padding)
        else:
            padding = [0] * spatial_size * 2
    pads = []
    pad_along_axis = []
    output_sp_shape = []
    for dim in range(spatial_size):
        # Re-order [begins..., ends...] into interleaved begin/end pairs.
        pads.append(padding[dim])
        pads.append(padding[dim + spatial_size])
        pad_along_axis.append(padding[dim] + padding[dim + spatial_size])
        input_size = input_shape[dim + 2]
        output_size = \
            _pooling_output_shape(input_size, kernel_shape[dim],
                                  strides[dim], dilations[dim],
                                  pad_along_axis[dim], ceil_mode)
        output_sp_shape.append(output_size)
    out_pool = np.zeros([input_shape[0], input_shape[1]] +
                        output_sp_shape, input_dtype)
    out_ind = np.zeros([input_shape[0], input_shape[1]] +
                       output_sp_shape, np.int64)
    for batch in range(batch_size):
        for channel in range(channels_num):
            _loop_over_output(batch, channel)
    if not include_indices:
        return out_pool
    else:
        return out_pool, out_ind
@@ -0,0 +1,45 @@
import tensorflow as tf
import numpy as np
def tf_shape(tensor):
    """
    Return the shape of a Tensor, statically when possible.

    A fully defined shape is returned as a numpy int64 array. Otherwise
    tf.shape() is used and the shape is returned as an int64 Tensor.
    """
    static_shape = tensor.shape
    if not static_shape.is_fully_defined():
        return tf.shape(tensor, out_type=tf.int64)
    return np.array(static_shape.as_list(), dtype=np.int64)
def tf_product(a, b):
    """
    Build the cartesian product of two column vectors a and b.

    Every row of `a` is paired with every row of `b`; rows of the result
    are ordered by `a` first, then by `b`.

    Example:

    a = [[1]
         [2]
         [3]]

    b = [[0]
         [1]]

    result = [[1 0]
              [1 1]
              [2 0]
              [2 1]
              [3 0]
              [3 1]]
    """
    # Repeat every element of `a` once per row of `b`, as one column.
    repeated_a = tf.tile(a, [1, tf.shape(b)[0]])
    repeated_a = tf.expand_dims(repeated_a, 2)
    repeated_a = tf.reshape(repeated_a, [-1, 1])
    # Stack `b` once per row of `a`.
    repeated_b = tf.tile(b, [tf.shape(a)[0], 1])
    return tf.concat([repeated_a, repeated_b], axis=1)