-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConcatDataset.py
203 lines (167 loc) · 8.41 KB
/
ConcatDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import bisect
from torch.utils.data import Dataset
import warnings
from torch._utils import _accumulate
from torch import randperm
from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
class IterableDataset(Dataset[T_co]):
r"""An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it.
Such form of datasets is particularly useful when data come from a stream.
All subclasses should overwrite :meth:`__iter__`, which would return an
iterator of samples in this dataset.
When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
iterator. When :attr:`num_workers > 0`, each worker process will have a
different copy of the dataset object, so it is often desired to configure
each copy independently to avoid having duplicate data returned from the
workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
process, returns information about the worker. It can be used in either the
dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
:attr:`worker_init_fn` option to modify each copy's behavior.
Example 1: splitting workload across all workers in :meth:`__iter__`::
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = torch.utils.data.get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]
Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]
>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
... worker_info = torch.utils.data.get_worker_info()
... dataset = worker_info.dataset # the dataset copy in this worker process
... overall_start = dataset.start
... overall_end = dataset.end
... # configure the dataset to only process the split workload
... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... dataset.start = overall_start + worker_id * per_worker
... dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
"""
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
class ChainDataset(IterableDataset):
r"""Dataset for chainning multiple :class:`IterableDataset` s.
This class is useful to assemble different existing dataset streams. The
chainning operation is done on-the-fly, so concatenating large-scale
datasets with this class will be efficient.
Arguments:
datasets (iterable of IterableDataset): datasets to be chained together
"""
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ChainDataset, self).__init__()
self.datasets = datasets
def __iter__(self):
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
for x in d:
yield x
def __len__(self):
total = 0
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
# Cannot verify that all self.datasets are Sized
total += len(d) # type: ignore
return total
class ConcatDataset(Dataset[T_co]):
r"""Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
Arguments:
datasets (sequence): List of datasets to be concatenated
"""
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ConcatDataset, self).__init__()
# Cannot verify that datasets is Sized
assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore
self.datasets = list(datasets)
for d in self.datasets:
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx], dataset_idx
@property
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes