Renamed some variables to reduce pylint warnings

This commit is contained in:
Heiko J Schick
2020-11-06 13:46:11 +01:00
parent aa3d8f33a4
commit 97e5e858e3
+38 -30
View File
@@ -321,15 +321,19 @@ def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
if objectness.all() <= obj_thresh: if objectness.all() <= obj_thresh:
continue continue
# First 4 elements are x, y, w, and h # First 4 elements to for the bounding box are x, y, w, and h
x, y, w, h = netout[int(row)][int(col)][j][:4] box_x, box_y, box_w, box_h = netout[int(row)][int(col)][j][:4]
x = (col + x) / grid_w # Center position, unit: image width box_x = (col + box_x) / grid_w # Center position, unit: image width
y = (row + y) / grid_h # Center position, unit: image height box_y = (row + box_y) / grid_h # Center position, unit: image height
w = anchors[2 * j + 0] * np.exp(w) / net_w # Unit: image width box_w = anchors[2 * j + 0] * np.exp(box_w) / net_w # Unit: image width
h = anchors[2 * j + 1] * np.exp(h) / net_h # Unit: image height box_h = anchors[2 * j + 1] * np.exp(box_h) / net_h # Unit: image height
# Last elements are class probabilities # Last elements are class probabilities
classes = netout[int(row)][col][j][5:] classes = netout[int(row)][col][j][5:]
box = BoundBox(x - w / 2, y - h / 2, x + w / 2, y + h / 2, objectness, classes) box = BoundBox(box_x - box_w / 2,
box_y - box_h / 2,
box_x + box_w / 2,
box_y + box_h / 2,
objectness, classes)
boxes.append(box) boxes.append(box)
return boxes return boxes
@@ -351,22 +355,24 @@ def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
# Step 6 # Step 6
def _interval_overlap(interval_a, interval_b): def _interval_overlap(interval_a, interval_b):
""" """
Implementing IOU Implementing Intersection over Unit (IoU)
Source: https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
https://medium.com/@amrokamal_47691/yolo-yolov2-and-yolov3-all-you-want-to-know-7e3e92dc4899
""" """
x1, x2 = interval_a x_1, x_2 = interval_a
x3, x4 = interval_b x_3, x_4 = interval_b
if x3 < x1: if x_3 < x_1:
if x4 < x1: if x_4 < x_1:
ret = 0 ret = 0
else: else:
ret = min(x2,x4) - x1 ret = min(x_2, x_4) - x_1
else: else:
if x2 < x3: if x_2 < x_3:
ret = 0 ret = 0
else: else:
ret = min(x2,x4) - x3 ret = min(x_2, x_4) - x_3
return ret return ret
@@ -377,9 +383,10 @@ def bbox_iou(box1, box2):
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax]) intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
intersect = intersect_w * intersect_h intersect = intersect_w * intersect_h
w1, h1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin w_1, h_1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
w2, h2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin w_2, h_2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
union = w1 * h1 + w2 * h2 - intersect union = w_1 * h_1 + w_2 * h_2 - intersect
return float(intersect) / union return float(intersect) / union
def do_nms(boxes, nms_thresh): def do_nms(boxes, nms_thresh):
@@ -387,21 +394,21 @@ def do_nms(boxes, nms_thresh):
TODO TODO
""" """
if len(boxes) > 0: if len(boxes) > 0:
nb_class = len(boxes[0].classes) nb_classes = len(boxes[0].classes)
else: else:
return return
for c in range(nb_class): for nb_class in range(nb_classes):
sorted_indices = np.argsort([-box.classes[c] for box in boxes]) sorted_indices = np.argsort([-box.classes[nb_class] for box in boxes])
for i, index_i in enumerate(sorted_indices): for i, index_i in enumerate(sorted_indices):
if boxes[index_i].classes[c] == 0: if boxes[index_i].classes[nb_class] == 0:
continue continue
for j in range(i+1, len(sorted_indices)): for j in range(i+1, len(sorted_indices)):
index_j = sorted_indices[j] index_j = sorted_indices[j]
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh: if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
boxes[index_j].classes[c] = 0 boxes[index_j].classes[nb_class] = 0
def get_boxes(boxes, labels, thresh): def get_boxes(boxes, labels, thresh):
""" """
@@ -431,20 +438,20 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
# Plot the image # Plot the image
pyplot.imshow(data) pyplot.imshow(data)
# Get the context for drawing boxes # Get the context for drawing boxes
ax = pyplot.gca() axes = pyplot.gca()
# Plot each box # Plot each box
for i, box in enumerate(v_boxes): for i, box in enumerate(v_boxes):
# Get coordinates # Get coordinates
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax y_1, x_1, y_2, x_2 = box.ymin, box.xmin, box.ymax, box.xmax
# Calculate width and height of the box # Calculate width and height of the box
width, height = x2 - x1, y2 - y1 width, height = x_2 - x_1, y_2 - y_1
# Create the shape # Create the shape
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2') rect = Rectangle((x_1, y_1), width, height, fill=False, color='red', linewidth = '2')
# Draw the box # Draw the box
ax.add_patch(rect) axes.add_patch(rect)
# Draw text and score in top left corner # Draw text and score in top left corner
label = "%s (%.3f)" % (v_labels[i], v_scores[i]) label = "%s (%.3f)" % (v_labels[i], v_scores[i])
pyplot.text(x1, y1, label, color='white', backgroundcolor='red') pyplot.text(x_1, y_1, label, color='white', backgroundcolor='red')
# Show the plot # Show the plot
pyplot.show() pyplot.show()
@@ -600,7 +607,8 @@ def main():
weight_reader.load_weights(yolov3) weight_reader.load_weights(yolov3)
# Save the model to file # Save the model to file
yolov3.save('yolov3.h5') # yolov3.trainable = False
yolov3.save('yolov3')
# Step 8: # Step 8:
# Make Prediction # Make Prediction