Coding style changes

This commit is contained in:
Heiko J Schick
2020-10-21 16:12:21 +02:00
parent 18a3c7a43a
commit 8eb406b02c
+49 -40
View File
@@ -17,7 +17,6 @@ from matplotlib.patches import Rectangle
# Step 1: # Step 1:
# Define WeightReader class # Define WeightReader class
class WeightReader: class WeightReader:
""" """
WeightReader class is used to parse the "yolov3.weights" file and load the model weights into WeightReader class is used to parse the "yolov3.weights" file and load the model weights into
@@ -85,7 +84,7 @@ class WeightReader:
""" """
self.offset = 0 self.offset = 0
# Step 2: # Step 2
def _conv_block(input_layer, convs, skip=True): def _conv_block(input_layer, convs, skip=True):
""" """
Function to create convolutional layer. Function to create convolutional layer.
@@ -370,6 +369,9 @@ def _interval_overlap(interval_a, interval_b):
return min(x2,x4) - x3 return min(x2,x4) - x3
def bbox_iou(box1, box2): def bbox_iou(box1, box2):
"""
Compute the intersection-over-union (IoU) of two bounding boxes.
"""
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax]) intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
intersect = intersect_w * intersect_h intersect = intersect_w * intersect_h
@@ -379,6 +381,9 @@ def bbox_iou(box1, box2):
return float(intersect) / union return float(intersect) / union
def do_nms(boxes, nms_thresh): def do_nms(boxes, nms_thresh):
"""
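Non-maximum suppression: per class, zero the score of any box whose IoU with a higher-scoring box is >= nms_thresh.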
"""
if len(boxes) > 0: if len(boxes) > 0:
nb_class = len(boxes[0].classes) nb_class = len(boxes[0].classes)
else: else:
@@ -387,54 +392,63 @@ def do_nms(boxes, nms_thresh):
sorted_indices = np.argsort([-box.classes[c] for box in boxes]) sorted_indices = np.argsort([-box.classes[c] for box in boxes])
for i in range(len(sorted_indices)): for i in range(len(sorted_indices)):
index_i = sorted_indices[i] index_i = sorted_indices[i]
if boxes[index_i].classes[c] == 0: continue
if boxes[index_i].classes[c] == 0:
continue
for j in range(i+1, len(sorted_indices)): for j in range(i+1, len(sorted_indices)):
index_j = sorted_indices[j] index_j = sorted_indices[j]
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh: if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
boxes[index_j].classes[c] = 0 boxes[index_j].classes[c] = 0
# get all of the results above a threshold
def get_boxes(boxes, labels, thresh): def get_boxes(boxes, labels, thresh):
"""
Get all of the results above a threshold
"""
v_boxes, v_labels, v_scores = list(), list(), list() v_boxes, v_labels, v_scores = list(), list(), list()
# enumerate all boxes
# Enumerate all boxes
for box in boxes: for box in boxes:
# enumerate all possible labels # Enumerate all possible labels
for i in range(len(labels)): for i, label in enumerate(labels):
# check if the threshold for this label is high enough # Check if the threshold for this label is high enough
if box.classes[i] > thresh: if box.classes[i] > thresh:
v_boxes.append(box) v_boxes.append(box)
v_labels.append(labels[i]) v_labels.append(label)
v_scores.append(box.classes[i]*100) v_scores.append(box.classes[i]*100)
# don't break, many labels may trigger for one box # Don't break, many labels may trigger for one box
return v_boxes, v_labels, v_scores return v_boxes, v_labels, v_scores
# draw all results
def draw_boxes(filename, v_boxes, v_labels, v_scores): def draw_boxes(filename, v_boxes, v_labels, v_scores):
"""
# load the image Draw all results
"""
# Load the image
data = pyplot.imread(filename) data = pyplot.imread(filename)
# plot the image # Plot the image
pyplot.imshow(data) pyplot.imshow(data)
# get the context for drawing boxes # Get the context for drawing boxes
ax = pyplot.gca() ax = pyplot.gca()
# plot each box # Plot each box
for i in range(len(v_boxes)): for i, box in enumerate(v_boxes):
box = v_boxes[i] # Get coordinates
# get coordinates
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
# calculate width and height of the box # Calculate width and height of the box
width, height = x2 - x1, y2 - y1 width, height = x2 - x1, y2 - y1
# create the shape # Create the shape
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2') rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
# draw the box # Draw the box
ax.add_patch(rect) ax.add_patch(rect)
# draw text and score in top left corner # Draw text and score in top left corner
label = "%s (%.3f)" % (v_labels[i], v_scores[i]) label = "%s (%.3f)" % (v_labels[i], v_scores[i])
pyplot.text(x1, y1, label, color='red') pyplot.text(x1, y1, label, color='red')
# show the plot # Show the plot
pyplot.show() pyplot.show()
"""**step 7:** declare several configuration""" # Step 7:
# Declare several configurations
# Define the anchors # Define the anchors
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]] ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
@@ -482,25 +496,22 @@ def main():
# Step 8: # Step 8:
# Make Prediction # Make Prediction
for photo_filename in glob.glob("images/test/dog/*"): for photo_filename in glob.glob("images/test/dog/*"):
# Define the expected input shape for the model
# for fn in upload.keys():
# photo_filename = '/content/' + fn
# photo_filename = 'test.jpg'
# define the expected input shape for the model
input_w, input_h = 416, 416 input_w, input_h = 416, 416
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h)) image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
# make prediction # Make prediction
yhat = yolov3.predict(image) netouts = yolov3.predict(image)
# summarize the shape of the list of arrays
print([a.shape for a in yhat]) # Summarize the shape of the list of arrays
print([a.shape for a in netouts])
boxes = list() boxes = list()
for i in range(len(yhat)):
# decode the output of the network for i, netout in enumerate(netouts):
boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w) # Decode the output of the network
boxes += decode_netout(netout[0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
# correct the sizes of the bounding boxes for the shape of the image # correct the sizes of the bounding boxes for the shape of the image
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w) correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
@@ -518,7 +529,5 @@ def main():
# draw what we found # draw what we found
draw_boxes(photo_filename, v_boxes, v_labels, v_scores) draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
print([a.shape for a in yhat])
if __name__ == "__main__": if __name__ == "__main__":
main() main()