Coding style changes
This commit is contained in:
@@ -279,67 +279,85 @@ class BoundBox:
|
|||||||
self.score = -1
|
self.score = -1
|
||||||
|
|
||||||
def get_label(self):
|
def get_label(self):
|
||||||
|
"""
|
||||||
|
Gets the label of the current object
|
||||||
|
"""
|
||||||
if self.label == -1:
|
if self.label == -1:
|
||||||
self.label = np.argmax(self.classes)
|
self.label = np.argmax(self.classes)
|
||||||
|
|
||||||
return self.label
|
return self.label
|
||||||
|
|
||||||
def get_score(self):
|
def get_score(self):
|
||||||
|
"""
|
||||||
|
Gets the score of the current object
|
||||||
|
"""
|
||||||
if self.score == -1:
|
if self.score == -1:
|
||||||
self.score = self.classes[self.get_label()]
|
self.score = self.classes[self.get_label()]
|
||||||
|
|
||||||
return self.get_score
|
return self.get_score
|
||||||
|
|
||||||
def _sigmoid(x):
|
def _sigmoid(inp):
|
||||||
return 1. /(1. + np.exp(-x))
|
return 1. / (1. + np.exp(-inp))
|
||||||
|
|
||||||
def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
|
def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
|
||||||
|
"""
|
||||||
|
Decode output information of network.
|
||||||
|
"""
|
||||||
grid_h, grid_w = netout.shape[:2]
|
grid_h, grid_w = netout.shape[:2]
|
||||||
nb_box = 3
|
nb_box = 3
|
||||||
netout = netout.reshape((grid_h, grid_w, nb_box, -1))
|
netout = netout.reshape((grid_h, grid_w, nb_box, -1))
|
||||||
nb_class = netout.shape[-1] - 5
|
|
||||||
boxes = []
|
boxes = []
|
||||||
netout[..., :2] = _sigmoid(netout[..., :2])
|
netout[..., :2] = _sigmoid(netout[..., :2])
|
||||||
netout[..., 4:] = _sigmoid(netout[..., 4:])
|
netout[..., 4:] = _sigmoid(netout[..., 4:])
|
||||||
netout[..., 5:] = netout[..., 4][..., np.newaxis] * netout[..., 5:]
|
netout[..., 5:] = netout[..., 4][..., np.newaxis] * netout[..., 5:]
|
||||||
netout[..., 5:] *= netout[..., 5:] > obj_thresh
|
netout[..., 5:] *= netout[..., 5:] > obj_thresh
|
||||||
|
|
||||||
for i in range(grid_h*grid_w):
|
for i in range(grid_h * grid_w):
|
||||||
row = i / grid_w
|
row = i / grid_w
|
||||||
col = i % grid_w
|
col = i % grid_w
|
||||||
for b in range(nb_box):
|
for j in range(nb_box):
|
||||||
# 4th element is objectness score
|
# 4th element is objectness score
|
||||||
objectness = netout[int(row)][int(col)][b][4]
|
objectness = netout[int(row)][int(col)][j][4]
|
||||||
if objectness.all() <= obj_thresh: continue
|
|
||||||
# first 4 elements are x, y, w, and h
|
if objectness.all() <= obj_thresh:
|
||||||
x, y, w, h = netout[int(row)][int(col)][b][:4]
|
continue
|
||||||
x = (col + x) / grid_w # center position, unit: image width
|
|
||||||
y = (row + y) / grid_h # center position, unit: image height
|
# First 4 elements are x, y, w, and h
|
||||||
w = anchors[2 * b + 0] * np.exp(w) / net_w # unit: image width
|
x, y, w, h = netout[int(row)][int(col)][j][:4]
|
||||||
h = anchors[2 * b + 1] * np.exp(h) / net_h # unit: image height
|
x = (col + x) / grid_w # Center position, unit: image width
|
||||||
# last elements are class probabilities
|
y = (row + y) / grid_h # Center position, unit: image height
|
||||||
classes = netout[int(row)][col][b][5:]
|
w = anchors[2 * j + 0] * np.exp(w) / net_w # Unit: image width
|
||||||
box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, objectness, classes)
|
h = anchors[2 * j + 1] * np.exp(h) / net_h # Unit: image height
|
||||||
|
# Last elements are class probabilities
|
||||||
|
classes = netout[int(row)][col][j][5:]
|
||||||
|
box = BoundBox(x - w / 2, y - h / 2, x + w / 2, y + h / 2, objectness, classes)
|
||||||
boxes.append(box)
|
boxes.append(box)
|
||||||
return boxes
|
return boxes
|
||||||
|
|
||||||
"""**Step 5:** strech the box to be fit to the image normal shape"""
|
# Step 5
|
||||||
|
|
||||||
def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
|
def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
|
||||||
|
"""
|
||||||
|
Strech the box to be fit to the image normal shape
|
||||||
|
"""
|
||||||
new_w, new_h = net_w, net_h
|
new_w, new_h = net_w, net_h
|
||||||
for i in range(len(boxes)):
|
for box in boxes:
|
||||||
x_offset, x_scale = (net_w - new_w)/2./net_w, float(new_w)/net_w
|
x_offset, x_scale = (net_w - new_w) / 2. / net_w, float(new_w) / net_w
|
||||||
y_offset, y_scale = (net_h - new_h)/2./net_h, float(new_h)/net_h
|
y_offset, y_scale = (net_h - new_h) / 2. / net_h, float(new_h) / net_h
|
||||||
boxes[i].xmin = int((boxes[i].xmin - x_offset) / x_scale * image_w)
|
|
||||||
boxes[i].xmax = int((boxes[i].xmax - x_offset) / x_scale * image_w)
|
|
||||||
boxes[i].ymin = int((boxes[i].ymin - y_offset) / y_scale * image_h)
|
|
||||||
boxes[i].ymax = int((boxes[i].ymax - y_offset) / y_scale * image_h)
|
|
||||||
|
|
||||||
"""**Step 6:** implementing IOU"""
|
box.xmin = int((box.xmin - x_offset) / x_scale * image_w)
|
||||||
|
box.xmax = int((box.xmax - x_offset) / x_scale * image_w)
|
||||||
|
box.ymin = int((box.ymin - y_offset) / y_scale * image_h)
|
||||||
|
box.ymax = int((box.ymax - y_offset) / y_scale * image_h)
|
||||||
|
|
||||||
|
# Step 6
|
||||||
def _interval_overlap(interval_a, interval_b):
|
def _interval_overlap(interval_a, interval_b):
|
||||||
|
"""
|
||||||
|
Implementing IOU
|
||||||
|
|
||||||
|
"""
|
||||||
x1, x2 = interval_a
|
x1, x2 = interval_a
|
||||||
x3, x4 = interval_b
|
x3, x4 = interval_b
|
||||||
|
|
||||||
if x3 < x1:
|
if x3 < x1:
|
||||||
if x4 < x1:
|
if x4 < x1:
|
||||||
return 0
|
return 0
|
||||||
@@ -418,14 +436,14 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
|
|||||||
|
|
||||||
"""**step 7:** declare several configuration"""
|
"""**step 7:** declare several configuration"""
|
||||||
|
|
||||||
# define the anchors
|
# Define the anchors
|
||||||
anchors = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
|
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
|
||||||
|
|
||||||
# define the probability threshold for detected objects
|
# Define the probability threshold for detected objects
|
||||||
class_threshold = 0.6
|
CLASS_THRESHOLD = 0.6
|
||||||
|
|
||||||
# define the labels
|
# Define the labels
|
||||||
labels = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck",
|
LABELS = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck",
|
||||||
"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
|
"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
|
||||||
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
|
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
|
||||||
"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
|
"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
|
||||||
@@ -482,7 +500,7 @@ def main():
|
|||||||
boxes = list()
|
boxes = list()
|
||||||
for i in range(len(yhat)):
|
for i in range(len(yhat)):
|
||||||
# decode the output of the network
|
# decode the output of the network
|
||||||
boxes += decode_netout(yhat[i][0], anchors[i], class_threshold, input_h, input_w)
|
boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
|
||||||
|
|
||||||
# correct the sizes of the bounding boxes for the shape of the image
|
# correct the sizes of the bounding boxes for the shape of the image
|
||||||
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
|
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
|
||||||
@@ -491,7 +509,7 @@ def main():
|
|||||||
do_nms(boxes, 0.5)
|
do_nms(boxes, 0.5)
|
||||||
|
|
||||||
# get the details of the detected objects
|
# get the details of the detected objects
|
||||||
v_boxes, v_labels, v_scores = get_boxes(boxes, labels, class_threshold)
|
v_boxes, v_labels, v_scores = get_boxes(boxes, LABELS, CLASS_THRESHOLD)
|
||||||
|
|
||||||
# summarize what we found
|
# summarize what we found
|
||||||
for i in range(len(v_boxes)):
|
for i in range(len(v_boxes)):
|
||||||
|
|||||||
Reference in New Issue
Block a user