import inspect
from itertools import chain

import numpy as np
from onnx import NodeProto
from onnx import TensorProto
from onnx import ValueInfoProto
from onnx import numpy_helper
from onnx.helper import make_graph
from onnx.helper import make_tensor
from onnx.helper import make_tensor_value_info
from onnx.helper import mapping
import tensorflow as tf
from tensorflow.core.framework.attr_value_pb2 import AttrValue
from tensorflow.core.framework.node_def_pb2 import NodeDef

from onnx_tf.common import attr_converter
from onnx_tf.common import attr_translator
from onnx_tf.common import CONST_MINUS_ONE_INT32
from onnx_tf.common import CONST_ONE_FP32
from onnx_tf.common import CONST_ONE_INT32
from onnx_tf.common import CONST_ZERO_INT32
from onnx_tf.common import IS_PYTHON3
from onnx_tf.common import logger
from onnx_tf.common.data_type import any_dtype_to_onnx_dtype


class TensorflowNode(object):
    """Uniform node wrapper built from explicit fields, an ONNX node or a TF NodeDef.

    Exposes ``node``, ``name``, ``inputs``, ``outputs``, ``attr``, ``domain``
    and ``op_type`` regardless of where the node came from.
    """

    def __init__(self,
                 node=None,
                 name=None,
                 inputs=None,
                 outputs=None,
                 attr=None,
                 domain=None,
                 op_type=None):
        """Build the wrapper.

        :param node: ``None`` to build from the keyword fields, or an
            ``OnnxNode`` / ``NodeProto`` / ``NodeDef`` to load from. Any other
            type leaves the instance unpopulated.
        :param name: Node name (only used when ``node`` is ``None``).
        :param inputs: List of input names (only used when ``node`` is ``None``).
        :param outputs: List of output names; derived with
            :meth:`get_outputs_names` when omitted or empty.
        :param attr: Attribute dict (only used when ``node`` is ``None``).
        :param domain: Op domain (only used when ``node`` is ``None``).
        :param op_type: Op type (only used when ``node`` is ``None``).
        """
        # storing a reference to the original protobuf object
        if node is None:
            self.node = None
            self.name = name or ""
            self.inputs = inputs or []
            self.attr = attr or {}
            self.domain = domain or ""
            self.op_type = op_type or ""
            # Must come last: get_outputs_names() reads self.attr and self.name.
            self.outputs = outputs or self.get_outputs_names()
        elif isinstance(node, (OnnxNode, NodeProto)):
            self._load_onnx_node(node)
        elif isinstance(node, NodeDef):
            self._load_tf_node(node)

    def _load_onnx_node(self, node):
        """Populate this instance from an ``OnnxNode`` or a raw ``NodeProto``."""
        if isinstance(node, NodeProto):
            node = OnnxNode(node)
        # Keep the original protobuf reachable, as the other two construction
        # paths do; previously ``self.node`` was left undefined here.
        self.node = node.node_proto
        self.name = node.name
        self.inputs = node.inputs
        self.outputs = node.outputs
        self.attr = node.attrs
        self.domain = node.domain
        self.op_type = node.op_type

    def _load_tf_node(self, node):
        """Populate this instance from a TensorFlow ``NodeDef``."""
        self.node = node
        self.name = node.name
        self.inputs = list(node.input)
        self.attr = {}
        for key, val in node.attr.items():
            new_val = attr_translator.translate_tf(key, val)
            # Values the translator leaves as AttrValue still need converting.
            if isinstance(new_val, AttrValue):
                new_val = attr_converter.convert_tf(new_val)
            self.attr[key] = new_val
        # "some.domain.Op" -> domain "some.domain", op_type "Op";
        # a bare "Op" yields an empty domain.
        self.domain, _, self.op_type = node.op.rpartition(".")
        self.outputs = self.get_outputs_names()

    def get_outputs_names(self, num=None):
        """ Helper method to get outputs names.
        e.g. tf.split: [Split, Split:1, Split:2]

        :param num: Force to get `num` outputs names.
        :return: List of outputs names.
        """
        if num is None:
            if "_output_shapes" in self.attr:
                num = len(self.attr["_output_shapes"])
            else:
                num = 1
                logger.warning("_output_shapes is not in node.attr. "
                               "The num of output is set to 1 for commonly. "
                               "It will cause problem with case of multiple outputs.")
        # The first output carries the bare node name; later ones get ":<index>".
        return [
            self.name + ":{}".format(i) if i > 0 else self.name
            for i in range(num)
        ]


class TensorflowGraph(object):
    """Wrapper around a TF ``GraphDef`` exposing its nodes as ``TensorflowNode``s."""

    def __init__(self, graph_def, outputs=(), graph_name="graph"):
        """Wrap ``graph_def``.

        :param graph_def: GraphDef object.
        :param outputs: Output node names; computed with
            :meth:`get_output_node_names` when empty.
        :param graph_name: Name stored for the graph.
        """
        self._graph_name = graph_name
        self._graph_def = self._process_graph_def(graph_def)
        # Utility constant nodes come first, followed by the graph's own nodes.
        self._nodes = self._create_util_nodes() + [
            TensorflowNode(node) for node in self.graph_def.node
        ]
        self._nodes_dict = {n.name: n for n in self._nodes}
        self._outputs = outputs or self.get_output_node_names(self.graph_def)

    @staticmethod
    def _create_util_nodes():
        """Return ``Const`` nodes for the int32 utility constants -1, 0 and 1."""
        util_nodes = [
            (CONST_MINUS_ONE_INT32, np.array([-1]).astype(np.int32)),
            (CONST_ZERO_INT32, np.array([0]).astype(np.int32)),
            (CONST_ONE_INT32, np.array([1]).astype(np.int32)),
        ]
        return [
            TensorflowNode(
                op_type="Const",
                name=name,
                attr={
                    "value": value,
                    "dtype": any_dtype_to_onnx_dtype(value.dtype),
                    "_output_shapes": [value.shape]
                }) for name, value in util_nodes
        ]

    def get_node_by_name(self, name):
        """Return the node called ``name``.

        :param name: Node name to look up.
        :return: The matching ``TensorflowNode``.
        :raises ValueError: If no node with that name exists.
        """
        node = self._nodes_dict.get(name, None)
        if node is None:
            raise ValueError(
                "Node {} is not found in the graph provided".format(name))
        return node

    def _process_graph_def(self, graph_def):
        """Return ``graph_def``, with inferred shapes added when they are missing.

        Only the first node is inspected for the ``_output_shapes`` attribute.
        An empty graph is returned unchanged instead of raising ``IndexError``.
        """
        if not graph_def.node:
            return graph_def
        # Check the raw attribute map directly: wrapping the node in a
        # TensorflowNode just for this test logged a spurious warning whenever
        # the attribute was absent.
        if "_output_shapes" not in graph_def.node[0].attr:
            graph_def = self._add_infer_shapes(graph_def)
        return graph_def

    @staticmethod
    def _add_infer_shapes(graph_def):
        """Re-import ``graph_def`` in a session created with ``infer_shapes=True``.

        :return: The session's graph_def after the import.
        """
        with tf.Graph().as_default():
            with tf.Session(
                    config=tf.ConfigProto(
                        graph_options=tf.GraphOptions(infer_shapes=True))) as sess:
                tf.import_graph_def(graph_def, name="")
                return sess.graph_def

    @staticmethod
    def get_output_node_names(graph_def):
        """Get output node names from GraphDef.

        A node is an output when no other node consumes it. Input references
        are normalised first, because TensorFlow writes them as ``name``,
        ``name:<index>`` (a non-first output) or ``^name`` (a control
        dependency); comparing the raw strings would report a node consumed
        only through ``name:1`` or ``^name`` as a graph output.

        :param graph_def: GraphDef object.
        :return: List of output node names, in graph order.
        """
        consumed = set()
        for node in graph_def.node:
            for input_name in node.input:
                if input_name.startswith("^"):
                    input_name = input_name[1:]
                consumed.add(input_name.split(":", 1)[0])
        return [node.name for node in graph_def.node if node.name not in consumed]

    def update_nodes(self, nodes):
        """Replace the node list and rebuild the name -> node index."""
        self._nodes = nodes
        self._nodes_dict = {n.name: n for n in self._nodes}

    @property
    def graph_def(self):
        return self._graph_def

    @property
    def graph_name(self):
        return self._graph_name

    @property
    def nodes(self):
        return self._nodes

    @property
    def nodes_dict(self):
        return self._nodes_dict

    @property
    def outputs(self):
        return self._outputs


# TODO: Move this into ONNX main library
class OnnxNode(object):
    """
    Reimplementation of NodeProto from ONNX, but in a form
    more convenient to work with from Python.
    """

    def __init__(self, node):
        """Copy the fields of ``node`` into plain Python values.

        :param node: The source node; its ``name``, ``op_type``, ``domain``,
            ``attribute``, ``input`` and ``output`` fields are read.
        """
        self.name = str(node.name)
        self.op_type = str(node.op_type)
        self.domain = str(node.domain)
        # Each attribute is converted to a Python value, then translated.
        self.attrs = {
            attr.name: attr_translator.translate_onnx(
                attr.name, attr_converter.convert_onnx(attr))
            for attr in node.attribute
        }
        self.inputs = list(node.input)
        self.outputs = list(node.output)
        self.node_proto = node


class OnnxGraph(object):
    """ A helper class for making ONNX graph.
    This class holds all information ONNX graph needs.
    """

    def __init__(self, name=None, graph_proto=None):
        """Start from an existing ``graph_proto`` or from an empty graph.

        :param name: Graph name, used only when ``graph_proto`` is not given.
        :param graph_proto: Optional graph proto to copy inputs, outputs,
            nodes, initializers and value_info from.
        """
        if graph_proto is not None:
            self._name = graph_proto.name
            self._inputs_proto = list(graph_proto.input)
            self._outputs_proto = list(graph_proto.output)
            self._nodes_proto = list(graph_proto.node)
            self._consts_proto = list(graph_proto.initializer)
            self._value_info_proto = list(graph_proto.value_info)
            self._consts = {
                init.name: numpy_helper.to_array(init)
                for init in graph_proto.initializer
            }
        else:
            self._name = name or ""
            self._inputs_proto = []
            self._outputs_proto = []
            self._nodes_proto = []
            self._consts = {}
            self._consts_proto = []
            self._value_info_proto = []
        # Either way, data_type_cast_map is empty when initialized.
        self._data_type_cast_map = {}
        self._add_utility_constants()

    def _add_utility_constants(self):
        """Register the float32 constant 1.0 as a const, an initializer and an input."""
        # NOTE(review): when loading from a graph_proto that already carries an
        # initializer with this name, a second proto with the same name is
        # appended -- confirm whether that can happen in practice.
        util_consts = {CONST_ONE_FP32: np.array([1.0]).astype(np.float32)}
        # Add a few useful utility constants:
        for name, value in util_consts.items():
            self.add_const_explicit(name=name, value=value)
            self.add_const_proto_explicit(
                name=name, value=value, np_dtype=value.dtype)
            self.add_input_proto_explicit(
                name=name, shape=value.shape, np_dtype=value.dtype)

    # This list holds the protobuf objects of type ValueInfoProto
    # representing the input to the converted ONNX graph.
    @property
    def inputs_proto(self):
        return self._inputs_proto

    @inputs_proto.setter
    def inputs_proto(self, inputs_proto):
        self._inputs_proto = inputs_proto

    @property
    def all_node_inputs(self):
        """Flat list of every input name referenced by any node proto."""
        return list(chain.from_iterable(p.input for p in self._nodes_proto))

    @property
    def outputs(self):
        """Names of the graph outputs."""
        return [p.name for p in self._outputs_proto]

    @property
    def outputs_proto(self):
        return self._outputs_proto

    # This list holds the protobuf objects of type NodeProto
    # representing the ops in the converted ONNX graph.
    @property
    def nodes_proto(self):
        return self._nodes_proto

    @nodes_proto.setter
    def nodes_proto(self, nodes_proto):
        self._nodes_proto = nodes_proto

    # This dictionary contains a map from the name of the constant
    # op to the array of values it holds. This is useful because
    # tensorflow is less eager to know about input values at
    # graph construction time than ONNX. That is to say, some ONNX
    # attributes are input tensors in TF. This dictionary extracts
    # those values of constant tensors that are known at graph
    # construction time.
    @property
    def consts(self):
        return self._consts

    @consts.setter
    def consts(self, consts):
        self._consts = consts

    # Sometimes the constants are used as inputs to ops. This list
    # holds initializers that create global constant tensors available
    # to be accessed by ops as inputs (as opposed to attributes which
    # are supplied by the `consts` map above).
    @property
    def consts_proto(self):
        return self._consts_proto

    @consts_proto.setter
    def consts_proto(self, consts_proto):
        self._consts_proto = consts_proto

    # A map holds nodes name and new data type. Will be used to
    # process protos to match ONNX type constraints.
    @property
    def data_type_cast_map(self):
        return self._data_type_cast_map

    @data_type_cast_map.setter
    def data_type_cast_map(self, data_type_cast_map):
        self._data_type_cast_map = data_type_cast_map

    # This list holds the protobuf objects of type ValueInfoProto
    # representing the all nodes' outputs to the converted ONNX graph.
    @property
    def value_info_proto(self):
        return self._value_info_proto

    def add_input_proto_explicit(self,
                                 name,
                                 shape,
                                 np_dtype=None,
                                 tf_dtype=None,
                                 onnx_dtype=None):
        """Append a graph input built from an explicit name, shape and dtype.

        :param name: Input name.
        :param shape: Input shape.
        :param np_dtype: Dtype, forwarded to ``any_dtype_to_onnx_dtype``.
        :param tf_dtype: Dtype, forwarded to ``any_dtype_to_onnx_dtype``.
        :param onnx_dtype: Dtype, forwarded to ``any_dtype_to_onnx_dtype``.
        """
        onnx_dtype = any_dtype_to_onnx_dtype(
            np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
        input_proto = make_tensor_value_info(name, onnx_dtype, shape)
        self._inputs_proto.append(input_proto)

    def add_input_proto(self, node):
        """Append a graph input described by ``node``.

        The shape comes from ``attr["value"].shape`` for ``Const`` nodes and
        from ``attr["shape"]`` otherwise; the dtype comes from ``attr["dtype"]``.
        """
        name = node.name
        onnx_dtype = node.attr["dtype"]
        if node.op_type != "Const":
            shape = node.attr["shape"]
        else:
            shape = node.attr['value'].shape
        self.add_input_proto_explicit(name, shape, onnx_dtype=onnx_dtype)

    @staticmethod
    def _make_output_value_infos(node):
        """Build one value info per entry of ``node.attr["_output_shapes"]``.

        The element type is ``node.attr["T"]``, falling back to
        ``TensorProto.BOOL``. The first output is named after the node and
        later ones are suffixed with ``:<index>``.
        """
        onnx_type = node.attr.get("T", TensorProto.BOOL)
        return [
            make_tensor_value_info(
                node.name + ":{}".format(i) if i > 0 else node.name, onnx_type,
                output_shape)
            for i, output_shape in enumerate(node.attr["_output_shapes"])
        ]

    def add_output_proto(self, node):
        """Append every output of ``node`` to the graph outputs."""
        self._outputs_proto.extend(self._make_output_value_infos(node))

    def add_node_proto(self, node_proto):
        """Append a single node proto, or each item of a list/tuple of them."""
        if not isinstance(node_proto, (list, tuple)):
            node_proto = [node_proto]
        self._nodes_proto.extend(node_proto)

    def remove_node_proto(self, names):
        """Drop every node proto whose name is in ``names``.

        :param names: A single name, or a list/tuple/set of names.
        """
        if not isinstance(names, (list, tuple, set, frozenset)):
            names = [names]
        # Set membership keeps the filter linear in the number of nodes.
        names = frozenset(names)
        self._nodes_proto = [
            proto for proto in self._nodes_proto if proto.name not in names
        ]

    def add_const_explicit(self, name, value):
        """Record ``value`` under ``name`` in the ``consts`` map."""
        self._consts[name] = value

    def add_const(self, node):
        """Record ``node.attr["value"]`` in the ``consts`` map under the node name."""
        self.add_const_explicit(node.name, node.attr["value"])

    def add_const_proto_explicit(self,
                                 name,
                                 value,
                                 np_dtype=None,
                                 tf_dtype=None,
                                 onnx_dtype=None):
        """Append an initializer tensor built from ``value``.

        :param name: Tensor name.
        :param value: Array-like with ``shape``, ``tolist()`` and ``flatten()``.
            A 0-d value is stored as a one-element tensor of shape ``(1,)``.
        :param np_dtype: Dtype, forwarded to ``any_dtype_to_onnx_dtype``.
        :param tf_dtype: Dtype, forwarded to ``any_dtype_to_onnx_dtype``.
        :param onnx_dtype: Dtype, forwarded to ``any_dtype_to_onnx_dtype``.
        """
        onnx_dtype = any_dtype_to_onnx_dtype(
            np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
        const_dim = len(value.shape)
        if const_dim == 0:
            # Wrap scalars so that they are emitted with one dimension.
            raw_values = [value.tolist()]
            values = [value]
        else:
            raw_values = value.flatten().tolist()
            values = value
        shape = np.array(values).shape
        const_proto = make_tensor(
            name=name, data_type=onnx_dtype, dims=shape, vals=raw_values)
        self._consts_proto.append(const_proto)

    def add_const_proto(self, node):
        """Append an initializer from ``node.attr["value"]`` and ``node.attr["dtype"]``."""
        self.add_const_proto_explicit(
            node.name, node.attr["value"], onnx_dtype=node.attr["dtype"])

    def add_value_info_proto(self, node):
        """Append every output of ``node`` to the graph's value_info list."""
        self._value_info_proto.extend(self._make_output_value_infos(node))

    # Remove proto in inputs_proto and consts_proto
    # if proto is not used as input or an output in ONNX
    def _clean_graph(self):
        # A set gives constant-time membership for the two filters below.
        in_out = set(self.all_node_inputs)
        in_out.update(self.outputs)
        self._inputs_proto = [p for p in self.inputs_proto if p.name in in_out]
        self._consts_proto = [p for p in self.consts_proto if p.name in in_out]

    def _fix_data_type(self):
        """Apply ``data_type_cast_map`` to the graph inputs and initializers."""
        self.inputs_proto = self._data_type_caster(self.inputs_proto,
                                                   self.data_type_cast_map)
        self.consts_proto = self._data_type_caster(self.consts_proto,
                                                   self.data_type_cast_map)

    @classmethod
    def _data_type_caster(cls, protos, data_type_cast_map):
        """Cast to a new data type if node name is in data_type_cast_map.
        Be used to process protos to match ONNX type constraints.

        :param protos: Target protos.
            ValueInfoProto for inputs and TensorProto for consts.
        :param data_type_cast_map: A {node.name: new_data_type} dict.
        :return: Processed protos. ``TensorProto`` entries that need a cast
            are rebuilt; ``ValueInfoProto`` entries are updated in place.
        """
        if not data_type_cast_map:
            return protos
        result = []
        for proto in protos:
            new_proto = proto
            if proto.name in data_type_cast_map:
                new_data_type = data_type_cast_map[proto.name]
                if (isinstance(proto, TensorProto) and
                        proto.data_type != new_data_type):
                    # Read the values from the typed field that stores the
                    # tensor's current data type, then rebuild the tensor.
                    field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
                        mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[
                            proto.data_type]]
                    vals = getattr(proto, field)
                    new_proto = make_tensor(
                        name=proto.name,
                        data_type=new_data_type,
                        dims=proto.dims,
                        vals=vals)
                elif (isinstance(proto, ValueInfoProto) and
                      proto.type.tensor_type.elem_type != new_data_type):
                    new_proto.type.tensor_type.elem_type = new_data_type
            result.append(new_proto)
        return result

    def make_graph_proto(self):
        """Clean up the collected protos and assemble the graph with ``make_graph``.

        ``initializer`` and ``value_info`` are passed only when the
        ``make_graph`` signature has parameters with those names.

        :return: The value returned by ``make_graph``.
        """
        self._clean_graph()
        self._fix_data_type()

        if IS_PYTHON3:
            params = list(inspect.signature(make_graph).parameters.keys())
        else:
            params = inspect.getargspec(make_graph).args

        kwargs = {
            "initializer": self.consts_proto,
            "value_info": self.value_info_proto
        }

        supported_kwargs = {k: v for k, v in kwargs.items() if k in params}
        return make_graph(self.nodes_proto, self._name, self.inputs_proto,
                          self.outputs_proto, **supported_kwargs)