Print all netout and moved th downloads of all files in seperate function

This commit is contained in:
Heiko J Schick
2020-10-22 09:09:22 +02:00
parent f31e7c649c
commit 0a5635fe85
+26 -22
View File
@@ -31,6 +31,23 @@ def download_file(url, file_name):
response = get(url)
file.write(response.content)
def download_files():
"""
Download all data and label files.
"""
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'train-images-idx3-ubyte.gz')
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'train-labels-idx1-ubyte.gz')
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
't10k-images-idx3-ubyte.gz')
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
't10k-labels-idx1-ubyte.gz')
def read_mnist(images_path: str, labels_path: str):
"""
Read data and labels of the MNIST dataset.
@@ -128,19 +145,7 @@ def main():
# Step 1:
# Download the MNIST dataset with consist of labeled handwritten images (28x28 px).
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'train-images-idx3-ubyte.gz')
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'train-labels-idx1-ubyte.gz')
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
't10k-images-idx3-ubyte.gz')
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
't10k-labels-idx1-ubyte.gz')
download_files()
# Step 2:
# Read MNIST dataset (training and testing)
@@ -190,10 +195,9 @@ def main():
model = create_lenet5()
# Step 8:
# Train the LeNet-5 model
train_lenet5(model)
model.save('lenet5.h5')
# Train and save the LeNet-5 model
# train_lenet5(model)
# model.save('lenet5.h5')
model = load_model('lenet5.h5')
@@ -210,12 +214,12 @@ def main():
# display_image(test, 1)
# display_image(test, 2)
outputs = model.predict(test['features'])
netouts = model.predict(test['features'])
for i, item in enumerate(outputs[0]):
print("%d: %.3f" % (i, item))
print(test['labels'][0])
for i, netout in enumerate(netouts):
print("%d: %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f"
% (i, netout[0], netout[1], netout[2], netout[3], netout[4],
netout[5], netout[6], netout[7], netout[8], netout[9]))
if __name__ == "__main__":
main()