Coding style changes
This commit is contained in:
@@ -17,7 +17,6 @@ from matplotlib.patches import Rectangle
|
|||||||
|
|
||||||
# Step 1:
|
# Step 1:
|
||||||
# Define WeightReader class
|
# Define WeightReader class
|
||||||
|
|
||||||
class WeightReader:
|
class WeightReader:
|
||||||
"""
|
"""
|
||||||
WeightReader class is used to parse the "yolov3.weights" file and load the model weights into
|
WeightReader class is used to parse the "yolov3.weights" file and load the model weights into
|
||||||
@@ -85,7 +84,7 @@ class WeightReader:
|
|||||||
"""
|
"""
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
|
|
||||||
# Step 2:
|
# Step 2
|
||||||
def _conv_block(input_layer, convs, skip=True):
|
def _conv_block(input_layer, convs, skip=True):
|
||||||
"""
|
"""
|
||||||
Function to create convolutional layer.
|
Function to create convolutional layer.
|
||||||
@@ -370,6 +369,9 @@ def _interval_overlap(interval_a, interval_b):
|
|||||||
return min(x2,x4) - x3
|
return min(x2,x4) - x3
|
||||||
|
|
||||||
def bbox_iou(box1, box2):
|
def bbox_iou(box1, box2):
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
"""
|
||||||
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
|
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
|
||||||
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
|
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
|
||||||
intersect = intersect_w * intersect_h
|
intersect = intersect_w * intersect_h
|
||||||
@@ -379,6 +381,9 @@ def bbox_iou(box1, box2):
|
|||||||
return float(intersect) / union
|
return float(intersect) / union
|
||||||
|
|
||||||
def do_nms(boxes, nms_thresh):
|
def do_nms(boxes, nms_thresh):
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
"""
|
||||||
if len(boxes) > 0:
|
if len(boxes) > 0:
|
||||||
nb_class = len(boxes[0].classes)
|
nb_class = len(boxes[0].classes)
|
||||||
else:
|
else:
|
||||||
@@ -387,54 +392,63 @@ def do_nms(boxes, nms_thresh):
|
|||||||
sorted_indices = np.argsort([-box.classes[c] for box in boxes])
|
sorted_indices = np.argsort([-box.classes[c] for box in boxes])
|
||||||
for i in range(len(sorted_indices)):
|
for i in range(len(sorted_indices)):
|
||||||
index_i = sorted_indices[i]
|
index_i = sorted_indices[i]
|
||||||
if boxes[index_i].classes[c] == 0: continue
|
|
||||||
|
if boxes[index_i].classes[c] == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
for j in range(i+1, len(sorted_indices)):
|
for j in range(i+1, len(sorted_indices)):
|
||||||
index_j = sorted_indices[j]
|
index_j = sorted_indices[j]
|
||||||
|
|
||||||
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
|
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
|
||||||
boxes[index_j].classes[c] = 0
|
boxes[index_j].classes[c] = 0
|
||||||
|
|
||||||
# get all of the results above a threshold
|
|
||||||
def get_boxes(boxes, labels, thresh):
|
def get_boxes(boxes, labels, thresh):
|
||||||
|
"""
|
||||||
|
Get all of the results above a threshold
|
||||||
|
"""
|
||||||
v_boxes, v_labels, v_scores = list(), list(), list()
|
v_boxes, v_labels, v_scores = list(), list(), list()
|
||||||
# enumerate all boxes
|
|
||||||
|
# Enumerate all boxes
|
||||||
for box in boxes:
|
for box in boxes:
|
||||||
# enumerate all possible labels
|
# Enumerate all possible labels
|
||||||
for i in range(len(labels)):
|
for i, label in enumerate(labels):
|
||||||
# check if the threshold for this label is high enough
|
# Check if the threshold for this label is high enough
|
||||||
if box.classes[i] > thresh:
|
if box.classes[i] > thresh:
|
||||||
v_boxes.append(box)
|
v_boxes.append(box)
|
||||||
v_labels.append(labels[i])
|
v_labels.append(label)
|
||||||
v_scores.append(box.classes[i]*100)
|
v_scores.append(box.classes[i]*100)
|
||||||
# don't break, many labels may trigger for one box
|
# Don't break, many labels may trigger for one box
|
||||||
|
|
||||||
return v_boxes, v_labels, v_scores
|
return v_boxes, v_labels, v_scores
|
||||||
|
|
||||||
# draw all results
|
|
||||||
def draw_boxes(filename, v_boxes, v_labels, v_scores):
|
def draw_boxes(filename, v_boxes, v_labels, v_scores):
|
||||||
|
"""
|
||||||
# load the image
|
Draw all results
|
||||||
|
"""
|
||||||
|
# Load the image
|
||||||
data = pyplot.imread(filename)
|
data = pyplot.imread(filename)
|
||||||
# plot the image
|
# Plot the image
|
||||||
pyplot.imshow(data)
|
pyplot.imshow(data)
|
||||||
# get the context for drawing boxes
|
# Get the context for drawing boxes
|
||||||
ax = pyplot.gca()
|
ax = pyplot.gca()
|
||||||
# plot each box
|
# Plot each box
|
||||||
for i in range(len(v_boxes)):
|
for i, box in enumerate(v_boxes):
|
||||||
box = v_boxes[i]
|
# Get coordinates
|
||||||
# get coordinates
|
|
||||||
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
|
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
|
||||||
# calculate width and height of the box
|
# Calculate width and height of the box
|
||||||
width, height = x2 - x1, y2 - y1
|
width, height = x2 - x1, y2 - y1
|
||||||
# create the shape
|
# Create the shape
|
||||||
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
|
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
|
||||||
# draw the box
|
# Draw the box
|
||||||
ax.add_patch(rect)
|
ax.add_patch(rect)
|
||||||
# draw text and score in top left corner
|
# Draw text and score in top left corner
|
||||||
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
|
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
|
||||||
pyplot.text(x1, y1, label, color='red')
|
pyplot.text(x1, y1, label, color='red')
|
||||||
# show the plot
|
# Show the plot
|
||||||
pyplot.show()
|
pyplot.show()
|
||||||
|
|
||||||
"""**step 7:** declare several configuration"""
|
# Step 7:
|
||||||
|
# Dclare several configurationd
|
||||||
|
|
||||||
# Define the anchors
|
# Define the anchors
|
||||||
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
|
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
|
||||||
@@ -482,25 +496,22 @@ def main():
|
|||||||
# Step 8:
|
# Step 8:
|
||||||
# Make Prediction
|
# Make Prediction
|
||||||
for photo_filename in glob.glob("images/test/dog/*"):
|
for photo_filename in glob.glob("images/test/dog/*"):
|
||||||
|
# Define the expected input shape for the model
|
||||||
# for fn in upload.keys():
|
|
||||||
# photo_filename = '/content/' + fn
|
|
||||||
# photo_filename = 'test.jpg'
|
|
||||||
|
|
||||||
# define the expected input shape for the model
|
|
||||||
input_w, input_h = 416, 416
|
input_w, input_h = 416, 416
|
||||||
|
|
||||||
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
|
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
|
||||||
|
|
||||||
# make prediction
|
# Make prediction
|
||||||
yhat = yolov3.predict(image)
|
netouts = yolov3.predict(image)
|
||||||
# summarize the shape of the list of arrays
|
|
||||||
print([a.shape for a in yhat])
|
# Summarize the shape of the list of arrays
|
||||||
|
print([a.shape for a in netouts])
|
||||||
|
|
||||||
boxes = list()
|
boxes = list()
|
||||||
for i in range(len(yhat)):
|
|
||||||
# decode the output of the network
|
for i, netout in enumerate(netouts):
|
||||||
boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
|
# Decode the output of the network
|
||||||
|
boxes += decode_netout(netout[0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
|
||||||
|
|
||||||
# correct the sizes of the bounding boxes for the shape of the image
|
# correct the sizes of the bounding boxes for the shape of the image
|
||||||
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
|
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
|
||||||
@@ -518,7 +529,5 @@ def main():
|
|||||||
# draw what we found
|
# draw what we found
|
||||||
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
|
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
|
||||||
|
|
||||||
print([a.shape for a in yhat])
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user