Initial commit
This commit is contained in:
@@ -0,0 +1,130 @@
|
||||
# Source: https://www.guru99.com/keras-tutorial.html
|
||||
|
||||
#### Data preparation
|
||||
from keras.preprocessing.image import ImageDataGenerator
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
train_path = 'images/train/'
|
||||
test_path = 'images/test/'
|
||||
batch_size = 16
|
||||
image_size = 224
|
||||
num_class = 8
|
||||
|
||||
'''
|
||||
The ImageDataGenerator will make an X_training data from a directory.
|
||||
The sub-directory in that directory will be used as a class for each object.
|
||||
The image will be loaded with the RGB color mode, with the categorical class
|
||||
mode for the Y_training data, with a batch size of 16. Finally, shuffle the
|
||||
data.
|
||||
'''
|
||||
train_datagen = ImageDataGenerator(validation_split=0.3,
|
||||
shear_range=0.2,
|
||||
zoom_range=0.2,
|
||||
horizontal_flip=True)
|
||||
|
||||
train_generator = train_datagen.flow_from_directory(
|
||||
directory=train_path,
|
||||
target_size=(image_size,image_size),
|
||||
batch_size=batch_size,
|
||||
class_mode='categorical',
|
||||
color_mode='rgb',
|
||||
shuffle=True)
|
||||
|
||||
'''
|
||||
Let's see our images randomly by plotting them with matplotlib
|
||||
'''
|
||||
x_batch, y_batch = train_generator.next()
|
||||
fig=plt.figure()
|
||||
columns = 4
|
||||
rows = 4
|
||||
for i in range(1, columns*rows):
|
||||
num = np.random.randint(batch_size)
|
||||
image = x_batch[num].astype(np.int)
|
||||
fig.add_subplot(rows, columns, i)
|
||||
plt.imshow(image)
|
||||
plt.show()
|
||||
|
||||
### Creating model
|
||||
import keras
|
||||
from keras.models import Model, load_model
|
||||
from keras.layers import Activation, Dropout, Flatten, Dense
|
||||
from keras.preprocessing.image import ImageDataGenerator
|
||||
from keras.applications.vgg16 import VGG16
|
||||
|
||||
'''
|
||||
Let's create our network model from VGG16 with imageNet pre-trained weight.
|
||||
We will freeze these layers so that the layers are not trainable to help us
|
||||
reduce the computation time.
|
||||
'''
|
||||
|
||||
# Load the VGG model
|
||||
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))
|
||||
|
||||
print(base_model.summary())
|
||||
|
||||
# Freeze the layers
|
||||
for layer in base_model.layers:
|
||||
layer.trainable = False
|
||||
|
||||
# Create the model
|
||||
model = keras.models.Sequential()
|
||||
|
||||
# Add the vgg convolutional base model
|
||||
model.add(base_model)
|
||||
|
||||
# Add new layers
|
||||
model.add(Flatten())
|
||||
model.add(Dense(1024, activation='relu'))
|
||||
model.add(Dense(1024, activation='relu'))
|
||||
model.add(Dense(num_class, activation='softmax'))
|
||||
|
||||
# Show a summary of the model. Check the number of trainable parameters
|
||||
print(model.summary())
|
||||
|
||||
input("Press Enter to continue...")
|
||||
|
||||
'''
|
||||
As you can see, the summary of our network model. From an input from
|
||||
VGG16 Layers, then we add 2 Fully Connected Layer which will extract 1024
|
||||
features and an output layer that will compute the 8 classes with the softmax
|
||||
activation.
|
||||
'''
|
||||
|
||||
#### Training
|
||||
# Compile the model
|
||||
from keras.optimizers import SGD
|
||||
|
||||
model.compile(loss='categorical_crossentropy',
|
||||
optimizer=SGD(lr=1e-3),
|
||||
metrics=['accuracy'])
|
||||
|
||||
'''
|
||||
# Start the training process
|
||||
model.fit(x_train, y_train, validation_split=0.30, batch_size=32, epochs=50, verbose=2)
|
||||
|
||||
# Save the model
|
||||
model.save('catdog.h5')
|
||||
'''
|
||||
|
||||
history = model.fit_generator(train_generator,
|
||||
steps_per_epoch=train_generator.n/batch_size,
|
||||
epochs=10)
|
||||
|
||||
model.save('fine_tune.h5')
|
||||
|
||||
# summarize history for accuracy
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.plot(history.history['loss'])
|
||||
plt.title('loss')
|
||||
plt.ylabel('loss')
|
||||
plt.xlabel('epoch')
|
||||
plt.legend(['loss'], loc='upper left')
|
||||
plt.show()
|
||||
|
||||
'''
|
||||
As you can see, our losses are dropped significantly and the accuracy is
|
||||
almost 100%. For testing our model, we randomly picked images over the internet
|
||||
and put it on the test folder with a different class to test
|
||||
'''
|
||||
Reference in New Issue
Block a user