forked from Kashu7100/Qualia2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstl10.py
86 lines (76 loc) · 2.81 KB
/
stl10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# -*- coding: utf-8 -*-
from .. import to_cpu
from ..core import *
from .dataset import *
from .transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
import tarfile
class STL10(Dataset):
    '''STL10 Dataset\n
    Downloads the STL-10 binary archive, extracts it under ``self.root`` and
    loads either the training or the test split into ``self.data`` /
    ``self.label``.

    Args:
        train (bool): if True, load training dataset
        transforms (transforms): transforms to apply on the features
        target_transforms (transforms): transforms to apply on the labels

    Shape:
        - data: [N, 3, 96, 96] as produced by ``_load_data``
        - label: one-hot encoded over 10 classes (see ``prepare``)
    '''
    # NOTE: the default ``Compose`` below is built once, when the class body is
    # evaluated, and is therefore shared by every instance created without an
    # explicit ``transforms`` argument.
    def __init__(self, train=True,
                 transforms=Compose([ToTensor(), Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]),
                 target_transforms=None):
        super().__init__(train, transforms, target_transforms)

    def __len__(self):
        '''Return the fixed split size: 5000 for training, 8000 otherwise.'''
        if self.train:
            return 5000
        return 8000

    def state_dict(self):
        '''Return a dict exposing the class-index -> name table as ``label_map``.'''
        return {
            'label_map': stl10_labels
        }

    def prepare(self):
        '''Download and extract the archive, then load the selected split.

        Sets ``self.data`` from ``<split>_X.bin`` and ``self.label`` (one-hot,
        10 classes) from ``<split>_y.bin`` where ``<split>`` is ``train`` or
        ``test`` depending on ``self.train``.
        '''
        url = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'
        self._download(url, 'stl10.tar.gz')
        # NOTE(review): the archive is fetched over plain HTTP and extracted
        # without member filtering; consider tarfile's ``filter='data'`` once
        # the minimum supported Python version provides it.
        # The context manager guarantees the archive handle is closed even if
        # extraction fails (the previous one-liner never closed it).
        with tarfile.open(self.root+'/stl10.tar.gz', 'r:gz') as archive:
            archive.extractall(self.root)
        split = 'train' if self.train else 'test'
        base = self.root+'/stl10_binary/'+split
        self.data = self._load_data(base+'_X.bin')
        self.label = STL10.to_one_hot(self._load_label(base+'_y.bin'), 10)

    def _read_uint8(self, filename):
        '''Read the whole file at ``filename`` as a flat array of uint8 values.

        When ``gpu`` is set the bytes are read with host NumPy first and then
        handed to ``np.asarray`` (presumably ``np`` is a GPU array module in
        that configuration -- confirm in ``..core``).
        '''
        with open(filename, 'rb') as file:
            if gpu:
                import numpy
                return np.asarray(numpy.fromfile(file, numpy.uint8))
            return np.fromfile(file, np.uint8)

    def _load_data(self, filename):
        '''Load images from a ``*_X.bin`` file as an array of shape [N, 3, 96, 96].'''
        raw = self._read_uint8(filename)
        # Swap the last two axes after reshaping; the STL-10 binaries are
        # documented as column-major per channel, so this yields row-major
        # images (assumption based on the dataset docs -- confirm if changed).
        return raw.reshape(-1, 3, 96, 96).transpose(0,1,3,2)

    def _load_label(self, filename):
        '''Load labels from a ``*_y.bin`` file, shifted down by one.'''
        # Shift so that labels line up with the zero-based keys of
        # ``stl10_labels`` (file labels presumably start at 1 -- confirm).
        return self._read_uint8(filename) - 1

    def show(self, row=5, col=5):
        '''Display a ``row`` x ``col`` grid of randomly chosen images.

        Each cell is drawn independently from ``self.data``, so the same image
        may appear more than once.

        Args:
            row (int): number of grid rows
            col (int): number of grid columns
        '''
        H, W = 96, 96
        img = np.zeros((H*row, W*col, 3))
        for r in range(row):
            for c in range(col):
                # Channel-first -> channel-last and scale to [0, 1] for imshow.
                img[r*H:(r+1)*H, c*W:(c+1)*W] = self.data[random.randint(0, len(self.data)-1)].reshape(3,H,W).transpose(1,2,0)/255
        plt.imshow(to_cpu(img), interpolation='nearest')
        plt.axis('off')
        plt.show()
# Maps the integer class index (starting at 0) to its human-readable name.
# The tuple order defines the index of each class.
stl10_labels = dict(enumerate((
    'airplane',
    'bird',
    'car',
    'cat',
    'deer',
    'dog',
    'horse',
    'monkey',
    'ship',
    'truck',
)))