forked from ppwwyyxx/RAM-multiprocess-dataloader
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: main-torchserialize.py
executable file
·38 lines (31 loc) · 1000 Bytes
/
main-torchserialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#!/usr/bin/env python
import sys
import time
import torch
import multiprocessing as mp
from common import MemoryMonitor, create_coco, DatasetFromList, read_sample
from serialize import TorchSerializedList
def worker(_, dataset: torch.utils.data.Dataset):
    """Endlessly iterate the dataset, deserializing every sample.

    Simulates a dataloader worker process: each pass touches all samples,
    so any per-access memory growth (e.g. copy-on-write) becomes visible
    to the parent's memory monitor. Never returns.
    """
    while True:
        for item in dataset:
            # Tiny sleep to mimic per-sample read latency.
            time.sleep(0.000001)
            read_sample(item)
if __name__ == "__main__":
    # Usage: main-torchserialize.py <start_method>
    # where <start_method> is "fork", "spawn", or "forkserver".
    start_method = sys.argv[1]
    monitor = MemoryMonitor()
    # Serialize the dataset into a single torch storage so worker processes
    # share it without per-object refcount copy-on-write.
    ds = DatasetFromList(TorchSerializedList(create_coco()))
    print(monitor.table())

    if start_method == "forkserver":
        # Reduce 150M-per-process USS due to "import torch".
        mp.set_forkserver_preload(["torch"])
    ctx = torch.multiprocessing.start_processes(
        worker, (ds, ), nprocs=4, join=False,
        daemon=True, start_method=start_method)
    # Register each worker PID with the monitor. Plain loop, not a list
    # comprehension: add_pid is called only for its side effect.
    for pid in ctx.pids():
        monitor.add_pid(pid)

    try:
        # Print a memory report roughly once per second for ~100 seconds.
        for _ in range(100):
            print(monitor.table())
            time.sleep(1)
    finally:
        ctx.join()