Skip to content

Commit

Permalink
[Bugfix] fix dataloader when setting persistent_workers=True
Browse files Browse the repository at this point in the history
  • Loading branch information
SylarTiaNII committed Jul 5, 2024
1 parent 9ce1226 commit 7a3835f
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions python/paddle/io/dataloader/dataloader_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import paddle
from paddle import profiler
from paddle.base.framework import _current_expected_place, _set_expected_place
from paddle.incubate import multiprocessing
from paddle.pir.core import datatype_to_vartype
from paddle.profiler.timer import benchmark
from paddle.profiler.utils import in_profiler_mode
Expand Down Expand Up @@ -171,7 +172,7 @@ def __init__(self, loader):
# only single process is used in single-process mode, we
# can record the data structure sequentially in a list without
# recording the send and recv index
self._structure_infos = []
self._structure_infos = multiprocessing.Queue()

# NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache "self._prefetch_factor" iterations of data
Expand Down Expand Up @@ -251,7 +252,7 @@ def _thread_loop(self, legacy_expected_place):

# flat batch and record structure infos
batch, structure = _flatten_batch(batch)
self._structure_infos.append(structure)
self._structure_infos.put(structure)

if self._thread_done_event.is_set():
break
Expand Down Expand Up @@ -297,15 +298,16 @@ def __next__(self):
data = core.eager.read_next_tensor_list(
self._reader.read_next_list()[0]
)
data = _restore_batch(data, self._structure_infos.pop(0))
structure_info = self._structure_infos.get()
data = _restore_batch(data, structure_info)
else:
# in static graph mode
if self._return_list:
data = self._reader.read_next_list()
for i in range(len(data)):
data[i] = data[i]._move_to_list()
structs = [
self._structure_infos.pop(0)
self._structure_infos.get()
for _ in range(len(self._places))
]
data = [_restore_batch(d, s) for d, s in zip(data, structs)]
Expand Down Expand Up @@ -384,7 +386,7 @@ def __init__(self, loader):
self._rcvd_idx = 0
self._batches_outstanding = 0
self._task_infos = {}
self._structure_infos = []
self._structure_infos = multiprocessing.Queue()

# indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity.
Expand Down Expand Up @@ -433,8 +435,6 @@ def __init__(self, loader):
self._shutdown = False

def _init_workers(self):
from paddle.incubate import multiprocessing

# multiprocess worker and indice queue list initial as empty
self._workers = []
self._worker_status = []
Expand Down Expand Up @@ -562,7 +562,8 @@ def _reset(self):
self._rcvd_idx = 0
self._batches_outstanding = 0
self._task_infos = {}
self._structure_infos = []
while not self._structure_infos.empty():
self._structure_infos.get()

# set all worker status available
self._worker_status = [True] * self._num_workers
Expand Down Expand Up @@ -690,7 +691,7 @@ def _get_data(self):
and len(self._task_infos[self._rcvd_idx]) == 3
):
info = self._task_infos.pop(self._rcvd_idx)
self._structure_infos.append(info[2])
self._structure_infos.put(info[2])
return info[1]

try:
Expand Down Expand Up @@ -767,7 +768,7 @@ def _get_data(self):
if idx == self._rcvd_idx:
if idx in self._task_infos:
del self._task_infos[idx]
self._structure_infos.append(structure)
self._structure_infos.put(structure)
return batch
else:
self._task_infos[idx] += (batch, structure)
Expand Down Expand Up @@ -838,14 +839,15 @@ def __next__(self):
data = core.eager.read_next_tensor_list(
self._reader.read_next_list()[0]
)
data = _restore_batch(data, self._structure_infos.pop(0))
structure_info = self._structure_infos.get()
data = _restore_batch(data, structure_info)
else:
if self._return_list:
data = self._reader.read_next_list()
for i in range(len(data)):
data[i] = data[i]._move_to_list()
structs = [
self._structure_infos.pop(0)
self._structure_infos.get()
for _ in range(len(self._places))
]
data = [_restore_batch(d, s) for d, s in zip(data, structs)]
Expand Down

0 comments on commit 7a3835f

Please sign in to comment.