重命名 pt2tf 为 pt2pb
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from . import backend
|
||||
from .version import version as __version__
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,352 @@
|
||||
"""Backend for running ONNX on Tensorflow
|
||||
|
||||
To run this, you will need to have Tensorflow installed as well.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
try:
|
||||
from itertools import izip as zip
|
||||
except ImportError: # will be 3.x series
|
||||
pass
|
||||
|
||||
from onnx import defs
|
||||
from onnx import numpy_helper
|
||||
from onnx.backend.base import Backend
|
||||
from onnx.backend.base import Device
|
||||
from onnx.backend.base import namedtupledict
|
||||
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
|
||||
from onnx.helper import make_opsetid
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf.backend_rep import TensorflowRep
|
||||
from onnx_tf.common import data_type
|
||||
from onnx_tf.common import get_device_option
|
||||
from onnx_tf.common import get_unique_suffix
|
||||
from onnx_tf.common import supports_device as common_supports_device
|
||||
from onnx_tf.common.handler_helper import get_all_backend_handlers
|
||||
from onnx_tf.pb_wrapper import OnnxNode
|
||||
import onnx_tf.common as common
|
||||
|
||||
|
||||
class TensorflowBackend(Backend):
|
||||
""" Tensorflow Backend for ONNX
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def prepare(cls,
|
||||
model,
|
||||
device='CPU',
|
||||
strict=True,
|
||||
logging_level='INFO',
|
||||
**kwargs):
|
||||
"""Prepare an ONNX model for Tensorflow Backend.
|
||||
|
||||
This function converts an ONNX model to an internel representation
|
||||
of the computational graph called TensorflowRep and returns
|
||||
the converted representation.
|
||||
|
||||
:param model: The ONNX model to be converted.
|
||||
:param device: The device to execute this model on.
|
||||
:param strict: Whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
|
||||
Changing to False is strongly discouraged.
|
||||
Currently, the strict flag only affects the behavior of MaxPool and AveragePool ops.
|
||||
:param logging_level: The logging level, default is INFO. Change it to DEBUG
|
||||
to see more conversion details or to WARNING to see less
|
||||
|
||||
:returns: A TensorflowRep class object representing the ONNX model
|
||||
"""
|
||||
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
|
||||
common.logger.setLevel(logging_level)
|
||||
common.logger.handlers[0].setLevel(logging_level)
|
||||
|
||||
return cls.onnx_model_to_tensorflow_rep(model, strict)
|
||||
|
||||
@classmethod
|
||||
def onnx_model_to_tensorflow_rep(cls, model, strict):
|
||||
""" Convert ONNX model to TensorflowRep.
|
||||
|
||||
:param model: ONNX ModelProto object.
|
||||
:param strict: whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model.
|
||||
:return: TensorflowRep object.
|
||||
"""
|
||||
|
||||
# Models with IR_VERSION less than 3 does not have opset_import set.
|
||||
# We default to minimum opset, this behavior is consistent with
|
||||
# onnx checker.
|
||||
# c.f. https://github.com/onnx/onnx/blob/427ac0c1b792363d373e3d7e4eef97fa46458420/onnx/checker.cc#L478
|
||||
if model.ir_version < 3:
|
||||
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
|
||||
else:
|
||||
opset_import = model.opset_import
|
||||
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
|
||||
|
||||
@classmethod
|
||||
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
|
||||
""" Convert ONNX graph to TensorflowRep.
|
||||
|
||||
:param graph_def: ONNX GraphProto object.
|
||||
:param opset: ONNX OperatorSetIdProto list.
|
||||
:param strict: whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model.
|
||||
:return: TensorflowRep object.
|
||||
"""
|
||||
handlers = cls._get_handlers(opset)
|
||||
|
||||
tf_rep_graph = tf.Graph()
|
||||
with tf_rep_graph.as_default():
|
||||
# initializer: TensorProtos representing the values to initialize
|
||||
# a given tensor.
|
||||
# initialized: A list of names of the initialized tensors.
|
||||
if graph_def.initializer:
|
||||
input_dict_items = cls._onnx_initializer_to_input_dict_items(
|
||||
graph_def.initializer)
|
||||
initialized = {init.name for init in graph_def.initializer}
|
||||
else:
|
||||
input_dict_items = []
|
||||
initialized = set()
|
||||
|
||||
# creating placeholders for currently unknown inputs
|
||||
for value_info in graph_def.input:
|
||||
if value_info.name in initialized:
|
||||
continue
|
||||
shape = list(
|
||||
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
|
||||
for d in value_info.type.tensor_type.shape.dim)
|
||||
value_info_name = value_info.name.replace(
|
||||
":", "_tf_") + "_" + get_unique_suffix(
|
||||
) if ":" in value_info.name else value_info.name
|
||||
|
||||
x = tf.placeholder(data_type.onnx2tf(
|
||||
value_info.type.tensor_type.elem_type),
|
||||
name=value_info_name,
|
||||
shape=shape)
|
||||
input_dict_items.append((value_info.name, x))
|
||||
|
||||
# tensor dict: this dictionary is a map from variable names
|
||||
# to the latest produced TF tensors of the given name.
|
||||
# This dictionary will get updated as we build the graph to
|
||||
# record the names of newly produced tensors.
|
||||
tensor_dict = dict(input_dict_items)
|
||||
# Since tensor dict may be updated, we need to keep a copy
|
||||
# of the original input dict where we track the earliest
|
||||
# defined tensors so we can have access to the placeholders
|
||||
# to feed in input tensors when we run the graph.
|
||||
input_dict = dict(input_dict_items)
|
||||
|
||||
for node in graph_def.node:
|
||||
onnx_node = OnnxNode(node)
|
||||
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
|
||||
tensor_dict,
|
||||
handlers,
|
||||
opset=opset,
|
||||
strict=strict)
|
||||
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
|
||||
tensor_dict.update(curr_node_output_map)
|
||||
|
||||
tf_rep = TensorflowRep()
|
||||
tf_rep.graph = tf_rep_graph
|
||||
tf_rep.inputs = [
|
||||
value_info.name
|
||||
for value_info in graph_def.input
|
||||
if value_info.name not in initialized
|
||||
]
|
||||
tf_rep.outputs = [value_info.name for value_info in graph_def.output]
|
||||
tf_rep.tensor_dict = tensor_dict
|
||||
return tf_rep
|
||||
|
||||
@classmethod
|
||||
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
|
||||
""" Run ONNX node.
|
||||
|
||||
:param node: ONNX NodeProto object.
|
||||
:param inputs: Inputs.
|
||||
:param device: Device run on.
|
||||
:param outputs_info: None.
|
||||
:param kwargs: Other args.
|
||||
:return: Outputs.
|
||||
"""
|
||||
super(TensorflowBackend, cls).run_node(node, inputs, device)
|
||||
node_graph = tf.Graph()
|
||||
with node_graph.as_default():
|
||||
node = OnnxNode(node)
|
||||
device_option = get_device_option(Device(device))
|
||||
input_tensors = []
|
||||
for i in inputs:
|
||||
input_tensors.append(tf.constant(i))
|
||||
|
||||
if isinstance(inputs, dict):
|
||||
feed_dict_raw = inputs
|
||||
else:
|
||||
assert len(node.inputs) == len(inputs)
|
||||
feed_dict_raw = dict(zip(node.inputs, inputs))
|
||||
|
||||
# TODO: is constant the best way for feeding inputs?
|
||||
input_dict = dict([
|
||||
(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
|
||||
])
|
||||
ops = cls._onnx_node_to_tensorflow_op(node, input_dict)
|
||||
|
||||
with tf.Session() as sess:
|
||||
with tf.device(device_option):
|
||||
sess.run(tf.global_variables_initializer())
|
||||
output_vals = sess.run(ops)
|
||||
|
||||
return namedtupledict('Outputs', node.outputs)(*output_vals)
|
||||
|
||||
@classmethod
|
||||
def _onnx_initializer_to_input_dict_items(cls, initializer):
|
||||
""" Convert ONNX graph initializer to input dict items.
|
||||
|
||||
:param initializer: ONNX graph initializer, list of TensorProto.
|
||||
:return: List of input dict items.
|
||||
"""
|
||||
|
||||
def tensor2list(onnx_tensor):
|
||||
# Use the onnx.numpy_helper because the data may be raw
|
||||
return numpy_helper.to_array(onnx_tensor).flatten().tolist()
|
||||
|
||||
def validate_initializer_name(name):
|
||||
# Prepend a unique suffix if leading charater is "_"
|
||||
name = get_unique_suffix() + name if name[0] is "_" else name
|
||||
|
||||
# Replace ":" with "_tf_" and append a unique suffix for
|
||||
# traceability
|
||||
return name.replace(
|
||||
":", "_tf_") + "_" + get_unique_suffix() if ":" in name else name
|
||||
|
||||
return [(init.name,
|
||||
tf.constant(tensor2list(init),
|
||||
shape=init.dims,
|
||||
dtype=data_type.onnx2tf(init.data_type),
|
||||
name=validate_initializer_name(init.name)))
|
||||
for init in initializer]
|
||||
|
||||
@classmethod
|
||||
def _onnx_node_to_tensorflow_op(cls,
|
||||
node,
|
||||
tensor_dict,
|
||||
handlers=None,
|
||||
opset=None,
|
||||
strict=True):
|
||||
"""
|
||||
Convert onnx node to tensorflow op.
|
||||
|
||||
Args:
|
||||
node: Onnx node object.
|
||||
tensor_dict: Tensor dict of graph.
|
||||
opset: Opset version of the operator set. Default 0 means using latest version.
|
||||
strict: whether to enforce semantic equivalence between the original model
|
||||
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
|
||||
Changing to False is strongly discouraged.
|
||||
Returns:
|
||||
Tensorflow op
|
||||
"""
|
||||
handlers = handlers or cls._get_handlers(opset)
|
||||
if handlers:
|
||||
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
|
||||
if handler:
|
||||
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
|
||||
|
||||
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
|
||||
|
||||
@classmethod
|
||||
def _get_handlers(cls, opset):
|
||||
""" Get all backend handlers with opset.
|
||||
|
||||
:param opset: ONNX OperatorSetIdProto list.
|
||||
:return: All backend handlers.
|
||||
"""
|
||||
opset = opset or [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
|
||||
opset_dict = dict([(o.domain, o.version) for o in opset])
|
||||
return get_all_backend_handlers(opset_dict)
|
||||
|
||||
@classmethod
|
||||
def supports_device(cls, device):
|
||||
return common_supports_device(device)
|
||||
|
||||
@classmethod
|
||||
def onnx_graph_to_tensorflow_ops(cls,
|
||||
subgraph,
|
||||
input_values,
|
||||
tensor_dict,
|
||||
opset=None,
|
||||
strict=True):
|
||||
"""
|
||||
Converts ONNX graph to Tensorflow operations
|
||||
Args:
|
||||
subgraph: the ONNX graph to be converted
|
||||
input_values: dictionary with values/tensors to initialize
|
||||
the subgraph inputs. if the subgraph.input
|
||||
are send in as parameters then it is required,
|
||||
otherwise this can be empty dictionary.
|
||||
tensor_dict: the dictionary that contain values for all the
|
||||
node.inputs in the subgraph that are not defined
|
||||
in the subgraph or input_values.
|
||||
opset: opset version of the operator set.
|
||||
strict: whether to enforce semantic equivalence between the
|
||||
original model and the converted tensorflow model,
|
||||
defaults to True (yes, enforce semantic equivalence).
|
||||
Returns:
|
||||
array of Tensorflow Tensors
|
||||
"""
|
||||
# get the subgraph.input from input_values
|
||||
subgraph_tensor_dict = input_values.copy()
|
||||
# get the rest of the subgraph input from tensor_dict
|
||||
for i in subgraph.input:
|
||||
if i.name not in subgraph_tensor_dict.keys():
|
||||
subgraph_tensor_dict[i.name] = tensor_dict[i.name]
|
||||
# get the required initializer constant node(s) for the subgraph
|
||||
# Need to get the initializer constant nodes from tensor_dict here
|
||||
# because input from initializer will not be send in as inputs
|
||||
# to the subgraph and those nodes are not in the subgraph
|
||||
nodes_outputs = []
|
||||
for node in subgraph.node:
|
||||
for o_name in node.output:
|
||||
nodes_outputs.append(o_name)
|
||||
for node in subgraph.node:
|
||||
for i_name in node.input:
|
||||
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
|
||||
):
|
||||
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
|
||||
onnx_node = OnnxNode(node)
|
||||
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
|
||||
subgraph_tensor_dict,
|
||||
opset=opset,
|
||||
strict=strict)
|
||||
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
|
||||
subgraph_tensor_dict.update(curr_node_output_map)
|
||||
return subgraph_tensor_dict
|
||||
|
||||
@classmethod
|
||||
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
|
||||
"""
|
||||
Converts ONNX graph to TensorflowRep
|
||||
Args:
|
||||
graph_def: the ONNX graph to be converted
|
||||
strict: whether to enforce semantic equivalence between the
|
||||
original model and the converted tensorflow model,
|
||||
defaults to True (yes, enforce semantic equivalence).
|
||||
Returns:
|
||||
TensorflowRep object.
|
||||
"""
|
||||
# get the opset of the installed ONNX
|
||||
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
|
||||
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
|
||||
|
||||
|
||||
prepare = TensorflowBackend.prepare
|
||||
|
||||
run_node = TensorflowBackend.run_node
|
||||
|
||||
run_model = TensorflowBackend.run_model
|
||||
|
||||
supports_device = TensorflowBackend.supports_device
|
||||
|
||||
onnx_graph_to_tensorflow_ops = TensorflowBackend.onnx_graph_to_tensorflow_ops
|
||||
|
||||
onnx_graph_to_tensorflow_rep = TensorflowBackend.onnx_graph_to_tensorflow_rep
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx.backend.base import BackendRep, namedtupledict
|
||||
|
||||
|
||||
class TensorflowRep(BackendRep):
|
||||
|
||||
def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
|
||||
super(TensorflowRep, self).__init__()
|
||||
self._graph = graph
|
||||
self._inputs = inputs or []
|
||||
self._outputs = outputs or []
|
||||
self._tensor_dict = tensor_dict or {}
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
return self._graph
|
||||
|
||||
@graph.setter
|
||||
def graph(self, graph):
|
||||
self._graph = graph
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return self._inputs
|
||||
|
||||
@inputs.setter
|
||||
def inputs(self, inputs):
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return self._outputs
|
||||
|
||||
@outputs.setter
|
||||
def outputs(self, outputs):
|
||||
self._outputs = outputs
|
||||
|
||||
@property
|
||||
def tensor_dict(self):
|
||||
return self._tensor_dict
|
||||
|
||||
@tensor_dict.setter
|
||||
def tensor_dict(self, tensor_dict):
|
||||
self._tensor_dict = tensor_dict
|
||||
|
||||
def run(self, inputs, **kwargs):
|
||||
""" Run TensorflowRep.
|
||||
|
||||
:param inputs: Given inputs.
|
||||
:param kwargs: Other args.
|
||||
:return: Outputs.
|
||||
"""
|
||||
super(TensorflowRep, self).run(inputs, **kwargs)
|
||||
|
||||
# TODO: handle name scope if necessary
|
||||
with self.graph.as_default():
|
||||
with tf.Session() as sess:
|
||||
if isinstance(inputs, dict):
|
||||
feed_dict = inputs
|
||||
elif isinstance(inputs, list) or isinstance(inputs, tuple):
|
||||
if len(self.inputs) != len(inputs):
|
||||
raise RuntimeError('Expected {} values for uninitialized '
|
||||
'graph inputs ({}), but got {}.'.format(
|
||||
len(self.inputs), ', '.join(self.inputs),
|
||||
len(inputs)))
|
||||
feed_dict = dict(zip(self.inputs, inputs))
|
||||
else:
|
||||
# single input
|
||||
feed_dict = dict([(self.inputs[0], inputs)])
|
||||
|
||||
feed_dict = {
|
||||
self.tensor_dict[key]: feed_dict[key] for key in self.inputs
|
||||
}
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
outputs = [self.tensor_dict[output] for output in self.outputs]
|
||||
|
||||
output_values = sess.run(outputs, feed_dict=feed_dict)
|
||||
return namedtupledict('Outputs', self.outputs)(*output_values)
|
||||
|
||||
def export_graph(self, path):
|
||||
"""Export backend representation to a Tensorflow proto file.
|
||||
|
||||
This function obtains the graph proto corresponding to the ONNX
|
||||
model associated with the backend representation and serializes
|
||||
to a protobuf file.
|
||||
|
||||
:param path: The path to the output TF protobuf file.
|
||||
|
||||
:returns: none.
|
||||
"""
|
||||
graph_proto = self.graph.as_graph_def()
|
||||
# rename the output nodes
|
||||
meaningful_names = {}
|
||||
for output_name in self.outputs:
|
||||
meaningful_names[self.tensor_dict[output_name].name.replace(':0', '')] = output_name
|
||||
for node in graph_proto.node:
|
||||
if node.name in meaningful_names.keys():
|
||||
node.name = meaningful_names[node.name]
|
||||
|
||||
file = open(path, "wb")
|
||||
file.write(graph_proto.SerializeToString())
|
||||
file.close()
|
||||
@@ -0,0 +1,24 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import onnx_tf.converter
|
||||
|
||||
|
||||
def main():
|
||||
args = sys.argv[1:]
|
||||
parser = argparse.ArgumentParser(
|
||||
description="ONNX-Tensorflow Command Line Interface")
|
||||
parser.add_argument(
|
||||
"command",
|
||||
choices=["convert"],
|
||||
help="Available commands.")
|
||||
|
||||
if len(args) == 0:
|
||||
parser.parse_args(["-h"])
|
||||
cli_tool = parser.parse_args([args[0]])
|
||||
if cli_tool.command == "convert":
|
||||
return onnx_tf.converter.main(args[1:])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,196 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
import logging
|
||||
|
||||
from onnx.backend.base import DeviceType
|
||||
from tensorflow.python.client import device_lib
|
||||
|
||||
IS_PYTHON3 = sys.version_info > (3,)
|
||||
logger = logging.getLogger('onnx-tf')
|
||||
|
||||
# create console handler and formatter for logger
|
||||
console = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
console.setFormatter(formatter)
|
||||
logger.addHandler(console)
|
||||
|
||||
|
||||
class Deprecated:
|
||||
"""Add deprecated message when function is called.
|
||||
|
||||
Usage:
|
||||
from onnx_tf.common import deprecated
|
||||
|
||||
@deprecated
|
||||
def func():
|
||||
pass
|
||||
|
||||
UserWarning: func is deprecated. It will be removed in future release.
|
||||
|
||||
@deprecated("Message")
|
||||
def func():
|
||||
pass
|
||||
|
||||
UserWarning: Message
|
||||
|
||||
@deprecated({"arg": "Message",
|
||||
"arg_1": deprecated.MSG_WILL_REMOVE,
|
||||
"arg_2": "",})
|
||||
def func(arg, arg_1, arg_2):
|
||||
pass
|
||||
|
||||
UserWarning: Message
|
||||
UserWarning: arg_1 of func is deprecated. It will be removed in future release.
|
||||
UserWarning: arg_2 of func is deprecated.
|
||||
"""
|
||||
|
||||
MSG_WILL_REMOVE = " It will be removed in future release."
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.deprecated_decorator(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def messages():
|
||||
return {v for k, v in inspect.getmembers(Deprecated) if k.startswith("MSG")}
|
||||
|
||||
@staticmethod
|
||||
def deprecated_decorator(arg=None):
|
||||
# deprecate function with default message MSG_WILL_REMOVE
|
||||
# @deprecated
|
||||
if inspect.isfunction(arg):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
warnings.warn("{} is deprecated.{}".format(
|
||||
arg.__module__ + "." + arg.__name__, Deprecated.MSG_WILL_REMOVE))
|
||||
return arg(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE
|
||||
|
||||
def deco(func):
|
||||
# deprecate arg
|
||||
# @deprecated({...})
|
||||
if isinstance(deprecated_arg, dict):
|
||||
for name, message in deprecated_arg.items():
|
||||
if message in Deprecated.messages():
|
||||
message = "{} of {} is deprecated.{}".format(
|
||||
name, func.__module__ + "." + func.__name__, message or "")
|
||||
warnings.warn(message)
|
||||
# deprecate function with message
|
||||
# @deprecated("message")
|
||||
elif isinstance(deprecated_arg, str):
|
||||
message = deprecated_arg
|
||||
if message in Deprecated.messages():
|
||||
message = "{} is deprecated.{}".format(
|
||||
func.__module__ + "." + func.__name__, message)
|
||||
warnings.warn(message)
|
||||
return func
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
deprecated = Deprecated()
|
||||
|
||||
|
||||
# This function inserts an underscore before every upper
|
||||
# case letter and lowers that upper case letter except for
|
||||
# the first letter.
|
||||
def op_name_to_lower(name):
|
||||
return re.sub('(?<!^)(?=[A-Z])', '_', name).lower()
|
||||
|
||||
|
||||
def get_unique_suffix():
|
||||
""" Get unique suffix by using first 8 chars from uuid.uuid4
|
||||
to make unique identity name.
|
||||
|
||||
:return: Unique suffix string.
|
||||
"""
|
||||
return str(uuid.uuid4())[:8]
|
||||
|
||||
|
||||
def get_perm_from_formats(from_, to_):
|
||||
""" Get perm from data formats.
|
||||
For example:
|
||||
get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]
|
||||
|
||||
:param from_: From data format string.
|
||||
:param to_: To data format string.
|
||||
:return: Perm. Int list.
|
||||
"""
|
||||
return list(map(lambda x: from_.find(x), to_))
|
||||
|
||||
|
||||
# TODO: allow more flexible placement
|
||||
def get_device_option(device):
|
||||
m = {DeviceType.CPU: '/cpu', DeviceType.CUDA: '/gpu'}
|
||||
return m[device.type]
|
||||
|
||||
|
||||
def get_data_format(x_rank):
|
||||
""" Get data format by input rank.
|
||||
Channel first if support CUDA.
|
||||
|
||||
:param x_rank: Input rank.
|
||||
:return: Data format.
|
||||
"""
|
||||
sp_dim_names = ["D", "H", "W"]
|
||||
sp_dim_lst = []
|
||||
for i in range(x_rank - 2):
|
||||
sp_dim_lst.append(sp_dim_names[-i - 1])
|
||||
|
||||
sp_dim_string = "".join(reversed(sp_dim_lst))
|
||||
storage_format = "NC" + sp_dim_string
|
||||
|
||||
if supports_device("CUDA"):
|
||||
compute_format = "NC" + sp_dim_string
|
||||
else:
|
||||
compute_format = "N" + sp_dim_string + "C"
|
||||
return storage_format, compute_format
|
||||
|
||||
|
||||
def supports_device(device):
|
||||
""" Check if support target device.
|
||||
|
||||
:param device: CUDA or CPU.
|
||||
:return: If supports.
|
||||
"""
|
||||
if device == "CUDA":
|
||||
local_device_protos = device_lib.list_local_devices()
|
||||
return len([x.name for x in local_device_protos if x.device_type == 'GPU'
|
||||
]) > 0
|
||||
elif device == "CPU":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@deprecated("onnx_tf.common.get_outputs_names is deprecated.{} {}".format(
|
||||
deprecated.MSG_WILL_REMOVE,
|
||||
"Use TensorflowGraph.get_outputs_names instead."))
|
||||
def get_output_node_names(graph_def):
|
||||
"""Get output node names from GraphDef.
|
||||
Args:
|
||||
graph_def: GraphDef object.
|
||||
Returns:
|
||||
List of output node names.
|
||||
"""
|
||||
nodes, input_names = dict(), set()
|
||||
for node in graph_def.node:
|
||||
nodes[node.name] = node
|
||||
input_names.update(set(node.input))
|
||||
return list(set(nodes) - input_names)
|
||||
|
||||
|
||||
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
|
||||
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
|
||||
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
|
||||
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from onnx_tf.common import IS_PYTHON3
|
||||
|
||||
|
||||
def convert_tf(attr):
|
||||
return __convert_tf_attr_value(attr)
|
||||
|
||||
|
||||
def convert_onnx(attr):
|
||||
return __convert_onnx_attribute_proto(attr)
|
||||
|
||||
|
||||
def __convert_tf_attr_value(attr):
|
||||
""" convert Tensorflow AttrValue object to Python object
|
||||
"""
|
||||
if attr.HasField('list'):
|
||||
return __convert_tf_list_value(attr.list)
|
||||
if attr.HasField('s'):
|
||||
return attr.s
|
||||
elif attr.HasField('i'):
|
||||
return attr.i
|
||||
elif attr.HasField('f'):
|
||||
return attr.f
|
||||
elif attr.HasField('b'):
|
||||
return attr.b
|
||||
elif attr.HasField('type'):
|
||||
return attr.type
|
||||
elif attr.HasField('shape'):
|
||||
return attr.type
|
||||
elif attr.HasField('tensor'):
|
||||
return attr.tensor
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
|
||||
|
||||
|
||||
def __convert_tf_list_value(list_value):
|
||||
""" convert Tensorflow ListValue object to Python object
|
||||
"""
|
||||
if list_value.s:
|
||||
return list_value.s
|
||||
elif list_value.i:
|
||||
return list_value.i
|
||||
elif list_value.f:
|
||||
return list_value.f
|
||||
elif list_value.b:
|
||||
return list_value.b
|
||||
elif list_value.tensor:
|
||||
return list_value.tensor
|
||||
elif list_value.type:
|
||||
return list_value.type
|
||||
elif list_value.shape:
|
||||
return list_value.shape
|
||||
elif list_value.func:
|
||||
return list_value.func
|
||||
else:
|
||||
raise ValueError("Unsupported Tensorflow attribute: {}".format(list_value))
|
||||
|
||||
|
||||
def __convert_onnx_attribute_proto(attr_proto):
|
||||
"""
|
||||
Convert an ONNX AttributeProto into an appropriate Python object
|
||||
for the type.
|
||||
NB: Tensor attribute gets returned as the straight proto.
|
||||
"""
|
||||
if attr_proto.HasField('f'):
|
||||
return attr_proto.f
|
||||
elif attr_proto.HasField('i'):
|
||||
return attr_proto.i
|
||||
elif attr_proto.HasField('s'):
|
||||
return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
|
||||
elif attr_proto.HasField('t'):
|
||||
return attr_proto.t # this is a proto!
|
||||
elif attr_proto.HasField('g'):
|
||||
return attr_proto.g
|
||||
elif attr_proto.floats:
|
||||
return list(attr_proto.floats)
|
||||
elif attr_proto.ints:
|
||||
return list(attr_proto.ints)
|
||||
elif attr_proto.strings:
|
||||
str_list = list(attr_proto.strings)
|
||||
if IS_PYTHON3:
|
||||
str_list = list(map(lambda x: str(x, 'utf-8'), str_list))
|
||||
return str_list
|
||||
elif attr_proto.HasField('sparse_tensor'):
|
||||
return attr_proto.sparse_tensor
|
||||
else:
|
||||
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
|
||||
@@ -0,0 +1,37 @@
|
||||
from tensorflow.python.framework.tensor_util import MakeNdarray
|
||||
|
||||
from onnx_tf.common import data_type
|
||||
|
||||
# Keyed by old attribute names.
|
||||
__tf_attr_translator = {
|
||||
"_output_shapes": lambda x: list(map(lambda shape: get_tf_shape_as_list(shape.dim), x.list.shape)),
|
||||
"shape": lambda x: get_tf_shape_as_list(x.shape.dim),
|
||||
"T": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"dtype": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"component_types": lambda x: data_type.tf2onnx(list(x.list.type) or x.type),
|
||||
"value": lambda x: MakeNdarray(x.tensor),
|
||||
"seed2": lambda x: float(x.i),
|
||||
"seed": lambda x: float(x.i),
|
||||
"keep_dims": lambda x: int(x.b),
|
||||
"squeeze_dims": lambda x: list(x.list.i),
|
||||
}
|
||||
|
||||
__onnx_attr_translator = {
|
||||
"axis": lambda x: int(x),
|
||||
"axes": lambda x: [int(a) for a in x],
|
||||
"dtype": lambda x: data_type.onnx2tf(x),
|
||||
"keepdims": lambda x: bool(x),
|
||||
"to": lambda x: data_type.onnx2tf(x),
|
||||
}
|
||||
|
||||
|
||||
def translate_tf(key, val):
|
||||
return __tf_attr_translator.get(key, lambda x: x)(val)
|
||||
|
||||
|
||||
def translate_onnx(key, val):
|
||||
return __onnx_attr_translator.get(key, lambda x: x)(val)
|
||||
|
||||
|
||||
def get_tf_shape_as_list(tf_shape_dim):
|
||||
return list(map(lambda x: x.size, list(tf_shape_dim)))
|
||||
@@ -0,0 +1,71 @@
|
||||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
from onnx import mapping
|
||||
from onnx import TensorProto
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def tf2onnx(dtype):
|
||||
if isinstance(dtype, Number):
|
||||
tf_dype = tf.as_dtype(dtype)
|
||||
elif isinstance(dtype, tf.DType):
|
||||
tf_dype = dtype
|
||||
elif isinstance(dtype, list):
|
||||
return [tf2onnx(t) for t in dtype]
|
||||
else:
|
||||
raise RuntimeError("dtype should be number or tf.DType.")
|
||||
|
||||
# Usually, tf2onnx is done via tf_type->numpy_type->onnx_type
|
||||
# to leverage existing type conversion infrastructure;
|
||||
# However, we need to intercept the string type early because
|
||||
# lowering tf.string type to numpy dtype results in loss of
|
||||
# information. <class 'object'> is returned instead of the
|
||||
# numpy string type desired.
|
||||
if tf_dype is tf.string:
|
||||
return TensorProto.STRING
|
||||
|
||||
onnx_dtype = None
|
||||
try:
|
||||
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(
|
||||
tf_dype.as_numpy_dtype)]
|
||||
finally:
|
||||
if onnx_dtype is None:
|
||||
common.logger.warning(
|
||||
"Can't convert tf dtype {} to ONNX dtype. Return 0 (TensorProto.UNDEFINED)."
|
||||
.format(tf_dype))
|
||||
onnx_dtype = TensorProto.UNDEFINED
|
||||
return onnx_dtype
|
||||
|
||||
|
||||
def onnx2tf(dtype):
|
||||
return tf.as_dtype(mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)])
|
||||
|
||||
|
||||
def onnx2field(dtype):
|
||||
return mapping.STORAGE_TENSOR_TYPE_TO_FIELD[_onnx_dtype(dtype)]
|
||||
|
||||
|
||||
def _onnx_dtype(dtype):
|
||||
if isinstance(dtype, Number):
|
||||
onnx_dype = dtype
|
||||
elif isinstance(dtype, str):
|
||||
onnx_dype = TensorProto.DataType.Value(dtype)
|
||||
else:
|
||||
raise RuntimeError("dtype should be number or str.")
|
||||
return onnx_dype
|
||||
|
||||
|
||||
# TODO (tjingrant) unify _onnx_dtype into any_dtype_to_onnx_dtype
|
||||
def any_dtype_to_onnx_dtype(np_dtype=None, tf_dtype=None, onnx_dtype=None):
|
||||
dtype_mask = [1 if val else 0 for val in [np_dtype, tf_dtype, onnx_dtype]]
|
||||
num_type_set = sum(dtype_mask)
|
||||
assert num_type_set == 1, "One and only one type must be set. However, {} set.".format(
|
||||
sum(num_type_set))
|
||||
|
||||
if np_dtype:
|
||||
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np_dtype]
|
||||
if tf_dtype:
|
||||
onnx_dtype = tf2onnx(tf_dtype)
|
||||
|
||||
return onnx_dtype
|
||||
@@ -0,0 +1,73 @@
|
||||
import inspect
|
||||
import onnx_tf.common as common
|
||||
|
||||
|
||||
class CustomException(object):
|
||||
|
||||
def __init__(self):
|
||||
self._func = RuntimeError
|
||||
self._message = ""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if inspect.isclass(self._func) and issubclass(self._func, Exception):
|
||||
raise self._func(self.get_message(*args, **kwargs))
|
||||
elif callable(self._func):
|
||||
self._func(self.get_message(*args, **kwargs))
|
||||
|
||||
def get_message(self, *args, **kwargs):
|
||||
return self._message
|
||||
|
||||
|
||||
class OpUnimplementedException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnimplementedException, self).__init__()
|
||||
self._func = NotImplementedError
|
||||
self._message = "{} is not implemented."
|
||||
|
||||
def __call__(self, op, version=None, domain=None):
|
||||
if IGNORE_UNIMPLEMENTED:
|
||||
self._func = common.logger.warning
|
||||
super(OpUnimplementedException, self).__call__(op, version, domain)
|
||||
|
||||
def get_message(self, op, version=None, domain=None):
|
||||
insert_message = op
|
||||
if version is not None:
|
||||
insert_message += " version {}".format(version)
|
||||
if domain is not None:
|
||||
insert_message += " in domain `{}`".format(domain)
|
||||
return self._message.format(insert_message)
|
||||
|
||||
|
||||
class OpUnsupportedException(object):
|
||||
|
||||
def __init__(self):
|
||||
super(OpUnsupportedException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} is not supported in {}."
|
||||
|
||||
def __call__(self, op, framework):
|
||||
raise self._func(self.get_message(op, framework))
|
||||
|
||||
def get_message(self, op, framework):
|
||||
return self._message.format(op, framework)
|
||||
|
||||
|
||||
class ConstNotFoundException(CustomException):
|
||||
|
||||
def __init__(self):
|
||||
super(ConstNotFoundException, self).__init__()
|
||||
self._func = RuntimeError
|
||||
self._message = "{} of {} is not found in graph consts."
|
||||
|
||||
def __call__(self, name, op):
|
||||
super(ConstNotFoundException, self).__call__(name, op)
|
||||
|
||||
def get_message(self, name, op):
|
||||
return self._message.format(name, op)
|
||||
|
||||
|
||||
IGNORE_UNIMPLEMENTED = False
|
||||
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
|
||||
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
|
||||
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
|
||||
@@ -0,0 +1,76 @@
|
||||
from onnx import defs
|
||||
|
||||
import onnx_tf.common as common
|
||||
from onnx_tf.handlers.backend import * # noqa
|
||||
from onnx_tf.handlers.backend_handler import BackendHandler
|
||||
|
||||
import onnx_tf.common as common
|
||||
|
||||
def get_all_backend_handlers(opset_dict):
|
||||
""" Get a dict of all backend handler classes.
|
||||
e.g. {'domain': {'Abs': Abs handler class}, ...}, }.
|
||||
|
||||
:param opset_dict: A dict of opset. e.g. {'domain': version, ...}
|
||||
:return: Dict.
|
||||
"""
|
||||
handlers = {}
|
||||
for handler in BackendHandler.__subclasses__():
|
||||
handler.check_cls()
|
||||
|
||||
domain = handler.DOMAIN
|
||||
version = opset_dict[domain] if domain in opset_dict else 1
|
||||
handler.VERSION = version
|
||||
|
||||
since_version = 1
|
||||
if defs.has(handler.ONNX_OP, domain=handler.DOMAIN):
|
||||
try:
|
||||
since_version = defs.get_schema(
|
||||
handler.ONNX_OP,
|
||||
domain=handler.DOMAIN,
|
||||
max_inclusive_version=version).since_version
|
||||
except RuntimeError:
|
||||
common.logger.debug("Fail to get since_version of {} in domain `{}` "
|
||||
"with max_inclusive_version={}. Set to 1.".format(
|
||||
handler.ONNX_OP, handler.DOMAIN, version))
|
||||
else:
|
||||
common.logger.debug("Unknown op {} in domain `{}`.".format(
|
||||
handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
|
||||
handler.SINCE_VERSION = since_version
|
||||
handlers.setdefault(domain, {})[handler.ONNX_OP] = handler
|
||||
return handlers
|
||||
|
||||
|
||||
def get_backend_coverage():
|
||||
""" Get backend coverage for document.
|
||||
|
||||
:return: onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
|
||||
"""
|
||||
|
||||
onnx_coverage = {}
|
||||
experimental_op = set()
|
||||
for handler in BackendHandler.__subclasses__():
|
||||
handler.check_cls()
|
||||
|
||||
versions = handler.get_versions()
|
||||
domain = handler.DOMAIN
|
||||
if getattr(handler, "EXPERIMENTAL", False):
|
||||
experimental_op.add(handler.ONNX_OP)
|
||||
_update_coverage(onnx_coverage, domain, handler.ONNX_OP, versions)
|
||||
return onnx_coverage, experimental_op
|
||||
|
||||
|
||||
def _update_coverage(coverage, domain, key, versions):
|
||||
domain_coverage = coverage.setdefault(domain, {})
|
||||
vers = domain_coverage.get(key, [])
|
||||
vers.extend(versions)
|
||||
domain_coverage[key] = sorted(list(set(vers)))
|
||||
|
||||
|
||||
def get_backend_partial_support_detail():
|
||||
ps_dict = {}
|
||||
opset_dict = dict([(defs.ONNX_DOMAIN, defs.onnx_opset_version())])
|
||||
handlers = get_all_backend_handlers(opset_dict)[defs.ONNX_DOMAIN]
|
||||
for op_name in handlers:
|
||||
if handlers[op_name].PARTIAL_SUPPORT:
|
||||
ps_dict[op_name] = handlers[op_name].PS_DESCRIPTION
|
||||
return ps_dict
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import onnx
|
||||
|
||||
|
||||
def get_onnx_version():
|
||||
return tuple(map(int, onnx.version.version.split(".")))
|
||||
|
||||
|
||||
# Returns whether onnx version is prior to major.minor.patch
|
||||
def legacy_onnx_pre_ver(major=0, minor=0, patch=0):
|
||||
return get_onnx_version() < (major, minor, patch)
|
||||
|
||||
|
||||
# Returns whether the opset version accompanying the
|
||||
# onnx installation is prior to version passed.
|
||||
def legacy_opset_pre_ver(version):
|
||||
return onnx.defs.onnx_opset_version() < version
|
||||
@@ -0,0 +1,263 @@
|
||||
from __future__ import division
|
||||
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import itertools
|
||||
|
||||
|
||||
pad_ops = namedtuple("pad_ops",
|
||||
["max_op", "ceil_op", "floor_op", "cast_int_op"])
|
||||
|
||||
pad_numpy_ops = pad_ops(np.maximum, np.ceil, np.floor,
|
||||
lambda arr: arr.astype(np.int64))
|
||||
pad_tf_ops = pad_ops(tf.maximum, tf.math.ceil, tf.math.floor,
|
||||
lambda tensor: tf.cast(tensor, tf.int64))
|
||||
|
||||
|
||||
def calc_pads_same(in_spatial_shape, kernel_shape, strides,
|
||||
dilations, padding, padding_ops=pad_numpy_ops,
|
||||
pads_order=1):
|
||||
"""
|
||||
Calculates the SAME paddings that need to be added to the input
|
||||
|
||||
Args:
|
||||
in_spatial_shape: input spatial shape
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis
|
||||
padding: padding to calculate: SAME_UPPER or
|
||||
SAME_LOWER
|
||||
padding_ops: namedtuple with ops to be used during
|
||||
calculations. there are two sets of ops
|
||||
defined pad_numpy_ops and pad_tf_ops with
|
||||
numpy and tensorflow ops
|
||||
pads_order: order of returned pads. possible options are:
|
||||
1 - b1, b2, ..., bn, e1, e2, ..., en
|
||||
2 - b1, e1, b2, e2, ..., bn, en
|
||||
where n = len(kernel_shape) * 2,
|
||||
b1, b2, ..., bn define pads at the begging of
|
||||
axis
|
||||
e1, e2, ..., en define pads at the end of
|
||||
axis
|
||||
Return:
|
||||
pads: array with calculated pads. the order of the
|
||||
values is determined by `pads_order`
|
||||
|
||||
"""
|
||||
spatial_size = len(kernel_shape)
|
||||
pads = [0] * (spatial_size * 2)
|
||||
for i in range(spatial_size):
|
||||
in_size = in_spatial_shape[i]
|
||||
filter_size = (kernel_shape[i] - 1) * dilations[i] + 1
|
||||
|
||||
out_size = padding_ops.ceil_op(in_size / strides[i])
|
||||
out_size = padding_ops.cast_int_op(out_size)
|
||||
pad_along_axis = \
|
||||
padding_ops.max_op((out_size - 1) * strides[i] +
|
||||
filter_size - in_size, 0)
|
||||
if padding.lower() == "same_lower":
|
||||
pad_op = padding_ops.ceil_op
|
||||
else:
|
||||
pad_op = padding_ops.floor_op
|
||||
pad_begin = pad_op(pad_along_axis / 2)
|
||||
|
||||
pad_begin = padding_ops.cast_int_op(pad_begin)
|
||||
pad_along_axis = padding_ops.cast_int_op(pad_along_axis)
|
||||
|
||||
pad_end = pad_along_axis - pad_begin
|
||||
|
||||
pads[i * pads_order] = pad_begin
|
||||
pads[i * pads_order +
|
||||
(spatial_size if pads_order == 1 else 1)] = pad_end
|
||||
|
||||
return pads
|
||||
|
||||
|
||||
def calc_output_shape(input_spatial_shape, kernel_shape, strides, dilations,
|
||||
padding, ceil_mode=False):
|
||||
"""
|
||||
Calculate output shape
|
||||
|
||||
Args:
|
||||
input_spatial_shape: input spatial shape
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis
|
||||
padding: can be explicit paddings, "SAME_UPPER" or
|
||||
"SAME_LOWER"
|
||||
Return:
|
||||
output_shape: calculated output shape
|
||||
"""
|
||||
spatial_size = len(input_spatial_shape)
|
||||
|
||||
if type(padding) is not list and type(padding) is not np.ndarray:
|
||||
if padding.lower().startswith("same"):
|
||||
padding = calc_pads_same(input_spatial_shape, kernel_shape,
|
||||
strides, dilations, padding)
|
||||
else:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
output_shape = []
|
||||
for dim in range(spatial_size):
|
||||
output_shape.append(_pooling_output_shape(input_spatial_shape[dim],
|
||||
kernel_shape[dim], strides[dim], dilations[dim],
|
||||
padding[dim] + padding[dim + spatial_size],
|
||||
ceil_mode))
|
||||
|
||||
return output_shape
|
||||
|
||||
|
||||
def _pooling_output_shape(input_size, ksize, stride, dilation, pad, ceil_mode):
|
||||
output_size = (input_size + pad - ((ksize - 1) * dilation + 1) +
|
||||
((stride-1) if ceil_mode else 0)) // stride + 1
|
||||
if (pad):
|
||||
if ((output_size - 1) * stride >= input_size + pad):
|
||||
output_size -= 1
|
||||
return output_size
|
||||
|
||||
|
||||
def py_pool(input, kernel_shape, strides=None, dilations=None,
|
||||
padding=None, ceil_mode=False, pooling_type="MAX",
|
||||
include_indices=True, p=2):
|
||||
"""
|
||||
Implementation of Max and Average pool operations in Python
|
||||
Args:
|
||||
input: input N-D data array in NC* format
|
||||
kernel_shape: the size of the kernel along each axis
|
||||
strides: stride along each spatial axis
|
||||
dilations: dilations value along each spatial axis of filter
|
||||
padding: padding for the beginning and ending along each
|
||||
spatial axis. `padding` format should be as follow
|
||||
[x1_begin, x2_begin...x1_end, x2_end,...]
|
||||
ceil_mode: whether to use ceil or floor (default) to compute
|
||||
the output shape.
|
||||
pooling_type: specifies pooling type. Values can be "MAX", "AVG" or
|
||||
"LP"
|
||||
include_indices: should indices be included in the output
|
||||
p: specifies the p parameter for LpPooling
|
||||
Return:
|
||||
pooled: output data from max pooling across the input
|
||||
ind: indices of the selected max values from the input
|
||||
"""
|
||||
|
||||
if type(pooling_type) is not str:
|
||||
pooling_type = pooling_type.decode("UTF-8")
|
||||
|
||||
input_shape = np.shape(input)
|
||||
inp_sp_shape = input_shape[2:]
|
||||
input_dtype = input.dtype
|
||||
if np.issubdtype(input_dtype, np.integer):
|
||||
input_dtype_min = np.iinfo(input_dtype).min
|
||||
else:
|
||||
input_dtype_min = np.finfo(input_dtype).min
|
||||
|
||||
if pooling_type == "LP":
|
||||
rootN = (1.0 / p)
|
||||
|
||||
def _loop_over_output(batch, channel):
|
||||
dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
|
||||
for counters in itertools.product(*dims):
|
||||
input_ranges = []
|
||||
for dim in range(spatial_size):
|
||||
dim_start = \
|
||||
counters[dim] * strides[dim] - pads[dim * 2]
|
||||
dim_end = \
|
||||
min(dim_start + (kernel_shape[dim] - 1) * dilations[dim]
|
||||
+ 1, inp_sp_shape[dim])
|
||||
while dim_start < 0:
|
||||
dim_start += dilations[dim]
|
||||
|
||||
cur_range = [i for i in range(dim_start,
|
||||
dim_end, dilations[dim])]
|
||||
input_ranges.append(cur_range)
|
||||
if pooling_type in ["AVG", "LP"]:
|
||||
val_sum = 0
|
||||
val_count = 0
|
||||
else:
|
||||
maxval = input_dtype_min
|
||||
maxind = -1
|
||||
for input_ind in itertools.product(*input_ranges):
|
||||
ind = (batch, channel) + input_ind
|
||||
val = input[ind]
|
||||
if pooling_type == "AVG":
|
||||
val_sum += val
|
||||
val_count += 1
|
||||
elif pooling_type == "LP":
|
||||
val_sum += abs(val ** p)
|
||||
else:
|
||||
if val > maxval:
|
||||
maxval = val
|
||||
ind = 0
|
||||
for i in range(spatial_size):
|
||||
coef = 1
|
||||
for j in range(i+1, spatial_size):
|
||||
coef *= inp_sp_shape[j]
|
||||
ind += input_ind[i] * coef
|
||||
maxind = ind
|
||||
ind = (batch, channel) + counters
|
||||
if pooling_type == "AVG":
|
||||
out_pool[ind] = val_sum / val_count
|
||||
elif pooling_type == "LP":
|
||||
out_pool[ind] = val_sum ** rootN
|
||||
else:
|
||||
out_pool[ind] = maxval
|
||||
out_ind[ind] = maxind
|
||||
|
||||
spatial_size = len(kernel_shape)
|
||||
|
||||
batch_size = input_shape[0]
|
||||
channels_num = input_shape[1]
|
||||
|
||||
if strides is None:
|
||||
strides = kernel_shape
|
||||
|
||||
if dilations is None:
|
||||
dilations = [1] * spatial_size
|
||||
|
||||
if padding is None:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
if type(padding) is bytes:
|
||||
padding = padding.decode()
|
||||
|
||||
if type(padding) is not list and type(padding) is not np.ndarray:
|
||||
if type(padding) is not str:
|
||||
padding = padding.decode("UTF-8")
|
||||
|
||||
if padding.lower().startswith("same"):
|
||||
padding = calc_pads_same(inp_sp_shape, kernel_shape, strides,
|
||||
dilations, padding)
|
||||
else:
|
||||
padding = [0] * spatial_size * 2
|
||||
|
||||
pads = []
|
||||
pad_along_axis = []
|
||||
output_sp_shape = []
|
||||
|
||||
for dim in range(spatial_size):
|
||||
pads.append(padding[dim])
|
||||
pads.append(padding[dim + spatial_size])
|
||||
pad_along_axis.append(padding[dim] + padding[dim + spatial_size])
|
||||
|
||||
input_size = input_shape[dim + 2]
|
||||
output_size = \
|
||||
_pooling_output_shape(input_size, kernel_shape[dim],
|
||||
strides[dim], dilations[dim],
|
||||
pad_along_axis[dim], ceil_mode)
|
||||
output_sp_shape.append(output_size)
|
||||
|
||||
out_pool = np.zeros([input_shape[0], input_shape[1]] +
|
||||
output_sp_shape, input_dtype)
|
||||
out_ind = np.zeros([input_shape[0], input_shape[1]] +
|
||||
output_sp_shape, np.int64)
|
||||
|
||||
for batch in range(batch_size):
|
||||
for channel in range(channels_num):
|
||||
_loop_over_output(batch, channel)
|
||||
|
||||
if not include_indices:
|
||||
return out_pool
|
||||
else:
|
||||
return out_pool, out_ind
|
||||
@@ -0,0 +1,45 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
def tf_shape(tensor):
|
||||
"""
|
||||
Helper function returning the shape of a Tensor.
|
||||
The function will check for fully defined shape and will return
|
||||
numpy array or if the shape is not fully defined will use tf.shape()
|
||||
to return the shape as a Tensor.
|
||||
"""
|
||||
if tensor.shape.is_fully_defined():
|
||||
return np.array(tensor.shape.as_list(), dtype=np.int64)
|
||||
else:
|
||||
return tf.shape(tensor, out_type=tf.int64)
|
||||
|
||||
|
||||
def tf_product(a, b):
|
||||
"""
|
||||
Calculates the cartesian product of two column vectors a and b
|
||||
|
||||
Example:
|
||||
|
||||
a = [[1]
|
||||
[2]
|
||||
[3]]
|
||||
|
||||
b = [[0]
|
||||
[1]]
|
||||
|
||||
result = [[1 0]
|
||||
[1 1]
|
||||
[2 0]
|
||||
[2 1]
|
||||
[3 0]
|
||||
[3 1]]
|
||||
"""
|
||||
tile_a = tf.tile(a, [1, tf.shape(b)[0]])
|
||||
tile_a = tf.expand_dims(tile_a, 2)
|
||||
tile_a = tf.reshape(tile_a, [-1, 1])
|
||||
|
||||
b = tf.tile(b, [tf.shape(a)[0], 1])
|
||||
b = tf.concat([tile_a, b], axis=1)
|
||||
|
||||
return b
|
||||
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import onnx
|
||||
import tensorflow as tf
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.tools import freeze_graph
|
||||
|
||||
import onnx_tf.backend as backend
|
||||
import onnx_tf.common as common
|
||||
from onnx_tf.common import get_unique_suffix
|
||||
from onnx_tf.pb_wrapper import TensorflowGraph
|
||||
|
||||
|
||||
def main(args):
|
||||
args = parse_args(args)
|
||||
convert(**{k: v for k, v in vars(args).items() if v is not None})
|
||||
|
||||
|
||||
def parse_args(args):
|
||||
|
||||
class ListAction(argparse.Action):
|
||||
""" Define how to convert command line list strings to Python objects.
|
||||
"""
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
values = values if values[0] not in ("(", "[") or values[-1] not in (
|
||||
")", "]") else values[1:-1]
|
||||
res = []
|
||||
for value in values.split(","):
|
||||
if value.isdigit():
|
||||
res.append(int(value))
|
||||
else:
|
||||
res.append(value)
|
||||
setattr(namespace, self.dest, res)
|
||||
|
||||
class OpsetAction(argparse.Action):
|
||||
""" Define how to convert command line opset strings to Python objects.
|
||||
"""
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if values.isdigit():
|
||||
setattr(namespace, "opset", int(values))
|
||||
else:
|
||||
res = []
|
||||
while values and values[0] in ("(", "["):
|
||||
values = values[1:]
|
||||
while values and values[-1] in (")", "]"):
|
||||
values = values[:-1]
|
||||
for value in values.split("),("):
|
||||
l, r = value.split(",")
|
||||
res.append((l, int(r)))
|
||||
setattr(namespace, "opset", res)
|
||||
|
||||
def get_param_doc_dict(funcs):
|
||||
"""Get doc of funcs params.
|
||||
|
||||
Args:
|
||||
funcs: Target funcs.
|
||||
|
||||
Returns:
|
||||
Dict of params doc.
|
||||
"""
|
||||
|
||||
# TODO(fumihwh): support google doc format
|
||||
def helper(doc, func):
|
||||
first_idx = doc.find(":param")
|
||||
last_idx = doc.find(":return")
|
||||
last_idx = last_idx if last_idx != -1 else len(doc)
|
||||
param_doc = doc[first_idx:last_idx]
|
||||
params_doc = param_doc.split(":param ")[1:]
|
||||
return {
|
||||
p[:p.find(": ")]: p[p.find(": ") + len(": "):] +
|
||||
" (from {})".format(func.__module__ + "." + func.__name__)
|
||||
for p in params_doc
|
||||
}
|
||||
|
||||
param_doc_dict = {}
|
||||
for func, persists in funcs:
|
||||
doc = inspect.getdoc(func)
|
||||
doc_dict = helper(doc, func)
|
||||
for k, v in doc_dict.items():
|
||||
if k not in persists:
|
||||
continue
|
||||
param_doc_dict[k] = {"doc": v, "params": persists[k]}
|
||||
return param_doc_dict
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=
|
||||
"This is the converter for converting protocol buffer between tf and onnx."
|
||||
)
|
||||
|
||||
# required two args, source and destination path
|
||||
parser.add_argument("--infile", "-i", help="Input file path.", required=True)
|
||||
parser.add_argument(
|
||||
"--outfile", "-o", help="Output file path.", required=True)
|
||||
|
||||
def add_argument_group(parser, group_name, funcs):
|
||||
group = parser.add_argument_group(group_name)
|
||||
param_doc_dict = get_param_doc_dict(funcs)
|
||||
for k, v in param_doc_dict.items():
|
||||
group.add_argument("--{}".format(k), help=v["doc"], **v["params"])
|
||||
|
||||
# backend args
|
||||
# Args must be named consistently with respect to backend.prepare.
|
||||
add_argument_group(parser, "backend arguments (onnx -> tf)",
|
||||
[(backend.prepare, {
|
||||
"device": {},
|
||||
"strict": {},
|
||||
"logging_level": {}
|
||||
})])
|
||||
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def convert(infile, outfile, **kwargs):
|
||||
"""Convert pb.
|
||||
|
||||
Args:
|
||||
infile: Input path.
|
||||
outfile: Output path.
|
||||
**kwargs: Other args for converting.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
logging_level = kwargs.get("logging_level", "INFO")
|
||||
common.logger.setLevel(logging_level)
|
||||
common.logger.handlers[0].setLevel(logging_level)
|
||||
|
||||
common.logger.info("Start converting onnx pb to tf pb:")
|
||||
onnx_model = onnx.load(infile)
|
||||
tf_rep = backend.prepare(onnx_model, **kwargs)
|
||||
tf_rep.export_graph(outfile)
|
||||
common.logger.info("Converting completes successfully.")
|
||||
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import onnx_tf.backend
|
||||
import onnx_tf.backend_rep
|
||||
from third_party import get_info
|
||||
|
||||
|
||||
def main(docs_dir):
|
||||
gen_api(docs_dir)
|
||||
gen_cli(docs_dir)
|
||||
|
||||
|
||||
def gen_api(docs_dir):
|
||||
gen_doc_for = {
|
||||
'onnx_tf.backend': [
|
||||
onnx_tf.backend.prepare,
|
||||
],
|
||||
'onnx_tf.backend_rep.TensorflowRep': [
|
||||
onnx_tf.backend_rep.TensorflowRep.export_graph,
|
||||
]
|
||||
}
|
||||
with open(os.path.join(docs_dir, 'API.md'), 'w') as doc_file:
|
||||
doc_file.write('ONNX-Tensorflow API\n')
|
||||
doc_file.write('======\n\n')
|
||||
|
||||
for scope, funcs in sorted(gen_doc_for.items()):
|
||||
for func in funcs:
|
||||
doc_parsed = get_info.parse_docstring(func.__doc__)
|
||||
doc_file.write('#### `' + scope + '.' + func.__name__ + '`\n\n')
|
||||
doc_file.write('<details>\n')
|
||||
doc_file.write(' <summary>')
|
||||
doc_file.write(doc_parsed['short_description'] + '\n\n')
|
||||
doc_file.write(' </summary>\n')
|
||||
doc_file.write(doc_parsed['long_description'] + '\n\n')
|
||||
doc_file.write('</details>\n\n\n\n')
|
||||
|
||||
doc_file.write('_params_:\n\n')
|
||||
for param in doc_parsed['params']:
|
||||
doc_file.write('`' + param['name'] + '` : ' + param['doc'] + '\n\n')
|
||||
|
||||
doc_file.write('_returns_:\n\n')
|
||||
doc_file.write(doc_parsed['returns'] + '\n\n')
|
||||
|
||||
|
||||
def gen_cli(docs_dir):
|
||||
with open(os.path.join(docs_dir, 'CLI_template.md'), 'r') as cli_temp_file:
|
||||
temp_lines = cli_temp_file.readlines()
|
||||
|
||||
lines = []
|
||||
for line in temp_lines:
|
||||
matched = re.match(r"{onnx-tf.*}", line)
|
||||
if matched:
|
||||
command = matched.string.strip()[1:-1]
|
||||
output = subprocess.check_output(command.split(" ")).decode("UTF-8")
|
||||
lines.append(output)
|
||||
else:
|
||||
lines.append(line)
|
||||
|
||||
with open(os.path.join(docs_dir, 'CLI.md'), 'w') as cli_file:
|
||||
cli_file.writelines(lines)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
docs_dir = os.path.join(base_dir, 'doc')
|
||||
main(docs_dir)
|
||||
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pprint
|
||||
|
||||
from onnx import defs
|
||||
|
||||
from onnx_tf.common.handler_helper import get_backend_coverage
|
||||
from onnx_tf.common.handler_helper import get_backend_partial_support_detail
|
||||
|
||||
|
||||
def main():
|
||||
backend_opset_dict = {}
|
||||
|
||||
for schema in defs.get_all_schemas():
|
||||
op_name = schema.name
|
||||
backend_opset_dict[op_name] = []
|
||||
|
||||
backend_onnx_coverage, backend_experimental_op = get_backend_coverage()
|
||||
backend_opset_dict.update(backend_onnx_coverage.get(defs.ONNX_DOMAIN, {}))
|
||||
backend_ps_dict = get_backend_partial_support_detail()
|
||||
|
||||
with open('opset_version.py', 'w') as version_file:
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
version_file.write("backend_opset_version = {\n " +
|
||||
pp.pformat(backend_opset_dict)[1:-1] + "\n}\n\n")
|
||||
version_file.write("backend_partial_support = {\n " +
|
||||
pp.pformat(backend_ps_dict)[1:-1] + "\n}\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import getopt
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import onnx
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from onnx_tf import opset_version, __version__
|
||||
|
||||
|
||||
def main(docs_dir):
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
docs_dir = os.path.join(base_dir, 'doc')
|
||||
onnx_version = onnx.__version__
|
||||
onnx_tf_release_build = False
|
||||
|
||||
try:
|
||||
opts, args = getopt.getopt(sys.argv[1:], 'h:mr',
|
||||
['onnx_master', 'onnx_tf_release_build'])
|
||||
except getopt.GetoptError:
|
||||
print('Usage:')
|
||||
print(' gen_status.py [-m -r]')
|
||||
print(' gen_status.py -h')
|
||||
print('Description:')
|
||||
print(' -m, --onnx_master installed ONNX is the latest master code')
|
||||
print(' if omitted, ONNX version is onnx.__version__')
|
||||
print(' -r, --onnx_tf_release_build create report for ONNX-TF release with version')
|
||||
print(' stated in the VERSION_NUMBER file')
|
||||
print(' if omitted, the report is for ONNX-TF master')
|
||||
print(' -h show this help message and exit')
|
||||
print('eg. 1. generate support_status.md for ONNX-TF master and ONNX onnx.__version__')
|
||||
print(' gen_status.py')
|
||||
print(' 2. generate support_status.md for ONNX-TF master and ONNX master')
|
||||
print(' gen_status.py -m')
|
||||
print(' 3. generate support_status_<onnx_tf_version>.md for ONNX-TF version')
|
||||
print(' stated in the VERSION_NUMBER file and ONNX onnx.__version__ ')
|
||||
print(' gen_status.py -r')
|
||||
sys.exit(2)
|
||||
for opt, arg in opts:
|
||||
if opt == '-h':
|
||||
print('Usage:')
|
||||
print(' gen_status.py [-m -r]')
|
||||
print(' gen_status.py -h')
|
||||
print('Description:')
|
||||
print(' -m, --onnx_master installed ONNX is the latest master code')
|
||||
print(' if omitted, ONNX version is onnx.__version__')
|
||||
print(' -r, --onnx_tf_release_build create report for ONNX-TF release with version')
|
||||
print(' stated in the VERSION_NUMBER file')
|
||||
print(' if omitted, the report is for ONNX-TF master')
|
||||
print(' -h show this help message and exit')
|
||||
print('eg. 1. generate support_status.md for ONNX-TF master and ONNX onnx.__version__')
|
||||
print(' gen_status.py')
|
||||
print(' 2. generate support_status.md for ONNX-TF master and ONNX master')
|
||||
print(' gen_status.py -m')
|
||||
print(' 3. generate support_status_<onnx_tf_version>.md for ONNX-TF version')
|
||||
print(' stated in the VERSION_NUMBER file and ONNX onnx.__version__ ')
|
||||
print(' gen_status.py -r')
|
||||
sys.exit()
|
||||
elif opt in ('-m', '--onnx_master'):
|
||||
onnx_version = 'master'
|
||||
elif opt in ('-r', '--onnx_tf_release_build'):
|
||||
onnx_tf_release_build = True
|
||||
|
||||
gen_support_status(docs_dir, onnx_version, onnx_tf_release_build)
|
||||
|
||||
|
||||
def gen_support_status(docs_dir, onnx_version, onnx_tf_release_build):
|
||||
|
||||
# set filename
|
||||
if onnx_tf_release_build:
|
||||
onnx_tf_version = 'v' + __version__
|
||||
filename = 'support_status_' + onnx_tf_version.replace('.', '_') + '.md'
|
||||
else: # onnx-tf = master
|
||||
# get onnx-tf commit id
|
||||
onnx_tf_commit_id = subprocess.check_output('git rev-parse HEAD',
|
||||
shell=True)
|
||||
onnx_tf_commit_id = onnx_tf_commit_id.decode().strip('\n')
|
||||
onnx_tf_version = 'Master ( commit id: {} )'.format(onnx_tf_commit_id)
|
||||
filename = 'support_status.md'
|
||||
|
||||
with open(os.path.join(docs_dir, filename), 'w') as status_file:
|
||||
status_file.write('# ONNX-Tensorflow Support Status\n')
|
||||
status_file.write('|||\n')
|
||||
status_file.write('|-:|:-|\n')
|
||||
status_file.write('|ONNX-Tensorflow Version|{}|\n'.format(onnx_tf_version))
|
||||
|
||||
# get onnx commit id
|
||||
if onnx_version == 'master':
|
||||
onnx_commit_id = onnx.version.git_version
|
||||
status_file.write(
|
||||
'|ONNX Version|Master ( commit id: {} )|\n'.format(onnx_commit_id))
|
||||
else:
|
||||
status_file.write('|ONNX Version|v{}|\n'.format(onnx_version))
|
||||
|
||||
# get tf_version
|
||||
status_file.write('|Tensorflow Version|v{}|\n\n'.format(tf.__version__))
|
||||
|
||||
# display the table legend
|
||||
status_file.write('Notes:\n')
|
||||
status_file.write('* Values that are new or updated from a ')
|
||||
status_file.write('previous opset version are in bold.\n')
|
||||
status_file.write('* -: not defined in corresponding ONNX ')
|
||||
status_file.write('opset version\n')
|
||||
status_file.write('* \*: the operator is deprecated\n')
|
||||
status_file.write('* :small_red_triangle:: not supported yet\n')
|
||||
status_file.write('* :small_orange_diamond:: partially supported\n')
|
||||
status_file.write('* the rest are all supported\n\n')
|
||||
|
||||
# get oll onnx ops
|
||||
onnx_ops = {}
|
||||
for schema in onnx.defs.get_all_schemas():
|
||||
if schema.domain == '': # only get onnx ops
|
||||
onnx_ops[schema.name] = {
|
||||
'versions': [],
|
||||
'deprecated': schema.since_version if schema.deprecated else -1
|
||||
}
|
||||
for schema in onnx.defs.get_all_schemas_with_history():
|
||||
if schema.domain == '': # only get onnx ops
|
||||
op = onnx_ops[schema.name]
|
||||
if schema.deprecated:
|
||||
if schema.since_version <= op['deprecated']:
|
||||
op['versions'].append(schema.since_version)
|
||||
op['deprecated'] = schema.since_version
|
||||
else:
|
||||
op['versions'].append(schema.since_version)
|
||||
|
||||
# get all onnx-tf supported ops
|
||||
onnx_tf_ops = opset_version.backend_opset_version
|
||||
onnx_tf_ops_ps = opset_version.backend_partial_support
|
||||
|
||||
# get the cureent opset version
|
||||
current_opset = onnx.defs.onnx_opset_version()
|
||||
|
||||
# setup table header
|
||||
status_file.write('|||')
|
||||
for i in range(current_opset):
|
||||
status_file.write('|')
|
||||
status_file.write('\n|:-:|:-:|')
|
||||
for i in range(current_opset):
|
||||
status_file.write(':-:|')
|
||||
status_file.write('\n|**ONNX Operator**|')
|
||||
for opset in range(1, current_opset + 1):
|
||||
status_file.write('**Opset {}**|'.format(opset))
|
||||
status_file.write('**ONNX Operator**|')
|
||||
|
||||
ops_count = len(onnx_ops)
|
||||
# fill in data for the table
|
||||
for key, val in sorted(onnx_ops.items()):
|
||||
try:
|
||||
status_file.write('\n|{}|'.format(key))
|
||||
i = 0
|
||||
vers = val['versions']
|
||||
deprecated = val['deprecated']
|
||||
for opset in range(1, current_opset + 1):
|
||||
if i <= len(vers) - 1:
|
||||
lb = vers[i]
|
||||
ub = vers[i + 1] if i < len(vers) - 1 else vers[i]
|
||||
if opset < lb:
|
||||
if i == 0:
|
||||
status_file.write('-')
|
||||
elif opset == lb:
|
||||
status_file.write('**{}**'.format(lb))
|
||||
if lb >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif lb not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
else: # opset > lb
|
||||
if opset < ub:
|
||||
status_file.write('{}'.format(lb))
|
||||
if lb >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif lb not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
elif opset == ub:
|
||||
status_file.write('**{}**'.format(ub))
|
||||
if ub >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif ub not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
i += 1
|
||||
else: #opset > ub
|
||||
status_file.write('{}'.format(ub))
|
||||
if ub >= deprecated and deprecated > 0:
|
||||
status_file.write('\*')
|
||||
elif ub not in onnx_tf_ops[key]:
|
||||
status_file.write(':small_red_triangle:')
|
||||
if opset == current_opset:
|
||||
ops_count -= 1
|
||||
elif key in onnx_tf_ops_ps:
|
||||
status_file.write(':small_orange_diamond:')
|
||||
status_file.write('|')
|
||||
status_file.write('{}|'.format(key))
|
||||
except:
|
||||
# ops defined in onnx but not in opset_version.backend_opset_versionn
|
||||
status_file.write(':small_red_triangle:|')
|
||||
|
||||
status_file.write(
|
||||
'\n\nONNX-TF Supported Operators / ONNX Operators: {} / {}'.format(
|
||||
ops_count, len(onnx_ops)))
|
||||
|
||||
# display partial support footnote
|
||||
status_file.write('\n\nNotes:\n')
|
||||
index = 1
|
||||
for key in onnx_tf_ops_ps:
|
||||
status_file.write(
|
||||
str(index) + '. ' + key + ': ' + onnx_tf_ops_ps[key] + '\n')
|
||||
index += 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,7 @@
|
||||
import os
|
||||
import pkgutil
|
||||
|
||||
__all__ = [
|
||||
modname for _, modname, _ in pkgutil.walk_packages(
|
||||
path=[os.path.split(__file__)[0]])
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user