From f0ab13cffd97ffbbd2176a87a62d57081b5cc55b Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 29 Jan 2025 11:34:30 +0000 Subject: [PATCH] 2025-01-29 nightly release (5238ca31eed61d404a3ecdf134432086b91071c2) --- test/stateful_dataloader/test_dataloader.py | 118 ++++++++++++++++++ torchdata/nodes/map.py | 8 +- .../stateful_dataloader.py | 65 ++++++++-- 3 files changed, 179 insertions(+), 12 deletions(-) diff --git a/test/stateful_dataloader/test_dataloader.py b/test/stateful_dataloader/test_dataloader.py index 40ed43cdb..17abd0dc5 100644 --- a/test/stateful_dataloader/test_dataloader.py +++ b/test/stateful_dataloader/test_dataloader.py @@ -3024,6 +3024,124 @@ def test_conv_after_fork(self): self.assertEqual(x.shape, (1, 1, 1, 23999)) +class _TestSlowIndexDataset(Dataset): + def __init__(self, end: int, slow_index: int): + self.end = end + self.slow_index = slow_index + self._worker_id = None + + def __getitem__(self, idx): + if not self._worker_id: + worker_info = torch.utils.data.get_worker_info() + self._worker_id = worker_info.id + if idx == self.slow_index: + time.sleep(1.0) + return (self._worker_id, idx) + + def __len__(self): + return self.end + + +class _TestSlowIterableDataset(IterableDataset): + def __init__(self, start: int, end: int): + self.start = start + self.end = end + self.mid = math.ceil((self.end - self.start) / 2) + + def give_data(self, worker_id, iter_start, iter_end): + for i in range(iter_start, iter_end): + if i == self.mid: + time.sleep(1.0) + yield (worker_id, i) + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) + worker_id = worker_info.id + iter_start = self.start + worker_id * per_worker + iter_end = min(iter_start + per_worker, self.end) + return self.give_data(worker_id, iter_start, iter_end) + + +class TestOutOfOrderDataLoader(TestCase): + def test_in_order_index_ds(self): + dataset = _TestSlowIndexDataset(end=10, slow_index=0) + + dataloader = DataLoader( + dataset, + num_workers=2, + in_order=True, + ) + + expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + outputs = list(dataloader) + worker_ids = [o[0] for o in outputs] + data = [o[1] for o in outputs] + self.assertEqual(expected_worker_ids, worker_ids) + self.assertEqual(expected_data, data) + + def test_out_of_order_index_ds(self): + dataset = _TestSlowIndexDataset(end=10, slow_index=0) + + dataloader = DataLoader( + dataset, + num_workers=2, + prefetch_factor=2, + in_order=False, + ) + + # worker_id = 0 gets 'stuck' on 0 and also has 2 in it's queue + # due to prefetch_factor being 2 + # this makes the test more deterministic as [0, 2] will be the last elements + expected_worker_ids = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0] + expected_data = [1, 3, 4, 5, 6, 7, 8, 9, 0, 2] + outputs = list(dataloader) + worker_ids = [o[0].item() for o in outputs] + data = [o[1].item() for o in outputs] + self.assertEqual(expected_worker_ids, worker_ids) + self.assertNotEqual(data, list(range(10))) + self.assertEqual(expected_data, data) + + def test_in_order_iterable_ds(self): + dataset = _TestSlowIterableDataset(start=0, end=10) + + dataloader = DataLoader( + dataset, + num_workers=2, + in_order=True, + ) + + expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + expected_data = [0, 5, 1, 6, 2, 7, 3, 8, 4, 9] + outputs = list(dataloader) + worker_ids = [o[0] for o in outputs] + data = [o[1] for o in outputs] + self.assertEqual(expected_worker_ids, worker_ids) + self.assertEqual(expected_data, data) + + def test_out_of_order_iterable_ds(self): + dataset = _TestSlowIterableDataset(start=0, end=10) + + dataloader = DataLoader( + dataset, + num_workers=2, + in_order=False, + ) + + # worker 0 has [0, 1, 2, 3, 4], worker 1 has [5, 6, 7, 8, 9] + # index 5 is slow, so expect all of worker 0 before worker 1 + expected_worker_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + outputs = list(dataloader) + worker_ids = [o[0] for o in outputs] + data = [o[1] for o in outputs] + self.assertEqual(expected_worker_ids, worker_ids) + self.assertEqual(sum(worker_ids), 5) + self.assertNotEqual(data, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]) + self.assertEqual(expected_data, data) + + instantiate_device_type_tests(TestDataLoaderDeviceType, globals()) diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index 04430d894..fab6a0f6c 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -318,8 +318,8 @@ def __init__( self._mp_context = mp.get_context(self.multiprocessing_context) if max_concurrent is not None and num_workers > 0: - if not isinstance(max_concurrent, int) and max_concurrent > num_workers: - raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!") + if isinstance(max_concurrent, int) and max_concurrent > num_workers: + raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!") self.max_concurrent = max_concurrent self.snapshot_frequency = snapshot_frequency self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None @@ -404,8 +404,8 @@ def __init__( self.method = method self.multiprocessing_context = multiprocessing_context if max_concurrent is not None and num_workers > 0: - if not isinstance(max_concurrent, int) and max_concurrent > num_workers: - raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!") + if isinstance(max_concurrent, int) and max_concurrent > num_workers: + raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!") self.max_concurrent = max_concurrent self.snapshot_frequency = snapshot_frequency self.prebatch = prebatch diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index 9b162b4f8..078b378ee 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -147,6 +147,8 @@ class StatefulDataLoader(DataLoader[_T_co]): maintain the workers `Dataset` instances alive. (default: ``False``) pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is ``True``. + in_order (bool, optional): If ``False``, the data loader will not enforce that batches + are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``) snapshot_every_n_steps (int, optional): Defines how often the state is transferred from the dataloader workers to the dataloader. By default, it is set to ``1``, i.e., state is transferred every step. If the state is large, this value can be increased (and ideally set to the frequency of training checkpointing) to reduce the overhead of transferring state every step. @@ -177,6 +179,10 @@ class StatefulDataLoader(DataLoader[_T_co]): .. warning:: See `Reproducibility `_, and `Dataloader-workers-random-seed `_, and `Data-loading-randomness `_ notes for random seed related questions. + .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data. + + .. warning:: Setting `in_order` to `False` currently has no guarantees for state management. + .. _multiprocessing context: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods """ @@ -202,6 +208,7 @@ def __init__( prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = "", + in_order: bool = True, snapshot_every_n_steps: Optional[int] = 1, ): torch._C._log_api_usage_once("python.stateful_data_loader") @@ -227,6 +234,13 @@ def __init__( if persistent_workers and num_workers == 0: raise ValueError("persistent_workers option needs num_workers > 0") + if num_workers > 0 and not in_order: + # TODO: remove warning log when state management is supported with in_order=False + logger.warning( + "using in_order=False with multiple workers does not give any guarantees for state management " + "and loading from a checkpoint may not work as expected." + ) + self.dataset = dataset self.num_workers = num_workers self.prefetch_factor = prefetch_factor @@ -235,6 +249,7 @@ def __init__( self.timeout = timeout self.worker_init_fn = worker_init_fn self.multiprocessing_context = multiprocessing_context + self.in_order = in_order # Adds forward compatibilities so classic DataLoader can work with DataPipes: # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler @@ -876,6 +891,7 @@ def __init__(self, loader, next_iter_state): super().__init__(loader) self._snapshot_interval = loader.snapshot_every_n_steps self._prefetch_factor = loader.prefetch_factor + self._in_order = loader.in_order assert self._num_workers > 0 assert self._prefetch_factor > 0 @@ -1083,6 +1099,11 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True): # It does not mean that a worker is dead. In case of `_persistent_workers`, # the worker will be reset to available in the next epoch. self._workers_status = [True for i in range(self._num_workers)] + # A list of integers representing how many tasks are outstanding for each worker + # Incremented when a task is dispatched to the worker + # Decremented when that data has been given to the main thread + # Each worker should have at most self._prefetch_factor tasks outstanding + self._workers_num_tasks = [0 for i in range(self._num_workers)] # Reset the worker queue cycle so it resumes next epoch at worker 0 self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) remaining = self._num_workers @@ -1149,6 +1170,12 @@ def _update_worker_snapshot(self, worker_key, state_dict): self._worker_snapshots[worker_key].apply_delta(state_dict) def state_dict(self): + if not self._in_order: + # TODO: remove warning log when state management is supported with in_order=False + logger.warning( + "using in_order=False with multiple workers does not give any guarantees for state management " + "and loading from a checkpoint may not work as expected." + ) steps_since_snapshot = self._num_yielded - self._snapshot[self._SNAPSHOT_STEP] state_dict = { self._SNAPSHOT: self._snapshot, @@ -1352,11 +1379,12 @@ def _next_data(self): # call and `_IterableDatasetStopIteration` check below can mark # extra worker(s) as dead. while self._rcvd_idx < self._send_idx: - info = self._task_info[self._rcvd_idx] - worker_id = info[0] - if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active - break - del self._task_info[self._rcvd_idx] + info = self._task_info.get(self._rcvd_idx, None) + if info: + worker_id = info[0] + if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active + break + del self._task_info[self._rcvd_idx] self._rcvd_idx += 1 else: # no valid `self._rcvd_idx` is found (i.e., didn't break) @@ -1374,6 +1402,7 @@ def _next_data(self): self._rcvd_idx += 1 continue else: + self._rcvd_idx += 1 return self._process_data(data, worker_id, state_dict) assert not self._shutdown and self._tasks_outstanding > 0 @@ -1394,6 +1423,13 @@ def _next_data(self): if idx != self._rcvd_idx: # store out-of-order samples + if not self._in_order: + # don't store it for later, process now + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict) + continue + del self._task_info[idx] + return self._process_data(data, worker_id, state_dict) self._task_info[idx] += ((data, worker_id, state_dict),) else: del self._task_info[idx] @@ -1402,6 +1438,7 @@ def _next_data(self): self._rcvd_idx += 1 continue else: + self._rcvd_idx += 1 return self._process_data(data, worker_id, state_dict) def _get_main_state(self): @@ -1433,7 +1470,8 @@ def _restore_main_state(self, state_dict): self._base_seed = state_dict[self._BASE_SEED] def _try_put_index(self): - assert self._tasks_outstanding < self._prefetch_factor * self._num_workers + max_tasks = self._prefetch_factor * self._num_workers + assert self._tasks_outstanding < max_tasks try: index = self._next_index() @@ -1461,7 +1499,12 @@ def _try_put_index(self): for _ in range(self._num_workers): # find the next active worker, if any worker_queue_idx = next(self._worker_queue_idx_cycle) if self._workers_status[worker_queue_idx]: - break + if self._in_order: + break + elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status): + # when self._in_order is False, distribute work to a worker if it has capacity + # _workers_status is updated only in this thread, so the sum is guaranteed > 0 + break else: # not found (i.e., didn't break) return @@ -1472,11 +1515,12 @@ def _try_put_index(self): self._index_queues[worker_queue_idx].put((self._send_idx, (index, snapshot))) # type: ignore[possibly-undefined] self._task_info[self._send_idx] = (worker_queue_idx,) + self._workers_num_tasks[worker_queue_idx] += 1 self._tasks_outstanding += 1 self._send_idx += 1 def _process_data(self, data, worker_id, state_dict): - self._rcvd_idx += 1 + self._workers_num_tasks[worker_id] -= 1 self._try_put_index() if isinstance(data, ExceptionWrapper): data.reraise() @@ -1489,8 +1533,13 @@ def _process_data(self, data, worker_id, state_dict): return data def _take_snapshot(self): + main_snapshot_idx = None while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1): main_snapshot_idx, main_snapshot = self._main_snapshots.popleft() + if not self._in_order and main_snapshot_idx is None: + # in_order is False and no main snapshot is available as we're ahead of rcvd_idx + # we can't take a snapshot with the current implementation + return assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1) self._update_snapshot( self._num_yielded + 1,