# Files: ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/backend_rep.py
# Last commit (per repository page): T, 2020-10-14 08:55:07 +08:00
# 110 lines, 3.1 KiB, Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import tensorflow as tf
from onnx.backend.base import BackendRep, namedtupledict
class TensorflowRep(BackendRep):
    """ONNX backend representation backed by a TensorFlow graph.

    Holds the TensorFlow graph built for an ONNX model, the names of the
    graph inputs and outputs, and a mapping from those names to the graph
    tensors that are fed and fetched when the model is run.
    """

    def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
        """Create a representation.

        :param graph: TensorFlow graph to run (may also be assigned later).
        :param inputs: Names of the graph inputs, in positional order.
        :param outputs: Names of the graph outputs, in the order `run`
            returns them.
        :param tensor_dict: Mapping from input/output names to the graph
            tensors used as feed keys and fetches.
        """
        super(TensorflowRep, self).__init__()
        self._graph = graph
        self._inputs = inputs or []
        self._outputs = outputs or []
        self._tensor_dict = tensor_dict or {}

    @property
    def graph(self):
        """The TensorFlow graph executed by `run` and exported by `export_graph`."""
        return self._graph

    @graph.setter
    def graph(self, graph):
        self._graph = graph

    @property
    def inputs(self):
        """Names of the graph inputs, in positional order."""
        return self._inputs

    @inputs.setter
    def inputs(self, inputs):
        self._inputs = inputs

    @property
    def outputs(self):
        """Names of the graph outputs, in the order `run` returns them."""
        return self._outputs

    @outputs.setter
    def outputs(self, outputs):
        self._outputs = outputs

    @property
    def tensor_dict(self):
        """Mapping from input/output names to graph tensors."""
        return self._tensor_dict

    @tensor_dict.setter
    def tensor_dict(self, tensor_dict):
        self._tensor_dict = tensor_dict

    def run(self, inputs, **kwargs):
        """Run TensorflowRep.

        A new `tf.Session` is opened for every call. Global variables are
        initialized, then the tensors named by `self.outputs` are evaluated.

        :param inputs: Given inputs. One of:
            - a dict keyed by input name,
            - a list/tuple matched positionally against `self.inputs`,
            - a single value, which is fed to the first input.
        :param kwargs: Other args, forwarded to `BackendRep.run`.
        :return: Outputs, as a namedtupledict whose fields are `self.outputs`.
        :raises RuntimeError: if a list/tuple with the wrong number of values
            is given.
        """
        super(TensorflowRep, self).run(inputs, **kwargs)
        # TODO: handle name scope if necessary
        with self.graph.as_default():
            with tf.Session() as sess:
                if isinstance(inputs, dict):
                    feed_dict = inputs
                elif isinstance(inputs, (list, tuple)):
                    if len(self.inputs) != len(inputs):
                        raise RuntimeError('Expected {} values for uninitialized '
                                           'graph inputs ({}), but got {}.'.format(
                                               len(self.inputs), ', '.join(self.inputs),
                                               len(inputs)))
                    feed_dict = dict(zip(self.inputs, inputs))
                else:
                    # single input
                    feed_dict = {self.inputs[0]: inputs}
                # Re-key the feed from input names to graph tensors. Only the
                # names listed in self.inputs are fed; extra dict keys are ignored.
                feed_dict = {
                    self.tensor_dict[key]: feed_dict[key] for key in self.inputs
                }
                # NOTE(review): in TF1 graph mode this presumably adds a new
                # initializer op to self.graph on every run() call, growing the
                # graph over repeated calls -- confirm; if so, build it once.
                sess.run(tf.global_variables_initializer())
                fetches = [self.tensor_dict[output] for output in self.outputs]
                output_values = sess.run(fetches, feed_dict=feed_dict)
        return namedtupledict('Outputs', self.outputs)(*output_values)

    @staticmethod
    def _rename_input_ref(input_ref, renames):
        """Return NodeDef input reference `input_ref` with its op name remapped.

        Input references look like "op", "op:N" or "^op" (control
        dependency). The "^" prefix and ":N" suffix are kept. References to
        ops not in `renames` are returned unchanged.
        """
        prefix = '^' if input_ref.startswith('^') else ''
        op_name, sep, index = input_ref[len(prefix):].partition(':')
        if op_name not in renames:
            return input_ref
        return prefix + renames[op_name] + sep + index

    def export_graph(self, path):
        """Export backend representation to a Tensorflow proto file.

        This function obtains the graph proto corresponding to the ONNX
        model associated with the backend representation and serializes
        to a protobuf file. Ops whose first output (":0") is an ONNX graph
        output are renamed to that output's name. Every reference to a
        renamed op is updated, so the exported graph stays self-consistent.

        :param path: The path to the output TF protobuf file.
        :returns: none.
        """
        graph_proto = self.graph.as_graph_def()

        # Map original op name -> meaningful ONNX output name. Only a ":0"
        # tensor shares its name with its op, so only those can be renamed.
        meaningful_names = {}
        for output_name in self.outputs:
            tensor_name = self.tensor_dict[output_name].name
            if tensor_name.endswith(':0'):
                meaningful_names[tensor_name[:-len(':0')]] = output_name

        for node in graph_proto.node:
            if node.name in meaningful_names:
                node.name = meaningful_names[node.name]
            # Renaming a node alone would leave its consumers pointing at the
            # old name, i.e. dangling inputs. Rewrite those references too.
            # A copy is iterated so the repeated field is not mutated while
            # being iterated.
            for i, input_ref in enumerate(list(node.input)):
                renamed = self._rename_input_ref(input_ref, meaningful_names)
                if renamed != input_ref:
                    node.input[i] = renamed

        # Context manager ensures the file is closed even if the write fails.
        with open(path, "wb") as proto_file:
            proto_file.write(graph_proto.SerializeToString())