189 lines
6.8 KiB
Python
189 lines
6.8 KiB
Python
import tensorflow as tf
|
|
from onnx.helper import make_opsetid
|
|
import onnx_tf
|
|
from onnx_tf.common import data_type
|
|
|
|
|
|
class ScanMixin(object):
    """Shared TensorFlow implementation of the ONNX ``Scan`` operator.

    Classes mixing this in must provide the ``DOMAIN``, ``VERSION`` and
    ``SINCE_VERSION`` class attributes, which ``scan`` reads.
    """

    @classmethod
    def scan(cls, node, input_dict, strict):
        """Convert an ONNX ``Scan`` node into TensorFlow ops.

        For ``SINCE_VERSION == 8`` the first node input is the optional
        ``sequence_lens`` (an empty name means it was omitted). State
        variables and scan inputs also carry a leading batch dimension that
        is processed with ``tf.map_fn``. Later versions have neither and run a
        single ``tf.scan``.

        Args:
          node: the ONNX node being converted. Its ``attrs`` provide ``body``
            and ``num_scan_inputs``, plus the optional axis/direction
            attributes.
          input_dict: mapping from tensor name to TF tensor. It is also passed
            as ``tensor_dict`` when converting the body subgraph.
          strict: forwarded unchanged to
            ``onnx_tf.backend.onnx_graph_to_tensorflow_ops``.

        Returns:
          A list with the final value of every state variable, followed by
          every scan output.
        """
        current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]
        body = node.attrs["body"]
        is_v8 = cls.SINCE_VERSION == 8

        # In version 8, node.inputs[0] is the sequence_lens input.
        node_inputs = node.inputs[1:] if is_v8 else node.inputs
        # M
        num_scan_inputs = int(node.attrs["num_scan_inputs"])
        # N = num_inputs - M
        num_state_vars = len(node_inputs) - num_scan_inputs
        # K = num_outputs - N
        num_scan_outputs = len(node.outputs) - num_state_vars

        def run_subgraph(a, b):
            """Run the body graph for one ``tf.scan`` step.

            ``a`` holds the previous step's outputs (state variables first)
            and ``b`` holds the current slice of every scan input. Returns one
            tensor per body output, in body output order.
            """
            input_values = {}
            # The first N body inputs receive the state variables.
            for i in range(num_state_vars):
                input_values[body.input[i].name] = a[i]
            # The remaining M body inputs receive the scan input slices.
            for i in range(num_scan_inputs):
                input_values[body.input[i + num_state_vars].name] = b[i]

            subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
                subgraph=body,
                input_values=input_values,
                tensor_dict=input_dict,
                opset=current_opset,
                strict=strict)
            return [subgraph_tensor_dict[output.name] for output in body.output]

        scan_input_axes = node.attrs.get("scan_input_axes",
                                         [0] * num_scan_inputs)
        scan_input_directions = node.attrs.get(
            "directions" if is_v8 else "scan_input_directions",
            [0] * num_scan_inputs)
        scan_output_axes = node.attrs.get("scan_output_axes",
                                          [0] * num_scan_outputs)
        scan_output_directions = node.attrs.get("scan_output_directions",
                                                [0] * num_scan_outputs)

        # Version 8 reads sequence_lens from the first input, if present.
        sequence_lens = None
        if is_v8 and node.inputs[0] != '':
            sequence_lens = input_dict[node.inputs[0]]

        inputs = [input_dict[node_input] for node_input in node_inputs]
        state_vars_init = inputs[:num_state_vars]
        scan_inputs = inputs[num_state_vars:]

        for i in range(num_scan_inputs):
            # Move a non-default scan axis to the front so tf.scan iterates it.
            if scan_input_axes[i] != 0:
                transpose_perm = cls._calc_transpose_perm_input(
                    tf.rank(scan_inputs[i]), scan_input_axes[i])
                scan_inputs[i] = tf.transpose(scan_inputs[i], transpose_perm)

            # Reverse-direction inputs. In version 8 each batch entry is later
            # sliced to its own seq_len. Reversing the whole tensor here would
            # make that slice pick up elements past seq_len. So version 8
            # reverses per batch, after slicing, inside run_batches.
            if scan_input_directions[i] == 1 and not is_v8:
                scan_inputs[i] = tf.reverse(scan_inputs[i], [0])

        # tf.scan needs an initial value for every output. Each scan output
        # starts as a zero tensor, shaped from the body output's static dims
        # (unknown/symbolic dims become None) with the body output's dtype.
        # NOTE(review): tf.zeros likely rejects a shape that contains None, so
        # body outputs with dynamic dims may fail here - confirm.
        scan_outputs_init = []
        for scan_output in body.output[num_state_vars:]:
            tensor_type = scan_output.type.tensor_type
            shape = [
                d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
                for d in tensor_type.shape.dim
            ]
            dtype = data_type.onnx2tf(tensor_type.elem_type)
            scan_outputs_init.append(tf.zeros(shape, dtype=dtype))

        if is_v8:

            def run_batches(x):
                """Run the scan for a single batch entry (version 8).

                ``x`` is one element mapped by ``tf.map_fn``: (state variable
                initial values, scan inputs, sequence length) for that entry.
                """
                initial, batch_scan_inputs, seq_len = x
                # Unsliced length of the sequence axis; outputs are padded back
                # to it so every batch entry has the same shape.
                full_len = tf.shape(batch_scan_inputs[0],
                                    out_type=seq_len.dtype)[0]

                # Keep only the first seq_len steps. Then reverse the inputs
                # scanned in reverse, so iteration runs seq_len - 1 ... 0.
                sliced_inputs = []
                for i, scan_input in enumerate(batch_scan_inputs):
                    scan_input = scan_input[:seq_len, ...]
                    if scan_input_directions[i] == 1:
                        scan_input = tf.reverse(scan_input, [0])
                    sliced_inputs.append(scan_input)

                out = tf.scan(run_subgraph,
                              sliced_inputs,
                              initializer=initial + scan_outputs_init)

                # Zero-pad the sequence axis of every output back to full_len.
                paddings = [[0, full_len - seq_len]]
                return [
                    tf.pad(
                        step_out,
                        tf.concat([
                            paddings,
                            tf.zeros([(tf.rank(step_out) - 1), 2], dtype=tf.int32)
                        ],
                                  axis=0)) for step_out in out
                ]

            if sequence_lens is None:
                # Without sequence_lens, every batch entry uses the full size
                # of axis 1 of the first scan input.
                sequence_lens = tf.fill(
                    [tf.shape(scan_inputs[0])[0]],
                    tf.shape(scan_inputs[0], out_type=tf.int32)[1])

            output_types = [
                data_type.onnx2tf(output.type.tensor_type.elem_type)
                for output in body.output
            ]
            # Run the scan once per batch entry.
            out = tf.map_fn(run_batches,
                            (state_vars_init, scan_inputs, sequence_lens),
                            dtype=output_types)

            # A state variable's final value for each batch entry is the one
            # produced at step seq_len - 1.
            state_vars_outputs = [
                tf.map_fn(lambda x: x[0][x[1] - 1], (state_var, sequence_lens),
                          state_var.dtype) for state_var in out[:num_state_vars]
            ]
        else:
            # Version > 8: a single scan over the leading axis. The initializer
            # is the state variables' initial values followed by scan_outputs_init.
            out = tf.scan(run_subgraph,
                          scan_inputs,
                          initializer=state_vars_init + scan_outputs_init)

            # A state variable's final value is its last accumulated step.
            state_vars_outputs = [
                state_var[tf.shape(state_var)[0] - 1]
                for state_var in out[:num_state_vars]
            ]

        scan_outputs = list(out[num_state_vars:])

        # Post-process scan outputs according to their directions and axes.
        for i in range(num_scan_outputs):
            if scan_output_directions[i] == 1:
                scan_outputs[i] = tf.reverse(scan_outputs[i], [0])

            if scan_output_axes[i] != 0:
                transpose_perm = cls._calc_transpose_perm_output(
                    tf.rank(scan_outputs[i]), scan_output_axes[i])
                scan_outputs[i] = tf.transpose(scan_outputs[i], transpose_perm)

        return state_vars_outputs + scan_outputs

    @classmethod
    def _calc_transpose_perm_input(cls, rank, axis):
        """Return a permutation that moves ``axis`` to position 0.

        The remaining axes keep their relative order. A negative ``axis`` is
        taken relative to ``rank``.
        """
        if axis < 0:
            axis = rank + axis
        return tf.concat([[axis], tf.range(axis), tf.range(axis + 1, rank)], 0)

    @classmethod
    def _calc_transpose_perm_output(cls, rank, axis):
        """Return a permutation that moves axis 0 to position ``axis``.

        This is the inverse of ``_calc_transpose_perm_input`` for the same
        ``axis``. A negative ``axis`` is taken relative to ``rank``.
        """
        if axis < 0:
            axis = rank + axis
        return tf.concat([tf.range(1, axis + 1), [0], tf.range(axis + 1, rank)], 0)
|