pt2tf工具评审意见修改
This commit is contained in:
+9
-27
@@ -21,12 +21,9 @@ def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument('--model_type', default=0, type=int,
|
||||
help="""the type of pytorch model,
|
||||
0:model wiht net and weight; 1:only weight""")
|
||||
parser.add_argument('--model_path', default=None,
|
||||
help="""the pytorch model pth file path""")
|
||||
parser.add_argument('--input_shape', default=[], type=list,
|
||||
parser.add_argument('--input_shape', nargs='+', type=int,
|
||||
help="""the model input shape, e.g. 1 3 224 224""")
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
@@ -37,12 +34,15 @@ def parse_args():
|
||||
|
||||
return args
|
||||
|
||||
def load_weight_model(model_file):
|
||||
def load_model(model_path, input_shape):
|
||||
if not os.path.exists(model_path):
|
||||
print("The pytorch model is not exist")
|
||||
return None
|
||||
|
||||
#修改点1:放开导入模型的注释,并导入自己的模型实现接口.
|
||||
#例如:模型实现代码目录为./resnet50,网络实现在resnet.py的class ResNet50类
|
||||
#from resnet50.resnet import ResNet50
|
||||
|
||||
model = None
|
||||
|
||||
#修改点2:放开创建模型对象注释,并根据自己的模型接口创建模型对象
|
||||
#model = ResNet50()
|
||||
@@ -52,30 +52,12 @@ def load_weight_model(model_file):
|
||||
|
||||
return model
|
||||
|
||||
def load_complete_mode(model_file):
|
||||
return torch.load(model_file)
|
||||
|
||||
|
||||
def load_model(model_type, model_path):
|
||||
if not os.path.exists(model_path):
|
||||
print("The pytorch model is not exist")
|
||||
return None
|
||||
|
||||
model = None
|
||||
if model_type == 0:
|
||||
model = load_complete_mode(model_path)
|
||||
elif model_type == 1:
|
||||
model = load_weight_model(model_path)
|
||||
else:
|
||||
print("Unknow model type %d, please "
|
||||
"execute --help to obtain help"%(model_type))
|
||||
|
||||
return model
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
print("model path ", args.model_path, ", shape ", args.input_shape)
|
||||
#加载模型
|
||||
model = load_model(args.model_type, args.model_path)
|
||||
model = load_model(args.model_path, args.input_shape)
|
||||
if model is None:
|
||||
print("Load model failed")
|
||||
return
|
||||
@@ -85,7 +67,7 @@ def main():
|
||||
#创建输入张量
|
||||
input = torch.randn(tuple(args.input_shape))
|
||||
#生成的onnx文件存放在pytorch模型同级目录下,文件名相同,后缀为onnx
|
||||
export_onnx_file = os.path.splitext(model_file)[0] + '.onnx'
|
||||
export_onnx_file = os.path.splitext(args.model_path)[0] + '.onnx'
|
||||
|
||||
# Export with ONNX
|
||||
torch.onnx.export(model, input, export_onnx_file, verbose=True)
|
||||
|
||||
Reference in New Issue
Block a user