data_util.py
from torch.utils.data import (
Subset,
DataLoader,
BatchSampler,
RandomSampler,
)
from util import (
cached_property,
)


class Loaders:
    """Mixin that builds DataLoaders over `train`, `dev` and `test`
    datasets, which subclasses or callers are expected to provide."""

    split_names = ['train', 'dev', 'test']
    use_batch_sampler = False

    def loaders(
            self,
            batch_size,
            *args,
            eval_batch_size=None,
            split_names=None,
            use_batch_sampler=False,
            log=None,
            **kwargs):
if not split_names:
split_names = self.split_names
self.batch_size = batch_size
self.eval_batch_size = eval_batch_size or batch_size
self.use_batch_sampler = use_batch_sampler
loaders = {
split_name: getattr(
self, split_name + '_loader')(*args, **kwargs)
for split_name in split_names}
        # Extra loader over a dev-sized subset of the training data,
        # for running evaluation-style inference on training instances.
        loaders['train_inference'] = DataLoader(
            Subset(self.train, list(range(len(self.dev)))),
            batch_size=self.eval_batch_size)
if log is not None:
for split_name, loader in loaders.items():
log(f'{split_name} batches: {len(loader)}')
return loaders

    def train_loader(self, *args, **kwargs):
assert 'train' in self.split_names
batch_size = (
kwargs.pop('batch_size')
if 'batch_size' in kwargs
else self.batch_size)
        if self.use_batch_sampler:
            # With batch_sampler, DataLoader must not also receive
            # batch_size, shuffle, sampler, or drop_last.
            batch_sampler = BatchSampler(
                RandomSampler(self.train), batch_size, drop_last=False)
            return DataLoader(
                self.train, *args, batch_sampler=batch_sampler, **kwargs)
        return DataLoader(
            self.train, *args, batch_size=batch_size, **kwargs)

    def dev_loader(self, *args, **kwargs):
assert 'dev' in self.split_names
batch_size = (
kwargs.pop('batch_size')
if 'batch_size' in kwargs
else self.eval_batch_size)
return DataLoader(
self.dev, *args, batch_size=batch_size, **kwargs, shuffle=False)

    def test_loader(self, *args, **kwargs):
assert 'test' in self.split_names
batch_size = (
kwargs.pop('batch_size')
if 'batch_size' in kwargs
else self.eval_batch_size)
return DataLoader(
self.test, *args, batch_size=batch_size, **kwargs, shuffle=False)


class FixedSplits(Loaders):
    """Loaders over train/dev/test datasets that are passed in directly."""

    def __init__(self, train=None, dev=None, test=None):
        self.train = train
        self.dev = dev
        self.test = test
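
# Usage sketch (hypothetical dataset objects; any map-style PyTorch
# dataset with __getitem__ and __len__ works):
#
#     splits = FixedSplits(train=train_ds, dev=dev_ds, test=test_ds)
#     loaders = splits.loaders(batch_size=32, eval_batch_size=64, log=print)
#     for batch in loaders['train']:
#         ...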


class WithSplits:
    """Base class for datasets that load their own raw splits.

    Subclasses are expected to provide a `splits` Loaders instance, a
    `conf` with batch sizes, and optionally an `instance_id` function
    used for overlap checking.
    """

    split_names = ['train', 'dev', 'test']

    def __init__(self, *args, do_check_overlap=True, **kwargs):
        # Extra args/kwargs are accepted for cooperative subclassing
        # and ignored here.
        self.do_check_overlap = do_check_overlap

    @cached_property
    def splits(self):
        raise NotImplementedError

    def load_raw_split(self, split_name):
        raise NotImplementedError

    @cached_property
    def train_raw(self):
        return self.load_raw_split('train')

    @cached_property
    def dev_raw(self):
        return self.load_raw_split('dev')

    @cached_property
    def test_raw(self):
        return self.load_raw_split('test')

    @cached_property
def raw(self):
split_name2split = {
split_name: getattr(self, split_name + '_raw')
for split_name in self.split_names
}
if hasattr(self, 'instance_id') and getattr(self, 'do_check_overlap', True):
self.check_splits_overlap(split_name2split.values())
return split_name2split

    @property
    def train_loader(self):
        return self.splits.train_loader(batch_size=self.conf.batch_size)

    @property
    def dev_loader(self):
        return self.splits.dev_loader(batch_size=self.conf.eval_batch_size)

    @property
    def test_loader(self):
        return self.splits.test_loader(batch_size=self.conf.eval_batch_size)

    def check_splits_overlap(self, splits):
        """Make sure that there isn't any overlap between the splits, i.e.,
        there shouldn't be any instances that are part of more than one
        split."""
        from itertools import combinations
        idss = [set(map(self.instance_id, split)) for split in splits]
        for ids0, ids1 in combinations(idss, 2):
            overlap = ids0 & ids1
            assert not overlap, f'splits share instances: {overlap}'
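
# Sketch of a hypothetical WithSplits subclass (load_jsonl, encode and
# the file paths are illustrative, not part of this module):
#
#     class MyCorpus(WithSplits):
#         def load_raw_split(self, split_name):
#             return load_jsonl(f'data/{split_name}.jsonl')
#
#         def instance_id(self, instance):
#             return instance['id']
#
#         @cached_property
#         def splits(self):
#             return FixedSplits(**{
#                 name: self.encode(split)
#                 for name, split in self.raw.items()})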


class TensorDictDataset:
    """Like PyTorch's TensorDataset, but instead of storing multiple
    tensors in a tuple, stores tensors in a dict."""

    def __init__(self, **tensors):
        # All tensors must share the same first (instance) dimension.
        assert all(
            next(iter(tensors.values())).size(0) == t.size(0)
            for t in tensors.values()
        )
        self.tensors = tensors

    def __getitem__(self, index):
        return {k: t[index] for k, t in self.tensors.items()}

    def __len__(self):
        return len(next(iter(self.tensors.values())))
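

if __name__ == '__main__':
    # Minimal usage sketch for TensorDictDataset with synthetic data
    # (not part of the original module). The default collate function
    # stacks each key's tensors, so every batch is a dict of tensors.
    import torch

    dataset = TensorDictDataset(
        x=torch.randn(10, 3),
        y=torch.arange(10),
    )
    loader = DataLoader(dataset, batch_size=4)
    for batch in loader:
        print(batch['x'].shape, batch['y'].shape)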