From 4a9fc7ef12888a730a3d9a75a51cc690415555c7 Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Thu, 26 Dec 2024 12:00:06 -0800
Subject: [PATCH 1/6] Add unbatcher

---
 test/nodes/test_batch.py    | 26 +++++++++++++++++-
 torchdata/nodes/__init__.py |  3 +-
 torchdata/nodes/batch.py    | 55 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/test/nodes/test_batch.py b/test/nodes/test_batch.py
index 1e2dd7fa8..ca518324e 100644
--- a/test/nodes/test_batch.py
+++ b/test/nodes/test_batch.py
@@ -9,7 +9,7 @@
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase
-from torchdata.nodes.batch import Batcher
+from torchdata.nodes.batch import Batcher, Unbatcher
 
 from .utils import MockSource, run_test_save_load_state
 
@@ -48,3 +48,27 @@ def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
         src = MockSource(num_samples=20)
         node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
         run_test_save_load_state(self, node, midpoint)
+
+
+class TestUnbatcher(TestCase):
+    def test_unbatcher(self) -> None:
+        batch_size = 6
+        n = 20
+        src = MockSource(num_samples=n)
+        node = Batcher(src, batch_size=batch_size, drop_last=False)
+        node = Unbatcher(node)
+
+        results = list(node)
+        self.assertEqual(len(results), n)
+        for i in range(n):
+            self.assertEqual(results[i]["step"], i)
+            self.assertEqual(results[i]["test_tensor"], torch.tensor([i]))
+            self.assertEqual(results[i]["test_str"], f"str_{i}")
+
+    @parameterized.expand(itertools.product([0, 2], [True, False]))
+    def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
+        batch_size = 6
+        src = MockSource(num_samples=20)
+        node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
+        node = Unbatcher(node)
+        run_test_save_load_state(self, node, midpoint)
diff --git a/torchdata/nodes/__init__.py b/torchdata/nodes/__init__.py
index 2f8d1f287..62eaae517 100644
--- a/torchdata/nodes/__init__.py
+++ b/torchdata/nodes/__init__.py
@@ -6,7 +6,7 @@
 
 from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
 from .base_node import BaseNode, T
-from .batch import Batcher
+from .batch import Batcher, Unbatcher
 from .loader import Loader
 from .map import Mapper, ParallelMapper
 from .pin_memory import PinMemory
@@ -31,6 +31,7 @@
     "Stateful",
     "StopCriteria",
     "T",
+    "Unbatcher",
 ]
 
 assert sorted(__all__) == __all__
diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py
index 184608907..a44f05d5f 100644
--- a/torchdata/nodes/batch.py
+++ b/torchdata/nodes/batch.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Sequence
 
 from torchdata.nodes.base_node import BaseNode, T
 
@@ -56,3 +56,56 @@ def next(self) -> List[T]:
 
     def get_state(self) -> Dict[str, Any]:
         return {self.SOURCE_KEY: self.source.state_dict()}
+
+
+class Unbatcher(BaseNode[Sequence[T]]):
+    """Unbatcher will flatten batches pulled from source, and
+    yield elements in sequential order when next() is called on it.
+
+    Args:
+        source (BaseNode[Sequence[T]]): The source node to pull batches from.
+    """
+
+    SOURCE_KEY = "source"
+    BATCH_IDX_KEY = "batch_idx"
+
+    def __init__(self, source: BaseNode[Sequence[T]]):
+        super().__init__(self)
+        self.source = source
+
+    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+        super().reset(initial_state)
+        if initial_state is not None:
+            self.source.reset(initial_state[self.SOURCE_KEY])
+            self._cached_state_dict = initial_state[self.SOURCE_KEY]
+            try:
+                self._batch = next(self.source)
+                self._batch_idx = initial_state[self.BATCH_IDX_KEY]
+            except StopIteration:
+                # next(self.source) will be called upon subsequent self.next() call
+                # and raise StopIteration in the correct place.
+                self._batch = []
+                self._batch_idx = 0
+        else:
+            self.source.reset()
+            self._batch = []
+            self._cached_state_dict = None
+            self._batch_idx = 0
+
+    def next(self) -> T:
+        while self._batch_idx >= len(self._batch):
+            self._cached_state_dict = self.source.state_dict()
+            self._batch = next(self.source)
+            self._batch_idx = 0
+
+        self._batch_idx += 1
+        return self._batch[self._batch_idx - 1]
+
+    def get_state(self) -> Dict[str, Any]:
+        if self._cached_state_dict is None:
+            self._cached_state_dict = self.source.state_dict()
+
+        return {
+            self.SOURCE_KEY: self._cached_state_dict,
+            self.BATCH_IDX_KEY: self._batch_idx,
+        }
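
For readers new to these nodes: Batcher groups items from its source into lists, and the new Unbatcher flattens such lists back into single items, so composing the two is an identity over the source items. Below is a minimal illustrative sketch, not part of the patch; it assumes the IterableWrapper adapter re-exported from torchdata.nodes in the __init__.py shown above:

    from torchdata.nodes import Batcher, IterableWrapper, Unbatcher

    # Batcher groups items into lists; Unbatcher flattens them back out.
    source = IterableWrapper(range(10))
    node = Batcher(source, batch_size=4, drop_last=False)  # yields [0..3], [4..7], [8, 9]
    node = Unbatcher(node)  # yields 0, 1, ..., 9, one item per next()

    assert list(node) == list(range(10))
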
From b771bea428a8b4dcb248c697747a48e4e0b720d4 Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Thu, 26 Dec 2024 12:01:44 -0800
Subject: [PATCH 2/6] fix type annotation

---
 torchdata/nodes/batch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py
index a44f05d5f..7e7ca47de 100644
--- a/torchdata/nodes/batch.py
+++ b/torchdata/nodes/batch.py
@@ -58,7 +58,7 @@ def get_state(self) -> Dict[str, Any]:
         return {self.SOURCE_KEY: self.source.state_dict()}
 
 
-class Unbatcher(BaseNode[Sequence[T]]):
+class Unbatcher(BaseNode[T]):
     """Unbatcher will flatten batches pulled from source, and
     yield elements in sequential order when next() is called on it.
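
The batch_idx bookkeeping above is what makes mid-batch resumption possible: get_state() returns the source's state as cached before the current batch was pulled, together with the position inside that batch, and reset(initial_state) re-pulls the batch and fast-forwards to that position. A hedged sketch of the save/resume round trip (illustrative only; it mirrors what run_test_save_load_state exercises in the tests above):

    from torchdata.nodes import Batcher, IterableWrapper, Unbatcher

    node = Unbatcher(Batcher(IterableWrapper(range(10)), batch_size=4, drop_last=False))
    for _ in range(6):
        next(node)  # consume items 0..5, stopping midway through the second batch

    state = node.state_dict()  # source state from before [4, 5, 6, 7] was pulled, plus batch_idx=2

    node = Unbatcher(Batcher(IterableWrapper(range(10)), batch_size=4, drop_last=False))
    node.reset(state)  # re-pulls [4, 5, 6, 7] and skips its first two items
    assert list(node) == [6, 7, 8, 9]
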
From 8a9ba5b229efb81226051f249da1afd3e6cbb088 Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Thu, 26 Dec 2024 13:26:19 -0800
Subject: [PATCH 3/6] add prebatch feature

---
 .pre-commit-config.yaml |   2 +-
 test/nodes/test_map.py  |  37 +++++++++---
 torchdata/nodes/map.py  | 129 ++++++++++++++++++++++++++++++++--------
 3 files changed, 134 insertions(+), 34 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6696de652..1db0155da 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,6 +34,6 @@ repos:
       - usort == 1.0.0
 
   - repo: https://github.com/pycqa/flake8
-    rev: 5.0.4
+    rev: 6.1.0
    hooks:
      - id: flake8
diff --git a/test/nodes/test_map.py b/test/nodes/test_map.py
index 7caacd428..25f3c6fd3 100644
--- a/test/nodes/test_map.py
+++ b/test/nodes/test_map.py
@@ -7,7 +7,7 @@
 
 import itertools
 import unittest
-from typing import List
+from typing import List, Optional
 
 from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
@@ -55,7 +55,7 @@ def test_exception_handling_mapper_multiprocess(self):
     def test_exception_handling_mapper_multiprocess_cuda(self):
         self._test_exception_handling_mapper(True, "process")
 
-    def _test_map(self, in_order, method) -> None:
+    def _test_map(self, in_order, method, prebatch) -> None:
         batch_size = 6
         n = 80
         multiprocessing_context = None if IS_WINDOWS else "forkserver"
@@ -68,6 +68,7 @@ def _test_map(self, in_order, method) -> None:
             in_order=in_order,
             method=method,
             multiprocessing_context=multiprocessing_context,
+            prebatch=prebatch,
         )
         node = Prefetcher(node, prefetch_factor=2)
 
@@ -98,25 +99,40 @@ def _test_map(self, in_order, method) -> None:
         )
 
     def test_in_order_threads(self):
-        self._test_map(True, "thread")
+        self._test_map(True, "thread", None)
 
     def test_out_of_order_threads(self):
-        self._test_map(False, "thread")
+        self._test_map(False, "thread", None)
 
     def test_in_order_process(self):
-        self._test_map(True, "process")
+        self._test_map(True, "process", None)
 
     def test_out_of_order_process(self):
-        self._test_map(False, "process")
+        self._test_map(False, "process", None)
+
+    def test_in_order_thread_prebatch(self):
+        self._test_map(True, "thread", 3)
+
+    def test_out_of_order_thread_prebatch(self):
+        self._test_map(False, "thread", 3)
+
+    def test_in_order_process_prebatch(self):
+        self._test_map(True, "process", 3)
+
+    def test_out_of_order_process_prebatch(self):
+        self._test_map(False, "process", 3)
 
     @parameterized.expand(
         itertools.product(
             [0, 7, 13],
             [True],  # TODO: define and fix in_order = False
             [0, 1, 9],  # TODO: define and fix in_order = False
+            [None, 3],  # prebatch
         )
     )
-    def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
+    def test_save_load_state_thread(
+        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+    ):
         method = "thread"
         batch_size = 6
         n = 80
@@ -129,6 +145,7 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
             in_order=in_order,
             method=method,
             snapshot_frequency=snapshot_frequency,
+            prebatch=prebatch,
         )
         node = Prefetcher(node, prefetch_factor=2)
         run_test_save_load_state(self, node, midpoint)
@@ -138,9 +155,12 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
             [0, 7, 13],
             [True],  # TODO: define and fix in_order = False
             [0, 1, 9],  # TODO: define and fix in_order = False
+            [None, 3],  # prebatch
         )
     )
-    def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_frequency: int):
+    def test_save_load_state_process(
+        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+    ):
         method = "process"
         batch_size = 6
         n = 80
@@ -155,6 +175,7 @@ def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_f
             method=method,
             multiprocessing_context=multiprocessing_context,
             snapshot_frequency=snapshot_frequency,
+            prebatch=prebatch,
         )
         node = Prefetcher(node, prefetch_factor=2)
         run_test_save_load_state(self, node, midpoint)
diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py
index 2b110cd0d..7d8555f02 100644
--- a/torchdata/nodes/map.py
+++ b/torchdata/nodes/map.py
@@ -7,10 +7,11 @@
 import queue
 import threading
 import time
-from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, TypeVar, Union
+from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union
 
 import torch.multiprocessing as mp
 from torchdata.nodes.base_node import BaseNode, T
+from torchdata.nodes.batch import Batcher, Unbatcher
 from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper
 from torchdata.nodes.snapshot_store import QueueSnapshotStore, SnapshotStore
 
@@ -52,6 +53,14 @@ def Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) -> "ParallelMapper[T]"
     )
 
 
+class MapOverBatch(Callable[[Sequence[X]], T]):
+    def __init__(self, map_fn: Callable[[X], T]):
+        self.map_fn = map_fn
+
+    def __call__(self, xlist: Sequence[X]) -> Sequence[T]:
+        return [self.map_fn(x) for x in xlist]
+
+
 def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event):
     buffer: Dict[int, Any] = {}
     cur_idx = 0
@@ -272,6 +281,77 @@ def _shutdown(self):
         t.join(timeout=QUEUE_TIMEOUT * 5)
 
 
+class _ParallelMapperImpl(BaseNode[T]):
+    """This class implements _ParallelMapperIter as a BaseNode, allowing it
+    to be composed with other BaseNodes.
+
+    TODO: In the future, this class may go away once we implement reset() on
+    _ParallelMapperIter itself so we don't need this additional level of abstraction
+    """
+
+    def __init__(
+        self,
+        source: BaseNode[X],
+        map_fn: Callable[[X], T],
+        num_workers: int,
+        in_order: bool = True,
+        method: Literal["thread", "process"] = "thread",
+        multiprocessing_context: Optional[str] = None,
+        max_concurrent: Optional[int] = None,
+        snapshot_frequency: int = 1,
+    ):
+        super().__init__()
+        assert method in ["thread", "process"]
+        self.source = source
+        self.map_fn = map_fn
+        self.num_workers = num_workers
+        self.in_order = in_order
+        self.method = method
+        self.multiprocessing_context = multiprocessing_context
+        self._mp_context: Any = mp
+        if self.method == "process" and self.multiprocessing_context is not None:
+            self._mp_context = mp.get_context(self.multiprocessing_context)
+
+        if max_concurrent is not None and num_workers > 0:
+            if not isinstance(max_concurrent, int) and max_concurrent > num_workers:
+                raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!")
+        self.max_concurrent = max_concurrent
+        self.snapshot_frequency = snapshot_frequency
+        self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
+
+    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+        super().reset(initial_state)
+        if self._it is not None:
+            del self._it
+
+        if self.num_workers > 0:
+            self._it = self._parallel_reset(initial_state)
+        else:
+            self._it = self._inline_reset(initial_state)
+
+    def _inline_reset(self, initial_state: Optional[Dict[str, Any]]):
+        return _InlineMapperIter(source=self.source, map_fn=self.map_fn, initial_state=initial_state)
+
+    def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
+        return _ParallelMapperIter(
+            source=self.source,
+            map_fn=self.map_fn,
+            num_workers=self.num_workers,
+            in_order=self.in_order,
+            method=self.method,
+            mp_context=self._mp_context,
+            max_concurrent=self.max_concurrent,
+            snapshot_frequency=self.snapshot_frequency,
+            initial_state=initial_state,
+        )
+
+    def next(self):
+        return next(self._it)  # type: ignore[arg-type, union-attr]
+
+    def get_state(self) -> Dict[str, Any]:
+        return self._it.get_state()  # type: ignore[union-attr]
+
+
 class ParallelMapper(BaseNode[T]):
     """ParallelMapper executes map_fn in parallel either in num_workers threads
     or processes. For processes, multiprocessing_context can be spawn, forkserver, fork,
""" + IT_STATE_KEY = "it_state" + def __init__( self, source: BaseNode[X], @@ -306,6 +390,7 @@ def __init__( multiprocessing_context: Optional[str] = None, max_concurrent: Optional[int] = None, snapshot_frequency: int = 1, + prebatch: Optional[int] = None, ): super().__init__() assert method in ["thread", "process"] @@ -315,49 +400,43 @@ def __init__( self.in_order = in_order self.method = method self.multiprocessing_context = multiprocessing_context - self._mp_context: Any = mp - if self.method == "process" and self.multiprocessing_context is not None: - self._mp_context = mp.get_context(self.multiprocessing_context) - if max_concurrent is not None and num_workers > 0: if not isinstance(max_concurrent, int) and max_concurrent > num_workers: raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!") self.max_concurrent = max_concurrent self.snapshot_frequency = snapshot_frequency - self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None - - def reset(self, initial_state: Optional[Dict[str, Any]] = None): - super().reset(initial_state) - if self._it is not None: - self._it._shutdown() - del self._it - - if self.num_workers > 0: - self._parallel_reset(initial_state) - else: - self._inline_reset(initial_state) + self.prebatch = prebatch + if self.prebatch is not None: + assert prebatch > 0, f"{prebatch=} must be a positive integer!" + self.map_fn = MapOverBatch(map_fn=self.map_fn) + self.source = Batcher(self.source, batch_size=prebatch, drop_last=False) - def _inline_reset(self, initial_state: Optional[Dict[str, Any]]): - self._it = _InlineMapperIter(source=self.source, map_fn=self.map_fn, initial_state=initial_state) - - def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]): - self._it = _ParallelMapperIter( + self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = _ParallelMapperImpl( source=self.source, map_fn=self.map_fn, num_workers=self.num_workers, in_order=self.in_order, method=self.method, - mp_context=self._mp_context, + multiprocessing_context=self.multiprocessing_context, max_concurrent=self.max_concurrent, snapshot_frequency=self.snapshot_frequency, - initial_state=initial_state, ) + if self.prebatch is not None: + self._it = Unbatcher(self._it) + + def reset(self, initial_state: Optional[Dict[str, Any]] = None): + super().reset(initial_state) + if initial_state is not None: + self._it.reset(initial_state[self.IT_STATE_KEY]) + else: + self._it.reset() + def next(self): return next(self._it) # type: ignore[arg-type, union-attr] def get_state(self) -> Dict[str, Any]: - return self._it.get_state() # type: ignore[union-attr] + return {self.IT_STATE_KEY: self._it.state_dict()} # type: ignore[union-attr] _WorkerType = Callable[ From b2bd1c998402910749e57ce6013ac40937225cc7 Mon Sep 17 00:00:00 2001 From: andrewkho Date: Mon, 30 Dec 2024 09:16:05 -0800 Subject: [PATCH 4/6] update docstring --- torchdata/nodes/map.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index 7d8555f02..e1c03bf72 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -282,11 +282,12 @@ def _shutdown(self): class _ParallelMapperImpl(BaseNode[T]): - """This class implements _ParallelMapperIter as a BaseNode, allowing it - to be composed with other BaseNodes. + """This class implements _ParallelMapperIter and _InlineMapperIter as a BaseNode, + allowing them to be composed with other BaseNodes. 
From b2bd1c998402910749e57ce6013ac40937225cc7 Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Mon, 30 Dec 2024 09:16:05 -0800
Subject: [PATCH 4/6] update docstring

---
 torchdata/nodes/map.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py
index 7d8555f02..e1c03bf72 100644
--- a/torchdata/nodes/map.py
+++ b/torchdata/nodes/map.py
@@ -282,11 +282,12 @@ def _shutdown(self):
 
 
 class _ParallelMapperImpl(BaseNode[T]):
-    """This class implements _ParallelMapperIter as a BaseNode, allowing it
-    to be composed with other BaseNodes.
+    """This class implements _ParallelMapperIter and _InlineMapperIter as a BaseNode,
+    allowing them to be composed with other BaseNodes.
 
     TODO: In the future, this class may go away once we implement reset() on
-    _ParallelMapperIter itself so we don't need this additional level of abstraction
+    _ParallelMapperIter and _InlineMapperIter themselves so we don't need this
+    additional level of abstraction.
     """
 
     def __init__(
From fbe616ae78da4e9306ff178b66c28699d693181b Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Mon, 30 Dec 2024 10:02:58 -0800
Subject: [PATCH 5/6] fix mypy

---
 torchdata/nodes/map.py | 34 +++++++++++++++++++++-------------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py
index e1c03bf72..90a8cf640 100644
--- a/torchdata/nodes/map.py
+++ b/torchdata/nodes/map.py
@@ -7,7 +7,7 @@
 import queue
 import threading
 import time
-from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union
 
 import torch.multiprocessing as mp
 from torchdata.nodes.base_node import BaseNode, T
@@ -53,7 +53,11 @@ def Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) -> "ParallelMapper[T]"
     )
 
 
-class MapOverBatch(Callable[[Sequence[X]], T]):
+Xseq = Sequence[X]
+Tseq = Sequence[T]
+
+
+class MapOverBatch(Generic[X, T]):
     def __init__(self, map_fn: Callable[[X], T]):
         self.map_fn = map_fn
 
@@ -346,7 +350,7 @@ def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
             initial_state=initial_state,
         )
 
-    def next(self):
+    def next(self) -> T:
         return next(self._it)  # type: ignore[arg-type, union-attr]
 
     def get_state(self) -> Dict[str, Any]:
@@ -395,8 +399,6 @@ def __init__(
     ):
         super().__init__()
         assert method in ["thread", "process"]
-        self.source = source
-        self.map_fn = map_fn
         self.num_workers = num_workers
         self.in_order = in_order
         self.method = method
@@ -407,12 +409,16 @@ def __init__(
         self.max_concurrent = max_concurrent
         self.snapshot_frequency = snapshot_frequency
         self.prebatch = prebatch
-        if self.prebatch is not None:
-            assert prebatch > 0, f"{prebatch=} must be a positive integer!"
-            self.map_fn = MapOverBatch(map_fn=self.map_fn)
-            self.source = Batcher(self.source, batch_size=prebatch, drop_last=False)
+        if prebatch is None:
+            self.map_fn = map_fn
+            self.source = source
+        else:
+            if prebatch <= 0:
+                raise ValueError(f"{prebatch=} must be a positive integer!")
+            self.map_fn = MapOverBatch(map_fn=map_fn)  # type: ignore[assignment]
+            self.source = Batcher(self.source, batch_size=prebatch, drop_last=False)  # type: ignore[assignment]
 
-        self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = _ParallelMapperImpl(
+        _it = _ParallelMapperImpl(
             source=self.source,
             map_fn=self.map_fn,
             num_workers=self.num_workers,
@@ -423,8 +429,10 @@ def __init__(
             snapshot_frequency=self.snapshot_frequency,
         )
 
-        if self.prebatch is not None:
-            self._it = Unbatcher(self._it)
+        if self.prebatch is None:
+            self._it = _it
+        else:
+            self._it = Unbatcher(_it)  # type: ignore[arg-type, assignment]
 
     def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         super().reset(initial_state)
@@ -433,7 +441,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         else:
             self._it.reset()
 
-    def next(self):
+    def next(self) -> T:
         return next(self._it)  # type: ignore[arg-type, union-attr]
 
     def get_state(self) -> Dict[str, Any]:

From 6a999171896121c77f50a16d5cc4e487aa8a5b69 Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Mon, 30 Dec 2024 10:38:17 -0800
Subject: [PATCH 6/6] fix test

---
 torchdata/nodes/map.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py
index 90a8cf640..04430d894 100644
--- a/torchdata/nodes/map.py
+++ b/torchdata/nodes/map.py
@@ -416,7 +416,7 @@ def __init__(
             if prebatch <= 0:
                 raise ValueError(f"{prebatch=} must be a positive integer!")
             self.map_fn = MapOverBatch(map_fn=map_fn)  # type: ignore[assignment]
-            self.source = Batcher(self.source, batch_size=prebatch, drop_last=False)  # type: ignore[assignment]
+            self.source = Batcher(source, batch_size=prebatch, drop_last=False)  # type: ignore[assignment]
 
         _it = _ParallelMapperImpl(
             source=self.source,
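
With the series applied, MapOverBatch is a plain generic callable that lifts map_fn over a sequence, which is the piece that lets the Batcher/Unbatcher sandwich type-check under mypy. A small standalone sketch of its behavior:

    from torchdata.nodes.map import MapOverBatch

    # MapOverBatch wraps a per-item function so it applies to a whole list.
    upper = MapOverBatch(map_fn=str.upper)
    assert upper(["a", "b", "c"]) == ["A", "B", "C"]
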