Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/backend.py
T
2020-10-14 08:55:07 +08:00

353 lines
14 KiB
Python

"""Backend for running ONNX on Tensorflow
To run this, you will need to have Tensorflow installed as well.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
try:
from itertools import izip as zip
except ImportError: # will be 3.x series
pass
from onnx import defs
from onnx import numpy_helper
from onnx.backend.base import Backend
from onnx.backend.base import Device
from onnx.backend.base import namedtupledict
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
from onnx.helper import make_opsetid
import tensorflow as tf
from onnx_tf.backend_rep import TensorflowRep
from onnx_tf.common import data_type
from onnx_tf.common import get_device_option
from onnx_tf.common import get_unique_suffix
from onnx_tf.common import supports_device as common_supports_device
from onnx_tf.common.handler_helper import get_all_backend_handlers
from onnx_tf.pb_wrapper import OnnxNode
import onnx_tf.common as common
class TensorflowBackend(Backend):
    """Tensorflow backend for ONNX.

    Every entry point is a classmethod; the class is never instantiated.
    ONNX models, graphs and single nodes are translated into Tensorflow
    ops by dispatching each node to the handler registered for its
    domain and op type.
    """

    @classmethod
    def prepare(cls,
                model,
                device='CPU',
                strict=True,
                logging_level='INFO',
                **kwargs):
        """Prepare an ONNX model for Tensorflow Backend.

        This function converts an ONNX model to an internal representation
        of the computational graph called TensorflowRep and returns
        the converted representation.

        :param model: The ONNX model to be converted.
        :param device: The device to execute this model on. It is only
          forwarded to the base class ``prepare``; the conversion itself
          does not read it.
        :param strict: Whether to enforce semantic equivalence between the
          original model and the converted tensorflow model, defaults to
          True (yes, enforce semantic equivalence).
          Changing to False is strongly discouraged.
          Currently, the strict flag only affects the behavior of MaxPool
          and AveragePool ops.
        :param logging_level: The logging level, default is INFO. Change it
          to DEBUG to see more conversion details or to WARNING to see less.
        :param kwargs: Extra keyword arguments, forwarded to the base class
          ``prepare`` only.
        :returns: A TensorflowRep class object representing the ONNX model.
        """
        super(TensorflowBackend, cls).prepare(model, device, **kwargs)
        common.logger.setLevel(logging_level)
        # NOTE(review): assumes common.logger already has at least one
        # handler attached; an empty handler list would raise IndexError.
        common.logger.handlers[0].setLevel(logging_level)
        return cls.onnx_model_to_tensorflow_rep(model, strict)

    @classmethod
    def onnx_model_to_tensorflow_rep(cls, model, strict):
        """Convert ONNX model to TensorflowRep.

        :param model: ONNX ModelProto object.
        :param strict: whether to enforce semantic equivalence between the
          original model and the converted tensorflow model.
        :return: TensorflowRep object.
        """
        # Models with IR_VERSION less than 3 do not have opset_import set.
        # We default to minimum opset, this behavior is consistent with
        # onnx checker.
        # c.f. https://github.com/onnx/onnx/blob/427ac0c1b792363d373e3d7e4eef97fa46458420/onnx/checker.cc#L478
        if model.ir_version < 3:
            opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
        else:
            opset_import = model.opset_import
        return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import,
                                                 strict)

    @classmethod
    def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
        """Convert ONNX graph to TensorflowRep.

        :param graph_def: ONNX GraphProto object.
        :param opset: ONNX OperatorSetIdProto list.
        :param strict: whether to enforce semantic equivalence between the
          original model and the converted tensorflow model.
        :return: TensorflowRep object.
        """
        handlers = cls._get_handlers(opset)

        tf_rep_graph = tf.Graph()
        with tf_rep_graph.as_default():
            # initializer: TensorProtos representing the values to
            # initialize a given tensor.
            # initialized: the names of the initialized tensors.
            if graph_def.initializer:
                input_dict_items = cls._onnx_initializer_to_input_dict_items(
                    graph_def.initializer)
                initialized = {init.name for init in graph_def.initializer}
            else:
                input_dict_items = []
                initialized = set()

            # Create placeholders for the inputs that have no initializer.
            for value_info in graph_def.input:
                if value_info.name in initialized:
                    continue

                tensor_type = value_info.type.tensor_type
                # A dimension is kept only when it has a positive value and
                # no symbolic name; every other dimension is left unknown.
                shape = [
                    d.dim_value
                    if (d.dim_value > 0 and d.dim_param == "") else None
                    for d in tensor_type.shape.dim
                ]

                # Replace ":" with "_tf_" and append a unique suffix; names
                # without ":" are used unchanged.
                if ":" in value_info.name:
                    value_info_name = (
                        value_info.name.replace(":", "_tf_") + "_" +
                        get_unique_suffix())
                else:
                    value_info_name = value_info.name

                x = tf.placeholder(data_type.onnx2tf(tensor_type.elem_type),
                                   name=value_info_name,
                                   shape=shape)
                # The dict key stays the original ONNX name.
                input_dict_items.append((value_info.name, x))

            # tensor dict: this dictionary is a map from variable names
            # to the latest produced TF tensors of the given name.
            # This dictionary will get updated as we build the graph to
            # record the names of newly produced tensors.
            tensor_dict = dict(input_dict_items)

            for node in graph_def.node:
                onnx_node = OnnxNode(node)
                output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
                                                             tensor_dict,
                                                             handlers,
                                                             opset=opset,
                                                             strict=strict)
                curr_node_output_map = dict(
                    zip(onnx_node.outputs, output_ops))
                tensor_dict.update(curr_node_output_map)

        tf_rep = TensorflowRep()
        tf_rep.graph = tf_rep_graph
        tf_rep.inputs = [
            value_info.name
            for value_info in graph_def.input
            if value_info.name not in initialized
        ]
        tf_rep.outputs = [value_info.name for value_info in graph_def.output]
        tf_rep.tensor_dict = tensor_dict
        return tf_rep

    @classmethod
    def run_node(cls, node, inputs, device='CPU', outputs_info=None,
                 **kwargs):
        """Run a single ONNX node and return its outputs.

        :param node: ONNX NodeProto object.
        :param inputs: Either a dict mapping input names to values, or a
          sequence of values given in the same order as the node's inputs.
        :param device: Device to run on.
        :param outputs_info: Not used by this backend.
        :param kwargs: Other args; not used by this backend.
        :return: A named tuple of output values, with one field per node
          output name.
        """
        super(TensorflowBackend, cls).run_node(node, inputs, device)

        node_graph = tf.Graph()
        with node_graph.as_default():
            node = OnnxNode(node)
            device_option = get_device_option(Device(device))

            if isinstance(inputs, dict):
                feed_dict_raw = inputs
            else:
                assert len(node.inputs) == len(inputs)
                feed_dict_raw = dict(zip(node.inputs, inputs))

            # TODO: is constant the best way for feeding inputs?
            input_dict = {
                name: tf.constant(value)
                for name, value in feed_dict_raw.items()
            }
            ops = cls._onnx_node_to_tensorflow_op(node, input_dict)

            with tf.Session() as sess:
                # NOTE(review): this device scope wraps only the session
                # calls; the node's ops were created above, outside of it.
                # Confirm this placement is intended.
                with tf.device(device_option):
                    sess.run(tf.global_variables_initializer())
                    output_vals = sess.run(ops)

        return namedtupledict('Outputs', node.outputs)(*output_vals)

    @classmethod
    def _onnx_initializer_to_input_dict_items(cls, initializer):
        """Convert ONNX graph initializer to input dict items.

        :param initializer: ONNX graph initializer, list of TensorProto.
        :return: List of ``(name, tf.constant)`` pairs, one per initializer,
          keyed by the original initializer name.
        """

        def tensor2list(onnx_tensor):
            # Use the onnx.numpy_helper because the data may be raw
            return numpy_helper.to_array(onnx_tensor).flatten().tolist()

        def validate_initializer_name(name):
            # Prepend a unique suffix if the leading character is "_".
            # startswith() compares by value (not identity) and also copes
            # with an empty name.
            if name.startswith("_"):
                name = get_unique_suffix() + name
            # Replace ":" with "_tf_" and append a unique suffix for
            # traceability
            if ":" in name:
                name = name.replace(":", "_tf_") + "_" + get_unique_suffix()
            return name

        return [(init.name,
                 tf.constant(tensor2list(init),
                             shape=init.dims,
                             dtype=data_type.onnx2tf(init.data_type),
                             name=validate_initializer_name(init.name)))
                for init in initializer]

    @classmethod
    def _onnx_node_to_tensorflow_op(cls,
                                    node,
                                    tensor_dict,
                                    handlers=None,
                                    opset=None,
                                    strict=True):
        """Convert onnx node to tensorflow op.

        Args:
          node: Onnx node object.
          tensor_dict: Tensor dict of graph.
          handlers: Backend handlers, indexed by domain and then by op
            type. When not given (or empty) they are looked up from opset.
          opset: ONNX OperatorSetIdProto list. When not given, the opset
            version of the installed ONNX is used.
          strict: whether to enforce semantic equivalence between the
            original model and the converted tensorflow model, defaults to
            True (yes, enforce semantic equivalence).
            Changing to False is strongly discouraged.
        Returns:
          Tensorflow op
        Raises:
          BackendIsNotSupposedToImplementIt: if no handler is registered
            for the node's domain and op type.
        """
        handlers = handlers or cls._get_handlers(opset)

        handler = None
        if handlers and node.domain in handlers:
            handler = handlers[node.domain].get(node.op_type, None)
        if handler:
            return handler.handle(node, tensor_dict=tensor_dict,
                                  strict=strict)

        raise BackendIsNotSupposedToImplementIt(
            "{} is not implemented.".format(node.op_type))

    @classmethod
    def _get_handlers(cls, opset):
        """Get all backend handlers with opset.

        :param opset: ONNX OperatorSetIdProto list. When it is None or
          empty, the opset version of the installed ONNX is used.
        :return: All backend handlers.
        """
        opset = opset or [
            make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())
        ]
        opset_dict = {o.domain: o.version for o in opset}
        return get_all_backend_handlers(opset_dict)

    @classmethod
    def supports_device(cls, device):
        """Report whether ``device`` is supported.

        The decision is delegated to ``onnx_tf.common.supports_device``.

        :param device: The device to check.
        :return: The result of the common ``supports_device`` helper.
        """
        return common_supports_device(device)

    @classmethod
    def onnx_graph_to_tensorflow_ops(cls,
                                     subgraph,
                                     input_values,
                                     tensor_dict,
                                     opset=None,
                                     strict=True):
        """Converts ONNX graph to Tensorflow operations.

        Args:
          subgraph: the ONNX graph to be converted
          input_values: dictionary with values/tensors to initialize
            the subgraph inputs. If the subgraph.input are sent in as
            parameters then it is required, otherwise this can be an
            empty dictionary. It is copied, not modified.
          tensor_dict: the dictionary that contains values for all the
            node.inputs in the subgraph that are not defined
            in the subgraph or input_values.
          opset: opset version of the operator set.
          strict: whether to enforce semantic equivalence between the
            original model and the converted tensorflow model,
            defaults to True (yes, enforce semantic equivalence).
        Returns:
          dictionary mapping names to Tensorflow Tensors; it holds the
          subgraph inputs, the values taken from tensor_dict and the
          outputs of every converted node.
        """
        # get the subgraph.input from input_values
        subgraph_tensor_dict = input_values.copy()

        # get the rest of the subgraph input from tensor_dict
        for graph_input in subgraph.input:
            if graph_input.name not in subgraph_tensor_dict:
                subgraph_tensor_dict[graph_input.name] = tensor_dict[
                    graph_input.name]

        # get the required initializer constant node(s) for the subgraph
        # Need to get the initializer constant nodes from tensor_dict here
        # because input from initializer will not be sent in as inputs
        # to the subgraph and those nodes are not in the subgraph
        nodes_outputs = {
            o_name for node in subgraph.node for o_name in node.output
        }

        for node in subgraph.node:
            for i_name in node.input:
                if (i_name not in nodes_outputs and
                        i_name not in subgraph_tensor_dict):
                    subgraph_tensor_dict[i_name] = tensor_dict[i_name]

            onnx_node = OnnxNode(node)
            output_ops = cls._onnx_node_to_tensorflow_op(
                onnx_node, subgraph_tensor_dict, opset=opset, strict=strict)
            curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
            subgraph_tensor_dict.update(curr_node_output_map)

        return subgraph_tensor_dict

    @classmethod
    def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
        """Converts ONNX graph to TensorflowRep.

        Args:
          graph_def: the ONNX graph to be converted
          strict: whether to enforce semantic equivalence between the
            original model and the converted tensorflow model,
            defaults to True (yes, enforce semantic equivalence).
        Returns:
          TensorflowRep object.
        """
        # get the opset of the installed ONNX
        opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
        return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
# Module-level aliases: expose the TensorflowBackend classmethods as plain
# module attributes so they can be called without naming the class.
# NOTE(review): run_model is not defined in TensorflowBackend itself;
# presumably it is inherited from onnx.backend.base.Backend -- confirm.
prepare = TensorflowBackend.prepare
run_node = TensorflowBackend.run_node
run_model = TensorflowBackend.run_model
supports_device = TensorflowBackend.supports_device
onnx_graph_to_tensorflow_ops = TensorflowBackend.onnx_graph_to_tensorflow_ops
onnx_graph_to_tensorflow_rep = TensorflowBackend.onnx_graph_to_tensorflow_rep