import tensorflow as tf

from onnx_tf.common import get_data_format
from onnx_tf.common import get_perm_from_formats


class UnpoolMixin(object):
    """Mixin that builds the TensorFlow graph for a MaxUnpool node."""

    @classmethod
    def max_unpool(cls, node, input_dict):
        """MaxUnpooling operation.

        Args:
            node: node whose ``inputs`` name the pooled tensor, the argmax
                indices and (optionally) the requested output shape, and
                whose ``attrs`` hold ``kernel_shape`` plus the optional
                ``strides`` and ``pads``
            input_dict: mapping from input name to tensor

        Return:
            one-element list with the unpooled tensor, transposed back to
            the storage format when a transpose to NHWC was needed
        """
        x = input_dict[node.inputs[0]]
        ind = input_dict[node.inputs[1]]
        # The third input (explicit output shape) is optional.
        if len(node.inputs) > 2:
            output_shape = input_dict.get(node.inputs[2], None)
        else:
            output_shape = None

        input_shape = x.get_shape()
        x_rank = len(input_shape)
        spatial_size = x_rank - 2
        storage_format, _ = get_data_format(x_rank)

        kernel_shape = node.attrs["kernel_shape"]
        # if strides are not provided default is 1 along each spatial axis
        strides = node.attrs.get("strides", [1] * spatial_size)
        pads = node.attrs.get("pads", None)

        default_shape = cls._get_default_shape(input_shape, kernel_shape,
                                               strides)

        need_trans = storage_format != "NHWC"
        if need_trans:
            # Both the values and their indices are moved to NHWC with the
            # same permutation.
            to_nhwc = get_perm_from_formats(storage_format, "NHWC")
            x = tf.transpose(x, perm=to_nhwc)
            ind = tf.transpose(ind, perm=to_nhwc)

        # default_shape to NHWC storage format: input axis 0 goes first and
        # input axis 1 goes last, around the computed spatial sizes.
        default_shape = ([int(input_shape[0])] + default_shape +
                         [int(input_shape[1])])

        unpooled = cls._unpool(x, ind, default_shape)

        if need_trans:
            unpooled = tf.transpose(
                unpooled, perm=get_perm_from_formats("NHWC", storage_format))

        # An explicit output_shape takes precedence over the "pads" attribute.
        if output_shape is not None:
            pads = cls._get_pads_from_output_shape(unpooled, output_shape)
        # NOTE(review): _pad_output reads one begin/end pair per axis of
        # ``unpooled`` (batch and channel included). If the "pads" attribute
        # only lists spatial axes it will raise IndexError -- confirm what
        # producers of this attribute emit.
        if pads is not None:
            unpooled = cls._pad_output(unpooled, pads, 0)

        return [unpooled]

    @classmethod
    def _get_default_shape(cls, input_shape, kernel_shape, strides):
        """Calculates default shape from kernel_shape and strides.

        Each spatial axis gets ``(input - 1) * stride + kernel`` elements.

        Args:
            input_shape: shape of the input to unpool op; spatial axes are
                read starting at index 2
            kernel_shape: the size of the kernel along each axis
            strides: stride along each spatial axis

        Return:
            default_shape: calculated default_shape (spatial axes only)
        """
        return [
            (int(input_shape[axis + 2]) - 1) * int(strides[axis]) + int(kernel)
            for axis, kernel in enumerate(kernel_shape)
        ]

    @classmethod
    def _get_pads_from_output_shape(cls, unpool, output_shape):
        """Calculates the paddings from specified output_shape.

        The difference between the requested and the actual size of every
        axis is split in two; when it is odd the extra element goes to the
        end of the axis.

        Args:
            unpool: result from unpool operation
            output_shape: expected shape of the output

        Return:
            pads: calculated paddings in format
                [x1_begin, x2_begin,.., x1_end, x2_end]
                where xi_... represent pads added to begin or end of axis i
        """
        unpool_shape = tf.cast(tf.shape(unpool), dtype=tf.int32)
        new_shape = tf.cast(output_shape, dtype=tf.int32)

        pads_begin = []
        pads_end = []
        for axis in range(len(unpool.get_shape())):
            pad_total = new_shape[axis] - unpool_shape[axis]
            # Integer floor division keeps the value int32 without a round
            # trip through floating point.
            pad_begin = pad_total // 2
            pads_begin.append(pad_begin)
            pads_end.append(pad_total - pad_begin)
        return pads_begin + pads_end

    @classmethod
    def _pad_output(cls, unpool, pads, constant_values):
        """Pad the output from unpool op.

        Args:
            unpool: result from unpool op
            pads: paddings in format [x1_begin, x2_begin,..., x1_end, x2_end];
                one begin and one end entry is read for every axis of unpool
            constant_values: constant value to fill up the padded spaces

        Return:
            padded: padded tensor
        """
        rank = len(unpool.get_shape())
        # tf.pad expects one [begin, end] pair per axis.
        paddings = [[pads[axis], pads[axis + rank]] for axis in range(rank)]
        return tf.pad(unpool, paddings, 'CONSTANT',
                      constant_values=constant_values)

    @classmethod
    def _unpool(cls, pool, ind, output_shape, scope='unpool'):
        """Unpooling layer after max_pool_with_argmax.

        Every pooled value is scattered to position ``(batch, ind)`` of a
        ``[batch, flattened]`` zero-initialised buffer, which is then
        reshaped to output_shape. Only 4-D tensors are handled: the batch
        index is broadcast through a ``[N, 1, 1, 1]`` tensor and the
        flattened size is the product of output_shape[1..3].

        Args:
            pool: max pooled output tensor
            ind: argmax indices
            output_shape: the shape of the output
            scope: name of the variable scope wrapping the created ops

        Return:
            unpool: unpooling tensor
        """
        with tf.variable_scope(scope):
            input_shape = tf.shape(pool)
            flat_input_size = tf.reduce_prod(input_shape)
            flat_output_shape = [
                output_shape[0],
                output_shape[1] * output_shape[2] * output_shape[3]
            ]

            pool_ = tf.reshape(pool, [flat_input_size])
            # Batch index of every element, with the same shape as ind.
            batch_range = tf.reshape(
                tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                shape=[input_shape[0], 1, 1, 1])
            b = tf.ones_like(ind) * batch_range
            b1 = tf.reshape(b, [flat_input_size, 1])
            ind_ = tf.reshape(ind, [flat_input_size, 1])
            # Rows of (batch, flat index) pairs consumed by scatter_nd.
            ind_ = tf.concat([b1, ind_], 1)

            ret = tf.scatter_nd(ind_, pool_,
                                shape=tf.cast(flat_output_shape, tf.int64))
            ret = tf.reshape(ret, output_shape)
        return ret