Coding style changes
This commit is contained in:
@@ -17,7 +17,6 @@ from matplotlib.patches import Rectangle
|
||||
|
||||
# Step 1:
|
||||
# Define WeightReader class
|
||||
|
||||
class WeightReader:
|
||||
"""
|
||||
WeightReader class is used to parse the "yolov3.weights" file and load the model weights into
|
||||
@@ -85,7 +84,7 @@ class WeightReader:
|
||||
"""
|
||||
self.offset = 0
|
||||
|
||||
# Step 2:
|
||||
# Step 2
|
||||
def _conv_block(input_layer, convs, skip=True):
|
||||
"""
|
||||
Function to create convolutional layer.
|
||||
@@ -370,15 +369,21 @@ def _interval_overlap(interval_a, interval_b):
|
||||
return min(x2,x4) - x3
|
||||
|
||||
def bbox_iou(box1, box2):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
intersect_w = _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])
|
||||
intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax])
|
||||
intersect = intersect_w * intersect_h
|
||||
w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin
|
||||
w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin
|
||||
union = w1*h1 + w2*h2 - intersect
|
||||
w1, h1 = box1.xmax - box1.xmin, box1.ymax - box1.ymin
|
||||
w2, h2 = box2.xmax - box2.xmin, box2.ymax - box2.ymin
|
||||
union = w1 * h1 + w2 * h2 - intersect
|
||||
return float(intersect) / union
|
||||
|
||||
def do_nms(boxes, nms_thresh):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
if len(boxes) > 0:
|
||||
nb_class = len(boxes[0].classes)
|
||||
else:
|
||||
@@ -387,54 +392,63 @@ def do_nms(boxes, nms_thresh):
|
||||
sorted_indices = np.argsort([-box.classes[c] for box in boxes])
|
||||
for i in range(len(sorted_indices)):
|
||||
index_i = sorted_indices[i]
|
||||
if boxes[index_i].classes[c] == 0: continue
|
||||
|
||||
if boxes[index_i].classes[c] == 0:
|
||||
continue
|
||||
|
||||
for j in range(i+1, len(sorted_indices)):
|
||||
index_j = sorted_indices[j]
|
||||
|
||||
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
|
||||
boxes[index_j].classes[c] = 0
|
||||
|
||||
# get all of the results above a threshold
|
||||
def get_boxes(boxes, labels, thresh):
|
||||
"""
|
||||
Get all of the results above a threshold
|
||||
"""
|
||||
v_boxes, v_labels, v_scores = list(), list(), list()
|
||||
# enumerate all boxes
|
||||
|
||||
# Enumerate all boxes
|
||||
for box in boxes:
|
||||
# enumerate all possible labels
|
||||
for i in range(len(labels)):
|
||||
# check if the threshold for this label is high enough
|
||||
# Enumerate all possible labels
|
||||
for i, label in enumerate(labels):
|
||||
# Check if the threshold for this label is high enough
|
||||
if box.classes[i] > thresh:
|
||||
v_boxes.append(box)
|
||||
v_labels.append(labels[i])
|
||||
v_labels.append(label)
|
||||
v_scores.append(box.classes[i]*100)
|
||||
# don't break, many labels may trigger for one box
|
||||
# Don't break, many labels may trigger for one box
|
||||
|
||||
return v_boxes, v_labels, v_scores
|
||||
|
||||
# draw all results
|
||||
def draw_boxes(filename, v_boxes, v_labels, v_scores):
|
||||
|
||||
# load the image
|
||||
"""
|
||||
Draw all results
|
||||
"""
|
||||
# Load the image
|
||||
data = pyplot.imread(filename)
|
||||
# plot the image
|
||||
# Plot the image
|
||||
pyplot.imshow(data)
|
||||
# get the context for drawing boxes
|
||||
# Get the context for drawing boxes
|
||||
ax = pyplot.gca()
|
||||
# plot each box
|
||||
for i in range(len(v_boxes)):
|
||||
box = v_boxes[i]
|
||||
# get coordinates
|
||||
# Plot each box
|
||||
for i, box in enumerate(v_boxes):
|
||||
# Get coordinates
|
||||
y1, x1, y2, x2 = box.ymin, box.xmin, box.ymax, box.xmax
|
||||
# calculate width and height of the box
|
||||
# Calculate width and height of the box
|
||||
width, height = x2 - x1, y2 - y1
|
||||
# create the shape
|
||||
# Create the shape
|
||||
rect = Rectangle((x1, y1), width, height, fill=False, color='red', linewidth = '2')
|
||||
# draw the box
|
||||
# Draw the box
|
||||
ax.add_patch(rect)
|
||||
# draw text and score in top left corner
|
||||
# Draw text and score in top left corner
|
||||
label = "%s (%.3f)" % (v_labels[i], v_scores[i])
|
||||
pyplot.text(x1, y1, label, color='red')
|
||||
# show the plot
|
||||
# Show the plot
|
||||
pyplot.show()
|
||||
|
||||
"""**step 7:** declare several configuration"""
|
||||
# Step 7:
|
||||
# Dclare several configurationd
|
||||
|
||||
# Define the anchors
|
||||
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
|
||||
@@ -482,25 +496,22 @@ def main():
|
||||
# Step 8:
|
||||
# Make Prediction
|
||||
for photo_filename in glob.glob("images/test/dog/*"):
|
||||
|
||||
# for fn in upload.keys():
|
||||
# photo_filename = '/content/' + fn
|
||||
# photo_filename = 'test.jpg'
|
||||
|
||||
# define the expected input shape for the model
|
||||
# Define the expected input shape for the model
|
||||
input_w, input_h = 416, 416
|
||||
|
||||
image, image_w, image_h = load_image_pixels(photo_filename, (input_w, input_h))
|
||||
|
||||
# make prediction
|
||||
yhat = yolov3.predict(image)
|
||||
# summarize the shape of the list of arrays
|
||||
print([a.shape for a in yhat])
|
||||
# Make prediction
|
||||
netouts = yolov3.predict(image)
|
||||
|
||||
# Summarize the shape of the list of arrays
|
||||
print([a.shape for a in netouts])
|
||||
|
||||
boxes = list()
|
||||
for i in range(len(yhat)):
|
||||
# decode the output of the network
|
||||
boxes += decode_netout(yhat[i][0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
|
||||
|
||||
for i, netout in enumerate(netouts):
|
||||
# Decode the output of the network
|
||||
boxes += decode_netout(netout[0], ANCHORS[i], CLASS_THRESHOLD, input_h, input_w)
|
||||
|
||||
# correct the sizes of the bounding boxes for the shape of the image
|
||||
correct_yolo_boxes(boxes, image_h, image_w, input_h, input_w)
|
||||
@@ -518,7 +529,5 @@ def main():
|
||||
# draw what we found
|
||||
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
|
||||
|
||||
print([a.shape for a in yhat])
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user