Coding style changes

This commit is contained in:
Heiko J Schick
2020-10-21 14:53:15 +02:00
parent 935d1d3f1a
commit 18a3c7a43a
+52 -34
View File
@@ -279,67 +279,85 @@ class BoundBox:
self.score = -1
def get_label(self):
"""
Gets the label of the current object
"""
if self.label == -1:
self.label = np.argmax(self.classes)
return self.label
def get_score(self):
"""
Gets the score of the current object
"""
if self.score == -1:
self.score = self.classes[self.get_label()]
return self.get_score
def _sigmoid(x):
return 1. /(1. + np.exp(-x))
def _sigmoid(inp):
return 1. / (1. + np.exp(-inp))
def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
"""
Decode output information of network.
"""
grid_h, grid_w = netout.shape[:2]
nb_box = 3
netout = netout.reshape((grid_h, grid_w, nb_box, -1))
nb_class = netout.shape[-1] - 5
boxes = []
netout[..., :2] = _sigmoid(netout[..., :2])
netout[..., 4:] = _sigmoid(netout[..., 4:])
netout[..., 5:] = netout[..., 4][..., np.newaxis] * netout[..., 5:]
netout[..., 5:] *= netout[..., 5:] > obj_thresh
for i in range(grid_h*grid_w):
for i in range(grid_h * grid_w):
row = i / grid_w
col = i % grid_w
for b in range(nb_box):
for j in range(nb_box):
# 4th element is objectness score
objectness = netout[int(row)][int(col)][b][4]
if objectness.all() <= obj_thresh: continue
# first 4 elements are x, y, w, and h
x, y, w, h = netout[int(row)][int(col)][b][:4]
x = (col + x) / grid_w # center position, unit: image width
y = (row + y) / grid_h # center position, unit: image height
w = anchors[2 * b + 0] * np.exp(w) / net_w # unit: image width
h = anchors[2 * b + 1] * np.exp(h) / net_h # unit: image height
# last elements are class probabilities
classes = netout[int(row)][col][b][5:]
box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, objectness, classes)
objectness = netout[int(row)][int(col)][j][4]
if objectness.all() <= obj_thresh:
continue
# First 4 elements are x, y, w, and h
x, y, w, h = netout[int(row)][int(col)][j][:4]
x = (col + x) / grid_w # Center position, unit: image width
y = (row + y) / grid_h # Center position, unit: image height
w = anchors[2 * j + 0] * np.exp(w) / net_w # Unit: image width
h = anchors[2 * j + 1] * np.exp(h) / net_h # Unit: image height
# Last elements are class probabilities
classes = netout[int(row)][col][j][5:]
box = BoundBox(x - w / 2, y - h / 2, x + w / 2, y + h / 2, objectness, classes)
boxes.append(box)
return boxes
"""**Step 5:** strech the box to be fit to the image normal shape"""
# Step 5
def correct_yolo_boxes(boxes, image_h, image_w, net_h, net_w):
"""
Strech the box to be fit to the image normal shape
"""
new_w, new_h = net_w, net_h
for i in range(len(boxes)):
x_offset, x_scale = (net_w - new_w)/2./net_w, float(new_w)/net_w
y_offset, y_scale = (net_h - new_h)/2./net_h, float(new_h)/net_h
boxes[i].xmin = int((boxes[i].xmin - x_offset) / x_scale * image_w)
boxes[i].xmax = int((boxes[i].xmax - x_offset) / x_scale * image_w)
boxes[i].ymin = int((boxes[i].ymin - y_offset) / y_scale * image_h)
boxes[i].ymax = int((boxes[i].ymax - y_offset) / y_scale * image_h)
for box in boxes:
x_offset, x_scale = (net_w - new_w) / 2. / net_w, float(new_w) / net_w
y_offset, y_scale = (net_h - new_h) / 2. / net_h, float(new_h) / net_h
"""**Step 6:** implementing IOU"""
box.xmin = int((box.xmin - x_offset) / x_scale * image_w)
box.xmax = int((box.xmax - x_offset) / x_scale * image_w)
box.ymin = int((box.ymin - y_offset) / y_scale * image_h)
box.ymax = int((box.ymax - y_offset) / y_scale * image_h)
# Step 6
def _interval_overlap(interval_a, interval_b):
"""
Implementing IOU
"""
x1, x2 = interval_a
x3, x4 = interval_b
if x3 < x1:
if x4 < x1:
return 0
@@ -418,14 +436,14 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
"""**step 7:** declare several configuration"""
# define the anchors
anchors = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
# Define the anchors
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
# define the probability threshold for detected objects
class_threshold = 0.6
# Define the probability threshold for detected objects
CLASS_THRESHOLD = 0.6
# define the labels
labels = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck",
# Define the labels
LABELS = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck",
"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
@@ -482,7 +500,7 @@ def main():
boxes = list()
for i in range(len(yhat)):
# decode the output of the network
boxes += decode_netout(yhat[i][0], anchors[i], class_threshold, input_h, input_w)
boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
# correct the sizes of the bounding boxes for the shape of the image
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
@@ -491,7 +509,7 @@ def main():
do_nms(boxes, 0.5)
# get the details of the detected objects
v_boxes, v_labels, v_scores = get_boxes(boxes, labels, class_threshold)
v_boxes, v_labels, v_scores = get_boxes(boxes, LABELS, CLASS_THRESHOLD)
# summarize what we found
for i in range(len(v_boxes)):