# File listing metadata: 2020-10-14 08:55:07 +08:00, 156 lines, 5.7 KiB, Python

import tensorflow as tf
from onnx_tf.common import get_data_format
from onnx_tf.common import get_perm_from_formats
class UnpoolMixin(object):
    """Mixin that builds a max-unpooling computation out of TensorFlow ops."""

    @classmethod
    def max_unpool(cls, node, input_dict):
        """
        MaxUnpooling operation
        Args:
          node:       node object exposing `inputs` (list of input names)
                      and `attrs` (`kernel_shape` is required, `strides`
                      and `pads` are optional)
          input_dict: mapping from input name to tensor
        Return:
          a one-element list holding the unpooled tensor
        """
        x = input_dict[node.inputs[0]]
        ind = input_dict[node.inputs[1]]
        # The third input (explicit output shape) is optional.
        if len(node.inputs) > 2:
            output_shape = input_dict.get(node.inputs[2], None)
        else:
            output_shape = None

        input_shape = x.get_shape()
        x_rank = len(input_shape)
        spatial_size = x_rank - 2
        storage_format, _ = get_data_format(x_rank)

        kernel_shape = node.attrs["kernel_shape"]
        # if strides are not provided default is 1 along each spatial axis
        strides = node.attrs.get("strides", [1] * spatial_size)
        # NOTE(review): `_pad_output` indexes `pads` as a full-rank list
        # (2 * rank entries). A `pads` attribute that only lists the spatial
        # axes would be too short for that -- confirm the expected layout.
        pads = node.attrs.get("pads", None)

        default_shape = cls._get_default_shape(input_shape, kernel_shape,
                                               strides)

        # `_unpool` works on channels-last data, so transpose on the way in
        # and back again on the way out when the storage format differs.
        # NOTE(review): the target format is hard-coded to "NHWC", which
        # presumably restricts this op to rank-4 inputs -- confirm.
        need_trans = storage_format != "NHWC"
        if need_trans:
            to_nhwc = get_perm_from_formats(storage_format, "NHWC")
            x = tf.transpose(x, perm=to_nhwc)
            ind = tf.transpose(
                ind, perm=get_perm_from_formats(storage_format, "NHWC"))

        # default_shape to NHWC storage format: `input_shape` was read before
        # the transpose, so axis 0 is the batch and axis 1 the channels.
        default_shape = ([int(input_shape[0])] + default_shape +
                         [int(input_shape[1])])

        unpooled = cls._unpool(x, ind, default_shape)

        if need_trans:
            unpooled = tf.transpose(
                unpooled, perm=get_perm_from_formats("NHWC", storage_format))

        # An explicit output shape takes precedence over the `pads` attribute.
        if output_shape is not None:
            pads = cls._get_pads_from_output_shape(unpooled, output_shape)
        if pads is not None:
            unpooled = cls._pad_output(unpooled, pads, 0)

        return [unpooled]

    @classmethod
    def _get_default_shape(cls, input_shape, kernel_shape, strides):
        """
        Calculates default shape from kernel_shape and strides
        Args:
          input_shape:   shape of the input to unpool op; the spatial
                         axes start at index 2
          kernel_shape:  the size of the kernel along each axis
          strides:       stride along each spatial axis
        Return:
          default_shape: calculated default_shape, one entry per
                         kernel axis
        """
        return [
            (int(input_shape[axis + 2]) - 1) * int(strides[axis]) + int(kernel)
            for axis, kernel in enumerate(kernel_shape)
        ]

    @classmethod
    def _get_pads_from_output_shape(cls, unpool, output_shape):
        """
        Calculates the paddings from specified output_shape
        Args:
          unpool:       result from unpool operation
          output_shape: expected shape of the output
        Return:
          pads:         calculated paddings in format
                        [x1_begin, x2_begin,.., x1_end, x2_end]
                        where xi_... represent pads added to begin
                        or end of axis i
        """
        unpool_shape = tf.cast(tf.shape(unpool), dtype=tf.int32)
        new_shape = tf.cast(output_shape, dtype=tf.int32)

        pads_begin = []
        pads_end = []
        for d in range(len(unpool.get_shape())):
            pad_total = new_shape[d] - unpool_shape[d]
            # Integer floor division keeps the value in int32; when the
            # total is odd the extra element goes to the end of the axis.
            pad_begin = pad_total // 2
            pads_begin.append(pad_begin)
            pads_end.append(pad_total - pad_begin)
        return pads_begin + pads_end

    @classmethod
    def _pad_output(cls, unpool, pads, constant_values):
        """
        Pad the output from unpool op
        Args:
          unpool:           result from unpool op
          pads:             paddings in format
                            [x1_begin, x2_begin,..., x1_end, x2_end]
                            with one begin and one end entry for every
                            axis of `unpool`
          constant_values:  constant value to fill up the padded spaces
        Return:
          padded:           padded tensor
        """
        rank = len(unpool.get_shape())
        # tf.pad expects one [begin, end] pair per axis.
        paddings = [[pads[d], pads[d + rank]] for d in range(rank)]
        return tf.pad(unpool, paddings, 'CONSTANT',
                      constant_values=constant_values)

    @classmethod
    def _unpool(cls, pool, ind, output_shape, scope='unpool'):
        """
        Unpooling layer after max_pool_with_argmax.
        Args:
          pool:         max pooled output tensor
          ind:          argmax indices; each value is used as an offset
                        into the flattened non-batch part of the output
          output_shape: the shape of the output (four entries are read)
          scope:        name of the variable scope wrapping the ops
        Return:
          unpool:       unpooling tensor
        """
        with tf.variable_scope(scope):
            input_shape = tf.shape(pool)
            flat_input_size = tf.reduce_prod(input_shape)
            # Scatter target: one row per batch element, all remaining
            # axes collapsed into a single flat axis.
            flat_output_shape = [
                output_shape[0],
                output_shape[1] * output_shape[2] * output_shape[3],
            ]

            pool_ = tf.reshape(pool, [flat_input_size])
            # Batch index of every pooled value, broadcast to ind's shape.
            batch_range = tf.reshape(
                tf.range(tf.cast(output_shape[0], tf.int64),
                         dtype=ind.dtype),
                shape=[input_shape[0], 1, 1, 1])
            b = tf.ones_like(ind) * batch_range
            b1 = tf.reshape(b, [flat_input_size, 1])
            ind_ = tf.reshape(ind, [flat_input_size, 1])
            # Each row becomes a (batch, flat offset) scatter coordinate.
            ind_ = tf.concat([b1, ind_], 1)

            ret = tf.scatter_nd(ind_, pool_,
                                shape=tf.cast(flat_output_shape, tf.int64))
            ret = tf.reshape(ret, output_shape)
        return ret