Files
ascend-tools/pt2pb/onnx-tensorflow/onnx_tf/converter.py
T
2020-10-14 08:55:07 +08:00

139 lines
4.0 KiB
Python

import argparse
import inspect
import logging
import os
import shutil
import onnx
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.tools import freeze_graph
import onnx_tf.backend as backend
import onnx_tf.common as common
from onnx_tf.common import get_unique_suffix
from onnx_tf.pb_wrapper import TensorflowGraph
def main(args):
args = parse_args(args)
convert(**{k: v for k, v in vars(args).items() if v is not None})
def parse_args(args):
class ListAction(argparse.Action):
""" Define how to convert command line list strings to Python objects.
"""
def __call__(self, parser, namespace, values, option_string=None):
values = values if values[0] not in ("(", "[") or values[-1] not in (
")", "]") else values[1:-1]
res = []
for value in values.split(","):
if value.isdigit():
res.append(int(value))
else:
res.append(value)
setattr(namespace, self.dest, res)
class OpsetAction(argparse.Action):
""" Define how to convert command line opset strings to Python objects.
"""
def __call__(self, parser, namespace, values, option_string=None):
if values.isdigit():
setattr(namespace, "opset", int(values))
else:
res = []
while values and values[0] in ("(", "["):
values = values[1:]
while values and values[-1] in (")", "]"):
values = values[:-1]
for value in values.split("),("):
l, r = value.split(",")
res.append((l, int(r)))
setattr(namespace, "opset", res)
def get_param_doc_dict(funcs):
"""Get doc of funcs params.
Args:
funcs: Target funcs.
Returns:
Dict of params doc.
"""
# TODO(fumihwh): support google doc format
def helper(doc, func):
first_idx = doc.find(":param")
last_idx = doc.find(":return")
last_idx = last_idx if last_idx != -1 else len(doc)
param_doc = doc[first_idx:last_idx]
params_doc = param_doc.split(":param ")[1:]
return {
p[:p.find(": ")]: p[p.find(": ") + len(": "):] +
" (from {})".format(func.__module__ + "." + func.__name__)
for p in params_doc
}
param_doc_dict = {}
for func, persists in funcs:
doc = inspect.getdoc(func)
doc_dict = helper(doc, func)
for k, v in doc_dict.items():
if k not in persists:
continue
param_doc_dict[k] = {"doc": v, "params": persists[k]}
return param_doc_dict
parser = argparse.ArgumentParser(
description=
"This is the converter for converting protocol buffer between tf and onnx."
)
# required two args, source and destination path
parser.add_argument("--infile", "-i", help="Input file path.", required=True)
parser.add_argument(
"--outfile", "-o", help="Output file path.", required=True)
def add_argument_group(parser, group_name, funcs):
group = parser.add_argument_group(group_name)
param_doc_dict = get_param_doc_dict(funcs)
for k, v in param_doc_dict.items():
group.add_argument("--{}".format(k), help=v["doc"], **v["params"])
# backend args
# Args must be named consistently with respect to backend.prepare.
add_argument_group(parser, "backend arguments (onnx -> tf)",
[(backend.prepare, {
"device": {},
"strict": {},
"logging_level": {}
})])
return parser.parse_args(args)
def convert(infile, outfile, **kwargs):
"""Convert pb.
Args:
infile: Input path.
outfile: Output path.
**kwargs: Other args for converting.
Returns:
None.
"""
logging_level = kwargs.get("logging_level", "INFO")
common.logger.setLevel(logging_level)
common.logger.handlers[0].setLevel(logging_level)
common.logger.info("Start converting onnx pb to tf pb:")
onnx_model = onnx.load(infile)
tf_rep = backend.prepare(onnx_model, **kwargs)
tf_rep.export_graph(outfile)
common.logger.info("Converting completes successfully.")