Renamed some variables to reduce pylint warnings
This commit is contained in:
@@ -321,15 +321,19 @@ def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
|
|||||||
if objectness.all() <= obj_thresh:
|
if objectness.all() <= obj_thresh:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# First 4 elements are x, y, w, and h
|
# First 4 elements to for the bounding box are x, y, w, and h
|
||||||
x, y, w, h = netout[int(row)][int(col)][j][:4]
|
box_x, box_y, box_w, box_h = netout[int(row)][int(col)][j][:4]
|
||||||
x = (col + x) / grid_w # Center position, unit: image width
|
box_x = (col + box_x) / grid_w # Center position, unit: image width
|
||||||
y = (row + y) / grid_h # Center position, unit: image height
|
box_y = (row + box_y) / grid_h # Center position, unit: image height
|
||||||
w = anchors[2 * j + 0] * np.exp(w) / net_w # Unit: image width
|
box_w = anchors[2 * j + 0] * np.exp(box_w) / net_w # Unit: image width
|
||||||
h = anchors[2 * j + 1] * np.exp(h) / net_h # Unit: image height
|
box_h = anchors[2 * j + 1] * np.exp(box_h) / net_h # Unit: image height
|
||||||
# Last elements are class probabilities
|
# Last elements are class probabilities
|
||||||
classes = netout[int(row)][col][j][5:]
|
classes = netout[int(row)][col][j][5:]
|
||||||
box = BoundBox(x - w / 2, y - h / 2, x + w / 2, y + h / 2, objectness, classes)
|
box = BoundBox(box_x - box_w / 2,
|
||||||
|
box_y - box_h / 2,
|
||||||
|
box_x + box_w / 2,
|
||||||
|
box_y + box_h / 2,
|
||||||
|
objectness, classes)
|
||||||
boxes.append(box)
|
boxes.append(box)
|
||||||
return boxes
|
return boxes
|
||||||
|
|
||||||
@@ -351,22 +355,24 @@ def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
|
|||||||
# Step 6
|
# Step 6
|
||||||
def _interval_overlap(interval_a, interval_b):
|
def _interval_overlap(interval_a, interval_b):
|
||||||
"""
|
"""
|
||||||
Implementing IOU
|
Implementing Intersection over Unit (IoU)
|
||||||
|
Source: https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
|
||||||
|
https://medium.com/@amrokamal_47691/yolo-yolov2-and-yolov3-all-you-want-to-know-7e3e92dc4899
|
||||||
|
|
||||||
"""
|
"""
|
||||||
x1, x2 = interval_a
|
x_1, x_2 = interval_a
|
||||||
x3, x4 = interval_b
|
x_3, x_4 = interval_b
|
||||||
|
|
||||||
if x3 < x1:
|
if x_3 < x_1:
|
||||||
if x4 < x1:
|
if x_4 < x_1:
|
||||||
ret = 0
|
ret = 0
|
||||||
else:
|
else:
|
||||||
ret = min(x2,x4) - x1
|
ret = min(x_2, x_4) - x_1
|
||||||
else:
|
else:
|
||||||
if x2 < x3:
|
if x_2 < x_3:
|
||||||
ret = 0
|
ret = 0
|
||||||
else:
|
else:
|
||||||
ret = min(x2,x4) - x3
|
ret = min(x_2, x_4) - x_3
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@@ -377,9 +383,10 @@ def bbox_iou(box1, box2):
|
|||||||
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
|
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
|
||||||
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
|
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
|
||||||
intersect = intersect_w * intersect_h
|
intersect = intersect_w * intersect_h
|
||||||
w1, h1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
|
w_1, h_1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
|
||||||
w2, h2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
|
w_2, h_2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
|
||||||
union = w1 * h1 + w2 * h2 - intersect
|
union = w_1 * h_1 + w_2 * h_2 - intersect
|
||||||
|
|
||||||
return float(intersect) / union
|
return float(intersect) / union
|
||||||
|
|
||||||
def do_nms(boxes, nms_thresh):
|
def do_nms(boxes, nms_thresh):
|
||||||
@@ -387,21 +394,21 @@ def do_nms(boxes, nms_thresh):
|
|||||||
TODO
|
TODO
|
||||||
"""
|
"""
|
||||||
if len(boxes) > 0:
|
if len(boxes) > 0:
|
||||||
nb_class = len(boxes[0].classes)
|
nb_classes = len(boxes[0].classes)
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
for c in range(nb_class):
|
for nb_class in range(nb_classes):
|
||||||
sorted_indices = np.argsort([-box.classes[c] for box in boxes])
|
sorted_indices = np.argsort([-box.classes[nb_class] for box in boxes])
|
||||||
|
|
||||||
for i, index_i in enumerate(sorted_indices):
|
for i, index_i in enumerate(sorted_indices):
|
||||||
if boxes[index_i].classes[c] == 0:
|
if boxes[index_i].classes[nb_class] == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for j in range(i+1, len(sorted_indices)):
|
for j in range(i+1, len(sorted_indices)):
|
||||||
index_j = sorted_indices[j]
|
index_j = sorted_indices[j]
|
||||||
|
|
||||||
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
|
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
|
||||||
boxes[index_j].classes[c] = 0
|
boxes[index_j].classes[nb_class] = 0
|
||||||
|
|
||||||
def get_boxes(boxes, labels, thresh):
|
def get_boxes(boxes, labels, thresh):
|
||||||
"""
|
"""
|
||||||
@@ -431,20 +438,20 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
|
|||||||
# Plot the image
|
# Plot the image
|
||||||
pyplot.imshow(data)
|
pyplot.imshow(data)
|
||||||
# Get the context for drawing boxes
|
# Get the context for drawing boxes
|
||||||
ax = pyplot.gca()
|
axes = pyplot.gca()
|
||||||
# Plot each box
|
# Plot each box
|
||||||
for i, box in enumerate(v_boxes):
|
for i, box in enumerate(v_boxes):
|
||||||
# Get coordinates
|
# Get coordinates
|
||||||
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
|
y_1, x_1, y_2, x_2 = box.ymin, box.xmin, box.ymax, box.xmax
|
||||||
# Calculate width and height of the box
|
# Calculate width and height of the box
|
||||||
width, height = x2 - x1, y2 - y1
|
width, height = x_2 - x_1, y_2 - y_1
|
||||||
# Create the shape
|
# Create the shape
|
||||||
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
|
rect = Rectangle((x_1, y_1), width, height, fill=False, color='red', linewidth = '2')
|
||||||
# Draw the box
|
# Draw the box
|
||||||
ax.add_patch(rect)
|
axes.add_patch(rect)
|
||||||
# Draw text and score in top left corner
|
# Draw text and score in top left corner
|
||||||
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
|
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
|
||||||
pyplot.text(x1, y1, label, color='white', backgroundcolor='red')
|
pyplot.text(x_1, y_1, label, color='white', backgroundcolor='red')
|
||||||
|
|
||||||
# Show the plot
|
# Show the plot
|
||||||
pyplot.show()
|
pyplot.show()
|
||||||
@@ -600,7 +607,8 @@ def main():
|
|||||||
weight_reader.load_weights(yolov3)
|
weight_reader.load_weights(yolov3)
|
||||||
|
|
||||||
# Save the model to file
|
# Save the model to file
|
||||||
yolov3.save('yolov3.h5')
|
# yolov3.trainable = False
|
||||||
|
yolov3.save('yolov3')
|
||||||
|
|
||||||
# Step 8:
|
# Step 8:
|
||||||
# Make Prediction
|
# Make Prediction
|
||||||
|
|||||||
Reference in New Issue
Block a user