diff --git a/yolov3.py b/yolov3.py index d774550..4c2ad8d 100644 --- a/yolov3.py +++ b/yolov3.py @@ -390,6 +390,7 @@ def do_nms(boxes, nms_thresh): return for c in range(nb_class): sorted_indices = np.argsort([-box.classes[c] for box in boxes]) + for i in range(len(sorted_indices)): index_i = sorted_indices[i] @@ -416,7 +417,7 @@ def get_boxes(boxes, labels, thresh): if box.classes[i] > thresh: v_boxes.append(box) v_labels.append(label) - v_scores.append(box.classes[i]*100) + v_scores.append(box.classes[i] * 100) # Don't break, many labels may trigger for one box return v_boxes, v_labels, v_scores @@ -443,7 +444,8 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores): ax.add_patch(rect) # Draw text and score in top left corner label = "%s (%.3f)" % (v_labels[i], v_scores[i]) - pyplot.text(x1, y1, label, color='red') + pyplot.text(x1, y1, label, color='white', backgroundcolor='red') + # Show the plot pyplot.show() @@ -457,17 +459,86 @@ ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33 CLASS_THRESHOLD = 0.6 # Define the labels -LABELS = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck", - "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", - "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", - "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", - "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", - "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", - "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", - "chair", "sofa", "pottedplant", "bed", "diningtable", "toilet", "tvmonitor", "laptop", "mouse", - "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", - "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] - +LABELS = ["person", # 0 + "bicycle", + "car", + "motorbike", + "aeroplane", + "bus", # 5 + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", # 10 + "stop sign", + "parking meter", + "bench", + "bird", + "cat", # 15 + "dog", + "horse", + "sheep", + "cow", + "elephant", # 20 + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", # 25 + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", # 30 + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", # 35 + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", # 40 + "cup", + "fork", + "knife", + "spoon", + "bowl", # 45 + "banana", + "apple", + "sandwich", + "orange", + "broccoli", # 50 + "carrot", + "hot dog", + "pizza", + "donut", + "cake", # 55 + "chair", + "sofa", + "pottedplant", + "bed", + "diningtable", # 60 + "toilet", + "tvmonitor", + "laptop", + "mouse", + "remote", # 65 + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", # 70 + "sink", + "refrigerator", + "book", + "clock", + "vase", # 75 + "scissors", + "teddy bear", + "hair drier", + "toothbrush"] def main(): """ @@ -513,20 +584,20 @@ def main(): # Decode the output of the network boxes += decode_netout(netout[0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w) - # correct the sizes of the bounding boxes for the shape of the image + # Correct the sizes of the bounding boxes for the shape of the image correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w) - # suppress non-maximal boxes + # Suppress non-maximal boxes do_nms(boxes, 0.5) - # get the details of the detected objects + # Get the details of the detected objects v_boxes, v_labels, v_scores = get_boxes(boxes, LABELS, CLASS_THRESHOLD) - # summarize what we found + # Summarize what we found for i in range(len(v_boxes)): print(v_labels[i], v_scores[i]) - # draw what we found + # Draw what we found draw_boxes(photo_filename, v_boxes, v_labels, v_scores) if __name__ == "__main__":