add pt2tf tool
This commit is contained in:
@@ -0,0 +1,188 @@
|
||||
import tensorflow as tf
|
||||
from onnx.helper import make_opsetid
|
||||
import onnx_tf
|
||||
from onnx_tf.common import data_type
|
||||
|
||||
|
||||
class ScanMixin(object):
|
||||
|
||||
@classmethod
|
||||
def scan(cls, node, input_dict, strict):
|
||||
current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]
|
||||
|
||||
body = node.attrs["body"]
|
||||
|
||||
# in version 8, node.inputs[0] is the sequence_lens
|
||||
node_inputs = node.inputs if cls.SINCE_VERSION != 8 else \
|
||||
node.inputs[1:]
|
||||
# M
|
||||
num_scan_inputs = int(node.attrs["num_scan_inputs"])
|
||||
# N = num_inputs - M
|
||||
num_state_vars = len(node_inputs) - num_scan_inputs
|
||||
# K = num_outputs - N
|
||||
num_scan_outputs = len(node.outputs) - num_state_vars
|
||||
|
||||
"""
|
||||
Function to run subgraph used with tf.scan
|
||||
"""
|
||||
|
||||
def run_subgraph(a, b):
|
||||
input_values = {}
|
||||
# set the input values for the subgraph
|
||||
# set the values for the state variables
|
||||
for i in range(num_state_vars):
|
||||
input_values[body.input[i].name] = a[i]
|
||||
# set the values for the scan inputs
|
||||
for i in range(num_scan_inputs):
|
||||
input_values[body.input[i + num_state_vars].name] = b[i]
|
||||
|
||||
# get the tensor operations for the onnx graph
|
||||
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
|
||||
subgraph=body,
|
||||
input_values=input_values,
|
||||
tensor_dict=input_dict,
|
||||
opset=current_opset,
|
||||
strict=strict)
|
||||
# return sequence of tensors for every subgraph output
|
||||
outputs = [subgraph_tensor_dict[output.name] for output in body.output]
|
||||
return outputs
|
||||
|
||||
scan_input_axes = node.attrs.get("scan_input_axes", [0] * num_scan_inputs)
|
||||
scan_input_directions = node.attrs.get(
|
||||
"directions" if cls.SINCE_VERSION == 8 else "scan_input_directions",
|
||||
[0] * num_scan_inputs)
|
||||
scan_output_axes = node.attrs.get("scan_output_axes",
|
||||
[0] * num_scan_outputs)
|
||||
scan_output_directions = node.attrs.get("scan_output_directions",
|
||||
[0] * num_scan_outputs)
|
||||
|
||||
# if version 8 read the sequnce_lens from the first input
|
||||
if cls.SINCE_VERSION == 8:
|
||||
sequence_lens = input_dict[node.inputs[0]] \
|
||||
if node.inputs[0] != '' else None
|
||||
|
||||
inputs = [input_dict[node_input] for node_input in node_inputs]
|
||||
|
||||
scan_inputs = inputs[num_state_vars:]
|
||||
# loop over all the scan inputs and apply transpose depending
|
||||
# on input axes provided and also reverse the scan inputs if
|
||||
# reverse direction for scan is provided
|
||||
for i in range(num_scan_inputs):
|
||||
# if input axes are different than 0, use transpose to scan over
|
||||
# the provided axes
|
||||
if scan_input_axes[i] != 0:
|
||||
transpose_perm = cls._calc_transpose_perm_input(tf.rank(scan_inputs[i]),
|
||||
scan_input_axes[i])
|
||||
scan_inputs[i] = tf.transpose(scan_inputs[i], transpose_perm)
|
||||
|
||||
# check for reverse direction scans
|
||||
if scan_input_directions[i] == 1:
|
||||
# version 8 has a batch dimension
|
||||
axis = 0 if cls.SINCE_VERSION != 8 else 1
|
||||
scan_inputs[i] = tf.reverse(scan_inputs[i], [axis])
|
||||
|
||||
state_vars_init = inputs[:num_state_vars]
|
||||
|
||||
scan_outputs_init = []
|
||||
# generate sequence of zero tensors for all scan outputs
|
||||
# with the correct shape and dtype
|
||||
for scan_output in body.output[num_state_vars:]:
|
||||
tensor_type = scan_output.type.tensor_type
|
||||
shape = [
|
||||
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
|
||||
for d in tensor_type.shape.dim
|
||||
]
|
||||
dtype = data_type.onnx2tf(tensor_type.elem_type)
|
||||
scan_outputs_init.append(tf.zeros(shape, dtype=dtype))
|
||||
|
||||
# tf.scan initilizer is state_variables_init + scan_outputs_init
|
||||
initializer = state_vars_init + scan_outputs_init
|
||||
|
||||
if cls.SINCE_VERSION == 8:
|
||||
# version == 8
|
||||
# function to process the batches. it is used with tf.map_fn
|
||||
def run_batches(x):
|
||||
# state vars initial values per batch
|
||||
initial = x[0]
|
||||
# scan inputs per batch
|
||||
scan_inputs = x[1]
|
||||
# sequence length for the batch
|
||||
seq_len = x[2]
|
||||
|
||||
# slice the input to the current sequence len
|
||||
scan_inputs = [scan_input[:seq_len, ...] for scan_input in scan_inputs]
|
||||
|
||||
# run scan on the current batch
|
||||
out = tf.scan(run_subgraph,
|
||||
scan_inputs,
|
||||
initializer=initial + scan_outputs_init)
|
||||
|
||||
# pad to the original shape with zeros
|
||||
paddings = [[0, tf.shape(x[1][0], out_type=seq_len.dtype)[0] - seq_len]]
|
||||
for i in range(len(out)):
|
||||
pads = tf.concat(
|
||||
[paddings,
|
||||
tf.zeros([(tf.rank(out[i]) - 1), 2], dtype=tf.int32)],
|
||||
axis=0)
|
||||
out[i] = tf.pad(out[i], pads)
|
||||
return out
|
||||
|
||||
if sequence_lens is None:
|
||||
# if sequence_lens is None, fill it with the shape of
|
||||
# the input axis 1
|
||||
sequence_lens = tf.fill([tf.shape(scan_inputs[0])[0]],
|
||||
tf.shape(scan_inputs[0], out_type=tf.int32)[1])
|
||||
|
||||
output_types = [
|
||||
data_type.onnx2tf(output.type.tensor_type.elem_type)
|
||||
for output in body.output
|
||||
]
|
||||
# run scan for every batch
|
||||
out = tf.map_fn(run_batches,
|
||||
(state_vars_init, scan_inputs, sequence_lens),
|
||||
dtype=output_types)
|
||||
|
||||
state_vars_outputs = []
|
||||
# extract the final values of the state variables
|
||||
for state_var in out[:num_state_vars]:
|
||||
state_vars_outputs.append(
|
||||
tf.map_fn(lambda x: x[0][x[1] - 1], (state_var, sequence_lens),
|
||||
state_var.dtype))
|
||||
else:
|
||||
# version > 8
|
||||
# run the scan
|
||||
out = tf.scan(run_subgraph, scan_inputs, initializer=initializer)
|
||||
|
||||
# extract the final values of the state variables
|
||||
state_vars_outputs = [
|
||||
state_var[tf.shape(state_var)[0] - 1]
|
||||
for state_var in out[:num_state_vars]
|
||||
]
|
||||
|
||||
scan_outputs = out[num_state_vars:]
|
||||
|
||||
# post process the scan outputs depending on the directions and
|
||||
# axes provided.
|
||||
for i in range(num_scan_outputs):
|
||||
# check for reverse direction scan outputs
|
||||
if scan_output_directions[i] == 1:
|
||||
scan_outputs[i] = tf.reverse(scan_outputs[i], [0])
|
||||
|
||||
if scan_output_axes[i] != 0:
|
||||
transpose_perm = cls._calc_transpose_perm_output(
|
||||
tf.rank(scan_outputs[i]), scan_output_axes[i])
|
||||
scan_outputs[i] = tf.transpose(scan_outputs[i], transpose_perm)
|
||||
|
||||
return state_vars_outputs + scan_outputs
|
||||
|
||||
@classmethod
|
||||
def _calc_transpose_perm_input(cls, rank, axis):
|
||||
if axis < 0:
|
||||
axis = rank + axis
|
||||
return tf.concat([[axis], tf.range(axis), tf.range(axis + 1, rank)], 0)
|
||||
|
||||
@classmethod
|
||||
def _calc_transpose_perm_output(cls, rank, axis):
|
||||
if axis < 0:
|
||||
axis = rank + axis
|
||||
return tf.concat([tf.range(1, axis + 1), [0], tf.range(axis + 1, rank)], 0)
|
||||
Reference in New Issue
Block a user