-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathutil_MNIST.py
71 lines (55 loc) · 2.14 KB
/
util_MNIST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
# Height and width, in pixels, of one MNIST digit image; displayImage
# reshapes its input to (img_rows, img_cols) before plotting.
img_rows = 28
img_cols = 28
def retrieveMNISTTrainingData(batch_size=128, shuffle=True, root='./data'):
    """
    Retrieve a training dataset of MNIST.

    Each image is converted to a tensor with ``transforms.ToTensor()``.
    Because ``download=True`` is passed, torchvision fetches the MNIST files
    into ``root`` when they are not already present there.

    Arguments:
        batch_size: batch size
        shuffle: whether the training data should be shuffled
        root: directory holding (or receiving) the MNIST files. The default
            './data' is resolved relative to the current working directory.
    Returns:
        data loader for the MNIST training data
    """
    transform = transforms.Compose([transforms.ToTensor()])
    MNIST_train_data = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=transform)
    # num_workers=0: batches are loaded in the main process.
    train_loader = torch.utils.data.DataLoader(
        MNIST_train_data, batch_size=batch_size, shuffle=shuffle,
        num_workers=0)
    return train_loader
def retrieveMNISTTestData(batch_size=128, shuffle=False, root='./data'):
    """
    Retrieve a test dataset of MNIST.

    Each image is converted to a tensor with ``transforms.ToTensor()``.
    Because ``download=True`` is passed, torchvision fetches the MNIST files
    into ``root`` when they are not already present there.

    Arguments:
        batch_size: batch size
        shuffle: whether the test data should be shuffled
        root: directory holding (or receiving) the MNIST files. The default
            './data' is resolved relative to the current working directory.
    Returns:
        data loader for the MNIST test data
    """
    transform = transforms.Compose([transforms.ToTensor()])
    MNIST_test_data = torchvision.datasets.MNIST(
        root=root, train=False, download=True, transform=transform)
    # num_workers=0: batches are loaded in the main process.
    test_loader = torch.utils.data.DataLoader(
        MNIST_test_data, batch_size=batch_size, shuffle=shuffle,
        num_workers=0)
    return test_loader
def displayImage(image, label):
    """
    Display an image of a digit from MNIST.

    Arguments:
        image: input image, given as a torch tensor or anything
            ``torch.as_tensor`` accepts (e.g. a NumPy array). It must hold
            exactly img_rows * img_cols elements so that it can be reshaped
            to (img_rows, img_cols). The tensor may require grad, live on a
            GPU, or be non-contiguous.
        label: prediction on this input image. A single-element tensor is
            shown as its Python scalar value.
    """
    # matplotlib converts the image through NumPy, which raises for tensors
    # that require grad or live on a CUDA device; detach and move to CPU.
    pixels = torch.as_tensor(image).detach().cpu()
    # reshape (unlike view) also works on non-contiguous tensors.
    pixels = pixels.reshape(img_rows, img_cols)
    if isinstance(label, torch.Tensor) and label.numel() == 1:
        # Show "5" in the title rather than "tensor(5)" / "tensor([5])".
        label = label.item()
    plt.imshow(pixels.numpy(), vmin=0.0, vmax=1.0, cmap='gray')
    plt.title("Predicted label: {}".format(label))
    plt.show()
if __name__ == "__main__":
    # Smoke test: load the training set one image at a time, in order, and
    # report the type and size of the first batch.
    train_loader = retrieveMNISTTrainingData(batch_size=1, shuffle=False)
    print("MNIST training data are loaded.")
    train_iterator = iter(train_loader)
    # Use the builtin next(): recent PyTorch DataLoader iterators no longer
    # provide a `.next()` method, so `train_iterator.next()` raises
    # AttributeError.
    images, labels = next(train_iterator)
    print("The type of the image is {}.".format(type(images)))
    print("The size of the image is {}.".format(images.size()))