66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from onnx_tf.handlers.backend_handler import BackendHandler
|
|
from onnx_tf.handlers.handler import onnx_op
|
|
from onnx_tf.handlers.handler import tf_func
|
|
|
|
|
|
@onnx_op("Pad")
@tf_func(tf.pad)
class Pad(BackendHandler):
  """Backend handler that converts an ONNX ``Pad`` node into TensorFlow ops.

  "constant" and "reflect" style modes are forwarded to the registered
  ``tf.pad`` function through ``make_tensor_from_onnx_node``. "edge" mode is
  built here from repeated one-element SYMMETRIC pads.
  """

  @classmethod
  def _common(cls, node, **kwargs):
    """Build the padded tensor for every supported opset.

    Args:
      node: The ONNX node being converted. ``node.inputs`` and ``node.attrs``
        are read; the "mode", "pads"/"paddings" and "value" attributes are
        popped from ``node.attrs`` (the node is mutated).
      **kwargs: Must contain "tensor_dict", mapping input names to tensors.
        All of ``kwargs`` is forwarded to ``make_tensor_from_onnx_node``.

    Returns:
      A single-element list holding the padded tensor.

    Raises:
      ValueError: For opsets below 11, if the node carries neither a "pads"
        nor a "paddings" attribute.
    """
    tensor_dict = kwargs["tensor_dict"]
    x = tensor_dict[node.inputs[0]]
    num_dim = len(x.get_shape())
    mode = node.attrs.pop("mode", "constant")

    if cls.SINCE_VERSION < 11:  # opset 1 and opset 2: pads are attributes
      pads_attr = node.attrs.pop("pads", None)
      # Opset 1 names this attribute "paddings"; "pads" was introduced in
      # opset 2. Pop it unconditionally so it is not left in node.attrs.
      legacy_pads_attr = node.attrs.pop("paddings", None)
      if pads_attr is None:
        pads_attr = legacy_pads_attr
      if pads_attr is None:
        raise ValueError(
            "Pad node requires a 'pads' (or, in opset 1, 'paddings') "
            "attribute.")
      # ONNX lays pads out as [begin_0..begin_n, end_0..end_n]; tf.pad wants
      # an int32 [num_dim, 2] array of (begin, end) rows.
      paddings = tf.constant(
          np.transpose(
              np.array(pads_attr).reshape([2, num_dim]).astype(np.int32)))
      constant_values = node.attrs.pop("value", 0.)
    else:  # opset 11: pads (and optionally the fill value) are inputs
      pads_input = tensor_dict[node.inputs[1]]
      # Same layout conversion as above, done in-graph; tf requires int32.
      paddings = tf.cast(
          tf.transpose(tf.reshape(pads_input, [2, num_dim])), dtype=tf.int32)
      # ONNX marks an omitted optional input with an empty name; treat that
      # the same as the input not being listed at all.
      has_constant_input = len(node.inputs) == 3 and node.inputs[2] != ""
      constant_values = tensor_dict[node.inputs[2]] if has_constant_input else 0

    if mode.lower() == "edge":
      # Edge padding is emulated by growing the tensor one element at a time:
      # a SYMMETRIC pad of width 1 duplicates the current border value.
      flat_paddings = tf.reshape(paddings, [-1])
      max_pad = tf.reduce_max(flat_paddings)

      def _symmetric_pad(step, padded):
        # Pad by 1 on each side whose requested amount is not yet reached.
        step_paddings = tf.map_fn(lambda amount: tf.where(step < amount, 1, 0),
                                  flat_paddings)
        step_paddings = tf.reshape(step_paddings, [num_dim, 2])
        padded = tf.pad(padded, step_paddings, 'SYMMETRIC')
        return step + 1, padded

      # The padded tensor changes shape every iteration, so its shape
      # invariant is left fully unknown.
      _, x = tf.while_loop(
          lambda step, padded: tf.less(step, max_pad),
          _symmetric_pad, [0, x],
          [tf.TensorShape([]), tf.TensorShape(None)])
      return [x]

    return [
        cls.make_tensor_from_onnx_node(
            node, inputs=[x, paddings, mode, None, constant_values], **kwargs)
    ]

  @classmethod
  def version_1(cls, node, **kwargs):
    """Handle opset 1 by delegating to ``_common``."""
    return cls._common(node, **kwargs)

  @classmethod
  def version_2(cls, node, **kwargs):
    """Handle opset 2 by delegating to ``_common``."""
    return cls._common(node, **kwargs)

  @classmethod
  def version_11(cls, node, **kwargs):
    """Handle opset 11 by delegating to ``_common``."""
    return cls._common(node, **kwargs)
|