111 lines
4.1 KiB
Python
111 lines
4.1 KiB
Python
import tensorflow as tf
|
|
|
|
import onnx_tf
|
|
from onnx.helper import make_opsetid
|
|
from onnx_tf.common import data_type
|
|
from onnx_tf.common import exception
|
|
from onnx_tf.handlers.backend_handler import BackendHandler
|
|
from onnx_tf.handlers.handler import onnx_op
|
|
|
|
|
|
@onnx_op("Loop")
class Loop(BackendHandler):
    """Backend handler that lowers the ONNX ``Loop`` operator to tf.while_loop.

    Each iteration converts the ``body`` graph attribute to TensorFlow ops.
    Loop-carried values are threaded through the while loop, and
    per-iteration scan outputs are accumulated in dynamically sized
    ``tf.TensorArray`` objects that are stacked once the loop finishes.
    """

    @classmethod
    def _common(cls, node, **kwargs):
        """Convert an ONNX ``Loop`` node into TensorFlow ops.

        Args:
          node: The ONNX node. ``node.attrs["body"]`` is the loop body graph;
            ``node.inputs`` is ``(M, cond, v_initial...)``, where an empty
            string marks an omitted optional input.
          **kwargs: Must contain ``tensor_dict``, mapping tensor names to the
            TensorFlow tensors already built for them.

        Returns:
          A flat list with one tensor per node output: the final value of
          every loop-carried dependency, followed by one stacked tensor per
          scan output of the body graph.

        Raises:
          Whatever ``exception.OP_UNSUPPORTED_EXCEPT`` raises when both ``M``
          and ``cond`` are omitted (an unbounded loop is not supported).
        """
        body = node.attrs["body"]
        tensor_dict = kwargs["tensor_dict"]
        M = tensor_dict[node.inputs[0]] if node.inputs[0] != "" else None
        cond_init = (tf.cast(tensor_dict[node.inputs[1]], tf.bool)
                     if node.inputs[1] != "" else None)

        if M is None and cond_init is None:
            exception.OP_UNSUPPORTED_EXCEPT(
                "Both M and cond in Loop are not set at the same time",
                "Tensorflow.(PS. if you want to create a do-while loop " +
                "then please set cond to True or 1)")

        if M is not None:
            # M is narrowed to int32 for maximum_iterations, but ONNX trip
            # counts are int64 and exporters may use INT64_MAX to mean
            # "no limit". Saturate first so such values cannot overflow into
            # a negative bound.
            int32_max = tf.cast(tf.int32.max, M.dtype)
            M = tf.cast(tf.minimum(M, int32_max), tf.int32)

        if cond_init is None:
            # For-loop form: the loop ignores cond, but the body graph still
            # has a cond input (boolean per the ONNX spec), so seed it with
            # True rather than a placeholder of the wrong dtype.
            cond_init = tf.constant(True)

            def should_continue(iter_num, cond, v, scan_outputs):
                return True
        else:
            # While / do-while form (optionally bounded by M): keep going
            # only while every element of the latest cond is True.
            def should_continue(iter_num, cond, v, scan_outputs):
                return tf.reduce_all(tf.equal(cond, True))

        v_initial = [tensor_dict[name] for name in node.inputs[2:]]
        v_shapes = [v.get_shape() for v in v_initial]
        current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]

        # Body outputs are laid out as
        # (cond, loop-carried dependencies..., scan outputs...).
        scan_outputs_start_index = 1 + len(v_initial)
        scan_outputs_init = [
            tf.TensorArray(
                dtype=data_type.onnx2tf(
                    body.output[i].type.tensor_type.elem_type),
                size=0,
                dynamic_size=True)
            for i in range(scan_outputs_start_index, len(body.output))
        ]
        scan_outputs_shapes = [tf.TensorShape(None)] * len(scan_outputs_init)

        # The body's first input is the current iteration number
        # (0, 1, 2, ...), carried through the loop as an int64 scalar.
        iter_num_init = tf.constant(0, dtype=tf.int64)

        def run_subgraph(iter_num, cond, v, scan_outputs):
            """Run one iteration of the body graph.

            Returns the next loop state in the same structure as loop_vars:
            [iteration number, cond, loop-carried values, scan TensorArrays].
            """
            input_values = {
                body.input[0].name: iter_num,
                body.input[1].name: cond,
            }
            for i in range(2, len(body.input)):
                input_values[body.input[i].name] = v[i - 2]

            subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
                subgraph=body,
                input_values=input_values,
                tensor_dict=tensor_dict,
                opset=current_opset)
            outputs = [
                subgraph_tensor_dict[output.name] for output in body.output
            ]

            # TensorArray.write returns a new TensorArray; the returned
            # object (not the input) is what must flow into the next step.
            step_scan_values = outputs[scan_outputs_start_index:]
            next_scan_outputs = [
                ta.write(ta.size(), value)
                for ta, value in zip(scan_outputs, step_scan_values)
            ]
            return [
                iter_num + 1,
                outputs[0],
                outputs[1:scan_outputs_start_index],
                next_scan_outputs,
            ]

        # maximum_iterations=None (no M) is tf.while_loop's default, so one
        # call covers the for, while/do-while and combined forms.
        _, _, v_final, scan_outputs_final = tf.while_loop(
            cond=should_continue,
            body=run_subgraph,
            loop_vars=[iter_num_init, cond_init, v_initial, scan_outputs_init],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape(None),
                v_shapes,
                scan_outputs_shapes,
            ],
            maximum_iterations=M)

        # NOTE(review): if the body never runs, stacking a TensorArray whose
        # element shape was never established may fail - TODO confirm and
        # handle the zero-iteration case explicitly.
        scan_outputs_tensors = [ta.stack() for ta in scan_outputs_final]

        # One entry per node output, in ONNX order: final loop-carried values
        # followed by the stacked scan outputs (empty when there are none).
        return list(v_final) + scan_outputs_tensors

    @classmethod
    def version_1(cls, node, **kwargs):
        """Handle ``Loop`` for opset version 1 via the shared ``_common``."""
        return cls._common(node, **kwargs)

    @classmethod
    def version_11(cls, node, **kwargs):
        """Handle ``Loop`` for opset version 11 via the shared ``_common``."""
        return cls._common(node, **kwargs)
|