-
Notifications
You must be signed in to change notification settings - Fork 15
/
fast_tensor_data_loader.py
48 lines (41 loc) · 1.62 KB
/
fast_tensor_data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).

    Batches are produced by slicing every stored tensor along dim 0, so one
    batch is a tuple with one slice per tensor. The object is its own
    iterator: iteration state lives on the instance, so only one pass over
    the data can be in progress at a time.

    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    """

    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length @ dim 0.
        :param batch_size: batch size to load. Must be a positive integer.
        :param shuffle: if True, re-permute the stored tensors along dim 0
            whenever an iterator is created out of this object. The same
            permutation is applied to every tensor, so rows stay aligned.
        :raises ValueError: if no tensors are given, if the tensors differ in
            length along dim 0, or if ``batch_size`` is not positive.
        :returns: A FastTensorDataLoader.
        """
        # Explicit raises instead of ``assert``: asserts vanish under
        # ``python -O`` and would let bad input through silently.
        if not tensors:
            raise ValueError("at least one tensor is required")
        dataset_len = tensors[0].shape[0]
        if any(t.shape[0] != dataset_len for t in tensors):
            raise ValueError("all tensors must have the same length at dim 0")
        # A zero batch size would divide by zero below; a negative one would
        # make ``__next__`` walk backwards and never terminate.
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        self.tensors = tensors
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Number of batches, counting a final partial batch (ceil division).
        self.n_batches = (dataset_len + batch_size - 1) // batch_size

        # Cursor into dim 0. Set here so ``next(loader)`` works even when
        # ``iter(loader)`` has not been called yet.
        self.i = 0

    def __iter__(self):
        """
        Start a new pass over the data and return ``self`` as the iterator.

        When ``shuffle`` is enabled, one random permutation is drawn and
        applied to every stored tensor; ``self.tensors`` is rebound to the
        permuted tensors.
        """
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.tensors = [t[r] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        """
        Return the next batch as a tuple with one slice per stored tensor.

        The last batch is shorter than ``batch_size`` when the dataset length
        is not a multiple of it.

        :raises StopIteration: once every row has been served.
        """
        if self.i >= self.dataset_len:
            raise StopIteration
        start = self.i
        stop = start + self.batch_size
        batch = tuple(t[start:stop] for t in self.tensors)
        self.i = stop
        return batch

    def __len__(self):
        """Return the number of batches produced by one full pass."""
        return self.n_batches