shared_dict.py
"""
Script to demonstrate the usage of shared dicts using multiple workers.
In the first epoch the shared dict in the dataset will be filled with
random values. The next epochs will just use the dict without "loading" the
data again.
@author: ptrblck
"""
from multiprocessing import Manager
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
    def __init__(self, shared_dict, length):
        self.shared_dict = shared_dict
        self.length = length

    def __getitem__(self, index):
        # Populate the shared dict lazily: the first access stores a tensor
        # for this index, later epochs (and all workers) read the cached value.
        if index not in self.shared_dict:
            print('Adding {} to shared_dict'.format(index))
            self.shared_dict[index] = torch.tensor(index)
        return self.shared_dict[index]

    def __len__(self):
        return self.length
# Init
# Note: with num_workers > 0 on platforms that use the 'spawn' start method
# (e.g. Windows), this setup code should be placed under an
# `if __name__ == '__main__':` guard.
manager = Manager()
shared_dict = manager.dict()
dataset = MyDataset(shared_dict, length=100)
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=6,
    shuffle=True,
    pin_memory=True
)
# First loop will add data to the shared_dict
for x in loader:
print(x)
# The second loop will just get the data
for x in loader:
print(x)
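
# Illustrative check (not part of the original script): after the first
# epoch every index should already be cached, so the second pass above
# prints no "Adding ..." lines and the dict holds one entry per sample.
assert len(shared_dict) == len(dataset)
print('shared_dict holds {} cached samples'.format(len(shared_dict)))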