-
Notifications
You must be signed in to change notification settings - Fork 12
/
datasets.py
123 lines (105 loc) · 4.58 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import torch
from torchvision import transforms, datasets
def _split_train_val(train_set, val_set, val_fraction, generator=None):
    """Carve disjoint train / validation subsets out of two views of one dataset.

    ``train_set`` and ``val_set`` are expected to enumerate the same samples in
    the same order and differ only in their transform.  A single random
    permutation picks the training indices; the validation subset is the exact
    complement, so the two subsets can never share a sample.

    Args:
        train_set: dataset view used for training.
        val_set: dataset view used for validation.
        val_fraction: fraction of the samples to hold out for validation.
        generator: optional ``torch.Generator`` driving the permutation; the
            global default generator is used when omitted.

    Returns:
        ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects.
    """
    num_samples = len(train_set)
    val_size = int(val_fraction * num_samples)
    train_size = num_samples - val_size
    if generator is None:
        generator = torch.default_generator
    train_subset, _ = torch.utils.data.random_split(
        train_set, [train_size, val_size], generator=generator
    )
    # Complement of the training indices, in ascending order.
    held_out = np.setdiff1d(np.arange(len(val_set)), train_subset.indices)
    val_subset = torch.utils.data.Subset(val_set, held_out.tolist())
    return train_subset, val_subset


def _mnist_splits(args, data_root, train_data_aug):
    """Build the (train, val, test) MNIST datasets; ``val`` is None without hold-out."""
    # Setups:
    #   - Original train data is split into (train, val)
    #   - Original test data is used as a test split as it is
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])

    train_set = datasets.MNIST(
        root=data_root, train=True, download=True,
        transform=transform_train if train_data_aug else transform_test,
    )
    val_set = None
    if args.val_heldout > 0:
        # Second view of the same training data, with the evaluation transform.
        val_set = datasets.MNIST(
            root=data_root, train=True, download=True, transform=transform_test
        )
        # Seed the split from args.seed when available (as the pets split
        # does); otherwise fall back to the global RNG.
        seed = getattr(args, 'seed', None)
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        # One shared permutation: two independent random_split calls would
        # draw different permutations and leak training samples into val.
        train_set, val_set = _split_train_val(
            train_set, val_set, args.val_heldout, generator
        )
    test_set = datasets.MNIST(
        root=data_root, train=False, download=True, transform=transform_test
    )
    return train_set, val_set, test_set


def _pets_splits(args, data_root, train_data_aug):
    """Build the (train, val, test) Oxford-IIIT Pet datasets; ``val`` is None without hold-out."""
    # Setups:
    #   - Original train and val data are merged, then split into (train, val)
    #   - Original test data is used as a test split as it is
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Official train+val split.
    train_set = datasets.OxfordIIITPet(
        root=data_root, split='trainval',
        transform=transform_train if train_data_aug else transform_test,
        download=True,
    )
    val_set = None
    if args.val_heldout > 0:  # re-split if args.val_heldout > 0
        val_set = datasets.OxfordIIITPet(
            root=data_root, split='trainval',
            transform=transform_test, download=True,
        )
        generator = torch.Generator().manual_seed(args.seed)
        train_set, val_set = _split_train_val(
            train_set, val_set, args.val_heldout, generator
        )
    test_set = datasets.OxfordIIITPet(
        root=data_root, split='test',
        transform=transform_test, download=True,
    )
    return train_set, val_set, test_set


def _make_loader(dataset, args, shuffle, num_workers):
    """Wrap ``dataset`` in a DataLoader configured from ``args``."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        pin_memory=args.use_cuda,
        num_workers=num_workers,
    )


def prepare(args, data_root='./data', train_data_aug=True, num_workers=4):
    """Create train / validation / test data loaders for the requested dataset.

    Args:
        args: namespace-like object.  Reads ``dataset`` ('mnist' or 'pets'),
            ``val_heldout`` (fraction of the training data held out for
            validation; no validation loader is built unless it is > 0),
            ``batch_size`` and ``use_cuda`` (passed to ``pin_memory``).
            ``seed`` is required for 'pets' when ``val_heldout > 0`` and is
            used for 'mnist' when present.  For 'pets', ``args.num_classes``
            is set to 37 as a side effect.
        data_root: directory the datasets are downloaded to / read from.
        train_data_aug: when True the training set uses the training transform
            and the training loader shuffles; when False the training set uses
            the test transform and is iterated in order.
        num_workers: worker processes per DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader, num_train)`` where
        ``val_loader`` is None when ``args.val_heldout`` is not positive and
        ``num_train`` is the size of the (possibly reduced) training set.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    if args.dataset == 'mnist':
        train_set, val_set, test_set = _mnist_splits(args, data_root, train_data_aug)
        # NOTE(review): unlike the 'pets' branch, args.num_classes is not set
        # here -- confirm callers configure it for MNIST themselves.
    elif args.dataset == 'pets':
        train_set, val_set, test_set = _pets_splits(args, data_root, train_data_aug)
        args.num_classes = 37
    else:
        raise NotImplementedError(f"unsupported dataset: {args.dataset!r}")

    train_loader = _make_loader(train_set, args, train_data_aug, num_workers)
    if val_set is not None:
        val_loader = _make_loader(val_set, args, False, num_workers)
    else:
        val_loader = None
    test_loader = _make_loader(test_set, args, False, num_workers)

    return train_loader, val_loader, test_loader, len(train_set)