Print all netout and moved th downloads of all files in seperate function
This commit is contained in:
@@ -31,6 +31,23 @@ def download_file(url, file_name):
|
||||
response = get(url)
|
||||
file.write(response.content)
|
||||
|
||||
def download_files():
|
||||
"""
|
||||
Download all data and label files.
|
||||
"""
|
||||
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
||||
'train-images-idx3-ubyte.gz')
|
||||
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
|
||||
'train-labels-idx1-ubyte.gz')
|
||||
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
|
||||
't10k-images-idx3-ubyte.gz')
|
||||
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
|
||||
't10k-labels-idx1-ubyte.gz')
|
||||
|
||||
def read_mnist(images_path: str, labels_path: str):
|
||||
"""
|
||||
Read data and labels of the MNIST dataset.
|
||||
@@ -128,19 +145,7 @@ def main():
|
||||
|
||||
# Step 1:
|
||||
# Download the MNIST dataset with consist of labeled handwritten images (28x28 px).
|
||||
|
||||
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
||||
'train-images-idx3-ubyte.gz')
|
||||
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
|
||||
'train-labels-idx1-ubyte.gz')
|
||||
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
|
||||
't10k-images-idx3-ubyte.gz')
|
||||
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
|
||||
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
|
||||
't10k-labels-idx1-ubyte.gz')
|
||||
download_files()
|
||||
|
||||
# Step 2:
|
||||
# Read MNIST dataset (training and testing)
|
||||
@@ -190,10 +195,9 @@ def main():
|
||||
model = create_lenet5()
|
||||
|
||||
# Step 8:
|
||||
# Train the LeNet-5 model
|
||||
train_lenet5(model)
|
||||
|
||||
model.save('lenet5.h5')
|
||||
# Train and save the LeNet-5 model
|
||||
# train_lenet5(model)
|
||||
# model.save('lenet5.h5')
|
||||
|
||||
model = load_model('lenet5.h5')
|
||||
|
||||
@@ -210,12 +214,12 @@ def main():
|
||||
# display_image(test, 1)
|
||||
# display_image(test, 2)
|
||||
|
||||
outputs = model.predict(test['features'])
|
||||
netouts = model.predict(test['features'])
|
||||
|
||||
for i, item in enumerate(outputs[0]):
|
||||
print("%d: %.3f" % (i, item))
|
||||
|
||||
print(test['labels'][0])
|
||||
for i, netout in enumerate(netouts):
|
||||
print("%d: %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f"
|
||||
% (i, netout[0], netout[1], netout[2], netout[3], netout[4],
|
||||
netout[5], netout[6], netout[7], netout[8], netout[9]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user