-
Notifications
You must be signed in to change notification settings - Fork 34
/
datasets.py
102 lines (86 loc) · 4.44 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import logging
import numpy as np
from PIL import Image
from scipy.ndimage import convolve1d
from torch.utils import data
import torchvision.transforms as transforms
# Project-local helper that builds the LDS smoothing kernel window
# used by _prepare_weights / get_bucket_info.
from utils import get_lds_kernel_window
# NOTE(review): deliberately shadows the builtin `print` so that every
# print(...) call in this module is routed to the logging system instead
# of stdout. Keep this in mind when debugging — adding a plain print()
# below this line will not reach the console directly.
print = logging.info
class IMDBWIKI(data.Dataset):
    """IMDB-WIKI age-regression dataset with optional inverse-frequency
    sample re-weighting and label-distribution smoothing (LDS).

    Each item is ``(image, label, weight)``:
      * ``image``  — transformed image tensor (augmented when split='train'),
      * ``label``  — float32 array of shape (1,) holding the age,
      * ``weight`` — float32 array of shape (1,) with the re-weighting factor
        (1.0 when re-weighting is disabled).
    """

    def __init__(self, df, data_dir, img_size, split='train', reweight='none',
                 lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        """
        Args:
            df: DataFrame with at least a 'path' column (image path relative
                to ``data_dir``) and an 'age' column (regression target).
            data_dir: root directory the image paths are joined against.
            img_size: side length the (square) images are resized to.
            split: 'train' enables random-crop/flip augmentation; any other
                value yields the deterministic eval pipeline.
            reweight: one of 'none', 'inverse', 'sqrt_inv'.
            lds: if True, smooth the per-age counts with an LDS kernel
                (requires reweight != 'none').
            lds_kernel / lds_ks / lds_sigma: LDS kernel type, window size
                and sigma, forwarded to get_lds_kernel_window.
        """
        self.df = df
        self.data_dir = data_dir
        self.img_size = img_size
        self.split = split
        # Transform pipeline is built lazily on first item access so that
        # constructing the dataset does not require the transform stack.
        self._transform = None
        self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        index = index % len(self.df)
        row = self.df.iloc[index]
        img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB')
        # Build the Compose pipeline once and reuse it; it was previously
        # re-created on every single item fetch.
        if self._transform is None:
            self._transform = self.get_transform()
        img = self._transform(img)
        label = np.asarray([row['age']]).astype('float32')
        weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)])
        return img, label, weight

    def get_transform(self):
        """Return the torchvision pipeline for this split.

        Train: resize -> random crop (16px padding) -> random h-flip ->
        tensor -> normalize to [-1, 1]. Eval: resize -> tensor -> normalize.
        """
        if self.split == 'train':
            transform = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.RandomCrop(self.img_size, padding=16),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
            ])
        else:
            transform = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
            ])
        return transform

    def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        """Compute one sample weight per row of ``self.df``.

        Counts labels into integer age buckets [0, max_target), optionally
        smooths the counts with an LDS kernel, then assigns each sample the
        inverse (or inverse-sqrt) of its bucket count, rescaled so the
        weights sum to the number of samples.

        Returns:
            list[np.float32] of per-sample weights, or None when
            reweight == 'none'.

        Raises:
            ValueError: on an unknown ``reweight`` mode, or when ``lds`` is
                requested without re-weighting.
        """
        # Explicit raises instead of `assert`: asserts are stripped under -O.
        if reweight not in {'none', 'inverse', 'sqrt_inv'}:
            raise ValueError(f"Unknown reweight mode: {reweight!r}")
        if lds and reweight == 'none':
            raise ValueError("Set reweight to 'sqrt_inv' (default) or 'inverse' when using LDS")

        value_dict = {x: 0 for x in range(max_target)}
        labels = self.df['age'].values
        # Ages >= max_target are clamped into the last bucket.
        for label in labels:
            value_dict[min(max_target - 1, int(label))] += 1
        if reweight == 'sqrt_inv':
            value_dict = {k: np.sqrt(v) for k, v in value_dict.items()}
        elif reweight == 'inverse':
            value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()}  # clip weights for inverse re-weight
        num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels]
        if not len(num_per_label) or reweight == 'none':
            return None
        print(f"Using re-weighting: [{reweight.upper()}]")

        if lds:
            lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
            print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
            # Smooth the empirical label distribution before inverting it.
            smoothed_value = convolve1d(
                np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
            num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels]

        weights = [np.float32(1 / x) for x in num_per_label]
        # Normalize so the weights sum to len(dataset), keeping the
        # effective loss scale comparable to unweighted training.
        scaling = len(weights) / np.sum(weights)
        weights = [scaling * x for x in weights]
        return weights

    def get_bucket_info(self, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        """Return (bucket_centers, bucket_weights) for the label distribution.

        Buckets are integer ages in [0, max_target); labels >= max_target are
        ignored here (unlike _prepare_weights, which clamps them). Counts are
        optionally LDS-smoothed, empty buckets are dropped, and the weights
        are normalized to sum to 1.
        """
        value_dict = {x: 0 for x in range(max_target)}
        labels = self.df['age'].values
        for label in labels:
            if int(label) < max_target:
                value_dict[int(label)] += 1
        bucket_centers = np.asarray([k for k, _ in value_dict.items()])
        bucket_weights = np.asarray([v for _, v in value_dict.items()])
        if lds:
            lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
            print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
            bucket_weights = convolve1d(bucket_weights, weights=lds_kernel_window, mode='constant')

        # Keep only non-empty buckets, then normalize to a probability vector.
        bucket_centers = np.asarray([bucket_centers[k] for k, v in enumerate(bucket_weights) if v > 0])
        bucket_weights = np.asarray([bucket_weights[k] for k, v in enumerate(bucket_weights) if v > 0])
        bucket_weights = bucket_weights / bucket_weights.sum()
        return bucket_centers, bucket_weights