Moved prediction code to specific functions

This commit is contained in:
Heiko J Schick
2020-10-23 20:41:35 +02:00
parent 883e3a6867
commit 2ebae83732
+32 -26
View File
@@ -540,40 +540,18 @@ LABELS = ["person", # 0
"hair drier",
"toothbrush"]
def main():
def make_prediction(model):
"""
Defined starting point of source code.
Execute predictions with YOLO v3.
"""
# Step 3:
# (1) Define the model
# (2) Load the weight
# (3) Save the model
# Define the YOLO v3 model
yolov3 = make_yolov3_model()
print(yolov3.summary())
# Load the weights
# Source: https://pjreddie.com/media/files/yolov3.weights
weight_reader = WeightReader('yolov3.weights')
# Set the weights
weight_reader.load_weights(yolov3)
# Save the model to file
yolov3.save('yolov3.h5')
# Step 8:
# Make Prediction
for photo_filename in glob.glob("images/test/dog/*"):
for photo_filename in glob.glob("images/test/motorbike/images2.jpg"):
# Define the expected input shape for the model
input_w, input_h = 416, 416
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
# Make prediction
netouts = yolov3.predict(image)
netouts = model.predict(image)
# Summarize the shape of the list of arrays
print([a.shape for a in netouts])
@@ -600,5 +578,33 @@ def main():
# Draw what we found
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
def main():
"""
Defined starting point of source code.
"""
# Step 3:
# (1) Define the model
# (2) Load the weight
# (3) Save the model
# Define the YOLO v3 model
yolov3 = make_yolov3_model()
print(yolov3.summary())
# Load the weights
# Source: https://pjreddie.com/media/files/yolov3.weights
weight_reader = WeightReader('yolov3.weights')
# Set the weights
weight_reader.load_weights(yolov3)
# Save the model to file
yolov3.save('yolov3.h5')
# Step 8:
# Make Prediction
make_prediction(yolov3)
if __name__ == "__main__":
main()