126 lines
4.4 KiB
Python
126 lines
4.4 KiB
Python
import tensorflow as tf
|
|
|
|
from onnx_tf.common import exception
|
|
from onnx_tf.common import get_data_format
|
|
from onnx_tf.common import get_perm_from_formats
|
|
|
|
from onnx_tf.common import logger
|
|
from .dilated_pooling import DilatedPooling
|
|
from onnx_tf.common.pooling_helper import py_pool
|
|
from onnx_tf.common.pooling_helper import calc_pads_same
|
|
from onnx_tf.common.pooling_helper import calc_output_shape
|
|
|
|
class PoolMixin(object):
    """Shared conversion logic for ONNX pooling nodes.

    Handles AveragePool, MaxPool, MaxPoolWithArgmax and LpPool by building
    TensorFlow ops through ``DilatedPooling``. When ``DilatedPooling`` cannot
    express the node's configuration, it can fall back to the Python
    reference implementation ``py_pool``.
    """

    # ONNX op names per pooling type; used only to build error messages.
    _POOLING_NAMES = {
        "AVG": "AveragePool",
        "MAX": "MaxPool",
        "MAX_WITH_ARGMAX": "MaxPoolWithArgmax",
        "LP": "LpPool",
    }

    @staticmethod
    def _resolve_pads(node, x, kernel_shape, strides, dilations):
        """Return the padding spec for the pooling implementation.

        If the node sets ``auto_pad`` to anything other than ``"NOTSET"``,
        that string is returned unchanged. Otherwise the explicit ``pads``
        attribute (default: all zeros) is returned. There is one exception:
        when ``x`` has a fully defined static shape and the explicit pads
        equal what ``calc_pads_same(..., "SAME_UPPER")`` computes, the string
        ``"SAME_UPPER"`` is returned instead.

        Args:
            node: ONNX node wrapper; only ``node.attrs`` is read.
            x: Input tensor in its original (untransposed) layout.
            kernel_shape: Pooling window size per spatial dimension.
            strides: Stride per spatial dimension.
            dilations: Dilation per spatial dimension.
        """
        pads = node.attrs.get("auto_pad", "NOTSET")
        if pads != "NOTSET":
            return pads

        no_pads = [0] * len(kernel_shape) * 2
        pads = node.attrs.get("pads", no_pads)
        if x.shape.is_fully_defined() and pads != no_pads:
            # x has not been transposed yet, so its layout is (N, C, spatial...).
            # The spatial dims therefore start at index 2. This matches the
            # shape[2:] slice used for the same tensor in _compat_pool.
            # Slicing from index 1 would wrongly include the channel dim.
            spatial_shape = x.get_shape().as_list()[2:]
            same_pads = calc_pads_same(spatial_shape, kernel_shape, strides,
                                       dilations, "SAME_UPPER")
            if pads == same_pads:
                pads = "SAME_UPPER"
        return pads

    @staticmethod
    def _compat_pool(x, kernel_shape, strides, dilations, pads, ceil_mode,
                     pooling_type):
        """Pool ``x`` with the Python ``py_pool`` implementation via py_func.

        The static output shape is set when the input shape is fully
        defined; otherwise every dimension is left unknown.

        Returns:
            A one-element list holding the pooled tensor.
        """
        # Only the message is passed. Extra positional args would be
        # %-merged into the message by logging and fail at emit time.
        logger.warning("Using the pooling op in compatibility mode. "
                       "This means your graph cannot be serialized.")

        # NOTE(review): the trailing False presumably disables index output
        # in py_pool, so no argmax is returned on this path. Confirm against
        # py_pool. tf.py_func is a TF1 API (tf.compat.v1.py_func under TF2).
        result = tf.py_func(py_pool, [
            x, kernel_shape, strides, dilations, pads, ceil_mode,
            pooling_type, False
        ], x.dtype)

        x_rank = len(kernel_shape) + 2
        if x.shape.is_fully_defined():
            shape = x.get_shape().as_list()
            # Batch and channel dims are kept; spatial dims are recomputed.
            output_shape = shape[0:2] + calc_output_shape(
                shape[2:x_rank], kernel_shape, strides, dilations, pads,
                ceil_mode)
        else:
            output_shape = [None] * x_rank
        result.set_shape(output_shape)
        return [result]

    @classmethod
    def pool(cls, node, input_dict, pooling_type, strict=True):
        """Convert an ONNX pooling node into TensorFlow ops.

        Args:
            node: ONNX node wrapper. Its first input names the tensor to
                pool. Attributes read: ``kernel_shape`` (required),
                ``strides``, ``dilations``, ``ceil_mode``, ``auto_pad``,
                ``pads``, ``p``, ``count_include_pad`` and ``storage_order``.
            input_dict: Mapping from input name to tensor.
            pooling_type: One of ``"AVG"``, ``"MAX"``, ``"MAX_WITH_ARGMAX"``
                or ``"LP"``.
            strict: Controls what happens when ``DilatedPooling`` does not
                support the configuration. If true, fall back to the py_func
                reference implementation. If false, report the node via
                ``exception.OP_UNSUPPORTED_EXCEPT``.

        Returns:
            ``[pooled]``, or ``[pooled, argmax]`` when the pooling op also
            produced indices.
        """
        x = input_dict[node.inputs[0]]
        orig_x = x

        kernel_shape = node.attrs["kernel_shape"]
        spatial_size = len(kernel_shape)
        x_rank = spatial_size + 2

        strides = node.attrs.get("strides", [1] * spatial_size)
        dilations = node.attrs.get("dilations", [1] * spatial_size)
        ceil_mode = bool(node.attrs.get("ceil_mode", 0))
        p = node.attrs.get("p", 2)
        pads = cls._resolve_pads(node, x, kernel_shape, strides, dilations)
        count_include_pad = bool(node.attrs.get("count_include_pad", 0))

        # Fall back to the raw type so error messages never hit an
        # unbound name.
        pooling_name = cls._POOLING_NAMES.get(pooling_type, pooling_type)

        if spatial_size > 3:
            exception.OP_UNSUPPORTED_EXCEPT(
                pooling_name + " with {}D input".format(x_rank), "Tensorflow")
        if pooling_type == "MAX_WITH_ARGMAX" and x_rank != 4:
            exception.OP_UNSUPPORTED_EXCEPT(
                pooling_name + " with {}D input".format(x_rank), "Tensorflow")
        if node.attrs.get("storage_order", 0) != 0:
            exception.OP_UNSUPPORTED_EXCEPT(pooling_name + " with column major",
                                            "Tensorflow")

        storage_format, _ = get_data_format(x_rank)

        # Channels-first input is moved to channels-last for computation and
        # moved back at the end.
        need_trans = storage_format.startswith("NC")
        if need_trans:
            compute_format = "N" + storage_format[2:] + "C"
            x = tf.transpose(
                x, perm=get_perm_from_formats(storage_format, compute_format))

        dp = DilatedPooling(
            input=x,
            kernel_shape=kernel_shape,
            strides=strides,
            dilations=dilations,
            padding=pads,
            ceil_mode=ceil_mode,
            pooling_type=pooling_type,
            count_include_pad=count_include_pad,
            p=p)

        if not dp.is_supported():
            if strict:
                return cls._compat_pool(orig_x, kernel_shape, strides,
                                        dilations, pads, ceil_mode,
                                        pooling_type)
            # Execution deliberately continues to the DilatedPooling path
            # below if this call returns instead of raising.
            exception.OP_UNSUPPORTED_EXCEPT("strict == 0 and " + pooling_name +
                                            " arguments not compatible",
                                            "Tensorflow")

        if pooling_type in ("MAX", "AVG", "LP"):
            pooled, argmax = dp.dilated_pool(), None
        else:
            pooled, argmax = dp.dilated_maxpool_with_argmax()

        if need_trans:
            # Restore the original storage layout. compute_format is only
            # bound on this branch, so perm is computed here as well.
            perm = get_perm_from_formats(compute_format, storage_format)
            pooled = tf.transpose(pooled, perm=perm)
            if argmax is not None:
                argmax = tf.transpose(argmax, perm=perm)

        return [pooled] if argmax is None else [pooled, argmax]
|