重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,125 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.common import exception
|
||||
from onnx_tf.common import get_data_format
|
||||
from onnx_tf.common import get_perm_from_formats
|
||||
|
||||
from onnx_tf.common import logger
|
||||
from .dilated_pooling import DilatedPooling
|
||||
from onnx_tf.common.pooling_helper import py_pool
|
||||
from onnx_tf.common.pooling_helper import calc_pads_same
|
||||
from onnx_tf.common.pooling_helper import calc_output_shape
|
||||
|
||||
class PoolMixin(object):
|
||||
|
||||
@classmethod
|
||||
def pool(cls, node, input_dict, pooling_type, strict=True):
|
||||
x = input_dict[node.inputs[0]]
|
||||
orig_x = x
|
||||
|
||||
kernel_shape = node.attrs["kernel_shape"]
|
||||
|
||||
spatial_size = len(kernel_shape)
|
||||
x_rank = spatial_size + 2
|
||||
|
||||
kernel_shape = node.attrs["kernel_shape"]
|
||||
strides = node.attrs.get("strides", [1] * spatial_size)
|
||||
dilations = node.attrs.get("dilations", [1] * spatial_size)
|
||||
ceil_mode = bool(node.attrs.get("ceil_mode", 0))
|
||||
pads = node.attrs.get("auto_pad", "NOTSET")
|
||||
p = node.attrs.get("p", 2)
|
||||
|
||||
if pads == "NOTSET":
|
||||
pads = node.attrs.get("pads", [0] * spatial_size * 2)
|
||||
# In case shape is fully defined, check if pads match
|
||||
# SAME padding in Tensorflow
|
||||
if x.shape.is_fully_defined() and pads != [0] * spatial_size * 2:
|
||||
in_shape = x.get_shape().as_list()
|
||||
same_paddings = calc_pads_same(in_shape[1:x_rank-1], kernel_shape,
|
||||
strides, dilations, "SAME_UPPER")
|
||||
if pads == same_paddings:
|
||||
pads = "SAME_UPPER"
|
||||
|
||||
count_include_pad = bool(node.attrs.get("count_include_pad", 0))
|
||||
if pooling_type == "AVG":
|
||||
pooling_name = "AveragePool"
|
||||
elif pooling_type == "MAX":
|
||||
pooling_name = "MaxPool"
|
||||
elif pooling_type == "MAX_WITH_ARGMAX":
|
||||
pooling_name = "MaxPoolWithArgmax"
|
||||
elif pooling_type == "LP":
|
||||
pooling_name = "LpPool"
|
||||
|
||||
if spatial_size > 3:
|
||||
exception.OP_UNSUPPORTED_EXCEPT(
|
||||
pooling_name + " with {}D input".format(x_rank), "Tensorflow")
|
||||
if pooling_type == "MAX_WITH_ARGMAX" and x_rank != 4:
|
||||
exception.OP_UNSUPPORTED_EXCEPT(
|
||||
pooling_name + " with {}D input".format(x_rank), "Tensorflow")
|
||||
if node.attrs.get("storage_order", 0) != 0:
|
||||
exception.OP_UNSUPPORTED_EXCEPT(pooling_name + " with column major",
|
||||
"Tensorflow")
|
||||
|
||||
storage_format, _ = get_data_format(x_rank)
|
||||
|
||||
need_trans = storage_format.startswith("NC")
|
||||
if need_trans:
|
||||
compute_format = "N" + storage_format[2:] + "C"
|
||||
x = tf.transpose(
|
||||
x, perm=get_perm_from_formats(storage_format, compute_format))
|
||||
|
||||
dp = DilatedPooling(
|
||||
input=x,
|
||||
kernel_shape=kernel_shape,
|
||||
strides=strides,
|
||||
dilations=dilations,
|
||||
padding=pads,
|
||||
ceil_mode=ceil_mode,
|
||||
pooling_type=pooling_type,
|
||||
count_include_pad=count_include_pad,
|
||||
p=p)
|
||||
if not dp.is_supported():
|
||||
if strict:
|
||||
logger.warning(
|
||||
"Using the pooling op in compatibility mode. "
|
||||
"This means your graph cannot be serialized.", UserWarning)
|
||||
|
||||
result = tf.py_func(py_pool, [
|
||||
orig_x, kernel_shape, strides, dilations, pads, ceil_mode,
|
||||
pooling_type, False
|
||||
], orig_x.dtype)
|
||||
|
||||
if orig_x.shape.is_fully_defined():
|
||||
shape = orig_x.get_shape().as_list()
|
||||
output_shape = shape[0:2] + calc_output_shape(shape[2:x_rank],
|
||||
kernel_shape, strides, dilations, pads, ceil_mode)
|
||||
else:
|
||||
output_shape = [None] * x_rank
|
||||
result.set_shape(output_shape)
|
||||
return [result]
|
||||
else:
|
||||
exception.OP_UNSUPPORTED_EXCEPT("strict == 0 and " + pooling_name +
|
||||
" arguments not compatible",
|
||||
"Tensorflow")
|
||||
|
||||
def dilated_pool():
|
||||
return (dp.dilated_pool(), None)
|
||||
|
||||
# select correct op depending on the pooling type
|
||||
pooling_op = dilated_pool if pooling_type in ["MAX", "AVG", "LP"] else \
|
||||
dp.dilated_maxpool_with_argmax
|
||||
|
||||
# select the correct transpose ops depending on the input storage format
|
||||
perm = get_perm_from_formats(compute_format, storage_format)
|
||||
|
||||
def postprocess(pooled, argmax):
|
||||
return (tf.transpose(pooled, perm=perm) if need_trans else pooled,
|
||||
tf.transpose(argmax, perm=perm)
|
||||
if need_trans and argmax is not None else argmax)
|
||||
|
||||
pooled, argmax = pooling_op()
|
||||
pooled, argmax = postprocess(pooled, argmax)
|
||||
|
||||
result = [pooled] if argmax is None else [pooled, argmax]
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user