重命名 pt2tf 为 pt2pb

This commit is contained in:
zhutian
2020-10-14 08:55:07 +08:00
committed by Gitee
parent 324ab60a5d
commit 90ae190559
407 changed files with 0 additions and 0 deletions
+461
View File
@@ -0,0 +1,461 @@
import inspect
from itertools import chain
import numpy as np
from onnx import NodeProto
from onnx import TensorProto
from onnx import ValueInfoProto
from onnx import numpy_helper
from onnx.helper import make_graph
from onnx.helper import make_tensor
from onnx.helper import make_tensor_value_info
from onnx.helper import mapping
import tensorflow as tf
from tensorflow.core.framework.attr_value_pb2 import AttrValue
from tensorflow.core.framework.node_def_pb2 import NodeDef
from onnx_tf.common import attr_converter
from onnx_tf.common import attr_translator
from onnx_tf.common import CONST_MINUS_ONE_INT32
from onnx_tf.common import CONST_ONE_FP32
from onnx_tf.common import CONST_ONE_INT32
from onnx_tf.common import CONST_ZERO_INT32
from onnx_tf.common import IS_PYTHON3
from onnx_tf.common import logger
from onnx_tf.common.data_type import any_dtype_to_onnx_dtype
class TensorflowNode(object):
def __init__(self,
node=None,
name=None,
inputs=None,
outputs=None,
attr=None,
domain=None,
op_type=None):
# storing a reference to the original protobuf object
if node is None:
self.node = None
self.name = name or ""
self.inputs = inputs or []
self.attr = attr or {}
self.domain = domain or ""
self.op_type = op_type or ""
self.outputs = outputs or self.get_outputs_names()
elif isinstance(node, (OnnxNode, NodeProto)):
self._load_onnx_node(node)
elif isinstance(node, NodeDef):
self._load_tf_node(node)
def _load_onnx_node(self, node):
if isinstance(node, NodeProto):
node = OnnxNode(node)
self.name = node.name
self.inputs = node.inputs
self.outputs = node.outputs
self.attr = node.attrs
self.domain = node.domain
self.op_type = node.op_type
def _load_tf_node(self, node):
self.node = node
self.name = node.name
self.inputs = list(node.input)
self.attr = {}
for key, val in node.attr.items():
new_val = attr_translator.translate_tf(key, val)
if isinstance(new_val, AttrValue):
new_val = attr_converter.convert_tf(new_val)
self.attr[key] = new_val
splitted_op_name = node.op.split(".")
self.domain = "" if len(splitted_op_name) == 1 else ".".join(
splitted_op_name[:-1])
self.op_type = splitted_op_name[-1]
self.outputs = self.get_outputs_names()
def get_outputs_names(self, num=None):
""" Helper method to get outputs names.
e.g. tf.split: [Split, Split:1, Split:2]
:param num: Force to get `num` outputs names.
:return: List of outputs names.
"""
if num is None:
if "_output_shapes" in self.attr:
num = len(self.attr["_output_shapes"])
else:
num = 1
logger.warning("_output_shapes is not in node.attr. "
"The num of output is set to 1 for commonly. "
"It will cause problem with case of multiple outputs.")
return [
self.name + ":{}".format(i) if i > 0 else self.name for i in range(num)
]
class TensorflowGraph(object):
def __init__(self, graph_def, outputs=(), graph_name="graph"):
self._graph_name = graph_name
self._graph_def = self._process_graph_def(graph_def)
self._nodes = self._create_util_nodes() + [
TensorflowNode(node) for node in self.graph_def.node
]
self._nodes_dict = {n.name: n for n in self._nodes}
self._outputs = outputs or self.get_output_node_names(self.graph_def)
@staticmethod
def _create_util_nodes():
util_nodes = [(CONST_MINUS_ONE_INT32, np.array([-1]).astype(np.int32)),
(CONST_ZERO_INT32, np.array([0]).astype(np.int32)),
(CONST_ONE_INT32, np.array([1]).astype(np.int32))]
return [
TensorflowNode(
op_type="Const",
name=name,
attr={
"value": value,
"dtype": any_dtype_to_onnx_dtype(value.dtype),
"_output_shapes": [value.shape]
}) for name, value in util_nodes
]
def get_node_by_name(self, name):
node = self._nodes_dict.get(name, None)
if node is None:
raise ValueError(
"Node {} is not found in the graph provided".format(name))
return node
def _process_graph_def(self, graph_def):
if "_output_shapes" not in TensorflowNode(graph_def.node[0]).attr:
graph_def = self._add_infer_shapes(graph_def)
return graph_def
@staticmethod
def _add_infer_shapes(graph_def):
with tf.Graph().as_default():
with tf.Session(
config=tf.ConfigProto(
graph_options=tf.GraphOptions(infer_shapes=True))) as sess:
tf.import_graph_def(graph_def, name="")
return sess.graph_def
@staticmethod
def get_output_node_names(graph_def):
"""Get output node names from GraphDef.
Args:
graph_def: GraphDef object.
Returns:
List of output node names.
"""
input_names, output_names = set(), set()
for node in graph_def.node:
output_names.add(node.name)
input_names.update(set(node.input))
return list(output_names - input_names)
def update_nodes(self, nodes):
self._nodes = nodes
self._nodes_dict = {n.name: n for n in self._nodes}
@property
def graph_def(self):
return self._graph_def
@property
def graph_name(self):
return self._graph_name
@property
def nodes(self):
return self._nodes
@property
def nodes_dict(self):
return self._nodes_dict
@property
def outputs(self):
return self._outputs
# TODO: Move this into ONNX main library
class OnnxNode(object):
"""
Reimplementation of NodeProto from ONNX, but in a form
more convenient to work with from Python.
"""
def __init__(self, node):
self.name = str(node.name)
self.op_type = str(node.op_type)
self.domain = str(node.domain)
self.attrs = dict([(attr.name,
attr_translator.translate_onnx(
attr.name, attr_converter.convert_onnx(attr)))
for attr in node.attribute])
self.inputs = list(node.input)
self.outputs = list(node.output)
self.node_proto = node
class OnnxGraph(object):
""" A helper class for making ONNX graph.
This class holds all information ONNX graph needs.
"""
def __init__(self, name=None, graph_proto=None):
if graph_proto:
self._name = graph_proto.name
self._inputs_proto = list(graph_proto.input)
self._outputs_proto = list(graph_proto.output)
self._nodes_proto = list(graph_proto.node)
self._consts_proto = list(graph_proto.initializer)
self._value_info_proto = list(graph_proto.value_info)
self._consts = dict([(init.name, numpy_helper.to_array(init))
for init in graph_proto.initializer])
else:
self._name = name or ""
self._inputs_proto = []
self._outputs_proto = []
self._nodes_proto = []
self._consts = {}
self._consts_proto = []
self._value_info_proto = []
# Either way, data_type_cast_map is empty when initialized.
self._data_type_cast_map = {}
self._add_utility_constants()
def _add_utility_constants(self):
util_consts = {CONST_ONE_FP32: np.array([1.0]).astype(np.float32)}
# Add a few useful utility constants:
for name, value in util_consts.items():
self.add_const_explicit(name=name, value=value)
self.add_const_proto_explicit(
name=name, value=value, np_dtype=value.dtype)
self.add_input_proto_explicit(
name=name, shape=value.shape, np_dtype=value.dtype)
# This list holds the protobuf objects of type ValueInfoProto
# representing the input to the converted ONNX graph.
@property
def inputs_proto(self):
return self._inputs_proto
@inputs_proto.setter
def inputs_proto(self, inputs_proto):
self._inputs_proto = inputs_proto
@property
def all_node_inputs(self):
return list(chain.from_iterable(map(lambda p: p.input, self._nodes_proto)))
@property
def outputs(self):
return list(map(lambda p: p.name, self._outputs_proto))
@property
def outputs_proto(self):
return self._outputs_proto
# This list holds the protobuf objects of type NodeProto
# representing the ops in the converted ONNX graph.
@property
def nodes_proto(self):
return self._nodes_proto
@nodes_proto.setter
def nodes_proto(self, nodes_proto):
self._nodes_proto = nodes_proto
# This dictionary contains a map from the name of the constant
# op to the array of values it holds. This is useful because
# tensorflow is less eager to know about input values at
# graph construction time than ONNX. That is to say, some ONNX
# attributes are input tensors in TF. This dictionary extracts
# those values of constant tensors that are known at graph
# construction time.
@property
def consts(self):
return self._consts
@consts.setter
def consts(self, consts):
self._consts = consts
# Sometimes the constants are used as inputs to ops. This list
# holds initializers that creates global constant tensors available
# to be accessed by ops as inputs (as oppose to attributes which
# is supplied by the `consts` map above).
@property
def consts_proto(self):
return self._consts_proto
@consts_proto.setter
def consts_proto(self, consts_proto):
self._consts_proto = consts_proto
# A map holds nodes name and new data type. Will be used to
# process protos to match ONNX type constraints.
@property
def data_type_cast_map(self):
return self._data_type_cast_map
@data_type_cast_map.setter
def data_type_cast_map(self, data_type_cast_map):
self._data_type_cast_map = data_type_cast_map
# This list holds the protobuf objects of type ValueInfoProto
# representing the all nodes' outputs to the converted ONNX graph.
@property
def value_info_proto(self):
return self._value_info_proto
def add_input_proto_explicit(self,
name,
shape,
np_dtype=None,
tf_dtype=None,
onnx_dtype=None):
onnx_dtype = any_dtype_to_onnx_dtype(
np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
input_proto = make_tensor_value_info(name, onnx_dtype, shape)
self._inputs_proto.append(input_proto)
def add_input_proto(self, node):
name = node.name
onnx_dtype = node.attr["dtype"]
shape = node.attr["shape"] if node.op_type != "Const" else node.attr[
'value'].shape
self.add_input_proto_explicit(name, shape, onnx_dtype=onnx_dtype)
def add_output_proto(self, node):
output_onnx_type = node.attr.get("T", TensorProto.BOOL)
for i, output_shape in enumerate(node.attr["_output_shapes"]):
output_name = node.name + ":{}".format(i) if i > 0 else node.name
self._outputs_proto.append(
make_tensor_value_info(output_name, output_onnx_type, output_shape))
def add_node_proto(self, node_proto):
if not isinstance(node_proto, (list, tuple)):
node_proto = [node_proto]
self._nodes_proto.extend(node_proto)
def remove_node_proto(self, names):
if not isinstance(names, (list, tuple)):
names = [names]
self._nodes_proto = list(
filter(lambda x: x.name not in names, self._nodes_proto))
def add_const_explicit(self, name, value):
self._consts[name] = value
def add_const(self, node):
self.add_const_explicit(node.name, node.attr["value"])
def add_const_proto_explicit(self,
name,
value,
np_dtype=None,
tf_dtype=None,
onnx_dtype=None):
onnx_dtype = any_dtype_to_onnx_dtype(
np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
const_dim = len(value.shape)
if const_dim == 0:
raw_values = [value.tolist()]
values = [value]
else:
raw_values = value.flatten().tolist()
values = value
shape = np.array(values).shape
const_proto = make_tensor(
name=name, data_type=onnx_dtype, dims=shape, vals=raw_values)
self._consts_proto.append(const_proto)
def add_const_proto(self, node):
self.add_const_proto_explicit(
node.name, node.attr["value"], onnx_dtype=node.attr["dtype"])
def add_value_info_proto(self, node):
node_onnx_type = node.attr.get("T", TensorProto.BOOL)
for i, output_shape in enumerate(node.attr["_output_shapes"]):
node_name = node.name + ":{}".format(i) if i > 0 else node.name
value_info_proto = make_tensor_value_info(node_name, node_onnx_type,
output_shape)
self._value_info_proto.append(value_info_proto)
# Remove proto in inputs_proto and consts_proto
# if proto is not used as input or an output in ONNX
def _clean_graph(self):
in_out = self.all_node_inputs + self.outputs
self._inputs_proto = list(
filter(lambda x: x.name in in_out, self.inputs_proto))
self._consts_proto = list(
filter(lambda x: x.name in in_out, self.consts_proto))
def _fix_data_type(self):
self.inputs_proto = self._data_type_caster(self.inputs_proto,
self.data_type_cast_map)
self.consts_proto = self._data_type_caster(self.consts_proto,
self.data_type_cast_map)
@classmethod
def _data_type_caster(cls, protos, data_type_cast_map):
"""Cast to a new data type if node name is in data_type_cast_map.
Be used to process protos to match ONNX type constraints.
:param protos: Target protos.
TensorProto for inputs and ValueInfoProto for consts.
:param data_type_cast_map: A {node.name: new_data_type} dict.
:return: Processed protos.
"""
if not data_type_cast_map:
return protos
result = []
for proto in protos:
new_proto = proto
if proto.name in data_type_cast_map:
new_data_type = data_type_cast_map[proto.name]
if type(proto) == TensorProto and proto.data_type != new_data_type:
field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[proto.data_type]]
vals = getattr(proto, field)
new_proto = make_tensor(
name=proto.name,
data_type=new_data_type,
dims=proto.dims,
vals=vals)
elif type(
proto
) == ValueInfoProto and proto.type.tensor_type.elem_type != new_data_type:
new_proto.type.tensor_type.elem_type = new_data_type
result.append(new_proto)
return result
def make_graph_proto(self):
self._clean_graph()
self._fix_data_type()
if IS_PYTHON3:
params = list(inspect.signature(make_graph).parameters.keys())
else:
params = inspect.getargspec(make_graph).args
kwargs = {
"initializer": self.consts_proto,
"value_info": self.value_info_proto
}
return make_graph(self.nodes_proto, self._name, self.inputs_proto,
self.outputs_proto,
**dict([(k, kwargs[k]) for k in kwargs if k in params]))