pt2tf readme modify

This commit is contained in:
zxros10
2020-09-28 20:04:17 -07:00
parent 27faf8e0f2
commit f673126e32
3 changed files with 214 additions and 146 deletions
+74 -14
View File
@@ -13,22 +13,82 @@
# limitations under the License.
#########################################################################
import os
import torch
from efficientnet_pytorch import EfficientNet
import argparse
# Specify which model to use
model_name = 'efficientnet-b3'
image_size = EfficientNet.get_image_size(model_name)
print('Image size: ', image_size)
def parse_args():
# Load model
model = EfficientNet.from_pretrained(model_name)
model.set_swish(memory_efficient=False)
model.eval()
print('Model image size: ', model._global_params.image_size)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Dummy input for ONNX
dummy_input = torch.randn(1, 3, 300, 300)
parser.add_argument('--model_type', default=0, type=int,
help="""the type of pytorch model,
0:model wiht net and weight; 1:only weight""")
parser.add_argument('--model_path', default=None,
help="""the pytorch model pth file path""")
parser.add_argument('--input_shape', default=[], type=list,
help="""the model input shape, e.g. 1 3 224 224""")
args, unknown_args = parser.parse_known_args()
if len(unknown_args) > 0:
for bad_arg in unknown_args:
print("ERROR: Unknown command line arg: %s" % bad_arg)
raise ValueError("Invalid command line arg(s)")
# Export with ONNX
torch.onnx.export(model, dummy_input, f"{model_name}.onnx", verbose=True)
return args
def load_weight_model(model_file):
#修改点1:放开导入模型的注释,并导入自己的模型实现接口.
#例如:模型实现代码目录为./resnet50,网络实现在resnet.py的class ResNet50类
#from resnet50.resnet import ResNet50
model = None
#修改点2:放开创建模型对象注释,并根据自己的模型接口创建模型对象
#model = ResNet50()
#修改点3:放开加载模型的注释
#model.load_state_dict(torch.load(model_file))
return model
def load_complete_mode(model_file):
return torch.load(model_file)
def load_model(model_type, model_path):
if not os.path.exists(model_path):
print("The pytorch model is not exist")
return None
model = None
if model_type == 0:
model = load_complete_mode(model_path)
elif model_type == 1:
model = load_weight_model(model_path)
else:
print("Unknow model type %d, please "
"execute --help to obtain help"%(model_type))
return model
def main():
args = parse_args()
#加载模型
model = load_model(args.model_type, args.model_path)
if model is None:
print("Load model failed")
return
#将模型切换到推理状态
model.eval()
#创建输入张量
input = torch.randn(tuple(args.input_shape))
#生成的onnx文件存放在pytorch模型同级目录下,文件名相同,后缀为onnx
export_onnx_file = os.path.splitext(model_file)[0] + '.onnx'
# Export with ONNX
torch.onnx.export(model, input, export_onnx_file, verbose=True)
if __name__== "__main__":
main()