重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.common import get_data_format
|
||||
from onnx_tf.common import get_perm_from_formats
|
||||
|
||||
|
||||
class UnpoolMixin(object):
    """Mixin that implements max-unpooling with TensorFlow ops.

    ``max_unpool`` is the entry point; the underscore-prefixed class methods
    are helpers for shape computation, padding and the scatter itself.
    """

    @classmethod
    def max_unpool(cls, node, input_dict):
        """
        MaxUnpooling operation

        Args:
            node: op description; ``node.inputs`` lists the input names
                (pooled tensor, indices and, optionally, an output shape)
                and ``node.attrs`` holds ``kernel_shape`` (required) plus
                the optional ``strides`` and ``pads``
            input_dict: mapping from input name to tensor
        Return:
            single-element list holding the unpooled tensor
        """
        x = input_dict[node.inputs[0]]
        ind = input_dict[node.inputs[1]]
        # The third input (explicit output shape) is optional.
        if len(node.inputs) > 2:
            output_shape = input_dict.get(node.inputs[2], None)
        else:
            output_shape = None

        input_shape = x.get_shape()
        x_rank = len(input_shape)
        spatial_size = x_rank - 2
        storage_format, _ = get_data_format(x_rank)

        kernel_shape = node.attrs["kernel_shape"]
        # if strides are not provided default is 1 along each spatial axis
        strides = node.attrs.get("strides", [1] * spatial_size)
        # NOTE(review): _pad_output reads one begin and one end entry per
        # axis of the tensor from `pads`; an attribute that only covers the
        # spatial axes would raise IndexError there - confirm the format.
        pads = node.attrs.get("pads", None)

        default_shape = cls._get_default_shape(input_shape, kernel_shape,
                                               strides)

        need_trans = storage_format != "NHWC"
        if need_trans:
            to_nhwc = get_perm_from_formats(storage_format, "NHWC")
            x = tf.transpose(x, perm=to_nhwc)
            ind = tf.transpose(ind, perm=to_nhwc)

        # default_shape to NHWC storage format:
        # [batch] + spatial dims + [axis 1 of the original input]
        default_shape = [int(input_shape[0])] + default_shape + \
                        [int(input_shape[1])]

        unpooled = cls._unpool(x, ind, default_shape)

        if need_trans:
            unpooled = tf.transpose(
                unpooled, perm=get_perm_from_formats("NHWC", storage_format))

        # An explicit output shape overrides whatever `pads` attribute was
        # given: the paddings are recomputed from the shape difference.
        if output_shape is not None:
            pads = cls._get_pads_from_output_shape(unpooled, output_shape)
        if pads is not None:
            unpooled = cls._pad_output(unpooled, pads, 0)

        return [unpooled]

    @classmethod
    def _get_default_shape(cls, input_shape, kernel_shape, strides):
        """
        Calculates default shape from kernel_shape and strides

        Args:
            input_shape:  shape of the input to unpool op; the spatial
                          extents are read starting at index 2
            kernel_shape: the size of the kernel along each axis
            strides:      stride along each spatial axis
        Return:
            default_shape: calculated default_shape, one entry per
                           kernel axis: (input - 1) * stride + kernel
        """
        return [
            (int(input_shape[axis + 2]) - 1) * int(strides[axis]) +
            int(kernel)
            for axis, kernel in enumerate(kernel_shape)
        ]

    @classmethod
    def _get_pads_from_output_shape(cls, unpool, output_shape):
        """
        Calculates the paddings from specified output_shape

        Args:
            unpool:       result from unpool operation
            output_shape: expected shape of the output
        Return:
            pads:         calculated paddings in format
                          [x1_begin, x2_begin,.., x1_end, x2_end]
                          where xi_... represent pads added to begin
                          or end of axis i
        """
        unpool_shape = tf.cast(tf.shape(unpool), dtype=tf.int32)
        new_shape = tf.cast(output_shape, dtype=tf.int32)

        pads_begin = []
        pads_end = []

        for d in range(len(unpool.get_shape())):
            pad_total = new_shape[d] - unpool_shape[d]
            # Half of the difference goes to the front; whatever is left
            # (including the odd remainder) goes to the back.
            pad_begin = tf.cast(pad_total / 2, tf.int32)
            pads_begin.append(pad_begin)
            pads_end.append(pad_total - pad_begin)

        return pads_begin + pads_end

    @classmethod
    def _pad_output(cls, unpool, pads, constant_values):
        """
        Pad the output from unpool op

        Args:
            unpool:          result from unpool op
            pads:            paddings in format
                             [x1_begin, x2_begin,..., x1_end, x2_end]
            constant_values: constant value to fill up the padded spaces
        Return:
            padded:          padded tensor
        """
        rank = len(unpool.get_shape())
        # tf.pad wants one [begin, end] pair per axis; `pads` stores all
        # begins first and all ends after them.
        paddings = [[pads[d], pads[d + rank]] for d in range(rank)]
        return tf.pad(unpool, paddings, 'CONSTANT',
                      constant_values=constant_values)

    @classmethod
    def _unpool(cls, pool, ind, output_shape, scope='unpool'):
        """
        Unpooling layer after max_pool_with_argmax.

        Args:
            pool:         max pooled output tensor
            ind:          argmax indices
            output_shape: the shape of the output; four entries are read,
                          so only 4-D tensors are handled here
            scope:        name of the variable scope wrapping the ops
        Return:
            unpool:       unpooling tensor
        """
        with tf.variable_scope(scope):
            input_shape = tf.shape(pool)

            flat_input_size = tf.reduce_prod(input_shape)
            # Scatter target: one row per batch entry, everything else
            # flattened into a single axis.
            flat_output_shape = [output_shape[0], output_shape[1] *
                                 output_shape[2] * output_shape[3]]

            pool_ = tf.reshape(pool, [flat_input_size])
            # Pair every index with the batch number it belongs to so that
            # scatter_nd can address (batch, flat_position).
            batch_range = tf.reshape(
                tf.range(tf.cast(output_shape[0], tf.int64),
                         dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])
            b = tf.ones_like(ind) * batch_range
            b1 = tf.reshape(b, [flat_input_size, 1])
            ind_ = tf.reshape(ind, [flat_input_size, 1])
            ind_ = tf.concat([b1, ind_], 1)

            ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape,
                                                           tf.int64))
            ret = tf.reshape(ret, output_shape)
        return ret
|
||||
Reference in New Issue
Block a user