353 lines
14 KiB
Python
353 lines
14 KiB
Python
"""Backend for running ONNX on Tensorflow
|
|
|
|
To run this, you will need to have Tensorflow installed as well.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
try:
|
|
from itertools import izip as zip
|
|
except ImportError: # will be 3.x series
|
|
pass
|
|
|
|
from onnx import defs
|
|
from onnx import numpy_helper
|
|
from onnx.backend.base import Backend
|
|
from onnx.backend.base import Device
|
|
from onnx.backend.base import namedtupledict
|
|
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
|
|
from onnx.helper import make_opsetid
|
|
import tensorflow as tf
|
|
|
|
from onnx_tf.backend_rep import TensorflowRep
|
|
from onnx_tf.common import data_type
|
|
from onnx_tf.common import get_device_option
|
|
from onnx_tf.common import get_unique_suffix
|
|
from onnx_tf.common import supports_device as common_supports_device
|
|
from onnx_tf.common.handler_helper import get_all_backend_handlers
|
|
from onnx_tf.pb_wrapper import OnnxNode
|
|
import onnx_tf.common as common
|
|
|
|
|
|
class TensorflowBackend(Backend):
|
|
""" Tensorflow Backend for ONNX
|
|
"""
|
|
|
|
@classmethod
|
|
def prepare(cls,
|
|
model,
|
|
device='CPU',
|
|
strict=True,
|
|
logging_level='INFO',
|
|
**kwargs):
|
|
"""Prepare an ONNX model for Tensorflow Backend.
|
|
|
|
This function converts an ONNX model to an internel representation
|
|
of the computational graph called TensorflowRep and returns
|
|
the converted representation.
|
|
|
|
:param model: The ONNX model to be converted.
|
|
:param device: The device to execute this model on.
|
|
:param strict: Whether to enforce semantic equivalence between the original model
|
|
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
|
|
Changing to False is strongly discouraged.
|
|
Currently, the strict flag only affects the behavior of MaxPool and AveragePool ops.
|
|
:param logging_level: The logging level, default is INFO. Change it to DEBUG
|
|
to see more conversion details or to WARNING to see less
|
|
|
|
:returns: A TensorflowRep class object representing the ONNX model
|
|
"""
|
|
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
|
|
common.logger.setLevel(logging_level)
|
|
common.logger.handlers[0].setLevel(logging_level)
|
|
|
|
return cls.onnx_model_to_tensorflow_rep(model, strict)
|
|
|
|
@classmethod
|
|
def onnx_model_to_tensorflow_rep(cls, model, strict):
|
|
""" Convert ONNX model to TensorflowRep.
|
|
|
|
:param model: ONNX ModelProto object.
|
|
:param strict: whether to enforce semantic equivalence between the original model
|
|
and the converted tensorflow model.
|
|
:return: TensorflowRep object.
|
|
"""
|
|
|
|
# Models with IR_VERSION less than 3 does not have opset_import set.
|
|
# We default to minimum opset, this behavior is consistent with
|
|
# onnx checker.
|
|
# c.f. https://github.com/onnx/onnx/blob/427ac0c1b792363d373e3d7e4eef97fa46458420/onnx/checker.cc#L478
|
|
if model.ir_version < 3:
|
|
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
|
|
else:
|
|
opset_import = model.opset_import
|
|
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
|
|
|
|
@classmethod
|
|
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
|
|
""" Convert ONNX graph to TensorflowRep.
|
|
|
|
:param graph_def: ONNX GraphProto object.
|
|
:param opset: ONNX OperatorSetIdProto list.
|
|
:param strict: whether to enforce semantic equivalence between the original model
|
|
and the converted tensorflow model.
|
|
:return: TensorflowRep object.
|
|
"""
|
|
handlers = cls._get_handlers(opset)
|
|
|
|
tf_rep_graph = tf.Graph()
|
|
with tf_rep_graph.as_default():
|
|
# initializer: TensorProtos representing the values to initialize
|
|
# a given tensor.
|
|
# initialized: A list of names of the initialized tensors.
|
|
if graph_def.initializer:
|
|
input_dict_items = cls._onnx_initializer_to_input_dict_items(
|
|
graph_def.initializer)
|
|
initialized = {init.name for init in graph_def.initializer}
|
|
else:
|
|
input_dict_items = []
|
|
initialized = set()
|
|
|
|
# creating placeholders for currently unknown inputs
|
|
for value_info in graph_def.input:
|
|
if value_info.name in initialized:
|
|
continue
|
|
shape = list(
|
|
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
|
|
for d in value_info.type.tensor_type.shape.dim)
|
|
value_info_name = value_info.name.replace(
|
|
":", "_tf_") + "_" + get_unique_suffix(
|
|
) if ":" in value_info.name else value_info.name
|
|
|
|
x = tf.placeholder(data_type.onnx2tf(
|
|
value_info.type.tensor_type.elem_type),
|
|
name=value_info_name,
|
|
shape=shape)
|
|
input_dict_items.append((value_info.name, x))
|
|
|
|
# tensor dict: this dictionary is a map from variable names
|
|
# to the latest produced TF tensors of the given name.
|
|
# This dictionary will get updated as we build the graph to
|
|
# record the names of newly produced tensors.
|
|
tensor_dict = dict(input_dict_items)
|
|
# Since tensor dict may be updated, we need to keep a copy
|
|
# of the original input dict where we track the earliest
|
|
# defined tensors so we can have access to the placeholders
|
|
# to feed in input tensors when we run the graph.
|
|
input_dict = dict(input_dict_items)
|
|
|
|
for node in graph_def.node:
|
|
onnx_node = OnnxNode(node)
|
|
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
|
|
tensor_dict,
|
|
handlers,
|
|
opset=opset,
|
|
strict=strict)
|
|
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
|
|
tensor_dict.update(curr_node_output_map)
|
|
|
|
tf_rep = TensorflowRep()
|
|
tf_rep.graph = tf_rep_graph
|
|
tf_rep.inputs = [
|
|
value_info.name
|
|
for value_info in graph_def.input
|
|
if value_info.name not in initialized
|
|
]
|
|
tf_rep.outputs = [value_info.name for value_info in graph_def.output]
|
|
tf_rep.tensor_dict = tensor_dict
|
|
return tf_rep
|
|
|
|
@classmethod
|
|
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
|
|
""" Run ONNX node.
|
|
|
|
:param node: ONNX NodeProto object.
|
|
:param inputs: Inputs.
|
|
:param device: Device run on.
|
|
:param outputs_info: None.
|
|
:param kwargs: Other args.
|
|
:return: Outputs.
|
|
"""
|
|
super(TensorflowBackend, cls).run_node(node, inputs, device)
|
|
node_graph = tf.Graph()
|
|
with node_graph.as_default():
|
|
node = OnnxNode(node)
|
|
device_option = get_device_option(Device(device))
|
|
input_tensors = []
|
|
for i in inputs:
|
|
input_tensors.append(tf.constant(i))
|
|
|
|
if isinstance(inputs, dict):
|
|
feed_dict_raw = inputs
|
|
else:
|
|
assert len(node.inputs) == len(inputs)
|
|
feed_dict_raw = dict(zip(node.inputs, inputs))
|
|
|
|
# TODO: is constant the best way for feeding inputs?
|
|
input_dict = dict([
|
|
(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
|
|
])
|
|
ops = cls._onnx_node_to_tensorflow_op(node, input_dict)
|
|
|
|
with tf.Session() as sess:
|
|
with tf.device(device_option):
|
|
sess.run(tf.global_variables_initializer())
|
|
output_vals = sess.run(ops)
|
|
|
|
return namedtupledict('Outputs', node.outputs)(*output_vals)
|
|
|
|
@classmethod
|
|
def _onnx_initializer_to_input_dict_items(cls, initializer):
|
|
""" Convert ONNX graph initializer to input dict items.
|
|
|
|
:param initializer: ONNX graph initializer, list of TensorProto.
|
|
:return: List of input dict items.
|
|
"""
|
|
|
|
def tensor2list(onnx_tensor):
|
|
# Use the onnx.numpy_helper because the data may be raw
|
|
return numpy_helper.to_array(onnx_tensor).flatten().tolist()
|
|
|
|
def validate_initializer_name(name):
|
|
# Prepend a unique suffix if leading charater is "_"
|
|
name = get_unique_suffix() + name if name[0] is "_" else name
|
|
|
|
# Replace ":" with "_tf_" and append a unique suffix for
|
|
# traceability
|
|
return name.replace(
|
|
":", "_tf_") + "_" + get_unique_suffix() if ":" in name else name
|
|
|
|
return [(init.name,
|
|
tf.constant(tensor2list(init),
|
|
shape=init.dims,
|
|
dtype=data_type.onnx2tf(init.data_type),
|
|
name=validate_initializer_name(init.name)))
|
|
for init in initializer]
|
|
|
|
@classmethod
|
|
def _onnx_node_to_tensorflow_op(cls,
|
|
node,
|
|
tensor_dict,
|
|
handlers=None,
|
|
opset=None,
|
|
strict=True):
|
|
"""
|
|
Convert onnx node to tensorflow op.
|
|
|
|
Args:
|
|
node: Onnx node object.
|
|
tensor_dict: Tensor dict of graph.
|
|
opset: Opset version of the operator set. Default 0 means using latest version.
|
|
strict: whether to enforce semantic equivalence between the original model
|
|
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
|
|
Changing to False is strongly discouraged.
|
|
Returns:
|
|
Tensorflow op
|
|
"""
|
|
handlers = handlers or cls._get_handlers(opset)
|
|
if handlers:
|
|
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
|
|
if handler:
|
|
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
|
|
|
|
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
|
|
|
|
@classmethod
|
|
def _get_handlers(cls, opset):
|
|
""" Get all backend handlers with opset.
|
|
|
|
:param opset: ONNX OperatorSetIdProto list.
|
|
:return: All backend handlers.
|
|
"""
|
|
opset = opset or [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
|
|
opset_dict = dict([(o.domain, o.version) for o in opset])
|
|
return get_all_backend_handlers(opset_dict)
|
|
|
|
@classmethod
|
|
def supports_device(cls, device):
|
|
return common_supports_device(device)
|
|
|
|
@classmethod
|
|
def onnx_graph_to_tensorflow_ops(cls,
|
|
subgraph,
|
|
input_values,
|
|
tensor_dict,
|
|
opset=None,
|
|
strict=True):
|
|
"""
|
|
Converts ONNX graph to Tensorflow operations
|
|
Args:
|
|
subgraph: the ONNX graph to be converted
|
|
input_values: dictionary with values/tensors to initialize
|
|
the subgraph inputs. if the subgraph.input
|
|
are send in as parameters then it is required,
|
|
otherwise this can be empty dictionary.
|
|
tensor_dict: the dictionary that contain values for all the
|
|
node.inputs in the subgraph that are not defined
|
|
in the subgraph or input_values.
|
|
opset: opset version of the operator set.
|
|
strict: whether to enforce semantic equivalence between the
|
|
original model and the converted tensorflow model,
|
|
defaults to True (yes, enforce semantic equivalence).
|
|
Returns:
|
|
array of Tensorflow Tensors
|
|
"""
|
|
# get the subgraph.input from input_values
|
|
subgraph_tensor_dict = input_values.copy()
|
|
# get the rest of the subgraph input from tensor_dict
|
|
for i in subgraph.input:
|
|
if i.name not in subgraph_tensor_dict.keys():
|
|
subgraph_tensor_dict[i.name] = tensor_dict[i.name]
|
|
# get the required initializer constant node(s) for the subgraph
|
|
# Need to get the initializer constant nodes from tensor_dict here
|
|
# because input from initializer will not be send in as inputs
|
|
# to the subgraph and those nodes are not in the subgraph
|
|
nodes_outputs = []
|
|
for node in subgraph.node:
|
|
for o_name in node.output:
|
|
nodes_outputs.append(o_name)
|
|
for node in subgraph.node:
|
|
for i_name in node.input:
|
|
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
|
|
):
|
|
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
|
|
onnx_node = OnnxNode(node)
|
|
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
|
|
subgraph_tensor_dict,
|
|
opset=opset,
|
|
strict=strict)
|
|
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
|
|
subgraph_tensor_dict.update(curr_node_output_map)
|
|
return subgraph_tensor_dict
|
|
|
|
@classmethod
|
|
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
|
|
"""
|
|
Converts ONNX graph to TensorflowRep
|
|
Args:
|
|
graph_def: the ONNX graph to be converted
|
|
strict: whether to enforce semantic equivalence between the
|
|
original model and the converted tensorflow model,
|
|
defaults to True (yes, enforce semantic equivalence).
|
|
Returns:
|
|
TensorflowRep object.
|
|
"""
|
|
# get the opset of the installed ONNX
|
|
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
|
|
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
|
|
|
|
|
|
prepare = TensorflowBackend.prepare
|
|
|
|
run_node = TensorflowBackend.run_node
|
|
|
|
run_model = TensorflowBackend.run_model
|
|
|
|
supports_device = TensorflowBackend.supports_device
|
|
|
|
onnx_graph_to_tensorflow_ops = TensorflowBackend.onnx_graph_to_tensorflow_ops
|
|
|
|
onnx_graph_to_tensorflow_rep = TensorflowBackend.onnx_graph_to_tensorflow_rep
|