Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/handlers/backend/pad.py
T
2020-10-14 08:55:07 +08:00

66 lines
2.0 KiB
Python

import numpy as np
import tensorflow as tf
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
@onnx_op("Pad")
@tf_func(tf.pad)
class Pad(BackendHandler):
    """Backend handler that lowers the ONNX ``Pad`` operator to TensorFlow.

    ``edge`` mode is emulated with a loop of one-element ``SYMMETRIC`` pads;
    every other mode is handed to ``make_tensor_from_onnx_node`` together
    with the converted paddings and the constant fill value.
    """

    @classmethod
    def _common(cls, node, **kwargs):
        """Build the TensorFlow output for a ``Pad`` node of any opset.

        Args:
            node: The ONNX node being converted. Its ``mode``, ``pads`` /
                ``paddings`` and ``value`` attributes are popped (the node's
                attribute dict is mutated).
            **kwargs: Must contain ``tensor_dict``, mapping input names to
                already-converted tensors.

        Returns:
            A single-element list holding the padded tensor.

        Raises:
            ValueError: If a pre-opset-11 node carries neither a ``pads`` nor
                a ``paddings`` attribute.
        """
        tensor_dict = kwargs["tensor_dict"]
        x = tensor_dict[node.inputs[0]]
        num_dim = len(x.get_shape())
        mode = node.attrs.pop("mode", "constant")

        if cls.SINCE_VERSION < 11:  # for opset 1 and opset 2
            # Opset 2 names the attribute "pads" while opset 1 names it
            # "paddings". Pop both so that neither is left behind on the
            # node, and prefer the newer spelling when both are present.
            pads = node.attrs.pop("pads", None)
            legacy_pads = node.attrs.pop("paddings", None)
            if pads is None:
                pads = legacy_pads
            if pads is None:
                raise ValueError(
                    "Pad node is missing its 'pads' (opset 2) or "
                    "'paddings' (opset 1) attribute")
            # ONNX lays pads out as [begin_0..begin_n, end_0..end_n]; tf.pad
            # wants int32 [[begin_0, end_0], ...], hence reshape + transpose.
            paddings = tf.constant(
                np.transpose(
                    np.array(pads).reshape([2, num_dim]).astype(np.int32)))
            constant_values = node.attrs.pop("value", 0.)
        else:  # for opset 11
            # From opset 11 on, pads (and optionally the fill value) arrive
            # as inputs rather than attributes.
            paddings = tensor_dict[node.inputs[1]]
            # tf requires int32 paddings
            paddings = tf.cast(
                tf.transpose(tf.reshape(paddings, [2, num_dim])),
                dtype=tf.int32)
            if len(node.inputs) == 3:
                constant_values = tensor_dict[node.inputs[2]]
            else:
                constant_values = 0

        if mode.lower() == "edge":
            # Flattened [begin_0, end_0, begin_1, end_1, ...] pad amounts.
            flat_paddings = tf.reshape(paddings, [-1])
            max_pad = tf.reduce_max(flat_paddings)

            def _symmetric_pad(i, padded):
                # Pad by one element on every side that still needs more
                # than ``i`` elements; a width-1 SYMMETRIC pad duplicates
                # the border value, so repeating it replicates the edge.
                step = tf.map_fn(lambda e: tf.where(i < e, 1, 0),
                                 flat_paddings)
                step = tf.reshape(step, [num_dim, 2])
                padded = tf.pad(padded, step, 'SYMMETRIC')
                return i + 1, padded

            # The tensor grows on each iteration, so its shape invariant is
            # left fully unknown.
            _, x = tf.while_loop(
                lambda i, _padded: tf.less(i, max_pad), _symmetric_pad,
                [0, x], [tf.TensorShape([]), tf.TensorShape(None)])
            return [x]

        return [
            cls.make_tensor_from_onnx_node(
                node,
                inputs=[x, paddings, mode, None, constant_values],
                **kwargs)
        ]

    @classmethod
    def version_1(cls, node, **kwargs):
        """Handle ``Pad`` at opset 1 by delegating to ``_common``."""
        return cls._common(node, **kwargs)

    @classmethod
    def version_2(cls, node, **kwargs):
        """Handle ``Pad`` at opset 2 by delegating to ``_common``."""
        return cls._common(node, **kwargs)

    @classmethod
    def version_11(cls, node, **kwargs):
        """Handle ``Pad`` at opset 11 by delegating to ``_common``."""
        return cls._common(node, **kwargs)