110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import tensorflow as tf
|
|
|
|
from onnx.backend.base import BackendRep, namedtupledict
|
|
|
|
|
|
class TensorflowRep(BackendRep):
  """Backend representation that holds a TensorFlow graph and can run it.

  The instance stores four pieces of state, each exposed through a
  read/write property:

  * ``graph``: the graph object; ``run`` enters it via ``as_default()``
    and ``export_graph`` serializes it via ``as_graph_def()``.
  * ``inputs``: names of the graph inputs that must be fed at run time.
  * ``outputs``: names of the graph outputs that are fetched at run time.
  * ``tensor_dict``: mapping from a name (as found in ``inputs`` /
    ``outputs``) to the corresponding tensor object.
  """

  def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
    """Create a representation.

    :param graph: Graph to execute/export. May be set later through the
      ``graph`` property.
    :param inputs: Names of the inputs to feed; defaults to an empty list.
    :param outputs: Names of the outputs to fetch; defaults to an empty
      list.
    :param tensor_dict: Mapping from name to tensor; defaults to an empty
      dict.
    """
    super(TensorflowRep, self).__init__()
    self._graph = graph
    # ``or`` gives every instance its own fresh container instead of
    # sharing a mutable default between instances.
    self._inputs = inputs or []
    self._outputs = outputs or []
    self._tensor_dict = tensor_dict or {}

  @property
  def graph(self):
    """The graph this representation runs and exports."""
    return self._graph

  @graph.setter
  def graph(self, graph):
    self._graph = graph

  @property
  def inputs(self):
    """Names of the graph inputs that ``run`` feeds."""
    return self._inputs

  @inputs.setter
  def inputs(self, inputs):
    self._inputs = inputs

  @property
  def outputs(self):
    """Names of the graph outputs that ``run`` fetches."""
    return self._outputs

  @outputs.setter
  def outputs(self, outputs):
    self._outputs = outputs

  @property
  def tensor_dict(self):
    """Mapping from input/output name to the corresponding tensor."""
    return self._tensor_dict

  @tensor_dict.setter
  def tensor_dict(self, tensor_dict):
    self._tensor_dict = tensor_dict

  def run(self, inputs, **kwargs):
    """ Run TensorflowRep.

    ``inputs`` is interpreted in one of three ways:

    * a ``dict``: used as-is; it is looked up by every name in
      ``self.inputs``.
    * a ``list`` or ``tuple``: matched positionally against
      ``self.inputs``; the lengths must be equal.
    * anything else: treated as the value of the single input
      ``self.inputs[0]``.

    :param inputs: Given inputs.
    :param kwargs: Other args. They are forwarded to the base class
      ``run`` only.
    :return: Outputs, as a named tuple called ``Outputs`` whose fields
      are the names in ``self.outputs``, in the same order.
    :raises RuntimeError: If ``inputs`` is a list/tuple whose length
      differs from ``len(self.inputs)``.
    """
    super(TensorflowRep, self).run(inputs, **kwargs)

    # TODO: handle name scope if necessary
    with self.graph.as_default():
      with tf.Session() as sess:
        if isinstance(inputs, dict):
          feed_dict = inputs
        elif isinstance(inputs, (list, tuple)):
          if len(self.inputs) != len(inputs):
            raise RuntimeError('Expected {} values for uninitialized '
                               'graph inputs ({}), but got {}.'.format(
                                   len(self.inputs), ', '.join(self.inputs),
                                   len(inputs)))
          feed_dict = dict(zip(self.inputs, inputs))
        else:
          # single input
          feed_dict = {self.inputs[0]: inputs}

        # Re-key the feed from input names to the tensor objects that
        # the session is given as feed targets.
        feed_dict = {
            self.tensor_dict[key]: feed_dict[key] for key in self.inputs
        }

        sess.run(tf.global_variables_initializer())
        outputs = [self.tensor_dict[output] for output in self.outputs]

        output_values = sess.run(outputs, feed_dict=feed_dict)
        return namedtupledict('Outputs', self.outputs)(*output_values)

  def export_graph(self, path):
    """Export backend representation to a Tensorflow proto file.

    This function obtains the graph proto corresponding to the ONNX
    model associated with the backend representation and serializes
    to a protobuf file.

    Before serializing, every node whose name equals an output tensor's
    name with ``':0'`` removed is renamed to the matching entry of
    ``self.outputs``.

    :param path: The path to the output TF protobuf file.

    :returns: none.
    """
    graph_proto = self.graph.as_graph_def()

    # rename the output nodes
    # NOTE(review): only the ':0' suffix is stripped, so an output that is
    # not the first output of its node would presumably never match a
    # node name -- confirm whether that case can occur.
    meaningful_names = {}
    for output_name in self.outputs:
      node_name = self.tensor_dict[output_name].name.replace(':0', '')
      meaningful_names[node_name] = output_name

    for node in graph_proto.node:
      if node.name in meaningful_names:
        node.name = meaningful_names[node.name]

    # Serialize first, then write inside ``with`` so the file handle is
    # closed even if the write raises.
    serialized = graph_proto.SerializeToString()
    with open(path, "wb") as proto_file:
      proto_file.write(serialized)
|