# HCPLoader.py
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler
from torchvision import datasets

from utils import *  # expected to provide BATCH_SIZE and RATIO used below

# prepare data
class AddGaussianNoise(object):
    """Transform that adds Gaussian noise to a tensor (apply after ToTensor)."""

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
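
# Usage sketch (illustrative only, not referenced elsewhere in this module):
# AddGaussianNoise operates on tensors, so it belongs after ToTensor() in a
# torchvision pipeline. The std value below is a placeholder, not a value
# taken from this repository.
#
#   from torchvision import transforms
#   trans_train = transforms.Compose([
#       transforms.ToTensor(),
#       AddGaussianNoise(mean=0., std=0.05),
#   ])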
def splitData(dataset, ratio):
    """Shuffle dataset indices and split them into train / held-out index lists."""
    indices = list(range(len(dataset)))
    split_pos = int(np.floor(ratio * len(dataset)))
    np.random.shuffle(indices)  # incorporates shuffle here
    trn_idx, tst_idx = indices[split_pos:], indices[:split_pos]
    return trn_idx, tst_idx
def createLoader(idx, dataset, batch, shuff):
    """Build a DataLoader restricted to the given indices, shuffled or in order."""
    if shuff:
        return DataLoader(dataset, sampler=SubsetRandomSampler(idx), batch_size=batch)
    # SequentialSampler(idx) would just yield positions 0..len(idx)-1 over the full
    # dataset, so restrict to the chosen subset first and iterate it in order.
    subset = torch.utils.data.Subset(dataset, idx)
    return DataLoader(subset, sampler=SequentialSampler(subset), batch_size=batch)
def split_trn_tst(train_data, valid_data, ratio=.2, batch_size=16):
    """Split the indices once, then sample train_data / valid_data from the two halves."""
    trn_idx, tst_idx = splitData(train_data, ratio)
    trn_sampler = SubsetRandomSampler(trn_idx)
    tst_sampler = SubsetRandomSampler(tst_idx)
    trainloader = torch.utils.data.DataLoader(train_data, sampler=trn_sampler, batch_size=batch_size,
                                              num_workers=2)
    validloader = torch.utils.data.DataLoader(valid_data, sampler=tst_sampler, batch_size=batch_size,
                                              num_workers=2, drop_last=True)
    return trainloader, validloader
def test_loader(source, trans_test):
    """Build the test DataLoader from source/test/ (BATCH_SIZE comes from utils)."""
    test_data = datasets.ImageFolder(source + "test/", transform=trans_test)
    testloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
    print("test batch count: ", len(testloader))
    return testloader
def train_valid_loader(source, trans_train, trans_test):
    """Build train/validation DataLoaders from source/train/.

    The same folder is loaded twice so the validation split can use the
    evaluation transforms (trans_test) instead of the training augmentations.
    """
    train_data = datasets.ImageFolder(source + "train/", transform=trans_train)
    valid_data = datasets.ImageFolder(source + "train/", transform=trans_test)
    trainloader, validloader = split_trn_tst(train_data, valid_data, batch_size=BATCH_SIZE, ratio=RATIO)
    print("train batch count: ", len(trainloader))
    print("validation batch count: ", len(validloader))
    return trainloader, validloader
class MapDataset(torch.utils.data.Dataset):
    """
    Given a dataset, creates a dataset which applies a mapping function
    to its items (lazily, only when an item is accessed).
    Note that data is not cloned/copied from the initial dataset.
    """

    def __init__(self, dataset, map_fn):
        self.dataset = dataset
        self.map = map_fn

    def __getitem__(self, index):
        return self.map(self.dataset[index])

    def __len__(self):
        return len(self.dataset)
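

if __name__ == "__main__":
    # Minimal usage sketch. Assumes an ImageFolder-style layout
    # (source/train/<class>/*.jpg, source/test/<class>/*.jpg) and that utils
    # defines BATCH_SIZE and RATIO; the path, image size, and noise level below
    # are placeholders, not values taken from this repository.
    from torchvision import transforms

    source = "./data/"  # hypothetical dataset root
    trans_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        AddGaussianNoise(mean=0., std=0.05),
    ])
    trans_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    trainloader, validloader = train_valid_loader(source, trans_train, trans_test)
    testloader = test_loader(source, trans_test)

    # MapDataset applies a function lazily to each (image, label) pair,
    # e.g. rescaling images to [-1, 1] without rebuilding the ImageFolder.
    rescaled = MapDataset(
        datasets.ImageFolder(source + "train/", transform=trans_test),
        lambda item: (item[0] * 2.0 - 1.0, item[1]),
    )
    print("mapped dataset size:", len(rescaled))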