add pt2tf tool

This commit is contained in:
zxros10
2020-09-23 09:09:49 +08:00
parent 7f7b7df65d
commit 18aefa4dd0
407 changed files with 16211 additions and 0 deletions
@@ -0,0 +1,2 @@
from . import backend
from .version import version as __version__
+352
View File
@@ -0,0 +1,352 @@
"""Backend for running ONNX on Tensorflow
To run this, you will need to have Tensorflow installed as well.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
try:
from itertools import izip as zip
except ImportError: # will be 3.x series
pass
from onnx import defs
from onnx import numpy_helper
from onnx.backend.base import Backend
from onnx.backend.base import Device
from onnx.backend.base import namedtupledict
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
from onnx.helper import make_opsetid
import tensorflow as tf
from onnx_tf.backend_rep import TensorflowRep
from onnx_tf.common import data_type
from onnx_tf.common import get_device_option
from onnx_tf.common import get_unique_suffix
from onnx_tf.common import supports_device as common_supports_device
from onnx_tf.common.handler_helper import get_all_backend_handlers
from onnx_tf.pb_wrapper import OnnxNode
import onnx_tf.common as common
class TensorflowBackend(Backend):
""" Tensorflow Backend for ONNX
"""
@classmethod
def prepare(cls,
model,
device='CPU',
strict=True,
logging_level='INFO',
**kwargs):
"""Prepare an ONNX model for Tensorflow Backend.
This function converts an ONNX model to an internel representation
of the computational graph called TensorflowRep and returns
the converted representation.
:param model: The ONNX model to be converted.
:param device: The device to execute this model on.
:param strict: Whether to enforce semantic equivalence between the original model
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
Changing to False is strongly discouraged.
Currently, the strict flag only affects the behavior of MaxPool and AveragePool ops.
:param logging_level: The logging level, default is INFO. Change it to DEBUG
to see more conversion details or to WARNING to see less
:returns: A TensorflowRep class object representing the ONNX model
"""
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
common.logger.setLevel(logging_level)
common.logger.handlers[0].setLevel(logging_level)
return cls.onnx_model_to_tensorflow_rep(model, strict)
@classmethod
def onnx_model_to_tensorflow_rep(cls, model, strict):
""" Convert ONNX model to TensorflowRep.
:param model: ONNX ModelProto object.
:param strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model.
:return: TensorflowRep object.
"""
# Models with IR_VERSION less than 3 does not have opset_import set.
# We default to minimum opset, this behavior is consistent with
# onnx checker.
# c.f. https://github.com/onnx/onnx/blob/427ac0c1b792363d373e3d7e4eef97fa46458420/onnx/checker.cc#L478
if model.ir_version < 3:
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
else:
opset_import = model.opset_import
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
@classmethod
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
""" Convert ONNX graph to TensorflowRep.
:param graph_def: ONNX GraphProto object.
:param opset: ONNX OperatorSetIdProto list.
:param strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model.
:return: TensorflowRep object.
"""
handlers = cls._get_handlers(opset)
tf_rep_graph = tf.Graph()
with tf_rep_graph.as_default():
# initializer: TensorProtos representing the values to initialize
# a given tensor.
# initialized: A list of names of the initialized tensors.
if graph_def.initializer:
input_dict_items = cls._onnx_initializer_to_input_dict_items(
graph_def.initializer)
initialized = {init.name for init in graph_def.initializer}
else:
input_dict_items = []
initialized = set()
# creating placeholders for currently unknown inputs
for value_info in graph_def.input:
if value_info.name in initialized:
continue
shape = list(
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
for d in value_info.type.tensor_type.shape.dim)
value_info_name = value_info.name.replace(
":", "_tf_") + "_" + get_unique_suffix(
) if ":" in value_info.name else value_info.name
x = tf.placeholder(data_type.onnx2tf(
value_info.type.tensor_type.elem_type),
name=value_info_name,
shape=shape)
input_dict_items.append((value_info.name, x))
# tensor dict: this dictionary is a map from variable names
# to the latest produced TF tensors of the given name.
# This dictionary will get updated as we build the graph to
# record the names of newly produced tensors.
tensor_dict = dict(input_dict_items)
# Since tensor dict may be updated, we need to keep a copy
# of the original input dict where we track the earliest
# defined tensors so we can have access to the placeholders
# to feed in input tensors when we run the graph.
input_dict = dict(input_dict_items)
for node in graph_def.node:
onnx_node = OnnxNode(node)
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
tensor_dict,
handlers,
opset=opset,
strict=strict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
tensor_dict.update(curr_node_output_map)
tf_rep = TensorflowRep()
tf_rep.graph = tf_rep_graph
tf_rep.inputs = [
value_info.name
for value_info in graph_def.input
if value_info.name not in initialized
]
tf_rep.outputs = [value_info.name for value_info in graph_def.output]
tf_rep.tensor_dict = tensor_dict
return tf_rep
@classmethod
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
""" Run ONNX node.
:param node: ONNX NodeProto object.
:param inputs: Inputs.
:param device: Device run on.
:param outputs_info: None.
:param kwargs: Other args.
:return: Outputs.
"""
super(TensorflowBackend, cls).run_node(node, inputs, device)
node_graph = tf.Graph()
with node_graph.as_default():
node = OnnxNode(node)
device_option = get_device_option(Device(device))
input_tensors = []
for i in inputs:
input_tensors.append(tf.constant(i))
if isinstance(inputs, dict):
feed_dict_raw = inputs
else:
assert len(node.inputs) == len(inputs)
feed_dict_raw = dict(zip(node.inputs, inputs))
# TODO: is constant the best way for feeding inputs?
input_dict = dict([
(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
])
ops = cls._onnx_node_to_tensorflow_op(node, input_dict)
with tf.Session() as sess:
with tf.device(device_option):
sess.run(tf.global_variables_initializer())
output_vals = sess.run(ops)
return namedtupledict('Outputs', node.outputs)(*output_vals)
@classmethod
def _onnx_initializer_to_input_dict_items(cls, initializer):
""" Convert ONNX graph initializer to input dict items.
:param initializer: ONNX graph initializer, list of TensorProto.
:return: List of input dict items.
"""
def tensor2list(onnx_tensor):
# Use the onnx.numpy_helper because the data may be raw
return numpy_helper.to_array(onnx_tensor).flatten().tolist()
def validate_initializer_name(name):
# Prepend a unique suffix if leading charater is "_"
name = get_unique_suffix() + name if name[0] is "_" else name
# Replace ":" with "_tf_" and append a unique suffix for
# traceability
return name.replace(
":", "_tf_") + "_" + get_unique_suffix() if ":" in name else name
return [(init.name,
tf.constant(tensor2list(init),
shape=init.dims,
dtype=data_type.onnx2tf(init.data_type),
name=validate_initializer_name(init.name)))
for init in initializer]
@classmethod
def _onnx_node_to_tensorflow_op(cls,
node,
tensor_dict,
handlers=None,
opset=None,
strict=True):
"""
Convert onnx node to tensorflow op.
Args:
node: Onnx node object.
tensor_dict: Tensor dict of graph.
opset: Opset version of the operator set. Default 0 means using latest version.
strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
Changing to False is strongly discouraged.
Returns:
Tensorflow op
"""
handlers = handlers or cls._get_handlers(opset)
if handlers:
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
if handler:
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
@classmethod
def _get_handlers(cls, opset):
""" Get all backend handlers with opset.
:param opset: ONNX OperatorSetIdProto list.
:return: All backend handlers.
"""
opset = opset or [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
opset_dict = dict([(o.domain, o.version) for o in opset])
return get_all_backend_handlers(opset_dict)
@classmethod
def supports_device(cls, device):
return common_supports_device(device)
@classmethod
def onnx_graph_to_tensorflow_ops(cls,
subgraph,
input_values,
tensor_dict,
opset=None,
strict=True):
"""
Converts ONNX graph to Tensorflow operations
Args:
subgraph: the ONNX graph to be converted
input_values: dictionary with values/tensors to initialize
the subgraph inputs. if the subgraph.input
are send in as parameters then it is required,
otherwise this can be empty dictionary.
tensor_dict: the dictionary that contain values for all the
node.inputs in the subgraph that are not defined
in the subgraph or input_values.
opset: opset version of the operator set.
strict: whether to enforce semantic equivalence between the
original model and the converted tensorflow model,
defaults to True (yes, enforce semantic equivalence).
Returns:
array of Tensorflow Tensors
"""
# get the subgraph.input from input_values
subgraph_tensor_dict = input_values.copy()
# get the rest of the subgraph input from tensor_dict
for i in subgraph.input:
if i.name not in subgraph_tensor_dict.keys():
subgraph_tensor_dict[i.name] = tensor_dict[i.name]
# get the required initializer constant node(s) for the subgraph
# Need to get the initializer constant nodes from tensor_dict here
# because input from initializer will not be send in as inputs
# to the subgraph and those nodes are not in the subgraph
nodes_outputs = []
for node in subgraph.node:
for o_name in node.output:
nodes_outputs.append(o_name)
for node in subgraph.node:
for i_name in node.input:
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
):
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
onnx_node = OnnxNode(node)
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
subgraph_tensor_dict,
opset=opset,
strict=strict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
subgraph_tensor_dict.update(curr_node_output_map)
return subgraph_tensor_dict
@classmethod
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
"""
Converts ONNX graph to TensorflowRep
Args:
graph_def: the ONNX graph to be converted
strict: whether to enforce semantic equivalence between the
original model and the converted tensorflow model,
defaults to True (yes, enforce semantic equivalence).
Returns:
TensorflowRep object.
"""
# get the opset of the installed ONNX
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
prepare = TensorflowBackend.prepare
run_node = TensorflowBackend.run_node
run_model = TensorflowBackend.run_model
supports_device = TensorflowBackend.supports_device
onnx_graph_to_tensorflow_ops = TensorflowBackend.onnx_graph_to_tensorflow_ops
onnx_graph_to_tensorflow_rep = TensorflowBackend.onnx_graph_to_tensorflow_rep
@@ -0,0 +1,109 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import tensorflow as tf
from onnx.backend.base import BackendRep, namedtupledict
class TensorflowRep(BackendRep):
def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
super(TensorflowRep, self).__init__()
self._graph = graph
self._inputs = inputs or []
self._outputs = outputs or []
self._tensor_dict = tensor_dict or {}
@property
def graph(self):
return self._graph
@graph.setter
def graph(self, graph):
self._graph = graph
@property
def inputs(self):
return self._inputs
@inputs.setter
def inputs(self, inputs):
self._inputs = inputs
@property
def outputs(self):
return self._outputs
@outputs.setter
def outputs(self, outputs):
self._outputs = outputs
@property
def tensor_dict(self):
return self._tensor_dict
@tensor_dict.setter
def tensor_dict(self, tensor_dict):
self._tensor_dict = tensor_dict
def run(self, inputs, **kwargs):
""" Run TensorflowRep.
:param inputs: Given inputs.
:param kwargs: Other args.
:return: Outputs.
"""
super(TensorflowRep, self).run(inputs, **kwargs)
# TODO: handle name scope if necessary
with self.graph.as_default():
with tf.Session() as sess:
if isinstance(inputs, dict):
feed_dict = inputs
elif isinstance(inputs, list) or isinstance(inputs, tuple):
if len(self.inputs) != len(inputs):
raise RuntimeError('Expected {} values for uninitialized '
'graph inputs ({}), but got {}.'.format(
len(self.inputs), ', '.join(self.inputs),
len(inputs)))
feed_dict = dict(zip(self.inputs, inputs))
else:
# single input
feed_dict = dict([(self.inputs[0], inputs)])
feed_dict = {
self.tensor_dict[key]: feed_dict[key] for key in self.inputs
}
sess.run(tf.global_variables_initializer())
outputs = [self.tensor_dict[output] for output in self.outputs]
output_values = sess.run(outputs, feed_dict=feed_dict)
return namedtupledict('Outputs', self.outputs)(*output_values)
def export_graph(self, path):
"""Export backend representation to a Tensorflow proto file.
This function obtains the graph proto corresponding to the ONNX
model associated with the backend representation and serializes
to a protobuf file.
:param path: The path to the output TF protobuf file.
:returns: none.
"""
graph_proto = self.graph.as_graph_def()
# rename the output nodes
meaningful_names = {}
for output_name in self.outputs:
meaningful_names[self.tensor_dict[output_name].name.replace(':0', '')] = output_name
for node in graph_proto.node:
if node.name in meaningful_names.keys():
node.name = meaningful_names[node.name]
file = open(path, "wb")
file.write(graph_proto.SerializeToString())
file.close()
+24
View File
@@ -0,0 +1,24 @@
import argparse
import sys
import onnx_tf.converter
def main():
args = sys.argv[1:]
parser = argparse.ArgumentParser(
description="ONNX-Tensorflow Command Line Interface")
parser.add_argument(
"command",
choices=["convert"],
help="Available commands.")
if len(args) == 0:
parser.parse_args(["-h"])
cli_tool = parser.parse_args([args[0]])
if cli_tool.command == "convert":
return onnx_tf.converter.main(args[1:])
if __name__ == '__main__':
main()
@@ -0,0 +1,196 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import inspect
import re
import sys
import uuid
import warnings
import logging
from onnx.backend.base import DeviceType
from tensorflow.python.client import device_lib
IS_PYTHON3 = sys.version_info > (3,)
logger = logging.getLogger('onnx-tf')
# create console handler and formatter for logger
console = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console.setFormatter(formatter)
logger.addHandler(console)
class Deprecated:
"""Add deprecated message when function is called.
Usage:
from onnx_tf.common import deprecated
@deprecated
def func():
pass
UserWarning: func is deprecated. It will be removed in future release.
@deprecated("Message")
def func():
pass
UserWarning: Message
@deprecated({"arg": "Message",
"arg_1": deprecated.MSG_WILL_REMOVE,
"arg_2": "",})
def func(arg, arg_1, arg_2):
pass
UserWarning: Message
UserWarning: arg_1 of func is deprecated. It will be removed in future release.
UserWarning: arg_2 of func is deprecated.
"""
MSG_WILL_REMOVE = " It will be removed in future release."
def __call__(self, *args, **kwargs):
return self.deprecated_decorator(*args, **kwargs)
@staticmethod
def messages():
return {v for k, v in inspect.getmembers(Deprecated) if k.startswith("MSG")}
@staticmethod
def deprecated_decorator(arg=None):
# deprecate function with default message MSG_WILL_REMOVE
# @deprecated
if inspect.isfunction(arg):
def wrapper(*args, **kwargs):
warnings.warn("{} is deprecated.{}".format(
arg.__module__ + "." + arg.__name__, Deprecated.MSG_WILL_REMOVE))
return arg(*args, **kwargs)
return wrapper
deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE
def deco(func):
# deprecate arg
# @deprecated({...})
if isinstance(deprecated_arg, dict):
for name, message in deprecated_arg.items():
if message in Deprecated.messages():
message = "{} of {} is deprecated.{}".format(
name, func.__module__ + "." + func.__name__, message or "")
warnings.warn(message)
# deprecate function with message
# @deprecated("message")
elif isinstance(deprecated_arg, str):
message = deprecated_arg
if message in Deprecated.messages():
message = "{} is deprecated.{}".format(
func.__module__ + "." + func.__name__, message)
warnings.warn(message)
return func
return deco
deprecated = Deprecated()
# This function inserts an underscore before every upper
# case letter and lowers that upper case letter except for
# the first letter.
def op_name_to_lower(name):
return re.sub('(?<!^)(?=[A-Z])', '_', name).lower()
def get_unique_suffix():
""" Get unique suffix by using first 8 chars from uuid.uuid4
to make unique identity name.
:return: Unique suffix string.
"""
return str(uuid.uuid4())[:8]
def get_perm_from_formats(from_, to_):
""" Get perm from data formats.
For example:
get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]
:param from_: From data format string.
:param to_: To data format string.
:return: Perm. Int list.
"""
return list(map(lambda x: from_.find(x), to_))
# TODO: allow more flexible placement
def get_device_option(device):
m = {DeviceType.CPU: '/cpu', DeviceType.CUDA: '/gpu'}
return m[device.type]
def get_data_format(x_rank):
""" Get data format by input rank.
Channel first if support CUDA.
:param x_rank: Input rank.
:return: Data format.
"""
sp_dim_names = ["D", "H", "W"]
sp_dim_lst = []
for i in range(x_rank - 2):
sp_dim_lst.append(sp_dim_names[-i - 1])
sp_dim_string = "".join(reversed(sp_dim_lst))
storage_format = "NC" + sp_dim_string
if supports_device("CUDA"):
compute_format = "NC" + sp_dim_string
else:
compute_format = "N" + sp_dim_string + "C"
return storage_format, compute_format
def supports_device(device):
""" Check if support target device.
:param device: CUDA or CPU.
:return: If supports.
"""
if device == "CUDA":
local_device_protos = device_lib.list_local_devices()
return len([x.name for x in local_device_protos if x.device_type == 'GPU'
]) > 0
elif device == "CPU":
return True
return False
@deprecated("onnx_tf.common.get_outputs_names is deprecated.{} {}".format(
deprecated.MSG_WILL_REMOVE,
"Use TensorflowGraph.get_outputs_names instead."))
def get_output_node_names(graph_def):
"""Get output node names from GraphDef.
Args:
graph_def: GraphDef object.
Returns:
List of output node names.
"""
nodes, input_names = dict(), set()
for node in graph_def.node:
nodes[node.name] = node
input_names.update(set(node.input))
return list(set(nodes) - input_names)
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"
@@ -0,0 +1,86 @@
from onnx_tf.common import IS_PYTHON3
def convert_tf(attr):
return __convert_tf_attr_value(attr)
def convert_onnx(attr):
return __convert_onnx_attribute_proto(attr)
def __convert_tf_attr_value(attr):
""" convert Tensorflow AttrValue object to Python object
"""
if attr.HasField('list'):
return __convert_tf_list_value(attr.list)
if attr.HasField('s'):
return attr.s
elif attr.HasField('i'):
return attr.i
elif attr.HasField('f'):
return attr.f
elif attr.HasField('b'):
return attr.b
elif attr.HasField('type'):
return attr.type
elif attr.HasField('shape'):
return attr.type
elif attr.HasField('tensor'):
return attr.tensor
else:
raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
def __convert_tf_list_value(list_value):
""" convert Tensorflow ListValue object to Python object
"""
if list_value.s:
return list_value.s
elif list_value.i:
return list_value.i
elif list_value.f:
return list_value.f
elif list_value.b:
return list_value.b
elif list_value.tensor:
return list_value.tensor
elif list_value.type:
return list_value.type
elif list_value.shape:
return list_value.shape
elif list_value.func:
return list_value.func
else:
raise ValueError("Unsupported Tensorflow attribute: {}".format(list_value))
def __convert_onnx_attribute_proto(attr_proto):
"""
Convert an ONNX AttributeProto into an appropriate Python object
for the type.
NB: Tensor attribute gets returned as the straight proto.
"""
if attr_proto.HasField('f'):
return attr_proto.f
elif attr_proto.HasField('i'):
return attr_proto.i
elif attr_proto.HasField('s'):
return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
elif attr_proto.HasField('t'):
return attr_proto.t # this is a proto!
elif attr_proto.HasField('g'):
return attr_proto.g
elif attr_proto.floats:
return list(attr_proto.floats)
elif attr_proto.ints:
return list(attr_proto.ints)
elif attr_proto.strings:
str_list = list(attr_proto.strings)
if IS_PYTHON3:
str_list = list(map(lambda x: str(x, 'utf-8'), str_list))
return str_list
elif attr_proto.HasField('sparse_tensor'):
return attr_proto.sparse_tensor
else:
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
@@ -0,0 +1,37 @@
from tensorflow.python.framework.tensor_util import MakeNdarray
from onnx_tf.common import data_type
# Keyed by old attribute names.
__tf_attr_translator = {
"_output_shapes": lambda x: list(map(lambda shape: get_tf_shape_as_list(shape.dim), x.list.shape)),
"shape": lambda x: get_tf_shape_as_list(x.shape.dim),
"T": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
"dtype": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
"component_types": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
"value": lambda x: MakeNdarray(x.tensor),
"seed2": lambda x: float(x.i),
"seed": lambda x: float(x.i),
"keep_dims": lambda x: int(x.b),
"squeeze_dims": lambda x: list(x.list.i),
}
__onnx_attr_translator = {
"axis": lambda x: int(x),
"axes": lambda x: [int(a) for a in x],
"dtype": lambda x: data_type.onnx2tf(x),
"keepdims": lambda x: bool(x),
"to": lambda x: data_type.onnx2tf(x),
}
def translate_tf(key, val):
return __tf_attr_translator.get(key, lambda x: x)(val)
def translate_onnx(key, val):
return __onnx_attr_translator.get(key, lambda x: x)(val)
def get_tf_shape_as_list(tf_shape_dim):
return list(map(lambda x: x.size, list(tf_shape_dim)))
@@ -0,0 +1,71 @@
from numbers import Number
import numpy as np
from onnx import mapping
from onnx import TensorProto
import tensorflow as tf
def tf2onnx(dtype):
if isinstance(dtype, Number):
tf_dype = tf.as_dtype(dtype)
elif isinstance(dtype, tf.DType):
tf_dype = dtype
elif isinstance(dtype, list):
return [tf2onnx(t) for t in dtype]
else:
raise RuntimeError("dtype should be number or tf.DType.")
# Usually, tf2onnx is done via tf_type->numpy_type->onnx_type
# to leverage existing type conversion infrastructure;
# However, we need to intercept the string type early because
# lowering tf.string type to numpy dtype results in loss of
# information. <class 'object'> is returned instead of the
# numpy string type desired.
if tf_dype is tf.string:
return TensorProto.STRING
onnx_dtype = None
try:
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(
tf_dype.as_numpy_dtype)]
finally:
if onnx_dtype is None:
common.logger.warning(
"Can't convert tf dtype {} to ONNX dtype. Return 0 (TensorProto.UNDEFINED)."
.format(tf_dype))
onnx_dtype = TensorProto.UNDEFINED
return onnx_dtype
def onnx2tf(dtype):
return tf.as_dtype(mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)])
def onnx2field(dtype):
return mapping.STORAGE_TENSOR_TYPE_TO_FIELD[_onnx_dtype(dtype)]
def _onnx_dtype(dtype):
if isinstance(dtype, Number):
onnx_dype = dtype
elif isinstance(dtype, str):
onnx_dype = TensorProto.DataType.Value(dtype)
else:
raise RuntimeError("dtype should be number or str.")
return onnx_dype
# TODO (tjingrant) unify _onnx_dtype into any_dtype_to_onnx_dtype
def any_dtype_to_onnx_dtype(np_dtype=None, tf_dtype=None, onnx_dtype=None):
dtype_mask = [1 if val else 0 for val in [np_dtype, tf_dtype, onnx_dtype]]
num_type_set = sum(dtype_mask)
assert num_type_set == 1, "One and only one type must be set. However, {} set.".format(
sum(num_type_set))
if np_dtype:
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np_dtype]
if tf_dtype:
onnx_dtype = tf2onnx(tf_dtype)
return onnx_dtype
@@ -0,0 +1,73 @@
import inspect
import onnx_tf.common as common
class CustomException(object):
def __init__(self):
self._func = RuntimeError
self._message = ""
def __call__(self, *args, **kwargs):
if inspect.isclass(self._func) and issubclass(self._func, Exception):
raise self._func(self.get_message(*args, **kwargs))
elif callable(self._func):
self._func(self.get_message(*args, **kwargs))
def get_message(self, *args, **kwargs):
return self._message
class OpUnimplementedException(CustomException):
def __init__(self):
super(OpUnimplementedException, self).__init__()
self._func = NotImplementedError
self._message = "{} is not implemented."
def __call__(self, op, version=None, domain=None):
if IGNORE_UNIMPLEMENTED:
self._func = common.logger.warning
super(OpUnimplementedException, self).__call__(op, version, domain)
def get_message(self, op, version=None, domain=None):
insert_message = op
if version is not None:
insert_message += " version {}".format(version)
if domain is not None:
insert_message += " in domain `{}`".format(domain)
return self._message.format(insert_message)
class OpUnsupportedException(object):
def __init__(self):
super(OpUnsupportedException, self).__init__()
self._func = RuntimeError
self._message = "{} is not supported in {}."
def __call__(self, op, framework):
raise self._func(self.get_message(op, framework))
def get_message(self, op, framework):
return self._message.format(op, framework)
class ConstNotFoundException(CustomException):
def __init__(self):
super(ConstNotFoundException, self).__init__()
self._func = RuntimeError
self._message = "{} of {} is not found in graph consts."
def __call__(self, name, op):
super(ConstNotFoundException, self).__call__(name, op)
def get_message(self, name, op):
return self._message.format(name, op)
IGNORE_UNIMPLEMENTED = False
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
@@ -0,0 +1,76 @@
from onnx import defs
import onnx_tf.common as common
from onnx_tf.handlers.backend import * # noqa
from onnx_tf.handlers.backend_handler import BackendHandler
import onnx_tf.common as common
def get_all_backend_handlers(opset_dict):
""" Get a dict of all backend handler classes.
e.g. {'domain': {'Abs': Abs handler class}, ...}, }.
:param opset_dict: A dict of opset. e.g. {'domain': version, ...}
:return: Dict.
"""
handlers = {}
for handler in BackendHandler.__subclasses__():
handler.check_cls()
domain = handler.DOMAIN
version = opset_dict[domain] if domain in opset_dict else 1
handler.VERSION = version
since_version = 1
if defs.has(handler.ONNX_OP, domain=handler.DOMAIN):
try:
since_version = defs.get_schema(
handler.ONNX_OP,
domain=handler.DOMAIN,
max_inclusive_version=version).since_version
except RuntimeError:
common.logger.debug("Fail to get since_version of {} in domain `{}` "
"with max_inclusive_version={}. Set to 1.".format(
handler.ONNX_OP, handler.DOMAIN, version))
else:
common.logger.debug("Unknown op {} in domain `{}`.".format(
handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
handler.SINCE_VERSION = since_version
handlers.setdefault(domain, {})[handler.ONNX_OP] = handler
return handlers
def get_backend_coverage():
""" Get backend coverage for document.
:return: onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
"""
onnx_coverage = {}
experimental_op = set()
for handler in BackendHandler.__subclasses__():
handler.check_cls()
versions = handler.get_versions()
domain = handler.DOMAIN
if getattr(handler, "EXPERIMENTAL", False):
experimental_op.add(handler.ONNX_OP)
_update_coverage(onnx_coverage, domain, handler.ONNX_OP, versions)
return onnx_coverage, experimental_op
def _update_coverage(coverage, domain, key, versions):
domain_coverage = coverage.setdefault(domain, {})
vers = domain_coverage.get(key, [])
vers.extend(versions)
domain_coverage[key] = sorted(list(set(vers)))
def get_backend_partial_support_detail():
ps_dict = {}
opset_dict = dict([(defs.ONNX_DOMAIN, defs.onnx_opset_version())])
handlers = get_all_backend_handlers(opset_dict)[defs.ONNX_DOMAIN]
for op_name in handlers:
if handlers[op_name].PARTIAL_SUPPORT:
ps_dict[op_name] = handlers[op_name].PS_DESCRIPTION
return ps_dict
@@ -0,0 +1,21 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import onnx
def get_onnx_version():
return tuple(map(int, onnx.version.version.split(".")))
# Returns whether onnx version is prior to major.minor.patch
def legacy_onnx_pre_ver(major=0, minor=0, patch=0):
return get_onnx_version() < (major, minor, patch)
# Returns whether the opset version accompanying the
# onnx installation is prior to version passed.
def legacy_opset_pre_ver(version):
return onnx.defs.onnx_opset_version() < version
@@ -0,0 +1,263 @@
from __future__ import division
from collections import namedtuple
import numpy as np
import tensorflow as tf
import itertools
pad_ops = namedtuple("pad_ops",
["max_op", "ceil_op", "floor_op", "cast_int_op"])
pad_numpy_ops = pad_ops(np.maximum, np.ceil, np.floor,
lambda arr: arr.astype(np.int64))
pad_tf_ops = pad_ops(tf.maximum, tf.math.ceil, tf.math.floor,
lambda tensor: tf.cast(tensor, tf.int64))
def calc_pads_same(in_spatial_shape, kernel_shape, strides,
dilations, padding, padding_ops=pad_numpy_ops,
pads_order=1):
"""
Calculates the SAME paddings that need to be added to the input
Args:
in_spatial_shape: input spatial shape
kernel_shape: the size of the kernel along each axis
strides: stride along each spatial axis
dilations: dilations value along each spatial axis
padding: padding to calculate: SAME_UPPER or
SAME_LOWER
padding_ops: namedtuple with ops to be used during
calculations. there are two sets of ops
defined pad_numpy_ops and pad_tf_ops with
numpy and tensorflow ops
pads_order: order of returned pads. possible options are:
1 - b1, b2, ..., bn, e1, e2, ..., en
2 - b1, e1, b2, e2, ..., bn, en
where n = len(kernel_shape) * 2,
b1, b2, ..., bn define pads at the begging of
axis
e1, e2, ..., en define pads at the end of
axis
Return:
pads: array with calculated pads. the order of the
values is determined by `pads_order`
"""
spatial_size = len(kernel_shape)
pads = [0] * (spatial_size * 2)
for i in range(spatial_size):
in_size = in_spatial_shape[i]
filter_size = (kernel_shape[i] - 1) * dilations[i] + 1
out_size = padding_ops.ceil_op(in_size / strides[i])
out_size = padding_ops.cast_int_op(out_size)
pad_along_axis = \
padding_ops.max_op((out_size - 1) * strides[i] +
filter_size - in_size, 0)
if padding.lower() == "same_lower":
pad_op = padding_ops.ceil_op
else:
pad_op = padding_ops.floor_op
pad_begin = pad_op(pad_along_axis / 2)
pad_begin = padding_ops.cast_int_op(pad_begin)
pad_along_axis = padding_ops.cast_int_op(pad_along_axis)
pad_end = pad_along_axis - pad_begin
pads[i * pads_order] = pad_begin
pads[i * pads_order +
(spatial_size if pads_order == 1 else 1)] = pad_end
return pads
def calc_output_shape(input_spatial_shape, kernel_shape, strides, dilations,
padding, ceil_mode=False):
"""
Calculate output shape
Args:
input_spatial_shape: input spatial shape
kernel_shape: the size of the kernel along each axis
strides: stride along each spatial axis
dilations: dilations value along each spatial axis
padding: can be explicit paddings, "SAME_UPPER" or
"SAME_LOWER"
Return:
output_shape: calculated output shape
"""
spatial_size = len(input_spatial_shape)
if type(padding) is not list and type(padding) is not np.ndarray:
if padding.lower().startswith("same"):
padding = calc_pads_same(input_spatial_shape, kernel_shape,
strides, dilations, padding)
else:
padding = [0] * spatial_size * 2
output_shape = []
for dim in range(spatial_size):
output_shape.append(_pooling_output_shape(input_spatial_shape[dim],
kernel_shape[dim], strides[dim], dilations[dim],
padding[dim] + padding[dim + spatial_size],
ceil_mode))
return output_shape
def _pooling_output_shape(input_size, ksize, stride, dilation, pad, ceil_mode):
output_size = (input_size + pad - ((ksize - 1) * dilation + 1) +
((stride-1) if ceil_mode else 0)) // stride + 1
if (pad):
if ((output_size - 1) * stride >= input_size + pad):
output_size -= 1
return output_size
def py_pool(input, kernel_shape, strides=None, dilations=None,
padding=None, ceil_mode=False, pooling_type="MAX",
include_indices=True, p=2):
"""
Implementation of Max and Average pool operations in Python
Args:
input: input N-D data array in NC* format
kernel_shape: the size of the kernel along each axis
strides: stride along each spatial axis
dilations: dilations value along each spatial axis of filter
padding: padding for the beginning and ending along each
spatial axis. `padding` format should be as follow
[x1_begin, x2_begin...x1_end, x2_end,...]
ceil_mode: whether to use ceil or floor (default) to compute
the output shape.
pooling_type: specifies pooling type. Values can be "MAX", "AVG" or
"LP"
include_indices: should indices be included in the output
p: specifies the p parameter for LpPooling
Return:
pooled: output data from max pooling across the input
ind: indices of the selected max values from the input
"""
if type(pooling_type) is not str:
pooling_type = pooling_type.decode("UTF-8")
input_shape = np.shape(input)
inp_sp_shape = input_shape[2:]
input_dtype = input.dtype
if np.issubdtype(input_dtype, np.integer):
input_dtype_min = np.iinfo(input_dtype).min
else:
input_dtype_min = np.finfo(input_dtype).min
if pooling_type == "LP":
rootN = (1.0 / p)
def _loop_over_output(batch, channel):
dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
for counters in itertools.product(*dims):
input_ranges = []
for dim in range(spatial_size):
dim_start = \
counters[dim] * strides[dim] - pads[dim * 2]
dim_end = \
min(dim_start + (kernel_shape[dim] - 1) * dilations[dim]
+ 1, inp_sp_shape[dim])
while dim_start < 0:
dim_start += dilations[dim]
cur_range = [i for i in range(dim_start,
dim_end, dilations[dim])]
input_ranges.append(cur_range)
if pooling_type in ["AVG", "LP"]:
val_sum = 0
val_count = 0
else:
maxval = input_dtype_min
maxind = -1
for input_ind in itertools.product(*input_ranges):
ind = (batch, channel) + input_ind
val = input[ind]
if pooling_type == "AVG":
val_sum += val
val_count += 1
elif pooling_type == "LP":
val_sum += abs(val ** p)
else:
if val > maxval:
maxval = val
ind = 0
for i in range(spatial_size):
coef = 1
for j in range(i+1, spatial_size):
coef *= inp_sp_shape[j]
ind += input_ind[i] * coef
maxind = ind
ind = (batch, channel) + counters
if pooling_type == "AVG":
out_pool[ind] = val_sum / val_count
elif pooling_type == "LP":
out_pool[ind] = val_sum ** rootN
else:
out_pool[ind] = maxval
out_ind[ind] = maxind
spatial_size = len(kernel_shape)
batch_size = input_shape[0]
channels_num = input_shape[1]
if strides is None:
strides = kernel_shape
if dilations is None:
dilations = [1] * spatial_size
if padding is None:
padding = [0] * spatial_size * 2
if type(padding) is bytes:
padding = padding.decode()
if type(padding) is not list and type(padding) is not np.ndarray:
if type(padding) is not str:
padding = padding.decode("UTF-8")
if padding.lower().startswith("same"):
padding = calc_pads_same(inp_sp_shape, kernel_shape, strides,
dilations, padding)
else:
padding = [0] * spatial_size * 2
pads = []
pad_along_axis = []
output_sp_shape = []
for dim in range(spatial_size):
pads.append(padding[dim])
pads.append(padding[dim + spatial_size])
pad_along_axis.append(padding[dim] + padding[dim + spatial_size])
input_size = input_shape[dim + 2]
output_size = \
_pooling_output_shape(input_size, kernel_shape[dim],
strides[dim], dilations[dim],
pad_along_axis[dim], ceil_mode)
output_sp_shape.append(output_size)
out_pool = np.zeros([input_shape[0], input_shape[1]] +
output_sp_shape, input_dtype)
out_ind = np.zeros([input_shape[0], input_shape[1]] +
output_sp_shape, np.int64)
for batch in range(batch_size):
for channel in range(channels_num):
_loop_over_output(batch, channel)
if not include_indices:
return out_pool
else:
return out_pool, out_ind
@@ -0,0 +1,45 @@
import tensorflow as tf
import numpy as np
def tf_shape(tensor):
"""
Helper function returning the shape of a Tensor.
The function will check for fully defined shape and will return
numpy array or if the shape is not fully defined will use tf.shape()
to return the shape as a Tensor.
"""
if tensor.shape.is_fully_defined():
return np.array(tensor.shape.as_list(), dtype=np.int64)
else:
return tf.shape(tensor, out_type=tf.int64)
def tf_product(a, b):
"""
Calculates the cartesian product of two column vectors a and b
Example:
a = [[1]
[2]
[3]]
b = [[0]
[1]]
result = [[1 0]
[1 1]
[2 0]
[2 1]
[3 0]
[3 1]]
"""
tile_a = tf.tile(a, [1, tf.shape(b)[0]])
tile_a = tf.expand_dims(tile_a, 2)
tile_a = tf.reshape(tile_a, [-1, 1])
b = tf.tile(b, [tf.shape(a)[0], 1])
b = tf.concat([tile_a, b], axis=1)
return b
+138
View File
@@ -0,0 +1,138 @@
import argparse
import inspect
import logging
import os
import shutil
import onnx
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.tools import freeze_graph
import onnx_tf.backend as backend
import onnx_tf.common as common
from onnx_tf.common import get_unique_suffix
from onnx_tf.pb_wrapper import TensorflowGraph
def main(args):
args = parse_args(args)
convert(**{k: v for k, v in vars(args).items() if v is not None})
def parse_args(args):
class ListAction(argparse.Action):
""" Define how to convert command line list strings to Python objects.
"""
def __call__(self, parser, namespace, values, option_string=None):
values = values if values[0] not in ("(", "[") or values[-1] not in (
")", "]") else values[1:-1]
res = []
for value in values.split(","):
if value.isdigit():
res.append(int(value))
else:
res.append(value)
setattr(namespace, self.dest, res)
class OpsetAction(argparse.Action):
""" Define how to convert command line opset strings to Python objects.
"""
def __call__(self, parser, namespace, values, option_string=None):
if values.isdigit():
setattr(namespace, "opset", int(values))
else:
res = []
while values and values[0] in ("(", "["):
values = values[1:]
while values and values[-1] in (")", "]"):
values = values[:-1]
for value in values.split("),("):
l, r = value.split(",")
res.append((l, int(r)))
setattr(namespace, "opset", res)
def get_param_doc_dict(funcs):
"""Get doc of funcs params.
Args:
funcs: Target funcs.
Returns:
Dict of params doc.
"""
# TODO(fumihwh): support google doc format
def helper(doc, func):
first_idx = doc.find(":param")
last_idx = doc.find(":return")
last_idx = last_idx if last_idx != -1 else len(doc)
param_doc = doc[first_idx:last_idx]
params_doc = param_doc.split(":param ")[1:]
return {
p[:p.find(": ")]: p[p.find(": ") + len(": "):] +
" (from {})".format(func.__module__ + "." + func.__name__)
for p in params_doc
}
param_doc_dict = {}
for func, persists in funcs:
doc = inspect.getdoc(func)
doc_dict = helper(doc, func)
for k, v in doc_dict.items():
if k not in persists:
continue
param_doc_dict[k] = {"doc": v, "params": persists[k]}
return param_doc_dict
parser = argparse.ArgumentParser(
description=
"This is the converter for converting protocol buffer between tf and onnx."
)
# required two args, source and destination path
parser.add_argument("--infile", "-i", help="Input file path.", required=True)
parser.add_argument(
"--outfile", "-o", help="Output file path.", required=True)
def add_argument_group(parser, group_name, funcs):
group = parser.add_argument_group(group_name)
param_doc_dict = get_param_doc_dict(funcs)
for k, v in param_doc_dict.items():
group.add_argument("--{}".format(k), help=v["doc"], **v["params"])
# backend args
# Args must be named consistently with respect to backend.prepare.
add_argument_group(parser, "backend arguments (onnx -> tf)",
[(backend.prepare, {
"device": {},
"strict": {},
"logging_level": {}
})])
return parser.parse_args(args)
def convert(infile, outfile, **kwargs):
"""Convert pb.
Args:
infile: Input path.
outfile: Output path.
**kwargs: Other args for converting.
Returns:
None.
"""
logging_level = kwargs.get("logging_level", "INFO")
common.logger.setLevel(logging_level)
common.logger.handlers[0].setLevel(logging_level)
common.logger.info("Start converting onnx pb to tf pb:")
onnx_model = onnx.load(infile)
tf_rep = backend.prepare(onnx_model, **kwargs)
tf_rep.export_graph(outfile)
common.logger.info("Converting completes successfully.")
+74
View File
@@ -0,0 +1,74 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import re
import subprocess
import onnx_tf.backend
import onnx_tf.backend_rep
from third_party import get_info
def main(docs_dir):
gen_api(docs_dir)
gen_cli(docs_dir)
def gen_api(docs_dir):
gen_doc_for = {
'onnx_tf.backend': [
onnx_tf.backend.prepare,
],
'onnx_tf.backend_rep.TensorflowRep': [
onnx_tf.backend_rep.TensorflowRep.export_graph,
]
}
with open(os.path.join(docs_dir, 'API.md'), 'w') as doc_file:
doc_file.write('ONNX-Tensorflow API\n')
doc_file.write('======\n\n')
for scope, funcs in sorted(gen_doc_for.items()):
for func in funcs:
doc_parsed = get_info.parse_docstring(func.__doc__)
doc_file.write('#### `' + scope + '.' + func.__name__ + '`\n\n')
doc_file.write('<details>\n')
doc_file.write(' <summary>')
doc_file.write(doc_parsed['short_description'] + '\n\n')
doc_file.write(' </summary>\n')
doc_file.write(doc_parsed['long_description'] + '\n\n')
doc_file.write('</details>\n\n\n\n')
doc_file.write('_params_:\n\n')
for param in doc_parsed['params']:
doc_file.write('`' + param['name'] + '` : ' + param['doc'] + '\n\n')
doc_file.write('_returns_:\n\n')
doc_file.write(doc_parsed['returns'] + '\n\n')
def gen_cli(docs_dir):
with open(os.path.join(docs_dir, 'CLI_template.md'), 'r') as cli_temp_file:
temp_lines = cli_temp_file.readlines()
lines = []
for line in temp_lines:
matched = re.match(r"{onnx-tf.*}", line)
if matched:
command = matched.string.strip()[1:-1]
output = subprocess.check_output(command.split(" ")).decode("UTF-8")
lines.append(output)
else:
lines.append(line)
with open(os.path.join(docs_dir, 'CLI.md'), 'w') as cli_file:
cli_file.writelines(lines)
if __name__ == '__main__':
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
docs_dir = os.path.join(base_dir, 'doc')
main(docs_dir)
@@ -0,0 +1,35 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import pprint
from onnx import defs
from onnx_tf.common.handler_helper import get_backend_coverage
from onnx_tf.common.handler_helper import get_backend_partial_support_detail
def main():
backend_opset_dict = {}
for schema in defs.get_all_schemas():
op_name = schema.name
backend_opset_dict[op_name] = []
backend_onnx_coverage, backend_experimental_op = get_backend_coverage()
backend_opset_dict.update(backend_onnx_coverage.get(defs.ONNX_DOMAIN, {}))
backend_ps_dict = get_backend_partial_support_detail()
with open('opset_version.py', 'w') as version_file:
pp = pprint.PrettyPrinter(indent=4)
version_file.write("backend_opset_version = {\n " +
pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
version_file.write("backend_partial_support = {\n " +
pp.pformat(backend_ps_dict)[1:-1] + "\n}\n")
if __name__ == '__main__':
main()
+231
View File
@@ -0,0 +1,231 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import getopt
import os
import subprocess
import sys
import onnx
import tensorflow as tf
from onnx_tf import opset_version, __version__
def main(docs_dir):
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
docs_dir = os.path.join(base_dir, 'doc')
onnx_version = onnx.__version__
onnx_tf_release_build = False
try:
opts, args = getopt.getopt(sys.argv[1:], 'h:mr',
['onnx_master', 'onnx_tf_release_build'])
except getopt.GetoptError:
print('Usage:')
print(' gen_status.py [-m -r]')
print(' gen_status.py -h')
print('Description:')
print(' -m, --onnx_master installed ONNX is the latest master code')
print(' if omitted, ONNX version is onnx.__version__')
print(' -r, --onnx_tf_release_build create report for ONNX-TF release with version')
print(' stated in the VERSION_NUMBER file')
print(' if omitted, the report is for ONNX-TF master')
print(' -h show this help message and exit')
print('eg. 1. generate support_status.md for ONNX-TF master and ONNX onnx.__version__')
print(' gen_status.py')
print(' 2. generate support_status.md for ONNX-TF master and ONNX master')
print(' gen_status.py -m')
print(' 3. generate support_status_<onnx_tf_version>.md for ONNX-TF version')
print(' stated in the VERSION_NUMBER file and ONNX onnx.__version__ ')
print(' gen_status.py -r')
sys.exit(2)
for opt, arg in opts:
if opt == '-h':
print('Usage:')
print(' gen_status.py [-m -r]')
print(' gen_status.py -h')
print('Description:')
print(' -m, --onnx_master installed ONNX is the latest master code')
print(' if omitted, ONNX version is onnx.__version__')
print(' -r, --onnx_tf_release_build create report for ONNX-TF release with version')
print(' stated in the VERSION_NUMBER file')
print(' if omitted, the report is for ONNX-TF master')
print(' -h show this help message and exit')
print('eg. 1. generate support_status.md for ONNX-TF master and ONNX onnx.__version__')
print(' gen_status.py')
print(' 2. generate support_status.md for ONNX-TF master and ONNX master')
print(' gen_status.py -m')
print(' 3. generate support_status_<onnx_tf_version>.md for ONNX-TF version')
print(' stated in the VERSION_NUMBER file and ONNX onnx.__version__ ')
print(' gen_status.py -r')
sys.exit()
elif opt in ('-m', '--onnx_master'):
onnx_version = 'master'
elif opt in ('-r', '--onnx_tf_release_build'):
onnx_tf_release_build = True
gen_support_status(docs_dir, onnx_version, onnx_tf_release_build)
def gen_support_status(docs_dir, onnx_version, onnx_tf_release_build):
# set filename
if onnx_tf_release_build:
onnx_tf_version = 'v' + __version__
filename = 'support_status_' + onnx_tf_version.replace('.', '_') + '.md'
else: # onnx-tf = master
# get onnx-tf commit id
onnx_tf_commit_id = subprocess.check_output('git rev-parse HEAD',
shell=True)
onnx_tf_commit_id = onnx_tf_commit_id.decode().strip('\n')
onnx_tf_version = 'Master ( commit id: {} )'.format(onnx_tf_commit_id)
filename = 'support_status.md'
with open(os.path.join(docs_dir, filename), 'w') as status_file:
status_file.write('# ONNX-Tensorflow Support Status\n')
status_file.write('|||\n')
status_file.write('|-:|:-|\n')
status_file.write('|ONNX-Tensorflow Version|{}|\n'.format(onnx_tf_version))
# get onnx commit id
if onnx_version == 'master':
onnx_commit_id = onnx.version.git_version
status_file.write(
'|ONNX Version|Master ( commit id: {} )|\n'.format(onnx_commit_id))
else:
status_file.write('|ONNX Version|v{}|\n'.format(onnx_version))
# get tf_version
status_file.write('|Tensorflow Version|v{}|\n\n'.format(tf.__version__))
# display the table legend
status_file.write('Notes:\n')
status_file.write('* Values that are new or updated from a ')
status_file.write('previous opset version are in bold.\n')
status_file.write('* -: not defined in corresponding ONNX ')
status_file.write('opset version\n')
status_file.write('* \*: the operator is deprecated\n')
status_file.write('* :small_red_triangle:: not supported yet\n')
status_file.write('* :small_orange_diamond:: partially supported\n')
status_file.write('* the rest are all supported\n\n')
# get oll onnx ops
onnx_ops = {}
for schema in onnx.defs.get_all_schemas():
if schema.domain == '': # only get onnx ops
onnx_ops[schema.name] = {
'versions': [],
'deprecated': schema.since_version if schema.deprecated else -1
}
for schema in onnx.defs.get_all_schemas_with_history():
if schema.domain == '': # only get onnx ops
op = onnx_ops[schema.name]
if schema.deprecated:
if schema.since_version <= op['deprecated']:
op['versions'].append(schema.since_version)
op['deprecated'] = schema.since_version
else:
op['versions'].append(schema.since_version)
# get all onnx-tf supported ops
onnx_tf_ops = opset_version.backend_opset_version
onnx_tf_ops_ps = opset_version.backend_partial_support
# get the cureent opset version
current_opset = onnx.defs.onnx_opset_version()
# setup table header
status_file.write('|||')
for i in range(current_opset):
status_file.write('|')
status_file.write('\n|:-:|:-:|')
for i in range(current_opset):
status_file.write(':-:|')
status_file.write('\n|**ONNX Operator**|')
for opset in range(1, current_opset + 1):
status_file.write('**Opset {}**|'.format(opset))
status_file.write('**ONNX Operator**|')
ops_count = len(onnx_ops)
# fill in data for the table
for key, val in sorted(onnx_ops.items()):
try:
status_file.write('\n|{}|'.format(key))
i = 0
vers = val['versions']
deprecated = val['deprecated']
for opset in range(1, current_opset + 1):
if i <= len(vers) - 1:
lb = vers[i]
ub = vers[i + 1] if i < len(vers) - 1 else vers[i]
if opset < lb:
if i == 0:
status_file.write('-')
elif opset == lb:
status_file.write('**{}**'.format(lb))
if lb >= deprecated and deprecated > 0:
status_file.write('\*')
elif lb not in onnx_tf_ops[key]:
status_file.write(':small_red_triangle:')
if opset == current_opset:
ops_count -= 1
elif key in onnx_tf_ops_ps:
status_file.write(':small_orange_diamond:')
else: # opset > lb
if opset < ub:
status_file.write('{}'.format(lb))
if lb >= deprecated and deprecated > 0:
status_file.write('\*')
elif lb not in onnx_tf_ops[key]:
status_file.write(':small_red_triangle:')
if opset == current_opset:
ops_count -= 1
elif key in onnx_tf_ops_ps:
status_file.write(':small_orange_diamond:')
elif opset == ub:
status_file.write('**{}**'.format(ub))
if ub >= deprecated and deprecated > 0:
status_file.write('\*')
elif ub not in onnx_tf_ops[key]:
status_file.write(':small_red_triangle:')
if opset == current_opset:
ops_count -= 1
elif key in onnx_tf_ops_ps:
status_file.write(':small_orange_diamond:')
i += 1
else: #opset > ub
status_file.write('{}'.format(ub))
if ub >= deprecated and deprecated > 0:
status_file.write('\*')
elif ub not in onnx_tf_ops[key]:
status_file.write(':small_red_triangle:')
if opset == current_opset:
ops_count -= 1
elif key in onnx_tf_ops_ps:
status_file.write(':small_orange_diamond:')
status_file.write('|')
status_file.write('{}|'.format(key))
except:
# ops defined in onnx but not in opset_version.backend_opset_versionn
status_file.write(':small_red_triangle:|')
status_file.write(
'\n\nONNX-TF Supported Operators / ONNX Operators: {} / {}'.format(
ops_count, len(onnx_ops)))
# display partial support footnote
status_file.write('\n\nNotes:\n')
index = 1
for key in onnx_tf_ops_ps:
status_file.write(
str(index) + '. ' + key + ': ' + onnx_tf_ops_ps[key] + '\n')
index += 1
if __name__ == '__main__':
main(sys.argv[1:])
@@ -0,0 +1,7 @@
import os
import pkgutil
__all__ = [
modname for _, modname, _ in pkgutil.walk_packages(
path=[os.path.split(__file__)[0]])
]

Some files were not shown because too many files have changed in this diff Show More