add pt2tf tool

This commit is contained in:
zxros10
2020-09-23 09:09:49 +08:00
parent 7f7b7df65d
commit 18aefa4dd0
407 changed files with 16211 additions and 0 deletions
@@ -0,0 +1,218 @@
from functools import partial
import tensorflow as tf
from onnx_tf.common import get_unique_suffix
from onnx_tf.common import exception
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import partial_support
from onnx_tf.handlers.handler import ps_description
from .rnn_mixin import RNNMixin
@onnx_op("LSTM")
@partial_support(True)
@ps_description("LSTM not using sigmoid for `f`, or " +
"LSTM not using the same activation for `g` and `h` " +
"are not supported in Tensorflow.")
class LSTM(RNNMixin, BackendHandler):
@classmethod
def args_check(cls, node, **kwargs):
direction = node.attrs.get("direction", "forward")
num_directions = 2 if direction == "bidirectional" else 1
if node.attrs.get("input_forget", 0):
# TODO(fumihwh): warning
pass
if "activations" in node.attrs:
activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
if activations[0] != "sigmoid":
exception.OP_UNSUPPORTED_EXCEPT("LSTM without sigmoid for `f`",
"Tensorflow")
if activations[1] != activations[2]:
exception.OP_UNSUPPORTED_EXCEPT(
"LSTM without same activation for `g` and `h`", "Tensorflow")
if num_directions == 2:
if activations[3] != "sigmoid":
exception.OP_UNSUPPORTED_EXCEPT("LSTM without sigmoid for `f`",
"Tensorflow")
if activations[4] != activations[5]:
exception.OP_UNSUPPORTED_EXCEPT(
"LSTM without same activation for `g` and `h`", "Tensorflow")
@classmethod
def _custom_getter(cls,
getter,
name,
node=None,
tensor_dict=None,
is_bidirectional=None,
*args,
**kwargs):
names = name.split("/")
if is_bidirectional:
if "fw" in names:
index = 0
elif "bw" in names:
index = 1
else:
raise RuntimeError("Can not get {} for bidirectional. "
"Either fw and bw is not in name scope.".format(
names[-1]))
if names[-1] == "kernel":
# onnx W[iofc], R[iofc]
if is_bidirectional:
w = tf.split(tensor_dict[node.inputs[1]], 2)[index]
r = tf.split(tensor_dict[node.inputs[2]], 2)[index]
else:
w = tensor_dict[node.inputs[1]]
r = tensor_dict[node.inputs[2]]
w_i, w_o, w_f, w_c = tf.split(tf.squeeze(w), 4)
r_i, r_o, r_f, r_c = tf.split(tf.squeeze(r), 4)
new_w = tf.transpose(tf.concat([w_i, w_c, w_f, w_o], 0))
new_r = tf.transpose(tf.concat([r_i, r_c, r_f, r_o], 0))
kernel = tf.concat([new_w, new_r], 0)
return kernel
if names[-1] == "bias":
if len(node.inputs) >= 4:
# onnx Wb[iofc], Rb[iofc]
if is_bidirectional:
b = tf.split(tensor_dict[node.inputs[3]], 2)[index]
else:
b = tensor_dict[node.inputs[3]]
w_b, r_b = tf.split(tf.squeeze(b), 2)
w_b_i, w_b_o, w_b_f, w_b_c = tf.split(w_b, 4)
r_b_i, r_b_o, r_b_f, r_b_c = tf.split(r_b, 4)
w_b = tf.transpose(tf.concat([w_b_i, w_b_c, w_b_f, w_b_o], 0))
r_b = tf.transpose(tf.concat([r_b_i, r_b_c, r_b_f, r_b_o], 0))
return tf.add(w_b, r_b)
return getter(name, *args, **kwargs)
# Only use_peepholes is True,
# will try to get w_f_diag, w_i_diag, w_o_diag
# onnx P[iof]
if names[-1] in ["w_f_diag", "w_i_diag", "w_o_diag"]:
if is_bidirectional:
p = tf.split(tensor_dict[node.inputs[7]], 2)[index]
else:
p = tensor_dict[node.inputs[7]]
if names[-1] == "w_f_diag":
return tf.split(p, 3, axis=1)[2]
if names[-1] == "w_i_diag":
return tf.split(p, 3, axis=1)[0]
if names[-1] == "w_o_diag":
return tf.split(p, 3, axis=1)[1]
return getter(name, *args, **kwargs)
@classmethod
def _common(cls, node, **kwargs):
tensor_dict = kwargs["tensor_dict"]
x = tensor_dict[node.inputs[0]]
input_shape = x.get_shape().as_list()
input_size = len(node.inputs)
hidden_size = node.attrs["hidden_size"]
direction = node.attrs.get("direction", "forward")
num_directions = 2 if direction == "bidirectional" else 1
# removed from version 7, default is 0
output_sequence = node.attrs.get("output_sequence", 0)
# TODO(fumihwh): check if prev node is one of RNN
# process input if it comes from other previous cell
# which has shape [seq_length, num_directions, batch_size, hidden_size]
if len(input_shape) == 4 and input_shape[1] == 1:
x = tf.squeeze(x)
sequence_length = None
if input_size >= 5 and node.inputs[4] in tensor_dict:
sequence_length = tensor_dict[node.inputs[4]]
cell_kwargs = {}
if "clip" in node.attrs:
cell_kwargs["cell_clip"] = node.attrs["clip"]
tf_activations = [tf.nn.tanh] * num_directions
if "activations" in node.attrs:
activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
activation_alpha = node.attrs.get("activation_alpha", [None] * 6)
activation_beta = node.attrs.get("activation_beta", [None] * 6)
# tf only supports cutomizing hidden states activation function,
# which correspond to activation functions specified at position 1
# and 4 in onnx's activations attribute.
activation_idxs = [1, 4] if num_directions == 2 else [1]
tf_activations = [
cls.rnn_get_activation(activations[i], activation_alpha[i],
activation_beta[i]) for i in activation_idxs
]
# TODO(fumihwh): check if reverse and bidirectional works
with tf.variable_scope("LSTM_" + get_unique_suffix(),
custom_getter=partial(
cls._custom_getter,
node=node,
tensor_dict=tensor_dict,
is_bidirectional=num_directions == 2)):
cell_kwargs[
"use_peepholes"] = input_size == 8 and node.inputs[7] in tensor_dict
cell_kwargs["forget_bias"] = 0.
cell_kwargs["num_units"] = hidden_size
initial_state = None
initial_state_bw = None
if input_size >= 6:
initial_h = tensor_dict.get(node.inputs[5], None)
initial_c = tensor_dict.get(
node.inputs[6],
None) if input_size >= 7 else tf.zeros_like(initial_h)
if initial_h is not None and initial_c is not None:
initial_state = (tf.nn.rnn_cell.LSTMStateTuple(
initial_c[0], initial_h[0]),)
if num_directions == 2:
initial_state_bw = (tf.nn.rnn_cell.LSTMStateTuple(
initial_c[1], initial_h[1]),)
rnn_kwargs = {}
if num_directions == 1:
rnn_kwargs["initial_state"] = initial_state
elif num_directions == 2:
rnn_kwargs["initial_state_fw"] = initial_state
rnn_kwargs["initial_state_bw"] = initial_state_bw
rnn_kwargs["sequence_length"] = sequence_length
rnn_kwargs["time_major"] = True
rnn_kwargs["dtype"] = tf.float32
outputs, states = cls.rnn(x, tf.nn.rnn_cell.LSTMCell, cell_kwargs,
rnn_kwargs, tf_activations, direction)
if num_directions == 1:
state = states[0]
c = tf.expand_dims(state[0], 0)
h = tf.expand_dims(state[1], 0)
output = tf.expand_dims(outputs, 1)
else:
state_fw = states[0][0]
state_bw = states[1][0]
output_fw = outputs[0]
output_bw = outputs[1]
c_fw = tf.expand_dims(state_fw[0], 0)
c_bw = tf.expand_dims(state_bw[0], 0)
c = tf.concat((c_fw, c_bw), axis=0)
h_fw = tf.expand_dims(state_fw[1], 0)
h_bw = tf.expand_dims(state_bw[1], 0)
h = tf.concat((h_fw, h_bw), axis=0)
output_fw = tf.expand_dims(output_fw, 1)
output_bw = tf.expand_dims(output_bw, 1)
output = tf.concat((output_fw, output_bw), axis=1)
return [output, h, c] if output_sequence == 0 else [h, c]
@classmethod
def version_1(cls, node, **kwargs):
return cls._common(node, **kwargs)
@classmethod
def version_7(cls, node, **kwargs):
return cls._common(node, **kwargs)