Coding style changes

This commit is contained in:
Heiko J Schick
2020-10-21 14:53:15 +02:00
parent 935d1d3f1a
commit 18a3c7a43a
+52 -34
View File
@@ -279,67 +279,85 @@ class BoundBox:
self.score = -1 self.score = -1
def get_label(self): def get_label(self):
"""
Gets the label of the current object
"""
if self.label == -1: if self.label == -1:
self.label = np.argmax(self.classes) self.label = np.argmax(self.classes)
return self.label return self.label
def get_score(self): def get_score(self):
"""
Gets the score of the current object
"""
if self.score == -1: if self.score == -1:
self.score = self.classes[self.get_label()] self.score = self.classes[self.get_label()]
return self.get_score return self.get_score
def _sigmoid(x): def _sigmoid(inp):
return 1. /(1. + np.exp(-x)) return 1. / (1. + np.exp(-inp))
def decode_netout(netout, anchors, obj_thresh, net_h, net_w): def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
"""
Decode output information of network.
"""
grid_h, grid_w = netout.shape[:2] grid_h, grid_w = netout.shape[:2]
nb_box = 3 nb_box = 3
netout = netout.reshape((grid_h, grid_w, nb_box, -1)) netout = netout.reshape((grid_h, grid_w, nb_box, -1))
nb_class = netout.shape[-1] - 5
boxes = [] boxes = []
netout[..., :2] = _sigmoid(netout[..., :2]) netout[..., :2] = _sigmoid(netout[..., :2])
netout[..., 4:] = _sigmoid(netout[..., 4:]) netout[..., 4:] = _sigmoid(netout[..., 4:])
netout[..., 5:] = netout[..., 4][..., np.newaxis] * netout[..., 5:] netout[..., 5:] = netout[..., 4][..., np.newaxis] * netout[..., 5:]
netout[..., 5:] *= netout[..., 5:] > obj_thresh netout[..., 5:] *= netout[..., 5:] > obj_thresh
for i in range(grid_h*grid_w): for i in range(grid_h * grid_w):
row = i / grid_w row = i / grid_w
col = i % grid_w col = i % grid_w
for b in range(nb_box): for j in range(nb_box):
# 4th element is objectness score # 4th element is objectness score
objectness = netout[int(row)][int(col)][b][4] objectness = netout[int(row)][int(col)][j][4]
if objectness.all() <= obj_thresh: continue
# first 4 elements are x, y, w, and h if objectness.all() <= obj_thresh:
x, y, w, h = netout[int(row)][int(col)][b][:4] continue
x = (col + x) / grid_w # center position, unit: image width
y = (row + y) / grid_h # center position, unit: image height # First 4 elements are x, y, w, and h
w = anchors[2 * b + 0] * np.exp(w) / net_w # unit: image width x, y, w, h = netout[int(row)][int(col)][j][:4]
h = anchors[2 * b + 1] * np.exp(h) / net_h # unit: image height x = (col + x) / grid_w # Center position, unit: image width
# last elements are class probabilities y = (row + y) / grid_h # Center position, unit: image height
classes = netout[int(row)][col][b][5:] w = anchors[2 * j + 0] * np.exp(w) / net_w # Unit: image width
box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, objectness, classes) h = anchors[2 * j + 1] * np.exp(h) / net_h # Unit: image height
# Last elements are class probabilities
classes = netout[int(row)][col][j][5:]
box = BoundBox(x - w / 2, y - h / 2, x + w / 2, y + h / 2, objectness, classes)
boxes.append(box) boxes.append(box)
return boxes return boxes
"""**Step 5:** strech the box to be fit to the image normal shape""" # Step 5
def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w): def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
"""
Strech the box to be fit to the image normal shape
"""
new_w, new_h = net_w, net_h new_w, new_h = net_w, net_h
for i in range(len(boxes)): for box in boxes:
x_offset, x_scale = (net_w - new_w)/2./net_w, float(new_w)/net_w x_offset, x_scale = (net_w - new_w) / 2. / net_w, float(new_w) / net_w
y_offset, y_scale = (net_h - new_h)/2./net_h, float(new_h)/net_h y_offset, y_scale = (net_h - new_h) / 2. / net_h, float(new_h) / net_h
boxes[i].xmin = int((boxes[i].xmin - x_offset) / x_scale * image_w)
boxes[i].xmax = int((boxes[i].xmax - x_offset) / x_scale * image_w)
boxes[i].ymin = int((boxes[i].ymin - y_offset) / y_scale * image_h)
boxes[i].ymax = int((boxes[i].ymax - y_offset) / y_scale * image_h)
"""**Step 6:** implementing IOU""" box.xmin = int((box.xmin - x_offset) / x_scale * image_w)
box.xmax = int((box.xmax - x_offset) / x_scale * image_w)
box.ymin = int((box.ymin - y_offset) / y_scale * image_h)
box.ymax = int((box.ymax - y_offset) / y_scale * image_h)
# Step 6
def _interval_overlap(interval_a, interval_b): def _interval_overlap(interval_a, interval_b):
"""
Implementing IOU
"""
x1, x2 = interval_a x1, x2 = interval_a
x3, x4 = interval_b x3, x4 = interval_b
if x3 < x1: if x3 < x1:
if x4 < x1: if x4 < x1:
return 0 return 0
@@ -418,14 +436,14 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
"""**step 7:** declare several configuration""" """**step 7:** declare several configuration"""
# define the anchors # Define the anchors
anchors = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]] ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
# define the probability threshold for detected objects # Define the probability threshold for detected objects
class_threshold = 0.6 CLASS_THRESHOLD = 0.6
# define the labels # Define the labels
labels = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck", LABELS = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck",
"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
@@ -482,7 +500,7 @@ def main():
boxes = list() boxes = list()
for i in range(len(yhat)): for i in range(len(yhat)):
# decode the output of the network # decode the output of the network
boxes += decode_netout(yhat[i][0], anchors[i], class_threshold, input_h, input_w) boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
# correct the sizes of the bounding boxes for the shape of the image # correct the sizes of the bounding boxes for the shape of the image
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w) correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
@@ -491,7 +509,7 @@ def main():
do_nms(boxes, 0.5) do_nms(boxes, 0.5)
# get the details of the detected objects # get the details of the detected objects
v_boxes, v_labels, v_scores = get_boxes(boxes, labels, class_threshold) v_boxes, v_labels, v_scores = get_boxes(boxes, LABELS, CLASS_THRESHOLD)
# summarize what we found # summarize what we found
for i in range(len(v_boxes)): for i in range(len(v_boxes)):