Coding style changes
This commit is contained in:
@@ -16,30 +16,40 @@ from matplotlib import pyplot
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
# Step 1:
|
||||
# WeightReader class is used to parse the "yolov3.weights" file and load the model weights into
|
||||
# memory in a format that we can set into keras model
|
||||
# Define WeightReader class
|
||||
|
||||
class WeightReader:
|
||||
"""
|
||||
WeightReader class is used to parse the "yolov3.weights" file and load the model weights into
|
||||
memory in a format that we can set into keras model.
|
||||
"""
|
||||
def __init__(self, weight_file):
|
||||
with open(weight_file, 'rb') as w_f:
|
||||
major, = struct.unpack('i', w_f.read(4))
|
||||
minor, = struct.unpack('i', w_f.read(4))
|
||||
revision, = struct.unpack('i', w_f.read(4))
|
||||
w_f.read(4) # ignore revision
|
||||
|
||||
if (major * 10 + minor) >= 2 and major < 1000 and minor < 1000:
|
||||
w_f.read(8)
|
||||
else:
|
||||
w_f.read(4)
|
||||
|
||||
transpose = (major > 1000) or (minor > 1000)
|
||||
binary = w_f.read()
|
||||
self.offset = 0
|
||||
self.all_weights = np.frombuffer(binary, dtype='float32')
|
||||
|
||||
def read_bytes(self, size):
|
||||
"""
|
||||
Helper function to read bytes from all_weights.
|
||||
"""
|
||||
self.offset = self.offset + size
|
||||
|
||||
return self.all_weights[self.offset - size:self.offset]
|
||||
|
||||
def load_weights(self, model):
|
||||
"""
|
||||
Load weights into created model
|
||||
"""
|
||||
for i in range(106):
|
||||
try:
|
||||
conv_layer = model.get_layer('conv_' + str(i))
|
||||
@@ -52,7 +62,7 @@ class WeightReader:
|
||||
gamma = self.read_bytes(size) # scale
|
||||
mean = self.read_bytes(size) # mean
|
||||
var = self.read_bytes(size) # variance
|
||||
weights = norm_layer.set_weights([gamma, beta, mean, var])
|
||||
norm_layer.set_weights([gamma, beta, mean, var])
|
||||
|
||||
if len(conv_layer.get_weights()) > 1:
|
||||
bias = self.read_bytes(np.prod(conv_layer.get_weights()[1].shape))
|
||||
@@ -70,33 +80,36 @@ class WeightReader:
|
||||
print("no convolution #" + str(i))
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets offset to restart loading weights
|
||||
"""
|
||||
self.offset = 0
|
||||
|
||||
# Step 2:
|
||||
# _conv_block(input, convs, skip=True) is a function to create convolutional layer
|
||||
def _conv_block(inp, convs, skip=True):
|
||||
x = inp
|
||||
def _conv_block(input_layer, convs, skip=True):
|
||||
tmp = input_layer
|
||||
count = 0
|
||||
for conv in convs:
|
||||
if count == (len(convs) - 2) and skip:
|
||||
skip_connection = x
|
||||
skip_connection = tmp
|
||||
count += 1
|
||||
if conv['stride'] > 1: x = ZeroPadding2D(((1,0),(1,0)))(x) # peculiar padding as darknet
|
||||
# prefer left and top
|
||||
x = Conv2D(conv['filter'],
|
||||
# Peculiar padding as darknet prefer left and top
|
||||
if conv['stride'] > 1: tmp = ZeroPadding2D(((1,0),(1,0)))(tmp)
|
||||
tmp = Conv2D(conv['filter'],
|
||||
conv['kernel'],
|
||||
strides=conv['stride'],
|
||||
padding='valid' if conv['stride'] > 1 else 'same', # peculiar padding as darknet
|
||||
# prefer left and top
|
||||
# Peculiar padding as darknet prefer left and top
|
||||
padding='valid' if conv['stride'] > 1 else 'same',
|
||||
name='conv_' + str(conv['layer_idx']),
|
||||
use_bias=False if conv['bnorm'] else True)(x)
|
||||
use_bias=False if conv['bnorm'] else True)(tmp)
|
||||
|
||||
if conv['bnorm']: x = BatchNormalization(epsilon=0.001, name='bnorm_'
|
||||
+ str(conv['layer_idx']))(x)
|
||||
if conv['leaky']: x = LeakyReLU(alpha=0.1, name='leaky_'
|
||||
+ str(conv['layer_idx']))(x)
|
||||
if conv['bnorm']: tmp = BatchNormalization(epsilon=0.001, name='bnorm_'
|
||||
+ str(conv['layer_idx']))(tmp)
|
||||
if conv['leaky']: tmp = LeakyReLU(alpha=0.1, name='leaky_'
|
||||
+ str(conv['layer_idx']))(tmp)
|
||||
|
||||
return add([skip_connection, x]) if skip else x
|
||||
return add([skip_connection, tmp]) if skip else tmp
|
||||
|
||||
# make_yolov3_model() is a function to create layers of convoluational and stack together as a
|
||||
# whole yolo model
|
||||
@@ -218,28 +231,8 @@ def make_yolov3_model():
|
||||
|
||||
model = Model(input_image, [yolo_82, yolo_94, yolo_106])
|
||||
|
||||
print(model.summary())
|
||||
|
||||
return model
|
||||
|
||||
"""**Step 3:**
|
||||
- define the model
|
||||
- load the weight
|
||||
- save the model
|
||||
"""
|
||||
|
||||
# define the yolo v3 model
|
||||
yolov3 = make_yolov3_model()
|
||||
|
||||
# load the weights
|
||||
weight_reader = WeightReader('yolov3.weights')
|
||||
|
||||
# set the weights
|
||||
weight_reader.load_weights(yolov3)
|
||||
|
||||
# save the model to file
|
||||
yolov3.save('model.h5')
|
||||
|
||||
"""**step 4:** Prediction
|
||||
by loading the image to model and make prediction
|
||||
"""
|
||||
@@ -435,11 +428,33 @@ labels = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train",
|
||||
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
|
||||
"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
|
||||
|
||||
"""**Step 8:** Make Prediction"""
|
||||
|
||||
# from google.colab import files
|
||||
# upload = files.upload()
|
||||
def main():
|
||||
"""
|
||||
Defined starting point of source code.
|
||||
"""
|
||||
|
||||
# Step 3:
|
||||
# (1) Define the model
|
||||
# (2) Load the weight
|
||||
# (3) Save the model
|
||||
|
||||
# Define the YOLO v3 model
|
||||
yolov3 = make_yolov3_model()
|
||||
print(yolov3.summary())
|
||||
|
||||
# Load the weights
|
||||
# Source: https://pjreddie.com/media/files/yolov3.weights
|
||||
weight_reader = WeightReader('yolov3.weights')
|
||||
|
||||
# Set the weights
|
||||
weight_reader.load_weights(yolov3)
|
||||
|
||||
# Save the model to file
|
||||
yolov3.save('yolov3.h5')
|
||||
|
||||
# Step 8:
|
||||
# Make Prediction
|
||||
for photo_filename in glob.glob("images/test/dog/*"):
|
||||
|
||||
# for fn in upload.keys():
|
||||
@@ -478,3 +493,6 @@ for photo_filename in glob.glob("images/test/dog/*"):
|
||||
draw_boxes(photo_filename, v_boxes, v_labels, v_scores)
|
||||
|
||||
print([a.shape for a in yhat])
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user