Renamed some variables to reduce pylint warnings

This commit is contained in:
Heiko J Schick
2020-11-06 13:46:11 +01:00
parent aa3d8f33a4
commit 97e5e858e3
+38 -30
View File
@@ -321,15 +321,19 @@ def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
if objectness.all() <= obj_thresh:
continue
# First 4 elements are x, y, w, and h
x, y, w, h = netout[int(row)][int(col)][j][:4]
x = (col + x) / grid_w # Center position, unit: image width
y = (row + y) / grid_h # Center position, unit: image height
w = anchors[2 * j + 0] * np.exp(w) / net_w # Unit: image width
h = anchors[2 * j + 1] * np.exp(h) / net_h # Unit: image height
# First 4 elements to for the bounding box are x, y, w, and h
box_x, box_y, box_w, box_h = netout[int(row)][int(col)][j][:4]
box_x = (col + box_x) / grid_w # Center position, unit: image width
box_y = (row + box_y) / grid_h # Center position, unit: image height
box_w = anchors[2 * j + 0] * np.exp(box_w) / net_w # Unit: image width
box_h = anchors[2 * j + 1] * np.exp(box_h) / net_h # Unit: image height
# Last elements are class probabilities
classes = netout[int(row)][col][j][5:]
box = BoundBox(x - w / 2, y - h / 2, x + w / 2, y + h / 2, objectness, classes)
box = BoundBox(box_x - box_w / 2,
box_y - box_h / 2,
box_x + box_w / 2,
box_y + box_h / 2,
objectness, classes)
boxes.append(box)
return boxes
@@ -351,22 +355,24 @@ def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
# Step 6
def _interval_overlap(interval_a, interval_b):
"""
Implementing IOU
Implementing Intersection over Unit (IoU)
Source: https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
https://medium.com/@amrokamal_47691/yolo-yolov2-and-yolov3-all-you-want-to-know-7e3e92dc4899
"""
x1, x2 = interval_a
x3, x4 = interval_b
x_1, x_2 = interval_a
x_3, x_4 = interval_b
if x3 < x1:
if x4 < x1:
if x_3 < x_1:
if x_4 < x_1:
ret = 0
else:
ret = min(x2,x4) - x1
ret = min(x_2, x_4) - x_1
else:
if x2 < x3:
if x_2 < x_3:
ret = 0
else:
ret = min(x2,x4) - x3
ret = min(x_2, x_4) - x_3
return ret
@@ -377,9 +383,10 @@ def bbox_iou(box1, box2):
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
intersect = intersect_w * intersect_h
w1, h1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
w2, h2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
union = w1 * h1 + w2 * h2 - intersect
w_1, h_1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
w_2, h_2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
union = w_1 * h_1 + w_2 * h_2 - intersect
return float(intersect) / union
def do_nms(boxes, nms_thresh):
@@ -387,21 +394,21 @@ def do_nms(boxes, nms_thresh):
TODO
"""
if len(boxes) > 0:
nb_class = len(boxes[0].classes)
nb_classes = len(boxes[0].classes)
else:
return
for c in range(nb_class):
sorted_indices = np.argsort([-box.classes[c] for box in boxes])
for nb_class in range(nb_classes):
sorted_indices = np.argsort([-box.classes[nb_class] for box in boxes])
for i, index_i in enumerate(sorted_indices):
if boxes[index_i].classes[c] == 0:
if boxes[index_i].classes[nb_class] == 0:
continue
for j in range(i+1, len(sorted_indices)):
index_j = sorted_indices[j]
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
boxes[index_j].classes[c] = 0
boxes[index_j].classes[nb_class] = 0
def get_boxes(boxes, labels, thresh):
"""
@@ -431,20 +438,20 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
# Plot the image
pyplot.imshow(data)
# Get the context for drawing boxes
ax = pyplot.gca()
axes = pyplot.gca()
# Plot each box
for i, box in enumerate(v_boxes):
# Get coordinates
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
y_1, x_1, y_2, x_2 = box.ymin, box.xmin, box.ymax, box.xmax
# Calculate width and height of the box
width, height = x2 - x1, y2 - y1
width, height = x_2 - x_1, y_2 - y_1
# Create the shape
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
rect = Rectangle((x_1, y_1), width, height, fill=False, color='red', linewidth = '2')
# Draw the box
ax.add_patch(rect)
axes.add_patch(rect)
# Draw text and score in top left corner
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
pyplot.text(x1, y1, label, color='white', backgroundcolor='red')
pyplot.text(x_1, y_1, label, color='white', backgroundcolor='red')
# Show the plot
pyplot.show()
@@ -600,7 +607,8 @@ def main():
weight_reader.load_weights(yolov3)
# Save the model to file
yolov3.save('yolov3.h5')
# yolov3.trainable = False
yolov3.save('yolov3')
# Step 8:
# Make Prediction