重命名 pt2tf 为 pt2pb

This commit is contained in:
zhutian
2020-10-14 08:55:07 +08:00
committed by Gitee
parent 324ab60a5d
commit 90ae190559
407 changed files with 0 additions and 0 deletions
@@ -0,0 +1,188 @@
import tensorflow as tf
from onnx.helper import make_opsetid
import onnx_tf
from onnx_tf.common import data_type
class ScanMixin(object):
@classmethod
def scan(cls, node, input_dict, strict):
current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]
body = node.attrs["body"]
# in version 8, node.inputs[0] is the sequence_lens
node_inputs = node.inputs if cls.SINCE_VERSION != 8 else \
node.inputs[1:]
# M
num_scan_inputs = int(node.attrs["num_scan_inputs"])
# N = num_inputs - M
num_state_vars = len(node_inputs) - num_scan_inputs
# K = num_outputs - N
num_scan_outputs = len(node.outputs) - num_state_vars
"""
Function to run subgraph used with tf.scan
"""
def run_subgraph(a, b):
input_values = {}
# set the input values for the subgraph
# set the values for the state variables
for i in range(num_state_vars):
input_values[body.input[i].name] = a[i]
# set the values for the scan inputs
for i in range(num_scan_inputs):
input_values[body.input[i + num_state_vars].name] = b[i]
# get the tensor operations for the onnx graph
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
subgraph=body,
input_values=input_values,
tensor_dict=input_dict,
opset=current_opset,
strict=strict)
# return sequence of tensors for every subgraph output
outputs = [subgraph_tensor_dict[output.name] for output in body.output]
return outputs
scan_input_axes = node.attrs.get("scan_input_axes", [0] * num_scan_inputs)
scan_input_directions = node.attrs.get(
"directions" if cls.SINCE_VERSION == 8 else "scan_input_directions",
[0] * num_scan_inputs)
scan_output_axes = node.attrs.get("scan_output_axes",
[0] * num_scan_outputs)
scan_output_directions = node.attrs.get("scan_output_directions",
[0] * num_scan_outputs)
# if version 8 read the sequnce_lens from the first input
if cls.SINCE_VERSION == 8:
sequence_lens = input_dict[node.inputs[0]] \
if node.inputs[0] != '' else None
inputs = [input_dict[node_input] for node_input in node_inputs]
scan_inputs = inputs[num_state_vars:]
# loop over all the scan inputs and apply transpose depending
# on input axes provided and also reverse the scan inputs if
# reverse direction for scan is provided
for i in range(num_scan_inputs):
# if input axes are different than 0, use transpose to scan over
# the provided axes
if scan_input_axes[i] != 0:
transpose_perm = cls._calc_transpose_perm_input(tf.rank(scan_inputs[i]),
scan_input_axes[i])
scan_inputs[i] = tf.transpose(scan_inputs[i], transpose_perm)
# check for reverse direction scans
if scan_input_directions[i] == 1:
# version 8 has a batch dimension
axis = 0 if cls.SINCE_VERSION != 8 else 1
scan_inputs[i] = tf.reverse(scan_inputs[i], [axis])
state_vars_init = inputs[:num_state_vars]
scan_outputs_init = []
# generate sequence of zero tensors for all scan outputs
# with the correct shape and dtype
for scan_output in body.output[num_state_vars:]:
tensor_type = scan_output.type.tensor_type
shape = [
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
for d in tensor_type.shape.dim
]
dtype = data_type.onnx2tf(tensor_type.elem_type)
scan_outputs_init.append(tf.zeros(shape, dtype=dtype))
# tf.scan initilizer is state_variables_init + scan_outputs_init
initializer = state_vars_init + scan_outputs_init
if cls.SINCE_VERSION == 8:
# version == 8
# function to process the batches. it is used with tf.map_fn
def run_batches(x):
# state vars initial values per batch
initial = x[0]
# scan inputs per batch
scan_inputs = x[1]
# sequence length for the batch
seq_len = x[2]
# slice the input to the current sequence len
scan_inputs = [scan_input[:seq_len, ...] for scan_input in scan_inputs]
# run scan on the current batch
out = tf.scan(run_subgraph,
scan_inputs,
initializer=initial + scan_outputs_init)
# pad to the original shape with zeros
paddings = [[0, tf.shape(x[1][0], out_type=seq_len.dtype)[0] - seq_len]]
for i in range(len(out)):
pads = tf.concat(
[paddings,
tf.zeros([(tf.rank(out[i]) - 1), 2], dtype=tf.int32)],
axis=0)
out[i] = tf.pad(out[i], pads)
return out
if sequence_lens is None:
# if sequence_lens is None, fill it with the shape of
# the input axis 1
sequence_lens = tf.fill([tf.shape(scan_inputs[0])[0]],
tf.shape(scan_inputs[0], out_type=tf.int32)[1])
output_types = [
data_type.onnx2tf(output.type.tensor_type.elem_type)
for output in body.output
]
# run scan for every batch
out = tf.map_fn(run_batches,
(state_vars_init, scan_inputs, sequence_lens),
dtype=output_types)
state_vars_outputs = []
# extract the final values of the state variables
for state_var in out[:num_state_vars]:
state_vars_outputs.append(
tf.map_fn(lambda x: x[0][x[1] - 1], (state_var, sequence_lens),
state_var.dtype))
else:
# version > 8
# run the scan
out = tf.scan(run_subgraph, scan_inputs, initializer=initializer)
# extract the final values of the state variables
state_vars_outputs = [
state_var[tf.shape(state_var)[0] - 1]
for state_var in out[:num_state_vars]
]
scan_outputs = out[num_state_vars:]
# post process the scan outputs depending on the directions and
# axes provided.
for i in range(num_scan_outputs):
# check for reverse direction scan outputs
if scan_output_directions[i] == 1:
scan_outputs[i] = tf.reverse(scan_outputs[i], [0])
if scan_output_axes[i] != 0:
transpose_perm = cls._calc_transpose_perm_output(
tf.rank(scan_outputs[i]), scan_output_axes[i])
scan_outputs[i] = tf.transpose(scan_outputs[i], transpose_perm)
return state_vars_outputs + scan_outputs
@classmethod
def _calc_transpose_perm_input(cls, rank, axis):
if axis < 0:
axis = rank + axis
return tf.concat([[axis], tf.range(axis), tf.range(axis + 1, rank)], 0)
@classmethod
def _calc_transpose_perm_output(cls, rank, axis):
if axis < 0:
axis = rank + axis
return tf.concat([tf.range(1, axis + 1), [0], tf.range(axis + 1, rank)], 0)