import argparse
import inspect
import logging
import os
import shutil

import onnx
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.tools import freeze_graph

import onnx_tf.backend as backend
import onnx_tf.common as common
from onnx_tf.common import get_unique_suffix
from onnx_tf.pb_wrapper import TensorflowGraph


def main(args):
    """Command-line entry point: parse ``args`` and run :func:`convert`.

    Options the user did not supply (parsed as ``None``) are dropped, so
    ``convert`` and ``backend.prepare`` fall back to their own defaults.

    Args:
        args: List of command-line argument strings (e.g. ``sys.argv[1:]``).
    """
    args = parse_args(args)
    convert(**{k: v for k, v in vars(args).items() if v is not None})


def parse_args(args):
    """Build the converter's argument parser and parse ``args``.

    Besides the required ``--infile``/``--outfile``, one option is generated
    for each selected parameter of ``backend.prepare``. Its help text is taken
    from that function's reST-style (``:param name: ...``) docstring.

    Args:
        args: List of command-line argument strings.

    Returns:
        The ``argparse.Namespace`` returned by ``parser.parse_args``.
    """

    def parse_bool(value):
        """Parse a boolean option value such as ``True``/``false``/``1``/``no``.

        Without an explicit ``type``, argparse passes the raw string through,
        and a value like ``"False"`` is truthy.

        Raises:
            argparse.ArgumentTypeError: If ``value`` is not a recognised
                boolean spelling (argparse reports it as a usage error).
        """
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    def get_param_doc_dict(funcs):
        """Get doc of funcs params.

        Args:
            funcs: Iterable of ``(func, persists)`` pairs. ``persists`` maps a
                parameter name to the extra ``add_argument`` keyword args for
                it; only names present in ``persists`` are kept.

        Returns:
            Dict mapping parameter name to
            ``{"doc": <help text>, "params": <add_argument kwargs>}``.
        """

        # TODO(fumihwh): support google doc format
        def helper(doc, func):
            """Extract ``{param_name: description}`` from a reST docstring."""
            origin = " (from {}.{})".format(func.__module__, func.__name__)
            first_idx = doc.find(":param")
            if first_idx == -1:
                return {}
            # Only look for ":return" after the params, so that a ":return"
            # appearing earlier cannot produce an empty slice.
            last_idx = doc.find(":return", first_idx)
            last_idx = last_idx if last_idx != -1 else len(doc)
            result = {}
            for chunk in doc[first_idx:last_idx].split(":param ")[1:]:
                name, sep, desc = chunk.partition(": ")
                if not sep:
                    # Malformed entry with no "name: description" separator.
                    continue
                result[name] = desc + origin
            return result

        param_doc_dict = {}
        for func, persists in funcs:
            # inspect.getdoc returns None when the function has no docstring.
            doc_dict = helper(inspect.getdoc(func) or "", func)
            for k, v in doc_dict.items():
                if k not in persists:
                    continue
                param_doc_dict[k] = {"doc": v, "params": persists[k]}
        return param_doc_dict

    parser = argparse.ArgumentParser(
        description=
        "This is the converter for converting protocol buffer between tf and onnx."
    )

    # required two args, source and destination path
    parser.add_argument("--infile", "-i", help="Input file path.", required=True)
    parser.add_argument(
        "--outfile", "-o", help="Output file path.", required=True)

    def add_argument_group(parser, group_name, funcs):
        """Add one ``--<param>`` option per documented, persisted parameter."""
        group = parser.add_argument_group(group_name)
        param_doc_dict = get_param_doc_dict(funcs)
        for k, v in param_doc_dict.items():
            group.add_argument("--{}".format(k), help=v["doc"], **v["params"])

    # backend args
    # Args must be named consistently with respect to backend.prepare.
    add_argument_group(parser, "backend arguments (onnx -> tf)",
                       [(backend.prepare, {
                           "device": {},
                           # Parse to a real bool; "False" would otherwise be truthy.
                           "strict": {"type": parse_bool},
                           # Logger.setLevel only accepts upper-case level names.
                           "logging_level": {"type": str.upper},
                       })])

    return parser.parse_args(args)


def convert(infile, outfile, **kwargs):
    """Convert an ONNX model file into an exported TensorFlow graph.

    Args:
        infile: Path of the ONNX model, read with ``onnx.load``.
        outfile: Output path passed to ``export_graph`` on the object
            returned by ``backend.prepare``.
        **kwargs: Other args for converting, forwarded unchanged to
            ``backend.prepare`` (e.g. ``device``, ``strict``,
            ``logging_level``). ``logging_level`` (default ``"INFO"``) also
            sets the level of ``onnx_tf.common.logger`` and its handlers.

    Returns:
        None.
    """
    logging_level = kwargs.get("logging_level", "INFO")
    common.logger.setLevel(logging_level)
    # Update every attached handler. Indexing handlers[0] would raise
    # IndexError when the logger has no handler.
    for handler in common.logger.handlers:
        handler.setLevel(logging_level)

    common.logger.info("Start converting onnx pb to tf pb:")
    onnx_model = onnx.load(infile)
    tf_rep = backend.prepare(onnx_model, **kwargs)
    tf_rep.export_graph(outfile)
    common.logger.info("Converting completes successfully.")