Print all netout and moved th downloads of all files in seperate function

This commit is contained in:
Heiko J Schick
2020-10-22 09:09:22 +02:00
parent f31e7c649c
commit 0a5635fe85
+26 -22
View File
@@ -31,6 +31,23 @@ def download_file(url, file_name):
response = get(url) response = get(url)
file.write(response.content) file.write(response.content)
def download_files():
"""
Download all data and label files.
"""
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'train-images-idx3-ubyte.gz')
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'train-labels-idx1-ubyte.gz')
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
't10k-images-idx3-ubyte.gz')
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
't10k-labels-idx1-ubyte.gz')
def read_mnist(images_path: str, labels_path: str): def read_mnist(images_path: str, labels_path: str):
""" """
Read data and labels of the MNIST dataset. Read data and labels of the MNIST dataset.
@@ -128,19 +145,7 @@ def main():
# Step 1: # Step 1:
# Download the MNIST dataset with consist of labeled handwritten images (28x28 px). # Download the MNIST dataset with consist of labeled handwritten images (28x28 px).
download_files()
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'train-images-idx3-ubyte.gz')
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'train-labels-idx1-ubyte.gz')
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
't10k-images-idx3-ubyte.gz')
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
't10k-labels-idx1-ubyte.gz')
# Step 2: # Step 2:
# Read MNIST dataset (training and testing) # Read MNIST dataset (training and testing)
@@ -190,10 +195,9 @@ def main():
model = create_lenet5() model = create_lenet5()
# Step 8: # Step 8:
# Train the LeNet-5 model # Train and save the LeNet-5 model
train_lenet5(model) # train_lenet5(model)
# model.save('lenet5.h5')
model.save('lenet5.h5')
model = load_model('lenet5.h5') model = load_model('lenet5.h5')
@@ -210,12 +214,12 @@ def main():
# display_image(test, 1) # display_image(test, 1)
# display_image(test, 2) # display_image(test, 2)
outputs = model.predict(test['features']) netouts = model.predict(test['features'])
for i, item in enumerate(outputs[0]): for i, netout in enumerate(netouts):
print("%d: %.3f" % (i, item)) print("%d: %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f"
% (i, netout[0], netout[1], netout[2], netout[3], netout[4],
print(test['labels'][0]) netout[5], netout[6], netout[7], netout[8], netout[9]))
if __name__ == "__main__": if __name__ == "__main__":
main() main()