From 8eb406b02c48af1480967ad62d9b8bf74c56a739 Mon Sep 17 00:00:00 2001 From: Heiko J Schick Date: Wed, 21 Oct 2020 16:12:21 +0200 Subject: [PATCH] Coding style changes --- yolov3.py | 95 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 52 insertions(+), 43 deletions(-) diff --git a/yolov3.py b/yolov3.py index 59f5fd8..d774550 100644 --- a/yolov3.py +++ b/yolov3.py @@ -17,7 +17,6 @@ from matplotlib.patches import Rectangle # Step 1: # Define WeightReader class - class WeightReader: """ WeightReader class is used to parse the "yolov3.weights" file and load the model weights into @@ -85,7 +84,7 @@ class WeightReader: """ self.offset = 0 -# Step 2: +# Step 2 def _conv_block(input_layer, convs, skip=True): """ Function to create convolutional layer. @@ -370,15 +369,21 @@ def _interval_overlap(interval_a, interval_b): return min(x2,x4) - x3 def bbox_iou(box1, box2): + """ + TODO + """ intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax]) intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) intersect = intersect_w * intersect_h - w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin - w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin - union = w1*h1 + w2*h2 - intersect + w1, h1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin + w2, h2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin + union = w1 * h1 + w2 * h2 - intersect return float(intersect) / union def do_nms(boxes, nms_thresh): + """ + TODO + """ if len(boxes) > 0: nb_class = len(boxes[0].classes) else: @@ -387,54 +392,63 @@ def do_nms(boxes, nms_thresh): sorted_indices = np.argsort([-box.classes[c] for box in boxes]) for i in range(len(sorted_indices)): index_i = sorted_indices[i] - if boxes[index_i].classes[c] == 0: continue + + if boxes[index_i].classes[c] == 0: + continue + for j in range(i+1, len(sorted_indices)): index_j = sorted_indices[j] + if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh: boxes[index_j].classes[c] = 0 -# get all of the results above a threshold def get_boxes(boxes, labels, thresh): + """ + Get all of the results above a threshold + """ v_boxes, v_labels, v_scores = list(), list(), list() - # enumerate all boxes + + # Enumerate all boxes for box in boxes: - # enumerate all possible labels - for i in range(len(labels)): - # check if the threshold for this label is high enough + # Enumerate all possible labels + for i, label in enumerate(labels): + # Check if the threshold for this label is high enough if box.classes[i] > thresh: v_boxes.append(box) - v_labels.append(labels[i]) + v_labels.append(label) v_scores.append(box.classes[i]*100) - # don't break, many labels may trigger for one box + # Don't break, many labels may trigger for one box + return v_boxes, v_labels, v_scores -# draw all results def draw_boxes(filename, v_boxes, v_labels, v_scores): - - # load the image + """ + Draw all results + """ + # Load the image data = pyplot.imread(filename) - # plot the image + # Plot the image pyplot.imshow(data) - # get the context for drawing boxes + # Get the context for drawing boxes ax = pyplot.gca() - # plot each box - for i in range(len(v_boxes)): - box = v_boxes[i] - # get coordinates + # Plot each box + for i, box in enumerate(v_boxes): + # Get coordinates y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax - # calculate width and height of the box + # Calculate width and height of the box width, height = x2 - x1, y2 - y1 - # create the shape + # Create the shape rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2') - # draw the box + # Draw the box ax.add_patch(rect) - # draw text and score in top left corner + # Draw text and score in top left corner label = "%s (%.3f)" % (v_labels[i], v_scores[i]) pyplot.text(x1, y1, label, color='red') - # show the plot + # Show the plot pyplot.show() -"""**step 7:** declare several configuration""" +# Step 7: +# Dclare several configurationd # Define the anchors ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]] @@ -482,25 +496,22 @@ def main(): # Step 8: # Make Prediction for photo_filename in glob.glob("images/test/dog/*"): - - # for fn in upload.keys(): - # photo_filename = '/content/' + fn - # photo_filename = 'test.jpg' - - # define the expected input shape for the model + # Define the expected input shape for the model input_w, input_h = 416, 416 image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h)) - # make prediction - yhat = yolov3.predict(image) - # summarize the shape of the list of arrays - print([a.shape for a in yhat]) + # Make prediction + netouts = yolov3.predict(image) + + # Summarize the shape of the list of arrays + print([a.shape for a in netouts]) boxes = list() - for i in range(len(yhat)): - # decode the output of the network - boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w) + + for i, netout in enumerate(netouts): + # Decode the output of the network + boxes += decode_netout(netout[0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w) # correct the sizes of the bounding boxes for the shape of the image correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w) @@ -518,7 +529,5 @@ def main(): # draw what we found draw_boxes(photo_filename, v_boxes, v_labels, v_scores) - print([a.shape for a in yhat]) - if __name__ == "__main__": main()