import tensorflow as tf

import onnx_tf
from onnx.helper import make_opsetid
from onnx_tf.common import data_type
from onnx_tf.common import exception
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op


@onnx_op("Loop")
class Loop(BackendHandler):
  """Lower the ONNX ``Loop`` operator to ``tf.while_loop``.

  The body graph receives ``(iteration_num, condition, v_1 .. v_N)`` and
  produces ``(condition, v_1 .. v_N, scan_1 .. scan_K)``. This handler returns
  the N final loop-carried values followed by the K scan outputs, where each
  scan output is its per-iteration values stacked (``TensorArray.stack``).
  """

  @classmethod
  def _trip_count(cls, M):
    """Convert the trip-count tensor ``M`` into an int32 scalar loop bound.

    A plain int64 -> int32 cast wraps values above the int32 range (for
    example INT64_MAX) to negative numbers, which would make the loop run
    zero times, so ``M`` is clamped to the int32 maximum before casting. It
    is also reshaped to a scalar because ``maximum_iterations`` must be one.

    Args:
      M: Trip-count tensor taken from the node's first input.

    Returns:
      An int32 scalar tensor suitable for ``maximum_iterations``.
    """
    M = tf.reshape(tf.cast(M, tf.int64), [])
    return tf.cast(tf.minimum(M, tf.constant(tf.int32.max, tf.int64)),
                   tf.int32)

  @classmethod
  def _common(cls, node, **kwargs):
    """Build the TensorFlow graph for one ONNX Loop node.

    The loop form is selected by which optional inputs are present:
      * M only      -> for loop; the body's condition output is ignored.
      * cond only   -> while / do-while loop.
      * M and cond  -> while loop that also stops after M iterations.

    Args:
      node: The ONNX Loop node; ``node.attrs["body"]`` is the body graph.
      **kwargs: Must contain ``tensor_dict`` mapping tensor names to tensors.

    Returns:
      A flat list: the final loop-carried values followed by the stacked
      scan outputs.

    Raises:
      Via ``exception.OP_UNSUPPORTED_EXCEPT`` (expected to raise) when neither
      M nor cond is given.
    """
    body = node.attrs["body"]
    tensor_dict = kwargs["tensor_dict"]
    M = tensor_dict[node.inputs[0]] if node.inputs[0] != "" else None
    cond = (tf.cast(tensor_dict[node.inputs[1]], tf.bool)
            if node.inputs[1] != "" else None)
    v_initial = [tensor_dict[graph_input] for graph_input in node.inputs[2:]]
    # NOTE(review): the initial shapes are used as shape invariants, so a body
    # that changes the shape of a loop-carried value is rejected by
    # tf.while_loop.
    v_shapes = [v.get_shape() for v in v_initial]
    current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]

    if M is None and cond is None:
      # Expected to raise: without M or cond there is nothing to stop the loop.
      exception.OP_UNSUPPORTED_EXCEPT(
          "Both M and cond in Loop are not set at the same time",
          "Tensorflow.(PS. if you want to create a do-while loop " +
          "then please set cond to True or 1)")

    # Outputs of the body are laid out as
    # (condition, loop carried dependencies..., scan_outputs...).
    scan_outputs_start_index = 1 + len(v_initial)
    scan_outputs = [
        tf.TensorArray(dtype=data_type.onnx2tf(
            body.output[i].type.tensor_type.elem_type),
                       size=0,
                       dynamic_size=True)
        for i in range(scan_outputs_start_index, len(body.output))
    ]
    scan_outputs_shapes = [tf.TensorShape(None) for _ in scan_outputs]

    def run_subgraph(iter_num, cond, v, scan_outputs):
      # The body's first input is the current iteration number (0-based),
      # not the trip count M.
      input_values = {
          body.input[0].name: iter_num,
          body.input[1].name: cond,
      }
      for i in range(2, len(body.input)):
        input_values[body.input[i].name] = v[i - 2]
      subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
          subgraph=body,
          input_values=input_values,
          tensor_dict=tensor_dict,
          opset=current_opset)
      outputs = [subgraph_tensor_dict[output.name] for output in body.output]
      # Append this iteration's scan values to their TensorArrays.
      for i in range(scan_outputs_start_index, len(outputs)):
        s_index = i - scan_outputs_start_index
        scan_outputs[s_index] = scan_outputs[s_index].write(
            scan_outputs[s_index].size(), outputs[i])
      return [
          iter_num + 1, outputs[0], outputs[1:scan_outputs_start_index],
          scan_outputs
      ]

    if cond is None:
      # For loop: the body's condition output is not used to stop the loop,
      # so only maximum_iterations bounds it. Seed the condition loop variable
      # with a boolean (not a string) so its dtype matches the body's output.
      condition = lambda iter_num, c, v, scan_outputs: True
      cond = tf.constant(True)
    else:
      # While / do-while loop, optionally also bounded by M.
      condition = lambda iter_num, c, v, scan_outputs: tf.reduce_all(
          tf.equal(c, True))
    max_iterations = cls._trip_count(M) if M is not None else None

    _, _, v_final, scan_outputs = tf.while_loop(
        cond=condition,
        body=run_subgraph,
        loop_vars=[
            tf.constant(0, dtype=tf.int64), cond, v_initial, scan_outputs
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape(None), v_shapes, scan_outputs_shapes
        ],
        maximum_iterations=max_iterations)

    # NOTE(review): stacking a scan output whose loop ran zero times may fail
    # when its element shape is unknown; confirm zero-trip behavior.
    scan_outputs_tensors = [o.stack() for o in scan_outputs]
    # Node outputs are flat: final loop-carried values, then scan outputs.
    return list(v_final) + scan_outputs_tensors

  @classmethod
  def version_1(cls, node, **kwargs):
    return cls._common(node, **kwargs)

  @classmethod
  def version_11(cls, node, **kwargs):
    return cls._common(node, **kwargs)