Renamed some variables to reduce pylint warnings
This commit is contained in:
@@ -321,15 +321,19 @@ def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
|
||||
if objectness.all() <= obj_thresh:
|
||||
continue
|
||||
|
||||
# First 4 elements are x, y, w, and h
|
||||
x, y, w, h = netout[int(row)][int(col)][j][:4]
|
||||
x = (col + x) / grid_w # Center position, unit: image width
|
||||
y = (row + y) / grid_h # Center position, unit: image height
|
||||
w = anchors[2 * j + 0] * np.exp(w) / net_w # Unit: image width
|
||||
h = anchors[2 * j + 1] * np.exp(h) / net_h # Unit: image height
|
||||
# First 4 elements to for the bounding box are x, y, w, and h
|
||||
box_x, box_y, box_w, box_h = netout[int(row)][int(col)][j][:4]
|
||||
box_x = (col + box_x) / grid_w # Center position, unit: image width
|
||||
box_y = (row + box_y) / grid_h # Center position, unit: image height
|
||||
box_w = anchors[2 * j + 0] * np.exp(box_w) / net_w # Unit: image width
|
||||
box_h = anchors[2 * j + 1] * np.exp(box_h) / net_h # Unit: image height
|
||||
# Last elements are class probabilities
|
||||
classes = netout[int(row)][col][j][5:]
|
||||
box = BoundBox(x - w / 2, y - h / 2, x + w / 2, y + h / 2, objectness, classes)
|
||||
box = BoundBox(box_x - box_w / 2,
|
||||
box_y - box_h / 2,
|
||||
box_x + box_w / 2,
|
||||
box_y + box_h / 2,
|
||||
objectness, classes)
|
||||
boxes.append(box)
|
||||
return boxes
|
||||
|
||||
@@ -351,22 +355,24 @@ def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
|
||||
# Step 6
|
||||
def _interval_overlap(interval_a, interval_b):
|
||||
"""
|
||||
Implementing IOU
|
||||
Implementing Intersection over Unit (IoU)
|
||||
Source: https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
|
||||
https://medium.com/@amrokamal_47691/yolo-yolov2-and-yolov3-all-you-want-to-know-7e3e92dc4899
|
||||
|
||||
"""
|
||||
x1, x2 = interval_a
|
||||
x3, x4 = interval_b
|
||||
x_1, x_2 = interval_a
|
||||
x_3, x_4 = interval_b
|
||||
|
||||
if x3 < x1:
|
||||
if x4 < x1:
|
||||
if x_3 < x_1:
|
||||
if x_4 < x_1:
|
||||
ret = 0
|
||||
else:
|
||||
ret = min(x2,x4) - x1
|
||||
ret = min(x_2, x_4) - x_1
|
||||
else:
|
||||
if x2 < x3:
|
||||
if x_2 < x_3:
|
||||
ret = 0
|
||||
else:
|
||||
ret = min(x2,x4) - x3
|
||||
ret = min(x_2, x_4) - x_3
|
||||
|
||||
return ret
|
||||
|
||||
@@ -377,9 +383,10 @@ def bbox_iou(box1, box2):
|
||||
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
|
||||
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
|
||||
intersect = intersect_w * intersect_h
|
||||
w1, h1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
|
||||
w2, h2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
|
||||
union = w1 * h1 + w2 * h2 - intersect
|
||||
w_1, h_1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
|
||||
w_2, h_2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
|
||||
union = w_1 * h_1 + w_2 * h_2 - intersect
|
||||
|
||||
return float(intersect) / union
|
||||
|
||||
def do_nms(boxes, nms_thresh):
|
||||
@@ -387,21 +394,21 @@ def do_nms(boxes, nms_thresh):
|
||||
TODO
|
||||
"""
|
||||
if len(boxes) > 0:
|
||||
nb_class = len(boxes[0].classes)
|
||||
nb_classes = len(boxes[0].classes)
|
||||
else:
|
||||
return
|
||||
for c in range(nb_class):
|
||||
sorted_indices = np.argsort([-box.classes[c] for box in boxes])
|
||||
for nb_class in range(nb_classes):
|
||||
sorted_indices = np.argsort([-box.classes[nb_class] for box in boxes])
|
||||
|
||||
for i, index_i in enumerate(sorted_indices):
|
||||
if boxes[index_i].classes[c] == 0:
|
||||
if boxes[index_i].classes[nb_class] == 0:
|
||||
continue
|
||||
|
||||
for j in range(i+1, len(sorted_indices)):
|
||||
index_j = sorted_indices[j]
|
||||
|
||||
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
|
||||
boxes[index_j].classes[c] = 0
|
||||
boxes[index_j].classes[nb_class] = 0
|
||||
|
||||
def get_boxes(boxes, labels, thresh):
|
||||
"""
|
||||
@@ -431,20 +438,20 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
|
||||
# Plot the image
|
||||
pyplot.imshow(data)
|
||||
# Get the context for drawing boxes
|
||||
ax = pyplot.gca()
|
||||
axes = pyplot.gca()
|
||||
# Plot each box
|
||||
for i, box in enumerate(v_boxes):
|
||||
# Get coordinates
|
||||
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
|
||||
y_1, x_1, y_2, x_2 = box.ymin, box.xmin, box.ymax, box.xmax
|
||||
# Calculate width and height of the box
|
||||
width, height = x2 - x1, y2 - y1
|
||||
width, height = x_2 - x_1, y_2 - y_1
|
||||
# Create the shape
|
||||
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
|
||||
rect = Rectangle((x_1, y_1), width, height, fill=False, color='red', linewidth = '2')
|
||||
# Draw the box
|
||||
ax.add_patch(rect)
|
||||
axes.add_patch(rect)
|
||||
# Draw text and score in top left corner
|
||||
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
|
||||
pyplot.text(x1, y1, label, color='white', backgroundcolor='red')
|
||||
pyplot.text(x_1, y_1, label, color='white', backgroundcolor='red')
|
||||
|
||||
# Show the plot
|
||||
pyplot.show()
|
||||
@@ -600,7 +607,8 @@ def main():
|
||||
weight_reader.load_weights(yolov3)
|
||||
|
||||
# Save the model to file
|
||||
yolov3.save('yolov3.h5')
|
||||
# yolov3.trainable = False
|
||||
yolov3.save('yolov3')
|
||||
|
||||
# Step 8:
|
||||
# Make Prediction
|
||||
|
||||
Reference in New Issue
Block a user