-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist.py
80 lines (60 loc) · 2.74 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gzip
import os
import struct
from array import array
import numpy as np
from scipy import stats
class MNIST():
    """Load the MNIST gzip-compressed IDX files from ``path`` into NumPy arrays.

    After construction the instance exposes ``train_images``,
    ``train_labels``, ``test_images`` and ``test_labels``. Images are 2-D
    arrays with one flattened image per row (``rows * cols`` columns, taken
    from the IDX header); labels are 1-D unless one-hot encoded.

    Args:
        path: Directory containing the four files named in ``__init__``
            (``train-*`` and ``t10k-*`` images/labels, ``.gz`` compressed).
        one_hot_encoding: If true, replace each label vector with an
            ``(n, 10)`` int array holding a single 1 per row.
        z_score: If true, standardize every image (row) independently with
            ``scipy.stats.zscore(axis=1)``.
        intercept: If true, prepend a column of ones to the train and test
            images.
        shuffle: If true, permute the training images and labels with the
            same permutation from ``np.random.permutation`` (the global NumPy
            RNG). The test set is never shuffled.

    Raises:
        ValueError: If a file has an unexpected magic number, a truncated
            header or payload, or the label and image counts disagree.
    """

    # Magic numbers defined by the IDX format: 0x00000801 for the 1-D
    # unsigned-byte label files, 0x00000803 for the 3-D unsigned-byte
    # image files.
    _LABEL_MAGIC = 2049
    _IMAGE_MAGIC = 2051
    _NUM_CLASSES = 10

    def __init__(self,
                 path='./datasets/mnist',
                 one_hot_encoding=False,
                 z_score=False,
                 intercept=False,
                 shuffle=False):
        self.path = path
        self.test_img_fname = 't10k-images-idx3-ubyte.gz'
        self.test_lbl_fname = 't10k-labels-idx1-ubyte.gz'
        self.train_img_fname = 'train-images-idx3-ubyte.gz'
        self.train_lbl_fname = 'train-labels-idx1-ubyte.gz'

        self.test_images, self.test_labels = self._load(
            os.path.join(self.path, self.test_img_fname),
            os.path.join(self.path, self.test_lbl_fname))
        self.train_images, self.train_labels = self._load(
            os.path.join(self.path, self.train_img_fname),
            os.path.join(self.path, self.train_lbl_fname))

        n_train = self.train_images.shape[0]
        n_test = self.test_images.shape[0]

        if z_score:
            # Per-image standardization (axis=1 is the pixel axis).
            self.train_images = stats.zscore(self.train_images, axis=1)
            self.test_images = stats.zscore(self.test_images, axis=1)

        if intercept:
            self.train_images = np.concatenate(
                (np.ones((n_train, 1), dtype=float), self.train_images), axis=1)
            self.test_images = np.concatenate(
                (np.ones((n_test, 1), dtype=float), self.test_images), axis=1)

        if shuffle:
            # One permutation for both arrays keeps image/label pairs aligned.
            perm = np.random.permutation(n_train)
            self.train_images = self.train_images[perm]
            self.train_labels = self.train_labels[perm]

        if one_hot_encoding:
            self.train_labels = self._one_hot(self.train_labels)
            self.test_labels = self._one_hot(self.test_labels)

    @classmethod
    def _one_hot(cls, labels):
        """Return an ``(len(labels), 10)`` int array with a 1 at column ``labels[i]``.

        Raises IndexError for a label outside ``0..9``.
        """
        return np.eye(cls._NUM_CLASSES, dtype=int)[labels]

    @staticmethod
    def _read_idx(path, header_fmt, expected_magic):
        """Read a gzipped IDX file and split it into header fields and payload.

        Args:
            path: File to open with ``gzip.open``.
            header_fmt: ``struct`` format of the big-endian header, whose
                first field is the magic number.
            expected_magic: Magic number the file must start with.

        Returns:
            (dims, payload): the header fields after the magic number as a
            tuple, and the remaining bytes of the file.

        Raises:
            ValueError: If the header is truncated or the magic number differs.
        """
        header_size = struct.calcsize(header_fmt)
        with gzip.open(path, 'rb') as file:
            header = file.read(header_size)
            payload = file.read()
        if len(header) != header_size:
            raise ValueError('{}: truncated IDX header'.format(path))
        fields = struct.unpack(header_fmt, header)
        if fields[0] != expected_magic:
            raise ValueError('{}: bad magic number {} (expected {})'.format(
                path, fields[0], expected_magic))
        return fields[1:], payload

    def _load(self, path_img, path_lbl):
        """Load one image file and its matching label file.

        Args:
            path_img: Path of the gzipped IDX image file.
            path_lbl: Path of the gzipped IDX label file.

        Returns:
            (images, labels): ``images`` is a writable ``(size, rows * cols)``
            array of NumPy's default int dtype; ``labels`` is a writable 1-D
            ``uint8`` array of length ``size``.

        Raises:
            ValueError: On a bad header, a short payload, or when the two
                files declare different item counts.
        """
        (n_labels,), label_data = self._read_idx(
            path_lbl, '>II', self._LABEL_MAGIC)
        (size, rows, cols), image_data = self._read_idx(
            path_img, '>IIII', self._IMAGE_MAGIC)

        if n_labels != size:
            raise ValueError('{} has {} labels but {} has {} images'.format(
                path_lbl, n_labels, path_img, size))
        if len(label_data) < n_labels:
            raise ValueError('{}: truncated label data'.format(path_lbl))

        n_pixels = rows * cols
        n_bytes = size * n_pixels
        if len(image_data) < n_bytes:
            raise ValueError('{}: truncated image data'.format(path_img))

        # frombuffer gives a read-only uint8 view; astype(int) copies it into
        # a writable array of the default int dtype so later arithmetic on
        # the pixels cannot wrap around the way uint8 would.
        images = (np.frombuffer(image_data, dtype=np.uint8)[:n_bytes]
                  .reshape(size, n_pixels)
                  .astype(int))
        # copy() makes the labels writable; only the declared count is used,
        # so trailing bytes in the file are ignored.
        labels = np.frombuffer(label_data, dtype=np.uint8)[:n_labels].copy()
        return images, labels