Files
2020-10-14 08:55:07 +08:00

52 lines
1.7 KiB
Python

import tensorflow as tf
import onnx_tf
from onnx.helper import make_opsetid
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
@onnx_op("If")
@tf_func(tf.cond)
class If(BackendHandler):
    """Handler decorated for the ONNX ``If`` operator, backed by ``tf.cond``."""

    @classmethod
    def _common(cls, node, **kwargs):
        """Build the conditional for ``node`` from its two branch subgraphs.

        Args:
            node: The ``If`` node. ``node.inputs[0]`` names the condition
                tensor; ``node.attrs`` holds the ``then_branch`` and
                ``else_branch`` subgraphs.
            **kwargs: Must contain ``tensor_dict``, which is indexed by tensor
                name to find the condition and is passed on unchanged to the
                branch conversion.

        Returns:
            A one-element list wrapping the result of
            ``cls.make_tensor_from_onnx_node``.
        """
        tensor_dict = kwargs["tensor_dict"]
        predicate = tensor_dict[node.inputs[0]]
        then_graph = node.attrs["then_branch"]
        else_graph = node.attrs["else_branch"]
        opset_ids = [make_opsetid(cls.DOMAIN, cls.VERSION)]

        def convert_branch(branch):
            # Conversion is deferred until the branch callable is invoked.
            converted = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
                subgraph=branch,
                input_values={},  # every branch input already lives in tensor_dict
                tensor_dict=tensor_dict,
                opset=opset_ids,
            )
            return [converted[out.name] for out in branch.output]

        def true_fn():
            return convert_branch(then_graph)

        def false_fn():
            return convert_branch(else_graph)

        # strict=True keeps singleton lists/tuples returned by true_fn and
        # false_fn from being implicitly unpacked into single values.
        # NOTE(review): the flag is handed over positionally as the fourth
        # input; confirm the wrapped tf.cond signature has ``strict`` in that
        # slot for the TensorFlow version in use.
        strict = True

        cond_inputs = [predicate, true_fn, false_fn, strict]
        return [cls.make_tensor_from_onnx_node(node, inputs=cond_inputs)]

    @classmethod
    def version_1(cls, node, **kwargs):
        """Handle opset version 1 by delegating to ``_common``."""
        return cls._common(node, **kwargs)

    @classmethod
    def version_11(cls, node, **kwargs):
        """Handle opset version 11 by delegating to ``_common``."""
        return cls._common(node, **kwargs)