-
Notifications
You must be signed in to change notification settings - Fork 184
/
handlers.py
49 lines (40 loc) · 1.47 KB
/
handlers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
class MNIST_Handler(Dataset):
    """Dataset wrapper that serves (image tensor, label, index) triples.

    Each sample is converted to a single-channel ('L' mode) PIL image and
    then passed through a fixed ToTensor + Normalize pipeline.
    """

    def __init__(self, X, Y):
        """Store the samples ``X`` and labels ``Y`` and build the transform.

        NOTE(review): ``X[index]`` must expose ``.numpy()`` (``__getitem__``
        calls it), so X is presumably a torch tensor -- confirm with callers.
        """
        self.X, self.Y = X, Y
        # Single-channel normalization: mean 0.1307, std 0.3081.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.transform = transforms.Compose([to_tensor, normalize])

    def __getitem__(self, index):
        """Return ``(transformed_image, label, index)`` for position ``index``.

        The index is returned alongside the sample so the caller can tell
        which position each item came from.
        """
        pixels = self.X[index]
        label = self.Y[index]
        image = Image.fromarray(pixels.numpy(), mode='L')
        return self.transform(image), label, index

    def __len__(self):
        """Return the number of samples, taken from ``len(self.X)``."""
        return len(self.X)
class SVHN_Handler(Dataset):
    """Dataset wrapper that serves (image tensor, label, index) triples.

    Each sample has its axes reordered with ``np.transpose(x, (1, 2, 0))``
    before being turned into a PIL image and passed through a fixed
    ToTensor + Normalize pipeline.
    """

    def __init__(self, X, Y):
        """Store the samples ``X`` and labels ``Y`` and build the transform."""
        self.X, self.Y = X, Y
        # Per-channel normalization statistics (three channels).
        channel_mean = (0.4377, 0.4438, 0.4728)
        channel_std = (0.1980, 0.2010, 0.1970)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    def __getitem__(self, index):
        """Return ``(transformed_image, label, index)`` for position ``index``.

        The index is returned alongside the sample so the caller can tell
        which position each item came from.
        """
        raw = self.X[index]
        label = self.Y[index]
        # Move axis 0 to the end; presumably channels-first -> channels-last
        # so PIL can read it -- TODO confirm the stored layout of X.
        reordered = np.transpose(raw, (1, 2, 0))
        image = Image.fromarray(reordered)
        return self.transform(image), label, index

    def __len__(self):
        """Return the number of samples, taken from ``len(self.X)``."""
        return len(self.X)
class CIFAR10_Handler(Dataset):
    """Dataset wrapper that serves (image tensor, label, index) triples.

    Each sample is handed to ``Image.fromarray`` as-is and then passed
    through a fixed ToTensor + Normalize pipeline.
    """

    def __init__(self, X, Y):
        """Store the samples ``X`` and labels ``Y`` and build the transform."""
        self.X, self.Y = X, Y
        # Per-channel normalization statistics (three channels).
        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2470, 0.2435, 0.2616)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    def __getitem__(self, index):
        """Return ``(transformed_image, label, index)`` for position ``index``.

        The index is returned alongside the sample so the caller can tell
        which position each item came from.
        """
        raw = self.X[index]
        label = self.Y[index]
        # No axis reordering here; assumes X[index] is already in a layout
        # Image.fromarray accepts -- TODO confirm against the data loader.
        image = Image.fromarray(raw)
        return self.transform(image), label, index

    def __len__(self):
        """Return the number of samples, taken from ``len(self.X)``."""
        return len(self.X)