Commit f0ab13c: 2025-01-29 nightly release (5238ca3)
pytorchbot committed Jan 29, 2025, 1 parent c0dd69f

Showing 3 changed files with 179 additions and 12 deletions.
118 changes: 118 additions & 0 deletions test/stateful_dataloader/test_dataloader.py
@@ -3024,6 +3024,124 @@ def test_conv_after_fork(self):
self.assertEqual(x.shape, (1, 1, 1, 23999))


class _TestSlowIndexDataset(Dataset):
def __init__(self, end: int, slow_index: int):
self.end = end
self.slow_index = slow_index
self._worker_id = None

def __getitem__(self, idx):
if self._worker_id is None:  # 0 is a valid worker id, so check for None rather than truthiness
worker_info = torch.utils.data.get_worker_info()
self._worker_id = worker_info.id
if idx == self.slow_index:
time.sleep(1.0)
return (self._worker_id, idx)

def __len__(self):
return self.end


class _TestSlowIterableDataset(IterableDataset):
def __init__(self, start: int, end: int):
self.start = start
self.end = end
self.mid = math.ceil((self.end - self.start) / 2)

def give_data(self, worker_id, iter_start, iter_end):
for i in range(iter_start, iter_end):
if i == self.mid:
time.sleep(1.0)
yield (worker_id, i)

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return self.give_data(worker_id, iter_start, iter_end)


class TestOutOfOrderDataLoader(TestCase):
def test_in_order_index_ds(self):
dataset = _TestSlowIndexDataset(end=10, slow_index=0)

dataloader = DataLoader(
dataset,
num_workers=2,
in_order=True,
)

expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
outputs = list(dataloader)
worker_ids = [o[0] for o in outputs]
data = [o[1] for o in outputs]
self.assertEqual(expected_worker_ids, worker_ids)
self.assertEqual(expected_data, data)

def test_out_of_order_index_ds(self):
dataset = _TestSlowIndexDataset(end=10, slow_index=0)

dataloader = DataLoader(
dataset,
num_workers=2,
prefetch_factor=2,
in_order=False,
)

# worker_id = 0 gets 'stuck' on index 0 and also has index 2 in its queue
# due to prefetch_factor being 2
# this makes the test more deterministic, as [0, 2] will be the last elements
expected_worker_ids = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
expected_data = [1, 3, 4, 5, 6, 7, 8, 9, 0, 2]
outputs = list(dataloader)
worker_ids = [o[0].item() for o in outputs]
data = [o[1].item() for o in outputs]
self.assertEqual(expected_worker_ids, worker_ids)
self.assertNotEqual(data, list(range(10)))
self.assertEqual(expected_data, data)

def test_in_order_iterable_ds(self):
dataset = _TestSlowIterableDataset(start=0, end=10)

dataloader = DataLoader(
dataset,
num_workers=2,
in_order=True,
)

expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
expected_data = [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
outputs = list(dataloader)
worker_ids = [o[0] for o in outputs]
data = [o[1] for o in outputs]
self.assertEqual(expected_worker_ids, worker_ids)
self.assertEqual(expected_data, data)

def test_out_of_order_iterable_ds(self):
dataset = _TestSlowIterableDataset(start=0, end=10)

dataloader = DataLoader(
dataset,
num_workers=2,
in_order=False,
)

# worker 0 has [0, 1, 2, 3, 4], worker 1 has [5, 6, 7, 8, 9]
# index 5 is slow, so expect all of worker 0 before worker 1
expected_worker_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
outputs = list(dataloader)
worker_ids = [o[0] for o in outputs]
data = [o[1] for o in outputs]
self.assertEqual(expected_worker_ids, worker_ids)
self.assertEqual(sum(worker_ids), 5)
self.assertNotEqual(data, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])
self.assertEqual(expected_data, data)


instantiate_device_type_tests(TestDataLoaderDeviceType, globals())


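A minimal standalone sketch of the behavior these tests verify, using the in_order keyword this commit adds to StatefulDataLoader (the SlowFirstIndex dataset below is hypothetical, not part of the diff). Worker 0 stalls on index 0, so with in_order=False the samples fetched by worker 1 are yielded first:

import time

from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader


class SlowFirstIndex(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        if idx == 0:
            time.sleep(1.0)  # simulate one expensive sample
        return idx


if __name__ == "__main__":  # guard required since workers are subprocesses
    loader = StatefulDataLoader(SlowFirstIndex(), num_workers=2, in_order=False)
    print([batch.item() for batch in loader])  # index 0 no longer comes first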
8 changes: 4 additions & 4 deletions torchdata/nodes/map.py
@@ -318,8 +318,8 @@ def __init__(
self._mp_context = mp.get_context(self.multiprocessing_context)

if max_concurrent is not None and num_workers > 0:
- if not isinstance(max_concurrent, int) and max_concurrent > num_workers:
-     raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!")
+ if isinstance(max_concurrent, int) and max_concurrent > num_workers:
+     raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!")
self.max_concurrent = max_concurrent
self.snapshot_frequency = snapshot_frequency
self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
@@ -404,8 +404,8 @@ def __init__(
self.method = method
self.multiprocessing_context = multiprocessing_context
if max_concurrent is not None and num_workers > 0:
- if not isinstance(max_concurrent, int) and max_concurrent > num_workers:
-     raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!")
+ if isinstance(max_concurrent, int) and max_concurrent > num_workers:
+     raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!")
self.max_concurrent = max_concurrent
self.snapshot_frequency = snapshot_frequency
self.prebatch = prebatch
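The two hunks above fix an inverted guard: the old check required max_concurrent to *not* be an int before comparing it against num_workers, so the validation never fired for integer inputs, and the old error message also pointed the comparison the wrong way. A standalone sketch of the corrected logic (validate_max_concurrent is a hypothetical helper for illustration):

def validate_max_concurrent(max_concurrent, num_workers):
    # Mirrors the corrected guard in map.py: only compare when max_concurrent
    # is actually an int, and require it not to exceed the worker count.
    if max_concurrent is not None and num_workers > 0:
        if isinstance(max_concurrent, int) and max_concurrent > num_workers:
            raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!")


validate_max_concurrent(2, 4)  # ok: 2 <= 4
try:
    validate_max_concurrent(8, 4)
except ValueError as err:
    print(err)  # max_concurrent=8 should be <= num_workers=4!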
65 changes: 57 additions & 8 deletions torchdata/stateful_dataloader/stateful_dataloader.py
@@ -147,6 +147,8 @@ class StatefulDataLoader(DataLoader[_T_co]):
maintain the workers `Dataset` instances alive. (default: ``False``)
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
``True``.
in_order (bool, optional): If ``False``, the data loader will not enforce that batches
are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
snapshot_every_n_steps (int, optional): Defines how often the state is
transferred from the dataloader workers to the dataloader. By default, it is set to ``1``, i.e., state is transferred every step. If the state is large, this value can be increased (and ideally set to the frequency of training checkpointing) to reduce the overhead of transferring state every step.
@@ -177,6 +179,10 @@ class StatefulDataLoader(DataLoader[_T_co]):
.. warning:: See `Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html#reproducibility>`_, and `Dataloader-workers-random-seed <https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed>`_, and
`Data-loading-randomness <https://pytorch.org/docs/stable/data.html#data-loading-randomness>`_ notes for random seed related questions.
.. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data.
.. warning:: Setting `in_order` to `False` currently has no guarantees for state management.
.. _multiprocessing context:
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
"""
@@ -202,6 +208,7 @@ def __init__(
prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
pin_memory_device: str = "",
in_order: bool = True,
snapshot_every_n_steps: Optional[int] = 1,
):
torch._C._log_api_usage_once("python.stateful_data_loader")
@@ -227,6 +234,13 @@ def __init__(
if persistent_workers and num_workers == 0:
raise ValueError("persistent_workers option needs num_workers > 0")

if num_workers > 0 and not in_order:
# TODO: remove warning log when state management is supported with in_order=False
logger.warning(
"using in_order=False with multiple workers does not give any guarantees for state management "
"and loading from a checkpoint may not work as expected."
)

self.dataset = dataset
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
@@ -235,6 +249,7 @@ def __init__(
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context
self.in_order = in_order

# Adds forward compatibilities so classic DataLoader can work with DataPipes:
# _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
@@ -876,6 +891,7 @@ def __init__(self, loader, next_iter_state):
super().__init__(loader)
self._snapshot_interval = loader.snapshot_every_n_steps
self._prefetch_factor = loader.prefetch_factor
self._in_order = loader.in_order

assert self._num_workers > 0
assert self._prefetch_factor > 0
@@ -1083,6 +1099,11 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
# It does not mean that a worker is dead. In case of `_persistent_workers`,
# the worker will be reset to available in the next epoch.
self._workers_status = [True for i in range(self._num_workers)]
# A list of integers representing how many tasks are outstanding for each worker
# Incremented when a task is dispatched to the worker
# Decremented when that data has been given to the main thread
# Each worker should have at most self._prefetch_factor tasks outstanding
self._workers_num_tasks = [0 for i in range(self._num_workers)]
# Reset the worker queue cycle so it resumes next epoch at worker 0
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
remaining = self._num_workers
@@ -1149,6 +1170,12 @@ def _update_worker_snapshot(self, worker_key, state_dict):
self._worker_snapshots[worker_key].apply_delta(state_dict)

def state_dict(self):
if not self._in_order:
# TODO: remove warning log when state management is supported with in_order=False
logger.warning(
"using in_order=False with multiple workers does not give any guarantees for state management "
"and loading from a checkpoint may not work as expected."
)
steps_since_snapshot = self._num_yielded - self._snapshot[self._SNAPSHOT_STEP]
state_dict = {
self._SNAPSHOT: self._snapshot,
@@ -1352,11 +1379,12 @@ def _next_data(self):
# call and `_IterableDatasetStopIteration` check below can mark
# extra worker(s) as dead.
while self._rcvd_idx < self._send_idx:
- info = self._task_info[self._rcvd_idx]
- worker_id = info[0]
- if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
-     break
- del self._task_info[self._rcvd_idx]
+ info = self._task_info.get(self._rcvd_idx, None)
+ if info:
+     worker_id = info[0]
+     if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
+         break
+     del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
# no valid `self._rcvd_idx` is found (i.e., didn't break)
@@ -1374,6 +1402,7 @@ def _next_data(self):
self._rcvd_idx += 1
continue
else:
self._rcvd_idx += 1
return self._process_data(data, worker_id, state_dict)

assert not self._shutdown and self._tasks_outstanding > 0
@@ -1394,6 +1423,13 @@ def _next_data(self):

if idx != self._rcvd_idx:
# store out-of-order samples
if not self._in_order:
# don't store it for later, process now
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
continue
del self._task_info[idx]
return self._process_data(data, worker_id, state_dict)
self._task_info[idx] += ((data, worker_id, state_dict),)
else:
del self._task_info[idx]
@@ -1402,6 +1438,7 @@ def _next_data(self):
self._rcvd_idx += 1
continue
else:
self._rcvd_idx += 1
return self._process_data(data, worker_id, state_dict)

def _get_main_state(self):
@@ -1433,7 +1470,8 @@ def _restore_main_state(self, state_dict):
self._base_seed = state_dict[self._BASE_SEED]

def _try_put_index(self):
- assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+ max_tasks = self._prefetch_factor * self._num_workers
+ assert self._tasks_outstanding < max_tasks

try:
index = self._next_index()
@@ -1461,7 +1499,12 @@ def _try_put_index(self):
for _ in range(self._num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]:
- break
+ if self._in_order:
+     break
+ elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status):
+     # when self._in_order is False, distribute work to a worker if it has capacity;
+     # _workers_status is updated only in this thread, so the sum is guaranteed > 0
+     break
else:
# not found (i.e., didn't break)
return
@@ -1472,11 +1515,12 @@

self._index_queues[worker_queue_idx].put((self._send_idx, (index, snapshot))) # type: ignore[possibly-undefined]
self._task_info[self._send_idx] = (worker_queue_idx,)
self._workers_num_tasks[worker_queue_idx] += 1
self._tasks_outstanding += 1
self._send_idx += 1

def _process_data(self, data, worker_id, state_dict):
- self._rcvd_idx += 1
+ self._workers_num_tasks[worker_id] -= 1
self._try_put_index()
if isinstance(data, ExceptionWrapper):
data.reraise()
@@ -1489,8 +1533,13 @@ def _process_data(self, data, worker_id, state_dict):
return data

def _take_snapshot(self):
main_snapshot_idx = None
while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1):
main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
if not self._in_order and main_snapshot_idx is None:
# in_order is False and no main snapshot is available because we're ahead of rcvd_idx,
# so we can't take a snapshot with the current implementation
return
assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1)
self._update_snapshot(
self._num_yielded + 1,

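Taken together, the loader changes trade strict FIFO ordering for better worker utilization when in_order=False, at the cost (for now) of state-management guarantees. A sketch of the checkpoint round trip the state_dict/load_state_dict machinery above supports, under the default in_order=True (the Squares dataset is hypothetical):

from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader


class Squares(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx * idx


if __name__ == "__main__":
    loader = StatefulDataLoader(Squares(), num_workers=2)
    it = iter(loader)
    for _ in range(3):
        next(it)  # consume the first three samples
    snapshot = loader.state_dict()  # state is transferred every step by default

    restored = StatefulDataLoader(Squares(), num_workers=2)
    restored.load_state_dict(snapshot)
    print([b.item() for b in restored])  # resumes from the fourth sample

With in_order=False the same calls run, but as the new warning in state_dict() notes, resuming from the snapshot may not pick up exactly where iteration stopped.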