Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/pb_wrapper.py
T
2020-10-14 08:55:07 +08:00

462 lines
15 KiB
Python

import inspect
from itertools import chain
import numpy as np
from onnx import NodeProto
from onnx import TensorProto
from onnx import ValueInfoProto
from onnx import numpy_helper
from onnx.helper import make_graph
from onnx.helper import make_tensor
from onnx.helper import make_tensor_value_info
from onnx.helper import mapping
import tensorflow as tf
from tensorflow.core.framework.attr_value_pb2 import AttrValue
from tensorflow.core.framework.node_def_pb2 import NodeDef
from onnx_tf.common import attr_converter
from onnx_tf.common import attr_translator
from onnx_tf.common import CONST_MINUS_ONE_INT32
from onnx_tf.common import CONST_ONE_FP32
from onnx_tf.common import CONST_ONE_INT32
from onnx_tf.common import CONST_ZERO_INT32
from onnx_tf.common import IS_PYTHON3
from onnx_tf.common import logger
from onnx_tf.common.data_type import any_dtype_to_onnx_dtype
class TensorflowNode(object):
    """Uniform view of one graph node, whatever format it came from.

    An instance can be built in three ways:

    * from explicit keyword arguments (``node`` left as ``None``),
    * from an ONNX node (an ``OnnxNode`` or a raw ``NodeProto``),
    * from a TensorFlow ``NodeDef``.

    In every case the instance exposes ``node``, ``name``, ``inputs``,
    ``outputs``, ``attr``, ``domain`` and ``op_type``.
    """

    def __init__(self,
                 node=None,
                 name=None,
                 inputs=None,
                 outputs=None,
                 attr=None,
                 domain=None,
                 op_type=None):
        """Build the node.

        :param node: ``None``, an ``OnnxNode``/``NodeProto`` or a ``NodeDef``.
            When given, the remaining arguments are ignored.
        :param name: Node name (only used when ``node`` is ``None``).
        :param inputs: List of input names.
        :param outputs: List of output names; derived from ``name`` and
            ``attr`` through ``get_outputs_names`` when empty.
        :param attr: Dict of node attributes.
        :param domain: Operator domain.
        :param op_type: Operator type.
        :raises TypeError: If ``node`` is of an unsupported type.
        """
        if node is None:
            # No original protobuf object to keep a reference to.
            self.node = None
            self.name = name or ""
            self.inputs = inputs or []
            self.attr = attr or {}
            self.domain = domain or ""
            self.op_type = op_type or ""
            # Must come last: the default names are computed from self.name
            # and self.attr.
            self.outputs = outputs or self.get_outputs_names()
        elif isinstance(node, (OnnxNode, NodeProto)):
            self._load_onnx_node(node)
        elif isinstance(node, NodeDef):
            self._load_tf_node(node)
        else:
            # Previously this case fell through silently and produced an
            # instance without any attribute set.
            raise TypeError(
                "Unsupported node type: {}".format(type(node).__name__))

    def _load_onnx_node(self, node):
        """Populate the instance from an ``OnnxNode`` or a ``NodeProto``."""
        if isinstance(node, NodeProto):
            node = OnnxNode(node)
        # Keep a reference to the original protobuf object so that
        # ``self.node`` exists for every construction path.
        self.node = node.node_proto
        self.name = node.name
        self.inputs = node.inputs
        self.outputs = node.outputs
        self.attr = node.attrs
        self.domain = node.domain
        self.op_type = node.op_type

    def _load_tf_node(self, node):
        """Populate the instance from a TensorFlow ``NodeDef``."""
        # Keep a reference to the original protobuf object.
        self.node = node
        self.name = node.name
        self.inputs = list(node.input)
        self.attr = {}
        for key, val in node.attr.items():
            translated = attr_translator.translate_tf(key, val)
            # Values the translator leaves as AttrValue still need to be
            # converted.
            if isinstance(translated, AttrValue):
                translated = attr_converter.convert_tf(translated)
            self.attr[key] = translated
        # An op written as "a.b.Op" is split into domain "a.b" and type "Op".
        op_parts = node.op.split(".")
        self.domain = "" if len(op_parts) == 1 else ".".join(op_parts[:-1])
        self.op_type = op_parts[-1]
        self.outputs = self.get_outputs_names()

    def get_outputs_names(self, num=None):
        """ Helper method to get outputs names.
        e.g. tf.split: [Split, Split:1, Split:2]

        :param num: Force to get `num` outputs names. When ``None`` the
            count is the length of ``attr["_output_shapes"]``, or 1 (with a
            warning) if that attribute is missing.
        :return: List of outputs names.
        """
        if num is None:
            if "_output_shapes" in self.attr:
                num = len(self.attr["_output_shapes"])
            else:
                num = 1
                logger.warning(
                    "_output_shapes is not in node.attr. "
                    "The num of output is set to 1 for commonly. "
                    "It will cause problem with case of multiple outputs.")
        # The first output carries the bare node name, later ones "name:i".
        return [
            self.name if i == 0 else "{}:{}".format(self.name, i)
            for i in range(num)
        ]
class TensorflowGraph(object):
    """Wrapper around a TensorFlow ``GraphDef``.

    Holds the (possibly shape-annotated) graph definition, a
    ``TensorflowNode`` for each of its nodes preceded by a few utility
    constant nodes, a name -> node lookup table and the graph output names.
    """

    def __init__(self, graph_def, outputs=(), graph_name="graph"):
        """Build the wrapper.

        :param graph_def: GraphDef object.
        :param outputs: Output node names; computed with
            ``get_output_node_names`` when empty.
        :param graph_name: Name of the graph.
        """
        self._graph_name = graph_name
        self._graph_def = self._process_graph_def(graph_def)
        self._nodes = self._create_util_nodes() + [
            TensorflowNode(node) for node in self.graph_def.node
        ]
        self._nodes_dict = {n.name: n for n in self._nodes}
        self._outputs = outputs or self.get_output_node_names(self.graph_def)

    @staticmethod
    def _create_util_nodes():
        """Return ``Const`` nodes holding the int32 values -1, 0 and 1."""
        util_nodes = [
            (CONST_MINUS_ONE_INT32, np.array([-1]).astype(np.int32)),
            (CONST_ZERO_INT32, np.array([0]).astype(np.int32)),
            (CONST_ONE_INT32, np.array([1]).astype(np.int32)),
        ]
        return [
            TensorflowNode(
                op_type="Const",
                name=name,
                attr={
                    "value": value,
                    "dtype": any_dtype_to_onnx_dtype(value.dtype),
                    "_output_shapes": [value.shape]
                }) for name, value in util_nodes
        ]

    def get_node_by_name(self, name):
        """Return the node called ``name``.

        :raises ValueError: If no node with that name is known.
        """
        node = self._nodes_dict.get(name, None)
        if node is None:
            raise ValueError(
                "Node {} is not found in the graph provided".format(name))
        return node

    def _process_graph_def(self, graph_def):
        """Make sure the graph carries ``_output_shapes`` attributes.

        Only the first node is inspected; when it lacks ``_output_shapes``
        the whole graph is re-imported with shape inference switched on.
        A graph without any node is returned unchanged.
        """
        if len(graph_def.node) == 0:
            # Nothing to inspect or infer; indexing node[0] would raise.
            return graph_def
        if "_output_shapes" not in TensorflowNode(graph_def.node[0]).attr:
            graph_def = self._add_infer_shapes(graph_def)
        return graph_def

    @staticmethod
    def _add_infer_shapes(graph_def):
        """Re-import ``graph_def`` in a session with ``infer_shapes=True``
        and return the resulting graph definition."""
        with tf.Graph().as_default():
            with tf.Session(
                    config=tf.ConfigProto(
                        graph_options=tf.GraphOptions(
                            infer_shapes=True))) as sess:
                tf.import_graph_def(graph_def, name="")
                return sess.graph_def

    @staticmethod
    def get_output_node_names(graph_def):
        """Get output node names from GraphDef.

        A node is an output when no other node consumes it. Input entries
        are normalised before the comparison because a ``NodeDef`` input may
        be written as ``name``, ``name:port`` or ``^name`` (control
        dependency); comparing the raw strings would report a node that is
        only consumed through one of the latter two forms as an output.

        Args:
          graph_def: GraphDef object.
        Returns:
          List of output node names, in graph order.
        """
        consumed = set()
        for node in graph_def.node:
            for input_name in node.input:
                # "^name" -> "name", "name:1" -> "name"
                consumed.add(input_name.lstrip("^").split(":", 1)[0])

        output_names, seen = [], set()
        for node in graph_def.node:
            if node.name in consumed or node.name in seen:
                continue
            seen.add(node.name)
            output_names.append(node.name)
        return output_names

    def update_nodes(self, nodes):
        """Replace the node list and rebuild the name lookup table."""
        self._nodes = nodes
        self._nodes_dict = {n.name: n for n in self._nodes}

    @property
    def graph_def(self):
        """The graph definition returned by ``_process_graph_def``."""
        return self._graph_def

    @property
    def graph_name(self):
        """Name given to the graph."""
        return self._graph_name

    @property
    def nodes(self):
        """List of nodes (utility constants first)."""
        return self._nodes

    @property
    def nodes_dict(self):
        """Mapping from node name to node."""
        return self._nodes_dict

    @property
    def outputs(self):
        """Names of the graph output nodes."""
        return self._outputs
# TODO: Move this into ONNX main library
class OnnxNode(object):
    """
    Reimplementation of NodeProto from ONNX, but in a form
    more convenient to work with from Python.

    The name, op type and domain are stored as ``str``, inputs and outputs
    as plain lists, and the attributes as a dict keyed by attribute name.
    The original proto stays reachable through ``node_proto``.
    """

    def __init__(self, node):
        """Copy the fields of ``node`` (a ``NodeProto``) into this object."""
        self.name = str(node.name)
        self.op_type = str(node.op_type)
        self.domain = str(node.domain)
        # Each attribute is first converted, then passed through the
        # translator together with its name.
        self.attrs = {
            attr.name: attr_translator.translate_onnx(
                attr.name, attr_converter.convert_onnx(attr))
            for attr in node.attribute
        }
        self.inputs = list(node.input)
        self.outputs = list(node.output)
        self.node_proto = node
class OnnxGraph(object):
    """ A helper class for making ONNX graph.
    This class holds all information ONNX graph needs.

    The state consists of the graph name, the input / output / node /
    initializer / value-info protos, a dict of constant values and a map of
    data-type casts applied when the final graph proto is built.
    """

    def __init__(self, name=None, graph_proto=None):
        """Start from ``graph_proto`` when given, otherwise from an empty
        graph called ``name``."""
        if graph_proto:
            self._name = graph_proto.name
            self._inputs_proto = list(graph_proto.input)
            self._outputs_proto = list(graph_proto.output)
            self._nodes_proto = list(graph_proto.node)
            self._consts_proto = list(graph_proto.initializer)
            self._value_info_proto = list(graph_proto.value_info)
            self._consts = {
                init.name: numpy_helper.to_array(init)
                for init in graph_proto.initializer
            }
        else:
            self._name = name or ""
            self._inputs_proto = []
            self._outputs_proto = []
            self._nodes_proto = []
            self._consts = {}
            self._consts_proto = []
            self._value_info_proto = []
        # Either way, data_type_cast_map is empty when initialized.
        self._data_type_cast_map = {}
        self._add_utility_constants()

    def _add_utility_constants(self):
        """Register a few useful utility constants as const value, const
        proto and input proto."""
        util_consts = {CONST_ONE_FP32: np.array([1.0]).astype(np.float32)}
        for name, value in util_consts.items():
            self.add_const_explicit(name=name, value=value)
            self.add_const_proto_explicit(
                name=name, value=value, np_dtype=value.dtype)
            self.add_input_proto_explicit(
                name=name, shape=value.shape, np_dtype=value.dtype)

    @property
    def inputs_proto(self):
        """List of ValueInfoProto objects representing the input to the
        converted ONNX graph."""
        return self._inputs_proto

    @inputs_proto.setter
    def inputs_proto(self, inputs_proto):
        self._inputs_proto = inputs_proto

    @property
    def all_node_inputs(self):
        """Flat list of the input names of every node proto."""
        return [
            input_name for node in self._nodes_proto
            for input_name in node.input
        ]

    @property
    def outputs(self):
        """Names of the output protos."""
        return [proto.name for proto in self._outputs_proto]

    @property
    def outputs_proto(self):
        """List of output protos."""
        return self._outputs_proto

    @property
    def nodes_proto(self):
        """List of NodeProto objects representing the ops in the converted
        ONNX graph."""
        return self._nodes_proto

    @nodes_proto.setter
    def nodes_proto(self, nodes_proto):
        self._nodes_proto = nodes_proto

    @property
    def consts(self):
        """Map from the name of a constant op to the array of values it
        holds.

        This is useful because tensorflow is less eager to know about input
        values at graph construction time than ONNX. That is to say, some
        ONNX attributes are input tensors in TF. This dictionary extracts
        those values of constant tensors that are known at graph
        construction time.
        """
        return self._consts

    @consts.setter
    def consts(self, consts):
        self._consts = consts

    @property
    def consts_proto(self):
        """List of initializers.

        Sometimes the constants are used as inputs to ops. This list holds
        initializers that create global constant tensors available to be
        accessed by ops as inputs (as opposed to attributes, which are
        supplied by the ``consts`` map).
        """
        return self._consts_proto

    @consts_proto.setter
    def consts_proto(self, consts_proto):
        self._consts_proto = consts_proto

    @property
    def data_type_cast_map(self):
        """Map from node name to new data type. Used to process protos to
        match ONNX type constraints."""
        return self._data_type_cast_map

    @data_type_cast_map.setter
    def data_type_cast_map(self, data_type_cast_map):
        self._data_type_cast_map = data_type_cast_map

    @property
    def value_info_proto(self):
        """List of ValueInfoProto objects representing all nodes' outputs
        in the converted ONNX graph."""
        return self._value_info_proto

    def add_input_proto_explicit(self,
                                 name,
                                 shape,
                                 np_dtype=None,
                                 tf_dtype=None,
                                 onnx_dtype=None):
        """Append an input value-info built from ``name``, ``shape`` and a
        data type given in numpy, tensorflow or ONNX form."""
        onnx_dtype = any_dtype_to_onnx_dtype(
            np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
        self._inputs_proto.append(
            make_tensor_value_info(name, onnx_dtype, shape))

    def add_input_proto(self, node):
        """Append an input value-info for ``node``.

        The shape is the shape of ``attr["value"]`` for ``Const`` nodes and
        ``attr["shape"]`` for every other op type.
        """
        name = node.name
        onnx_dtype = node.attr["dtype"]
        if node.op_type != "Const":
            shape = node.attr["shape"]
        else:
            shape = node.attr["value"].shape
        self.add_input_proto_explicit(name, shape, onnx_dtype=onnx_dtype)

    @staticmethod
    def _make_output_value_infos(node):
        """Build one value-info per entry of ``node.attr["_output_shapes"]``.

        The element type is ``node.attr["T"]``, falling back to
        ``TensorProto.BOOL``. The first output is named after the node,
        later ones ``"<name>:<index>"``.
        """
        onnx_type = node.attr.get("T", TensorProto.BOOL)
        return [
            make_tensor_value_info(
                node.name if i == 0 else "{}:{}".format(node.name, i),
                onnx_type, output_shape)
            for i, output_shape in enumerate(node.attr["_output_shapes"])
        ]

    def add_output_proto(self, node):
        """Append a value-info for each output of ``node`` to the graph
        outputs."""
        self._outputs_proto.extend(self._make_output_value_infos(node))

    def add_node_proto(self, node_proto):
        """Append one node proto, or a list/tuple of them."""
        if not isinstance(node_proto, (list, tuple)):
            node_proto = [node_proto]
        self._nodes_proto.extend(node_proto)

    def remove_node_proto(self, names):
        """Drop every node proto whose name is ``names`` or is contained in
        ``names`` (list/tuple)."""
        if not isinstance(names, (list, tuple)):
            names = [names]
        self._nodes_proto = [
            proto for proto in self._nodes_proto if proto.name not in names
        ]

    def add_const_explicit(self, name, value):
        """Record ``value`` as the constant called ``name``."""
        self._consts[name] = value

    def add_const(self, node):
        """Record the ``value`` attribute of ``node`` as a constant."""
        self.add_const_explicit(node.name, node.attr["value"])

    def add_const_proto_explicit(self,
                                 name,
                                 value,
                                 np_dtype=None,
                                 tf_dtype=None,
                                 onnx_dtype=None):
        """Append an initializer holding ``value``.

        The data type may be given in numpy, tensorflow or ONNX form.
        """
        onnx_dtype = any_dtype_to_onnx_dtype(
            np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
        if len(value.shape) == 0:
            # A 0-d value is stored as a one-element tensor of shape (1,).
            raw_values = [value.tolist()]
            shape = np.array([value]).shape
        else:
            raw_values = value.flatten().tolist()
            shape = np.array(value).shape
        const_proto = make_tensor(
            name=name, data_type=onnx_dtype, dims=shape, vals=raw_values)
        self._consts_proto.append(const_proto)

    def add_const_proto(self, node):
        """Append an initializer built from the ``value`` and ``dtype``
        attributes of ``node``."""
        self.add_const_proto_explicit(
            node.name, node.attr["value"], onnx_dtype=node.attr["dtype"])

    def add_value_info_proto(self, node):
        """Append a value-info for each output of ``node`` to the graph
        value infos."""
        self._value_info_proto.extend(self._make_output_value_infos(node))

    def _clean_graph(self):
        """Remove proto in inputs_proto and consts_proto if proto is not
        used as input or an output in ONNX."""
        # A set keeps each membership test O(1) instead of scanning a list.
        in_out = set(self.all_node_inputs)
        in_out.update(self.outputs)
        self._inputs_proto = [
            proto for proto in self.inputs_proto if proto.name in in_out
        ]
        self._consts_proto = [
            proto for proto in self.consts_proto if proto.name in in_out
        ]

    def _fix_data_type(self):
        """Apply ``data_type_cast_map`` to the input and const protos."""
        self.inputs_proto = self._data_type_caster(self.inputs_proto,
                                                   self.data_type_cast_map)
        self.consts_proto = self._data_type_caster(self.consts_proto,
                                                   self.data_type_cast_map)

    @classmethod
    def _data_type_caster(cls, protos, data_type_cast_map):
        """Cast to a new data type if node name is in data_type_cast_map.
        Be used to process protos to match ONNX type constraints.

        A ``TensorProto`` is rebuilt with the new type; a ``ValueInfoProto``
        has its element type updated in place.

        :param protos: Target protos.
            ValueInfoProto for inputs and TensorProto for consts.
        :param data_type_cast_map: A {node.name: new_data_type} dict.
        :return: Processed protos.
        """
        if not data_type_cast_map:
            return protos
        result = []
        for proto in protos:
            new_proto = proto
            if proto.name in data_type_cast_map:
                new_data_type = data_type_cast_map[proto.name]
                if isinstance(proto, TensorProto):
                    if proto.data_type != new_data_type:
                        # NOTE(review): only the typed storage field is
                        # read; a tensor keeping its payload in raw_data
                        # would presumably yield no values here - confirm.
                        field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
                            mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[
                                proto.data_type]]
                        vals = getattr(proto, field)
                        new_proto = make_tensor(
                            name=proto.name,
                            data_type=new_data_type,
                            dims=proto.dims,
                            vals=vals)
                elif isinstance(proto, ValueInfoProto):
                    if proto.type.tensor_type.elem_type != new_data_type:
                        new_proto.type.tensor_type.elem_type = new_data_type
            result.append(new_proto)
        return result

    def make_graph_proto(self):
        """Clean the graph, apply the type casts and build the graph proto.

        ``initializer`` and ``value_info`` are only forwarded to
        ``make_graph`` when its signature has parameters of those names.
        """
        self._clean_graph()
        self._fix_data_type()

        if IS_PYTHON3:
            params = list(inspect.signature(make_graph).parameters.keys())
        else:
            params = inspect.getargspec(make_graph).args

        optional = {
            "initializer": self.consts_proto,
            "value_info": self.value_info_proto
        }
        supported = {
            key: value for key, value in optional.items() if key in params
        }
        return make_graph(self.nodes_proto, self._name, self.inputs_proto,
                          self.outputs_proto, **supported)