"""Command line converter from an ONNX protobuf model to a TensorFlow graph."""
import argparse
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import shutil
|
|
|
|
import onnx
|
|
import tensorflow as tf
|
|
from tensorflow.core.framework import graph_pb2
|
|
from tensorflow.python.tools import freeze_graph
|
|
|
|
import onnx_tf.backend as backend
|
|
import onnx_tf.common as common
|
|
from onnx_tf.common import get_unique_suffix
|
|
from onnx_tf.pb_wrapper import TensorflowGraph
|
|
|
|
|
|
def main(args):
    """Parse the command line and run the conversion.

    Args:
      args: Command line argument strings, handed to parse_args unchanged.

    Returns:
      None.
    """
    namespace = parse_args(args)

    # Options whose parsed value is None are left out entirely rather than
    # being forwarded to convert as explicit None values.
    options = {}
    for name, value in vars(namespace).items():
        if value is not None:
            options[name] = value

    convert(**options)
|
|
|
|
|
|
def parse_args(args):
    """Build the converter's argument parser and parse `args`.

    Besides the required --infile/--outfile options, one option is registered
    for each parameter of backend.prepare that is declared in the argument
    group below. Help text for those options is extracted from the
    ":param name: text" entries of the function's docstring.

    Args:
      args: List of command line argument strings.

    Returns:
      The argparse.Namespace produced by parser.parse_args(args).
    """

    # NOTE: ListAction and OpsetAction are not referenced by any argument
    # registered in this function; they are kept available for options that
    # need list / opset parsing.
    class ListAction(argparse.Action):
        """ Define how to convert command line list strings to Python objects.
        """

        def __call__(self, parser, namespace, values, option_string=None):
            # Strip one pair of enclosing brackets only when both ends are
            # present. Slicing (instead of indexing) keeps an empty string
            # from raising IndexError.
            if values[:1] in ("(", "[") and values[-1:] in (")", "]"):
                values = values[1:-1]
            # Purely numeric items become ints; everything else stays a str.
            res = [
                int(value) if value.isdigit() else value
                for value in values.split(",")
            ]
            setattr(namespace, self.dest, res)

    class OpsetAction(argparse.Action):
        """ Define how to convert command line opset strings to Python objects.
        """

        def __call__(self, parser, namespace, values, option_string=None):
            # A bare number is stored as a single int.
            if values.isdigit():
                setattr(namespace, "opset", int(values))
                return

            # Otherwise expect "(name,version),(name,version)..." optionally
            # wrapped in further brackets; drop all outer brackets first.
            values = values.lstrip("([").rstrip(")]")
            res = []
            for value in values.split("),("):
                name, version = value.split(",")
                res.append((name, int(version)))
            setattr(namespace, "opset", res)

    def get_param_doc_dict(funcs):
        """Get doc of funcs params.

        Args:
          funcs: Iterable of (func, persists) pairs, where persists maps a
            parameter name to the keyword arguments for its add_argument call.

        Returns:
          Dict mapping each persisted parameter name to
          {"doc": help text or None, "params": add_argument kwargs}.
        """

        # TODO(fumihwh): support google doc format
        def helper(doc, func):
            # Collect ":param name: text" entries that appear before the
            # first ":return" marker (or the end of the docstring).
            first_idx = doc.find(":param")
            if first_idx == -1:
                return {}
            last_idx = doc.find(":return")
            if last_idx == -1:
                last_idx = len(doc)
            origin = " (from {})".format(func.__module__ + "." + func.__name__)
            result = {}
            for p in doc[first_idx:last_idx].split(":param ")[1:]:
                sep = p.find(": ")
                result[p[:sep]] = p[sep + len(": "):] + origin
            return result

        param_doc_dict = {}
        for func, persists in funcs:
            # inspect.getdoc returns None when func has no docstring (for
            # example under `python -OO`); treat that as "no documentation".
            doc_dict = helper(inspect.getdoc(func) or "", func)
            for k, v in doc_dict.items():
                if k in persists:
                    param_doc_dict[k] = {"doc": v, "params": persists[k]}
            # A declared option must not vanish because its docstring entry
            # is missing or in another format: register it without help text.
            for k, params in persists.items():
                param_doc_dict.setdefault(k, {"doc": None, "params": params})
        return param_doc_dict

    parser = argparse.ArgumentParser(
        description=
        "This is the converter for converting protocol buffer between tf and onnx."
    )

    # required two args, source and destination path
    parser.add_argument(
        "--infile", "-i", help="Input file path.", required=True)
    parser.add_argument(
        "--outfile", "-o", help="Output file path.", required=True)

    def add_argument_group(parser, group_name, funcs):
        # Register one "--<param>" option per entry from get_param_doc_dict.
        group = parser.add_argument_group(group_name)
        param_doc_dict = get_param_doc_dict(funcs)
        for k, v in param_doc_dict.items():
            group.add_argument("--{}".format(k), help=v["doc"], **v["params"])

    # backend args
    # Args must be named consistently with respect to backend.prepare.
    add_argument_group(parser, "backend arguments (onnx -> tf)",
                       [(backend.prepare, {
                           "device": {},
                           "strict": {},
                           "logging_level": {}
                       })])

    return parser.parse_args(args)
|
|
|
|
|
|
def convert(infile, outfile, **kwargs):
    """Convert an ONNX pb file into a TensorFlow pb file.

    Args:
      infile: Path of the ONNX model to load.
      outfile: Path the exported graph is written to.
      **kwargs: Extra options forwarded unchanged to backend.prepare. The
        "logging_level" entry, when present, also sets the level of
        common.logger and of its first handler ("INFO" when absent).

    Returns:
      None.
    """
    level = kwargs.get("logging_level", "INFO")
    common.logger.setLevel(level)
    # The first attached handler is switched to the same level as the logger.
    common.logger.handlers[0].setLevel(level)

    common.logger.info("Start converting onnx pb to tf pb:")
    tf_rep = backend.prepare(onnx.load(infile), **kwargs)
    tf_rep.export_graph(outfile)
    common.logger.info("Converting completes successfully.")
|