add pt2tf tool
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
from onnx_tf.handlers.handler import tf_func
|
||||
|
||||
|
||||
@onnx_op("Pad")
|
||||
@tf_func(tf.pad)
|
||||
class Pad(BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def _common(cls, node, **kwargs):
|
||||
tensor_dict = kwargs["tensor_dict"]
|
||||
x = tensor_dict[node.inputs[0]]
|
||||
num_dim = len(tensor_dict[node.inputs[0]].get_shape())
|
||||
mode = node.attrs.pop("mode", "constant")
|
||||
|
||||
if cls.SINCE_VERSION < 11: # for opset 1 and opset 2
|
||||
paddings = node.attrs.pop("pads", None)
|
||||
# tf requires int32 paddings
|
||||
paddings = tf.constant(
|
||||
np.transpose(
|
||||
np.array(paddings).reshape([2, num_dim]).astype(np.int32)))
|
||||
constant_values = node.attrs.pop("value", 0.)
|
||||
|
||||
else: # for opset 11
|
||||
paddings = tensor_dict[node.inputs[1]]
|
||||
# tf requires int32 paddings
|
||||
paddings = tf.cast(tf.transpose(tf.reshape(paddings, [2, num_dim])),
|
||||
dtype=tf.int32)
|
||||
constant_values = tensor_dict[node.inputs[2]] if len(
|
||||
node.inputs) == 3 else 0
|
||||
|
||||
def _symmetric_pad(i, x):
|
||||
paddings_i = tf.map_fn(lambda e: tf.where(i < e, 1, 0), paddings)
|
||||
paddings_i = tf.reshape(paddings_i, [num_dim, 2])
|
||||
x = tf.pad(x, paddings_i, 'SYMMETRIC')
|
||||
return i + 1, x
|
||||
|
||||
if mode.lower() == "edge":
|
||||
paddings = tf.reshape(paddings, [-1])
|
||||
max_i = tf.reduce_max(paddings)
|
||||
_, x = tf.while_loop(
|
||||
lambda i, x: tf.less(i, max_i), _symmetric_pad, [0, x],
|
||||
[tf.TensorShape([]), tf.TensorShape(None)])
|
||||
return [x]
|
||||
|
||||
return [
|
||||
cls.make_tensor_from_onnx_node(
|
||||
node, inputs=[x, paddings, mode, None, constant_values], **kwargs)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def version_1(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_2(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_11(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
Reference in New Issue
Block a user