Re-enabled training
This commit is contained in:
@@ -164,10 +164,6 @@ def main():
|
|||||||
test['features'], test['labels'] = read_mnist('t10k-images-idx3-ubyte.gz',
|
test['features'], test['labels'] = read_mnist('t10k-images-idx3-ubyte.gz',
|
||||||
't10k-labels-idx1-ubyte.gz')
|
't10k-labels-idx1-ubyte.gz')
|
||||||
|
|
||||||
print(type(train))
|
|
||||||
print(type(train['features']))
|
|
||||||
print(type(train['labels']))
|
|
||||||
|
|
||||||
# Step 3:
|
# Step 3:
|
||||||
# Explore the dataset
|
# Explore the dataset
|
||||||
print('Number of training images:', train['features'].shape[0])
|
print('Number of training images:', train['features'].shape[0])
|
||||||
@@ -208,9 +204,9 @@ def main():
|
|||||||
model = create_lenet5()
|
model = create_lenet5()
|
||||||
|
|
||||||
# Step 8:
|
# Step 8:
|
||||||
# Train and save the LeNet-5 model
|
# Train, save and load the LeNet-5 model
|
||||||
# train_lenet5(model)
|
train_lenet5(model)
|
||||||
# model.save('lenet5.h5')
|
model.save('lenet5.h5')
|
||||||
model = load_model('lenet5.h5')
|
model = load_model('lenet5.h5')
|
||||||
|
|
||||||
# Step 9:
|
# Step 9:
|
||||||
|
|||||||
Reference in New Issue
Block a user