-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_loader_v5.py
76 lines (62 loc) · 2.94 KB
/
data_loader_v5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
"""
args: path
"""
dict = {'ca': 0, 'ne': 1, 'se': 2}
# train transformer
train_transformer = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(256),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()])
# transforms.Normalize((0.4793, 0.4921, 0.4731), (0.0670, 0.0837, 0.1140))])
# evl and test transformer
eval_transformer = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor()])
# transforms.Normalize((0.4789, 0.4905, 0.4740), (0.2007, 0.2004, 0.2277))])
class BuildingDatasetWithRegion(Dataset):
def __init__(self, data_dir, transform):
self.transform = transform
self.images = []
self.labels_type = []
self.labels_region = []
self.root = os.listdir(data_dir)
self.filenames = [os.path.join(data_dir, f) for f in self.root]
for region in self.filenames:
type = [os.path.join(region, f) for f in os.listdir(region)]
for i, file in enumerate(type):
for image in os.listdir(file):
src = os.path.join(file, image)
self.images.append(src)
self.labels_type.append(i)
self.labels_region.append(os.path.basename(os.path.normpath(region)))
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = Image.open(self.images[idx])
image = self.transform(image)
return image, self.labels_type[idx], dict[self.labels_region[idx]]
# load a train, val, text in mini-batch size
def fetch_dataloader(types, data_dir, params):
dataloaders = {}
for split in ['train', 'test']:
if split in types:
path = os.path.join(data_dir, "{}".format(split))
# use the train_transformer if training data, else use eval_transformer without random flip
if split == 'train':
dl = DataLoader(BuildingDatasetWithRegion(path, train_transformer), batch_size=params.batch_size, shuffle=True,
num_workers=params.num_workers, pin_memory=params.cuda)
val_dl_length = int(len(dl) * 0.1)
train_dl, val_dl = random_split(dl, [int(len(dl) - val_dl_length), val_dl_length])
# automatically split val data from train data as 10% of total train data
dataloaders[split] = train_dl.dataset
dataloaders['val'] = val_dl.dataset
else:
dl = DataLoader(BuildingDatasetWithRegion(path, eval_transformer), batch_size=params.batch_size, shuffle=False,
num_workers=params.num_workers, pin_memory=params.cuda)
dataloaders[split] = dl
return dataloaders