156 lines
5.7 KiB
Python
156 lines
5.7 KiB
Python
import tensorflow as tf
|
|
|
|
from onnx_tf.common import get_data_format
|
|
from onnx_tf.common import get_perm_from_formats
|
|
|
|
|
|
class UnpoolMixin(object):
    """Mixin that builds a MaxUnpool operation out of TensorFlow ops."""

    @classmethod
    def max_unpool(cls, node, input_dict):
        """
        MaxUnpooling operation

        Args:
            node: node to convert; its ``inputs`` and ``attrs`` are read
            input_dict: mapping from input name to tensor
        Return:
            single-element list holding the unpooled tensor
        """
        x = input_dict[node.inputs[0]]
        ind = input_dict[node.inputs[1]]
        # the third input (requested output shape) is optional
        if len(node.inputs) > 2:
            output_shape = input_dict.get(node.inputs[2], None)
        else:
            output_shape = None

        input_shape = x.get_shape()
        x_rank = len(input_shape)
        spatial_size = x_rank - 2
        storage_format, _ = get_data_format(x_rank)

        kernel_shape = node.attrs["kernel_shape"]
        # if strides are not provided default is 1 along each spatial axis
        strides = node.attrs.get("strides", [1] * spatial_size)
        pads = node.attrs.get("pads", None)

        default_shape = cls._get_default_shape(input_shape, kernel_shape,
                                               strides)

        # _unpool works on channel-last data, so transpose when needed
        need_trans = storage_format != "NHWC"
        if need_trans:
            to_nhwc = get_perm_from_formats(storage_format, "NHWC")
            x = tf.transpose(x, perm=to_nhwc)
            ind = tf.transpose(ind, perm=to_nhwc)

        # default_shape to NHWC storage format: axis 0 of the original input
        # goes first and axis 1 goes last around the spatial sizes
        default_shape = ([int(input_shape[0])] + default_shape +
                         [int(input_shape[1])])

        unpooled = cls._unpool(x, ind, default_shape)

        if need_trans:
            unpooled = tf.transpose(
                unpooled, perm=get_perm_from_formats("NHWC", storage_format))

        # an explicit output shape overrides the "pads" attribute
        if output_shape is not None:
            pads = cls._get_pads_from_output_shape(unpooled, output_shape)
        # NOTE(review): _pad_output reads 2 * rank entries from pads; a
        # "pads" attribute that only lists the spatial axes would raise
        # IndexError there -- confirm the expected attribute layout.
        if pads is not None:
            unpooled = cls._pad_output(unpooled, pads, 0)

        return [unpooled]

    @classmethod
    def _get_default_shape(cls, input_shape, kernel_shape, strides):
        """
        Calculates default shape from kernel_shape and strides

        Each spatial size is (input_size - 1) * stride + kernel_size, where
        the spatial sizes of the input start at index 2 of input_shape.

        Args:
            input_shape: shape of the input to unpool op
            kernel_shape: the size of the kernel along each axis
            strides: stride along each spatial axis
        Return:
            default_shape: calculated default_shape, one int per entry
                           of kernel_shape
        """
        return [
            (int(input_shape[axis + 2]) - 1) * int(strides[axis]) +
            int(kernel)
            for axis, kernel in enumerate(kernel_shape)
        ]

    @classmethod
    def _get_pads_from_output_shape(cls, unpool, output_shape):
        """
        Calculates the paddings from specified output_shape

        Args:
            unpool: result from unpool operation
            output_shape: expected shape of the output
        Return:
            pads: calculated paddings in format
                  [x1_begin, x2_begin,.., x1_end, x2_end]
                  where xi_... represent pads added to begin
                  or end of axis i
        """
        unpool_shape = tf.cast(tf.shape(unpool), dtype=tf.int32)
        new_shape = tf.cast(output_shape, dtype=tf.int32)

        pads_begin = []
        pads_end = []
        for axis in range(len(unpool.get_shape())):
            pad_total = new_shape[axis] - unpool_shape[axis]
            # half of the difference goes to the beginning of the axis,
            # whatever is left over goes to the end
            pad_begin = tf.cast(pad_total / 2, tf.int32)
            pad_end = pad_total - pad_begin
            pads_begin.append(pad_begin)
            pads_end.append(pad_end)

        return pads_begin + pads_end

    @classmethod
    def _pad_output(cls, unpool, pads, constant_values):
        """
        Pad the output from unpool op

        Args:
            unpool: result from unpool op
            pads: paddings in format
                  [x1_begin, x2_begin,..., x1_end, x2_end]
                  with one begin and one end entry per axis of unpool
            constant_values: constant value to fill up the padded spaces
        Return:
            padded: padded tensor
        """
        rank = len(unpool.get_shape())
        # regroup the flat list into [[begin, end], ...] pairs for tf.pad
        paddings = [[pads[axis], pads[axis + rank]] for axis in range(rank)]
        return tf.pad(unpool, paddings, 'CONSTANT',
                      constant_values=constant_values)

    @classmethod
    def _unpool(cls, pool, ind, output_shape, scope='unpool'):
        """
        Unpooling layer after max_pool_with_argmax.

        Only four-dimensional data is handled: output_shape is indexed
        at positions 0..3 and the batch index is reshaped to four axes.

        Args:
            pool: max pooled output tensor
            ind: argmax indices
            output_shape: the shape of the output
            scope: name of the variable scope the ops are created in
        Return:
            unpool: unpooling tensor
        """
        with tf.variable_scope(scope):
            input_shape = tf.shape(pool)

            flat_input_size = tf.reduce_prod(input_shape)
            # scatter target: one row per batch entry, remaining axes
            # flattened into a single axis
            flat_output_shape = [
                output_shape[0],
                output_shape[1] * output_shape[2] * output_shape[3]
            ]

            pool_ = tf.reshape(pool, [flat_input_size])
            # batch index of every element, broadcast to the shape of ind
            batch_range = tf.reshape(
                tf.range(tf.cast(output_shape[0], tf.int64),
                         dtype=ind.dtype),
                shape=[input_shape[0], 1, 1, 1])
            b = tf.ones_like(ind) * batch_range
            b1 = tf.reshape(b, [flat_input_size, 1])
            ind_ = tf.reshape(ind, [flat_input_size, 1])
            # pair every index with its batch entry: [batch, flat_index]
            ind_ = tf.concat([b1, ind_], 1)

            ret = tf.scatter_nd(ind_, pool_,
                                shape=tf.cast(flat_output_shape, tf.int64))
            ret = tf.reshape(ret, output_shape)
            return ret
|