77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#########################################################################
|
|
|
|
import os
|
|
import torch
|
|
import argparse
|
|
|
|
def parse_args(argv=None):
    """Parse and validate the command-line arguments of the converter.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_path`` (str) and ``input_shape``
        (list of int).

    Raises:
        ValueError: If any unrecognized argument is present.
        SystemExit: Raised by argparse when a required option is missing
            or a value cannot be parsed.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Both options are required: without them the model cannot be located
    # and no dummy input can be built, so fail early with a usage message
    # instead of crashing later on a None value.
    parser.add_argument('--model_path', required=True,
                        help="""the pytorch model pth file path""")
    parser.add_argument('--input_shape', nargs='+', type=int, required=True,
                        help="""the model input shape, e.g. 1 3 224 224""")

    # parse_known_args (instead of parse_args) lets every unknown argument
    # be reported before failing.
    args, unknown_args = parser.parse_known_args(argv)
    if unknown_args:
        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)
        raise ValueError("Invalid command line arg(s)")

    return args
|
|
|
|
def load_model(model_path, input_shape):
    """Build the PyTorch model and load its weights from ``model_path``.

    This is a template: the model-class import, model construction and
    weight loading below are commented out and must be filled in for the
    concrete network before the script can convert anything.

    Args:
        model_path: Path to the ``.pth`` state-dict file.
        input_shape: Model input shape from the command line. Not used by
            the template; available to implementations that need it.

    Returns:
        The model object, or ``None`` if the file does not exist or no
        model implementation has been configured yet.
    """
    if model_path is None or not os.path.exists(model_path):
        print("The pytorch model is not exist")
        return None

    # Stays None until the template steps below are filled in; without this
    # binding the final ``return model`` raised NameError.
    model = None

    # Modification point 1: uncomment the import and import your own model
    # implementation.
    # Example: the model code lives in ./resnet50 and the network is
    # implemented by class ResNet50 in resnet.py.
    #from resnet50.resnet import ResNet50

    # Modification point 2: uncomment and create the model object according
    # to your own model's constructor.
    #model = ResNet50()

    # Modification point 3: uncomment to load the weights. map_location='cpu'
    # allows checkpoints saved on a GPU to be loaded on a CPU-only machine.
    #model.load_state_dict(torch.load(model_path, map_location='cpu'))

    if model is None:
        print("No model implementation configured: edit load_model() first")
    return model
|
|
|
|
|
|
def main():
    """Convert a PyTorch model checkpoint to an ONNX file.

    Reads ``--model_path`` and ``--input_shape`` from the command line,
    builds the model via ``load_model``, exports it with
    ``torch.onnx.export`` using a random input of the given shape, and
    writes ``<model_path without extension>.onnx`` next to the checkpoint.
    """
    args = parse_args()
    print("model path ", args.model_path, ", shape ", args.input_shape)

    # Guard here too, so main() does not rely on parse_args() having made
    # these options mandatory (tuple(None) / splitext(None) would crash).
    if args.model_path is None or not args.input_shape:
        print("Both --model_path and --input_shape must be provided")
        return

    # Load the model
    model = load_model(args.model_path, args.input_shape)
    if model is None:
        print("Load model failed")
        return

    # Switch the model to inference mode before exporting
    model.eval()

    # Random dummy input with the requested shape, passed to the exporter.
    # Named dummy_input to avoid shadowing the builtin input().
    dummy_input = torch.randn(tuple(args.input_shape))

    # The ONNX file is written next to the PyTorch model, with the same
    # base name and an .onnx extension.
    export_onnx_file = os.path.splitext(args.model_path)[0] + '.onnx'

    # Export with ONNX
    torch.onnx.export(model, dummy_input, export_onnx_file, verbose=True)
    print("ONNX model saved to", export_onnx_file)
|
|
|
|
if __name__ == "__main__":
    # Run the converter only when executed as a script, not when imported.
    main()
|