Coding style changes

This commit is contained in:
Heiko J Schick
2020-10-21 16:12:21 +02:00
parent 18a3c7a43a
commit 8eb406b02c
+52 -43
View File
@@ -17,7 +17,6 @@ from matplotlib.patches import Rectangle
# Step 1:
# Define WeightReader class
class WeightReader:
"""
WeightReader class is used to parse the "yolov3.weights" file and load the model weights into
@@ -85,7 +84,7 @@ class WeightReader:
"""
self.offset = 0
# Step 2:
# Step 2
def _conv_block(input_layer, convs, skip=True):
"""
Function to create convolutional layer.
@@ -370,15 +369,21 @@ def _interval_overlap(interval_a, interval_b):
return min(x2,x4) - x3
def bbox_iou(box1, box2):
"""
TODO
"""
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
intersect = intersect_w * intersect_h
w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin
w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin
union = w1*h1 + w2*h2 - intersect
w1, h1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
w2, h2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
union = w1 * h1 + w2 * h2 - intersect
return float(intersect) / union
def do_nms(boxes, nms_thresh):
"""
TODO
"""
if len(boxes) > 0:
nb_class = len(boxes[0].classes)
else:
@@ -387,54 +392,63 @@ def do_nms(boxes, nms_thresh):
sorted_indices = np.argsort([-box.classes[c] for box in boxes])
for i in range(len(sorted_indices)):
index_i = sorted_indices[i]
if boxes[index_i].classes[c] == 0: continue
if boxes[index_i].classes[c] == 0:
continue
for j in range(i+1, len(sorted_indices)):
index_j = sorted_indices[j]
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
boxes[index_j].classes[c] = 0
# get all of the results above a threshold
def get_boxes(boxes, labels, thresh):
"""
Get all of the results above a threshold
"""
v_boxes, v_labels, v_scores = list(), list(), list()
# enumerate all boxes
# Enumerate all boxes
for box in boxes:
# enumerate all possible labels
for i in range(len(labels)):
# check if the threshold for this label is high enough
# Enumerate all possible labels
for i, label in enumerate(labels):
# Check if the threshold for this label is high enough
if box.classes[i] > thresh:
v_boxes.append(box)
v_labels.append(labels[i])
v_labels.append(label)
v_scores.append(box.classes[i]*100)
# don't break, many labels may trigger for one box
# Don't break, many labels may trigger for one box
return v_boxes, v_labels, v_scores
# draw all results
def draw_boxes(filename, v_boxes, v_labels, v_scores):
# load the image
"""
Draw all results
"""
# Load the image
data = pyplot.imread(filename)
# plot the image
# Plot the image
pyplot.imshow(data)
# get the context for drawing boxes
# Get the context for drawing boxes
ax = pyplot.gca()
# plot each box
for i in range(len(v_boxes)):
box = v_boxes[i]
# get coordinates
# Plot each box
for i, box in enumerate(v_boxes):
# Get coordinates
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
# calculate width and height of the box
# Calculate width and height of the box
width, height = x2 - x1, y2 - y1
# create the shape
# Create the shape
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
# draw the box
# Draw the box
ax.add_patch(rect)
# draw text and score in top left corner
# Draw text and score in top left corner
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
pyplot.text(x1, y1, label, color='red')
# show the plot
# Show the plot
pyplot.show()
"""**step 7:** declare several configuration"""
# Step 7:
# Dclare several configurationd
# Define the anchors
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
@@ -482,25 +496,22 @@ def main():
# Step 8:
# Make Prediction
for photo_filename in glob.glob("images/test/dog/*"):
# for fn in upload.keys():
# photo_filename = '/content/' + fn
# photo_filename = 'test.jpg'
# define the expected input shape for the model
# Define the expected input shape for the model
input_w, input_h = 416, 416
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
# make prediction
yhat = yolov3.predict(image)
# summarize the shape of the list of arrays
print([a.shape for a in yhat])
# Make prediction
netouts = yolov3.predict(image)
# Summarize the shape of the list of arrays
print([a.shape for a in netouts])
boxes = list()
for i in range(len(yhat)):
# decode the output of the network
boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
for i, netout in enumerate(netouts):
# Decode the output of the network
boxes += decode_netout(netout[0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
# correct the sizes of the bounding boxes for the shape of the image
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
@@ -518,7 +529,5 @@ def main():
# draw what we found
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
print([a.shape for a in yhat])
if __name__ == "__main__":
main()