Print all netout and moved th downloads of all files in seperate function
This commit is contained in:
@@ -31,6 +31,23 @@ def download_file(url, file_name):
|
|||||||
response = get(url)
|
response = get(url)
|
||||||
file.write(response.content)
|
file.write(response.content)
|
||||||
|
|
||||||
|
def download_files():
|
||||||
|
"""
|
||||||
|
Download all data and label files.
|
||||||
|
"""
|
||||||
|
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
|
||||||
|
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
||||||
|
'train-images-idx3-ubyte.gz')
|
||||||
|
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
|
||||||
|
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
|
||||||
|
'train-labels-idx1-ubyte.gz')
|
||||||
|
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
|
||||||
|
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
|
||||||
|
't10k-images-idx3-ubyte.gz')
|
||||||
|
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
|
||||||
|
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
|
||||||
|
't10k-labels-idx1-ubyte.gz')
|
||||||
|
|
||||||
def read_mnist(images_path: str, labels_path: str):
|
def read_mnist(images_path: str, labels_path: str):
|
||||||
"""
|
"""
|
||||||
Read data and labels of the MNIST dataset.
|
Read data and labels of the MNIST dataset.
|
||||||
@@ -128,19 +145,7 @@ def main():
|
|||||||
|
|
||||||
# Step 1:
|
# Step 1:
|
||||||
# Download the MNIST dataset with consist of labeled handwritten images (28x28 px).
|
# Download the MNIST dataset with consist of labeled handwritten images (28x28 px).
|
||||||
|
download_files()
|
||||||
# train-images-idx3-ubyte.gz: training set images (9912422 bytes)
|
|
||||||
download_file('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
|
|
||||||
'train-images-idx3-ubyte.gz')
|
|
||||||
# train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
|
|
||||||
download_file('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
|
|
||||||
'train-labels-idx1-ubyte.gz')
|
|
||||||
# t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
|
|
||||||
download_file('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
|
|
||||||
't10k-images-idx3-ubyte.gz')
|
|
||||||
# t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
|
|
||||||
download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
|
|
||||||
't10k-labels-idx1-ubyte.gz')
|
|
||||||
|
|
||||||
# Step 2:
|
# Step 2:
|
||||||
# Read MNIST dataset (training and testing)
|
# Read MNIST dataset (training and testing)
|
||||||
@@ -190,10 +195,9 @@ def main():
|
|||||||
model = create_lenet5()
|
model = create_lenet5()
|
||||||
|
|
||||||
# Step 8:
|
# Step 8:
|
||||||
# Train the LeNet-5 model
|
# Train and save the LeNet-5 model
|
||||||
train_lenet5(model)
|
# train_lenet5(model)
|
||||||
|
# model.save('lenet5.h5')
|
||||||
model.save('lenet5.h5')
|
|
||||||
|
|
||||||
model = load_model('lenet5.h5')
|
model = load_model('lenet5.h5')
|
||||||
|
|
||||||
@@ -210,12 +214,12 @@ def main():
|
|||||||
# display_image(test, 1)
|
# display_image(test, 1)
|
||||||
# display_image(test, 2)
|
# display_image(test, 2)
|
||||||
|
|
||||||
outputs = model.predict(test['features'])
|
netouts = model.predict(test['features'])
|
||||||
|
|
||||||
for i, item in enumerate(outputs[0]):
|
for i, netout in enumerate(netouts):
|
||||||
print("%d: %.3f" % (i, item))
|
print("%d: %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f %.0f"
|
||||||
|
% (i, netout[0], netout[1], netout[2], netout[3], netout[4],
|
||||||
print(test['labels'][0])
|
netout[5], netout[6], netout[7], netout[8], netout[9]))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user