重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,110 @@
|
||||
import tensorflow as tf
|
||||
|
||||
import onnx_tf
|
||||
from onnx.helper import make_opsetid
|
||||
from onnx_tf.common import data_type
|
||||
from onnx_tf.common import exception
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
from onnx_tf.handlers.handler import onnx_op
|
||||
|
||||
|
||||
@onnx_op("Loop")
|
||||
class Loop(BackendHandler):
|
||||
|
||||
@classmethod
|
||||
def _common(cls, node, **kwargs):
|
||||
body = node.attrs["body"]
|
||||
tensor_dict = kwargs["tensor_dict"]
|
||||
M = tensor_dict[node.inputs[0]] if node.inputs[0] != "" else None
|
||||
cond = tf.cast(tensor_dict[node.inputs[1]],
|
||||
tf.bool) if node.inputs[1] != "" else None
|
||||
v_initial = [tensor_dict[graph_input] for graph_input in node.inputs[2:]]
|
||||
v_shapes = [v.get_shape() for v in v_initial]
|
||||
current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]
|
||||
# outputs of the body will be in this format:
|
||||
# (condition, loop carried dependencies..., scan_outputs...)
|
||||
scan_outputs_start_index = 1 + len(v_initial)
|
||||
scan_outputs = [
|
||||
tf.TensorArray(dtype=data_type.onnx2tf(
|
||||
body.output[i].type.tensor_type.elem_type),
|
||||
size=0,
|
||||
dynamic_size=True)
|
||||
for i in range(scan_outputs_start_index, len(body.output))
|
||||
]
|
||||
scan_outputs_shapes = [tf.TensorShape(None) for o in scan_outputs]
|
||||
|
||||
def run_subgraph(cond, v, scan_outputs):
|
||||
input_values = {}
|
||||
input_values[body.input[0].name] = M
|
||||
input_values[body.input[1].name] = cond
|
||||
for i in range(2, len(body.input)):
|
||||
input_values[body.input[i].name] = v[i - 2]
|
||||
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
|
||||
subgraph=body,
|
||||
input_values=input_values,
|
||||
tensor_dict=tensor_dict,
|
||||
opset=current_opset)
|
||||
outputs = [subgraph_tensor_dict[output.name] for output in body.output]
|
||||
for i in range(scan_outputs_start_index, len(outputs)):
|
||||
s_index = i - scan_outputs_start_index
|
||||
insert_index = scan_outputs[s_index].size()
|
||||
scan_outputs[s_index] = scan_outputs[s_index].write(
|
||||
insert_index, outputs[i])
|
||||
return [outputs[0], outputs[1:scan_outputs_start_index], scan_outputs]
|
||||
|
||||
# for loop
|
||||
if M is not None and cond is None:
|
||||
M = tf.cast(M, tf.int32)
|
||||
condition = lambda cond, v, scan_outputs: True
|
||||
_, v_final, scan_outputs = tf.while_loop(
|
||||
cond=condition,
|
||||
body=run_subgraph,
|
||||
loop_vars=["", v_initial, scan_outputs],
|
||||
shape_invariants=[
|
||||
tf.TensorShape(None), v_shapes, scan_outputs_shapes
|
||||
],
|
||||
maximum_iterations=M)
|
||||
# while and do-while loop
|
||||
elif M is None and cond is not None:
|
||||
condition = lambda cond, v, scan_outputs: tf.reduce_all(
|
||||
tf.equal(cond, True))
|
||||
cond, v_final, scan_outputs = tf.while_loop(
|
||||
cond=condition,
|
||||
body=run_subgraph,
|
||||
loop_vars=[cond, v_initial, scan_outputs],
|
||||
shape_invariants=[
|
||||
tf.TensorShape(None), v_shapes, scan_outputs_shapes
|
||||
])
|
||||
# combine for loop and while loop together
|
||||
elif M is not None and cond is not None:
|
||||
M = tf.cast(M, tf.int32)
|
||||
condition = lambda cond, v, scan_outputs: tf.reduce_all(
|
||||
tf.equal(cond, True))
|
||||
cond, v_final, scan_outputs = tf.while_loop(
|
||||
cond=condition,
|
||||
body=run_subgraph,
|
||||
loop_vars=[cond, v_initial, scan_outputs],
|
||||
shape_invariants=[
|
||||
tf.TensorShape(None), v_shapes, scan_outputs_shapes
|
||||
],
|
||||
maximum_iterations=M)
|
||||
else: # M is None and cond is None
|
||||
exception.OP_UNSUPPORTED_EXCEPT(
|
||||
"Both M and cond in Loop are not set at the same time",
|
||||
"Tensorflow.(PS. if you want to create a do-while loop " +
|
||||
"then please set cond to True or 1)")
|
||||
|
||||
scan_outputs_tensors = [o.stack() for o in scan_outputs]
|
||||
if scan_outputs_start_index == len(body.output):
|
||||
# there is no scan_output in the body graph
|
||||
return [v_final]
|
||||
else:
|
||||
return [v_final, scan_outputs_tensors]
|
||||
|
||||
@classmethod
|
||||
def version_1(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def version_11(cls, node, **kwargs):
|
||||
return cls._common(node, **kwargs)
|
||||
Reference in New Issue
Block a user