Minor restructuring

This commit is contained in:
Heiko J Schick
2020-10-23 21:16:10 +02:00
parent b032b58ffa
commit aa3d8f33a4
+10 -10
View File
@@ -52,7 +52,7 @@ class WeightReader:
for i in range(106): for i in range(106):
try: try:
conv_layer = model.get_layer('conv_' + str(i)) conv_layer = model.get_layer('conv_' + str(i))
print("loading weights of convolution #" + str(i)) print("Loading weights of convolution #" + str(i))
if i not in [81, 93, 105]: if i not in [81, 93, 105]:
norm_layer = model.get_layer('bnorm_' + str(i)) norm_layer = model.get_layer('bnorm_' + str(i))
@@ -76,7 +76,7 @@ class WeightReader:
conv_layer.set_weights([kernel]) conv_layer.set_weights([kernel])
except ValueError: except ValueError:
print("no convolution #" + str(i)) print("No convolution #" + str(i))
def reset(self): def reset(self):
""" """
@@ -359,14 +359,16 @@ def _interval_overlap(interval_a, interval_b):
if x3 < x1: if x3 < x1:
if x4 < x1: if x4 < x1:
return 0 ret = 0
else: else:
return min(x2,x4) - x1 ret = min(x2,x4) - x1
else: else:
if x2 < x3: if x2 < x3:
return 0 ret = 0
else: else:
return min(x2,x4) - x3 ret = min(x2,x4) - x3
return ret
def bbox_iou(box1, box2): def bbox_iou(box1, box2):
""" """
@@ -391,9 +393,7 @@ def do_nms(boxes, nms_thresh):
for c in range(nb_class): for c in range(nb_class):
sorted_indices = np.argsort([-box.classes[c] for box in boxes]) sorted_indices = np.argsort([-box.classes[c] for box in boxes])
for i in range(len(sorted_indices)): for i, index_i in enumerate(sorted_indices):
index_i = sorted_indices[i]
if boxes[index_i].classes[c] == 0: if boxes[index_i].classes[c] == 0:
continue continue
@@ -455,7 +455,7 @@ def draw_boxes(filename, v_boxes, v_labels, v_scores):
# Define the anchors # Define the anchors
ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]] ANCHORS = [[116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]]
# Define the probability threshold for detected objects # Define thLOe probability threshold for detected objects
CLASS_THRESHOLD = 0.6 CLASS_THRESHOLD = 0.6
# Define the labels # Define the labels