From 0a5635fe85943a1135a25a9ee5be47c6836154fc Mon Sep 17 00:00:00 2001 From: Heiko J Schick Date: Thu, 22 Oct 2020 09:09:22 +0200 Subject: [PATCH] Print all netout and moved th downloads of all files in seperate function --- lenet5.py | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/lenet5.py b/lenet5.py index 8455d58..2845c2f 100644 --- a/lenet5.py +++ b/lenet5.py @@ -31,6 +31,23 @@ def download_file(url, file_name): response = get(url) file.write(response.content) +def download_files(): + """ + Download all data and label files. + """ + # train-images-idx3-ubyte.gz: training set images (9912422 bytes) + download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', + 'train-images-idx3-ubyte.gz') + # train-labels-idx1-ubyte.gz: training set labels (28881 bytes) + download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', + 'train-labels-idx1-ubyte.gz') + # t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) + download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', + 't10k-images-idx3-ubyte.gz') + # t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) + download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', + 't10k-labels-idx1-ubyte.gz') + def read_mnist(images_path: str, labels_path: str): """ Read data and labels of the MNIST dataset. @@ -128,19 +145,7 @@ def main(): # Step 1: # Download the MNIST dataset with consist of labeled handwritten images (28x28 px). - - # train-images-idx3-ubyte.gz: training set images (9912422 bytes) - download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', - 'train-images-idx3-ubyte.gz') - # train-labels-idx1-ubyte.gz: training set labels (28881 bytes) - download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', - 'train-labels-idx1-ubyte.gz') - # t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) - download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', - 't10k-images-idx3-ubyte.gz') - # t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) - download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', - 't10k-labels-idx1-ubyte.gz') + download_files() # Step 2: # Read MNIST dataset (training and testing) @@ -190,10 +195,9 @@ def main(): model = create_lenet5() # Step 8: - # Train the LeNet-5 model - train_lenet5(model) - - model.save('lenet5.h5') + # Train and save the LeNet-5 model + # train_lenet5(model) + # model.save('lenet5.h5') model = load_model('lenet5.h5') @@ -210,12 +214,12 @@ def main(): # display_image(test, 1) # display_image(test, 2) - outputs = model.predict(test['features']) + netouts = model.predict(test['features']) - for i, item in enumerate(outputs[0]): - print("%d: %.3f" % (i, item)) - - print(test['labels'][0]) + for i, netout in enumerate(netouts): + print("%d: %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f" + % (i, netout[0], netout[1], netout[2], netout[3], netout[4], + netout[5], netout[6], netout[7], netout[8], netout[9])) if __name__ == "__main__": main()