-
Notifications
You must be signed in to change notification settings - Fork 1
/
DataLoaders.py
91 lines (78 loc) · 3.24 KB
/
DataLoaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from faug.imagenet_c_bar.transform_finder import build_transform
from faug.imagenet_c_bar.utils.converters import PilToNumpy, NumpyToTensor
import torchvision.transforms as transforms
from random import choice
def read_corruption_csv(filename="faug/imagenet_c_bar/imagenet_c_bar.csv"):
    """Read the ImageNet-C-bar corruption table into (name, severity) pairs.

    Each non-blank line of *filename* holds a corruption name followed by
    comma-separated numeric severities, e.g. ``name,0.5,1.0``. Every
    severity on a line yields one ``(name, float(severity))`` tuple, in
    file order. Fields are whitespace-stripped. Blank lines and empty
    severity fields (e.g. from a trailing comma) are skipped.

    Args:
        filename: Path to the CSV file. The default path is relative to
            the current working directory.

    Returns:
        list[tuple[str, float]]: All (corruption name, severity) pairs.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-empty severity field is not a valid float.
    """
    corruptions = []
    with open(filename, encoding="utf-8") as f:
        for line in f:
            # str.split never returns [], so blank lines must be detected
            # explicitly rather than via an empty-list check.
            if not line.strip():
                continue
            name, *severities = (field.strip() for field in line.split(","))
            corruptions.extend(
                (name, float(severity)) for severity in severities if severity
            )
    return corruptions
class EvalDataset(Dataset):
    """Evaluation dataset over a flat directory of images.

    Images are read from ``<main_dir>/images`` and indexed in sorted file-name
    order. When *annot* is given, it names a tab-separated file inside
    *main_dir* whose first column is an image file name and whose second
    column is a class label contained in *classes*. Each image is then paired
    with that label's position in *classes*.

    Args:
        main_dir: Dataset root directory.
        transform: Callable applied to each loaded RGB ``PIL.Image``.
        classes: Sequence of class labels; ``classes.index(label)`` is the
            target index.
        annot: Optional annotation file name, relative to *main_dir*.
    """

    def __init__(self, main_dir, transform, classes, annot=None):
        self.main_dir = main_dir
        self.transform = transform
        self.img_dir = os.path.join(main_dir, 'images')
        # Sort so the index -> image mapping does not depend on os.listdir's
        # arbitrary, filesystem-specific ordering.
        self.total_imgs = sorted(os.listdir(self.img_dir))
        # None means "unlabelled"; otherwise maps image file name -> class index.
        self.targ = None
        if annot:
            self.targ = self._read_annotations(os.path.join(main_dir, annot), classes)

    @staticmethod
    def _read_annotations(path, classes):
        """Parse a tab-separated annotation file into {file name: class index}.

        Raises:
            ValueError: If a label is not present in *classes*.
            IndexError: If a non-blank line has fewer than two columns.
        """
        targets = {}
        with open(path, 'r', encoding='utf-8') as f:
            # splitlines() keeps the last line even without a trailing
            # newline (the old ``split('\n')[:-1]`` silently dropped it).
            for line in f.read().splitlines():
                if not line.strip():
                    continue
                fields = line.split('\t')
                targets[fields[0]] = classes.index(fields[1])
        return targets

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Return the transformed image, plus a class-index tensor if labelled."""
        img_name = self.total_imgs[idx]
        img_loc = os.path.join(self.img_dir, img_name)
        # Close the file handle promptly; convert() returns a loaded copy.
        with Image.open(img_loc) as img:
            image = img.convert("RGB")
        tensor_image = self.transform(image)
        if self.targ is not None:
            return tensor_image, torch.tensor(self.targ[img_name])
        return tensor_image
class CBar(ImageFolder):
    """ImageFolder that applies a random ImageNet-C-bar corruption per sample.

    On every ``__getitem__`` call, one (corruption name, severity) pair is
    picked with ``random.choice`` from the corruption table. A fresh pipeline
    is then built around it:

    1. Resize the PIL image to 224x224.
    2. Convert it to NumPy and apply the corruption.
    3. Convert it to a tensor and resize to ``size`` x ``size``.
    4. Normalize with the standard ImageNet mean/std.

    Args:
        root: Root directory in ``ImageFolder`` layout.
        size: Side length of the square output image.
        corruption_csv: Optional path to the corruption table. When ``None``
            (default), ``read_corruption_csv``'s default path is used.
    """

    def __init__(self, root, size, *, corruption_csv=None):
        super().__init__(root, transform=None)
        self.size = size
        # list of (corruption name, severity) pairs
        self.corruptions = (
            read_corruption_csv()
            if corruption_csv is None
            else read_corruption_csv(corruption_csv)
        )

    def _random_corruption_transform(self):
        """Build the full preprocessing pipeline around one random corruption."""
        name, severity = choice(self.corruptions)
        return transforms.Compose([
            transforms.Resize((224, 224)),
            PilToNumpy(),
            build_transform(name=name, severity=severity, dataset_type='imagenet'),
            NumpyToTensor(),
            transforms.Resize((self.size, self.size)),
            # these are the standard norm vectors used for imagenet
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, idx):
        """Return ``ImageFolder``'s (sample, target) with a freshly drawn corruption."""
        # ImageFolder.__getitem__ reads self.transform, so the per-sample
        # pipeline is installed on the instance before delegating.
        # NOTE(review): this mutates instance state on every call; concurrent
        # calls on one instance from multiple threads could race.
        self.transform = self._random_corruption_transform()
        return super().__getitem__(idx)
class EvalCBar(EvalDataset):
    """EvalDataset that applies a random ImageNet-C-bar corruption per image.

    On every ``__getitem__`` call, one (corruption name, severity) pair is
    picked with ``random.choice`` from the corruption table. A fresh pipeline
    is then built around it:

    1. Resize the PIL image to 224x224.
    2. Convert it to NumPy and apply the corruption.
    3. Convert it to a tensor and resize to ``size`` x ``size``.
    4. Normalize with the standard ImageNet mean/std.

    Args:
        root: Dataset root passed to ``EvalDataset`` as ``main_dir``.
        classes: Sequence of class labels (see ``EvalDataset``).
        size: Side length of the square output image.
        annot: Optional annotation file name, relative to *root*.
        corruption_csv: Optional path to the corruption table. When ``None``
            (default), ``read_corruption_csv``'s default path is used.
    """

    def __init__(self, root, classes, size, annot=None, *, corruption_csv=None):
        # transform=None: the real pipeline is built per sample in __getitem__.
        super().__init__(root, None, classes, annot)
        self.size = size
        # list of (corruption name, severity) pairs
        self.corruptions = (
            read_corruption_csv()
            if corruption_csv is None
            else read_corruption_csv(corruption_csv)
        )

    def _random_corruption_transform(self):
        """Build the full preprocessing pipeline around one random corruption."""
        name, severity = choice(self.corruptions)
        return transforms.Compose([
            transforms.Resize((224, 224)),
            PilToNumpy(),
            build_transform(name=name, severity=severity, dataset_type='imagenet'),
            NumpyToTensor(),
            transforms.Resize((self.size, self.size)),
            # these are the standard norm vectors used for imagenet
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, idx):
        """Return ``EvalDataset.__getitem__``'s result with a freshly drawn corruption."""
        # EvalDataset.__getitem__ reads self.transform, so the per-sample
        # pipeline is installed on the instance before delegating.
        # NOTE(review): this mutates instance state on every call; concurrent
        # calls on one instance from multiple threads could race.
        self.transform = self._random_corruption_transform()
        return super().__getitem__(idx)