Moved prediction code to specific functions
This commit is contained in:
@@ -540,40 +540,18 @@ LABELS = ["person", # 0
|
|||||||
"hair drier",
|
"hair drier",
|
||||||
"toothbrush"]
|
"toothbrush"]
|
||||||
|
|
||||||
def main():
|
def make_prediction(model):
|
||||||
"""
|
"""
|
||||||
Defined starting point of source code.
|
Execute predictions with YOLO v3.
|
||||||
"""
|
"""
|
||||||
|
for photo_filename in glob.glob("images/test/motorbike/images2.jpg"):
|
||||||
# Step 3:
|
|
||||||
# (1) Define the model
|
|
||||||
# (2) Load the weight
|
|
||||||
# (3) Save the model
|
|
||||||
|
|
||||||
# Define the YOLO v3 model
|
|
||||||
yolov3 = make_yolov3_model()
|
|
||||||
print(yolov3.summary())
|
|
||||||
|
|
||||||
# Load the weights
|
|
||||||
# Source: https://pjreddie.com/media/files/yolov3.weights
|
|
||||||
weight_reader = WeightReader('yolov3.weights')
|
|
||||||
|
|
||||||
# Set the weights
|
|
||||||
weight_reader.load_weights(yolov3)
|
|
||||||
|
|
||||||
# Save the model to file
|
|
||||||
yolov3.save('yolov3.h5')
|
|
||||||
|
|
||||||
# Step 8:
|
|
||||||
# Make Prediction
|
|
||||||
for photo_filename in glob.glob("images/test/dog/*"):
|
|
||||||
# Define the expected input shape for the model
|
# Define the expected input shape for the model
|
||||||
input_w, input_h = 416, 416
|
input_w, input_h = 416, 416
|
||||||
|
|
||||||
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
|
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
|
||||||
|
|
||||||
# Make prediction
|
# Make prediction
|
||||||
netouts = yolov3.predict(image)
|
netouts = model.predict(image)
|
||||||
|
|
||||||
# Summarize the shape of the list of arrays
|
# Summarize the shape of the list of arrays
|
||||||
print([a.shape for a in netouts])
|
print([a.shape for a in netouts])
|
||||||
@@ -600,5 +578,33 @@ def main():
|
|||||||
# Draw what we found
|
# Draw what we found
|
||||||
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
|
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
Defined starting point of source code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Step 3:
|
||||||
|
# (1) Define the model
|
||||||
|
# (2) Load the weight
|
||||||
|
# (3) Save the model
|
||||||
|
|
||||||
|
# Define the YOLO v3 model
|
||||||
|
yolov3 = make_yolov3_model()
|
||||||
|
print(yolov3.summary())
|
||||||
|
|
||||||
|
# Load the weights
|
||||||
|
# Source: https://pjreddie.com/media/files/yolov3.weights
|
||||||
|
weight_reader = WeightReader('yolov3.weights')
|
||||||
|
|
||||||
|
# Set the weights
|
||||||
|
weight_reader.load_weights(yolov3)
|
||||||
|
|
||||||
|
# Save the model to file
|
||||||
|
yolov3.save('yolov3.h5')
|
||||||
|
|
||||||
|
# Step 8:
|
||||||
|
# Make Prediction
|
||||||
|
make_prediction(yolov3)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user