Files
2020-10-14 08:55:07 +08:00

126 lines
4.4 KiB
Python

import tensorflow as tf
from onnx_tf.common import exception
from onnx_tf.common import get_data_format
from onnx_tf.common import get_perm_from_formats
from onnx_tf.common import logger
from .dilated_pooling import DilatedPooling
from onnx_tf.common.pooling_helper import py_pool
from onnx_tf.common.pooling_helper import calc_pads_same
from onnx_tf.common.pooling_helper import calc_output_shape
class PoolMixin(object):
    """Shared ONNX -> TensorFlow conversion for the pooling op handlers.

    Handlers mixing this in delegate to ``pool`` with the pooling type they
    implement: "AVG", "MAX", "MAX_WITH_ARGMAX" or "LP".
    """

    @classmethod
    def pool(cls, node, input_dict, pooling_type, strict=True):
        """Build the TensorFlow ops for an ONNX pooling node.

        Args:
          node: ONNX node wrapper. Reads ``node.inputs[0]`` and, from
            ``node.attrs``, kernel_shape (required), strides, dilations,
            ceil_mode, auto_pad, pads, p, count_include_pad and
            storage_order.
          input_dict: mapping from input name to tensor. The pooled input is
            ``input_dict[node.inputs[0]]``, treated as (N, C, spatial...)
            layout (the "NC..." storage format and the compatibility-mode
            output shape computation below both rely on this).
          pooling_type: "AVG", "MAX", "MAX_WITH_ARGMAX" or "LP".
          strict: what to do when DilatedPooling cannot handle the
            configuration. True falls back to a ``py_func`` running
            ``py_pool`` (the resulting graph cannot be serialized); False
            reports the op as unsupported.

        Returns:
          ``[pooled]``, or ``[pooled, argmax]`` when the selected pooling op
          also produces argmax indices.

        Raises:
          Whatever ``exception.OP_UNSUPPORTED_EXCEPT`` raises (presumably a
          RuntimeError -- TODO confirm) for more than 3 spatial dims,
          MaxPoolWithArgmax on non-4D input, column-major storage_order, or
          an unsupported configuration when ``strict`` is False.
        """
        x = input_dict[node.inputs[0]]
        orig_x = x

        kernel_shape = node.attrs["kernel_shape"]
        spatial_size = len(kernel_shape)
        x_rank = spatial_size + 2

        strides = node.attrs.get("strides", [1] * spatial_size)
        dilations = node.attrs.get("dilations", [1] * spatial_size)
        ceil_mode = bool(node.attrs.get("ceil_mode", 0))
        pads = node.attrs.get("auto_pad", "NOTSET")
        p = node.attrs.get("p", 2)

        if pads == "NOTSET":
            pads = node.attrs.get("pads", [0] * spatial_size * 2)
            # In case shape is fully defined, check if pads match
            # SAME padding in Tensorflow
            if x.shape.is_fully_defined() and pads != [0] * spatial_size * 2:
                in_shape = x.get_shape().as_list()
                # x is still in (N, C, spatial...) layout here, so the
                # spatial dims are [2:x_rank] (not [1:x_rank - 1], which
                # would include the channel dim).
                same_paddings = calc_pads_same(in_shape[2:x_rank],
                                               kernel_shape, strides,
                                               dilations, "SAME_UPPER")
                if pads == same_paddings:
                    pads = "SAME_UPPER"

        count_include_pad = bool(node.attrs.get("count_include_pad", 0))

        # ONNX op name used in error messages; fall back to the raw type so
        # an unexpected pooling_type never leaves the name unbound.
        pooling_name = {
            "AVG": "AveragePool",
            "MAX": "MaxPool",
            "MAX_WITH_ARGMAX": "MaxPoolWithArgmax",
            "LP": "LpPool",
        }.get(pooling_type, pooling_type)

        if spatial_size > 3:
            exception.OP_UNSUPPORTED_EXCEPT(
                pooling_name + " with {}D input".format(x_rank), "Tensorflow")
        if pooling_type == "MAX_WITH_ARGMAX" and x_rank != 4:
            exception.OP_UNSUPPORTED_EXCEPT(
                pooling_name + " with {}D input".format(x_rank), "Tensorflow")
        if node.attrs.get("storage_order", 0) != 0:
            exception.OP_UNSUPPORTED_EXCEPT(pooling_name + " with column major",
                                            "Tensorflow")

        storage_format, _ = get_data_format(x_rank)
        need_trans = storage_format.startswith("NC")
        if need_trans:
            # DilatedPooling is fed a channels-last (N, spatial..., C) tensor.
            compute_format = "N" + storage_format[2:] + "C"
            x = tf.transpose(
                x, perm=get_perm_from_formats(storage_format, compute_format))

        dp = DilatedPooling(
            input=x,
            kernel_shape=kernel_shape,
            strides=strides,
            dilations=dilations,
            padding=pads,
            ceil_mode=ceil_mode,
            pooling_type=pooling_type,
            count_include_pad=count_include_pad,
            p=p)

        if not dp.is_supported():
            if strict:
                # No extra positional args: Logger.warning would %-format
                # them into this placeholder-free message and fail.
                logger.warning(
                    "Using the pooling op in compatibility mode. "
                    "This means your graph cannot be serialized.")
                # py_pool runs on the original, untransposed input.
                # NOTE(review): count_include_pad and p are not forwarded to
                # py_pool -- confirm its defaults match the ONNX attributes.
                result = tf.compat.v1.py_func(py_pool, [
                    orig_x, kernel_shape, strides, dilations, pads, ceil_mode,
                    pooling_type, False
                ], orig_x.dtype)
                if orig_x.shape.is_fully_defined():
                    shape = orig_x.get_shape().as_list()
                    output_shape = shape[0:2] + calc_output_shape(
                        shape[2:x_rank], kernel_shape, strides, dilations,
                        pads, ceil_mode)
                else:
                    output_shape = [None] * x_rank
                # py_func output carries no static shape; set it explicitly.
                result.set_shape(output_shape)
                return [result]
            else:
                exception.OP_UNSUPPORTED_EXCEPT("strict == 0 and " +
                                                pooling_name +
                                                " arguments not compatible",
                                                "Tensorflow")

        # select correct op depending on the pooling type
        if pooling_type in ["MAX", "AVG", "LP"]:
            pooled, argmax = dp.dilated_pool(), None
        else:
            pooled, argmax = dp.dilated_maxpool_with_argmax()

        # Transpose every output back to the input storage format.
        if need_trans:
            perm = get_perm_from_formats(compute_format, storage_format)
            pooled = tf.transpose(pooled, perm=perm)
            if argmax is not None:
                argmax = tf.transpose(argmax, perm=perm)

        return [pooled] if argmax is None else [pooled, argmax]