[Nodes] Add Prebatch setting to ParallelMapper #1417

Merged (7 commits, Jan 2, 2025)
Changes from 3 commits
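
For context, a minimal sketch of how the new setting might be used: prebatch (None or a small int such as the 3 used in the tests below) appears to group items into chunks before they are dispatched to workers and to flatten the results back afterward, which would also explain why this PR adds an Unbatcher node. The map_fn and num_workers keyword names are assumed from ParallelMapper's existing interface and are not shown in this diff.

# Hedged sketch; map_fn/num_workers are assumed names, prebatch is the new setting.
from torchdata.nodes import IterableWrapper, ParallelMapper, Prefetcher

node = IterableWrapper(range(80))
node = ParallelMapper(
    node,
    map_fn=lambda x: x * 2,  # per-item function; assumed keyword name
    num_workers=4,           # assumed keyword name
    in_order=True,
    method="thread",         # lambdas are fine with the thread method
    prebatch=3,              # new in this PR: group items before dispatch
)
node = Prefetcher(node, prefetch_factor=2)
results = list(node)
assert results == [x * 2 for x in range(80)]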
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -34,6 +34,6 @@ repos:
- usort == 1.0.0

- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 6.1.0
hooks:
- id: flake8
26 changes: 25 additions & 1 deletion test/nodes/test_batch.py
@@ -9,7 +9,7 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.batch import Batcher
from torchdata.nodes.batch import Batcher, Unbatcher

from .utils import MockSource, run_test_save_load_state

@@ -48,3 +48,27 @@ def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
src = MockSource(num_samples=20)
node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
run_test_save_load_state(self, node, midpoint)


class TestUnbatcher(TestCase):
def test_unbatcher(self) -> None:
batch_size = 6
n = 20
src = MockSource(num_samples=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = Unbatcher(node)

results = list(node)
self.assertEqual(len(results), n)
for i in range(n):
self.assertEqual(results[i]["step"], i)
self.assertEqual(results[i]["test_tensor"], torch.tensor([i]))
self.assertEqual(results[i]["test_str"], f"str_{i}")

@parameterized.expand(itertools.product([0, 2], [True, False]))
def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
batch_size = 6
src = MockSource(num_samples=20)
node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
node = Unbatcher(node)
run_test_save_load_state(self, node, midpoint)
37 changes: 29 additions & 8 deletions test/nodes/test_map.py
@@ -7,7 +7,7 @@
import itertools

import unittest
from typing import List
from typing import List, Optional

from parameterized import parameterized
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
@@ -55,7 +55,7 @@ def test_exception_handling_mapper_multiprocess(self):
def test_exception_handling_mapper_multiprocess_cuda(self):
self._test_exception_handling_mapper(True, "process")

def _test_map(self, in_order, method) -> None:
def _test_map(self, in_order, method, prebatch) -> None:
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
@@ -68,6 +68,7 @@ def _test_map(self, in_order, method) -> None:
in_order=in_order,
method=method,
multiprocessing_context=multiprocessing_context,
prebatch=prebatch,
)
node = Prefetcher(node, prefetch_factor=2)

@@ -98,25 +99,40 @@ def _test_map(self, in_order, method) -> None:
)

def test_in_order_threads(self):
self._test_map(True, "thread")
self._test_map(True, "thread", None)

def test_out_of_order_threads(self):
self._test_map(False, "thread")
self._test_map(False, "thread", None)

def test_in_order_process(self):
self._test_map(True, "process")
self._test_map(True, "process", None)

def test_out_of_order_process(self):
self._test_map(False, "process")
self._test_map(False, "process", None)

def test_in_order_thread_prebatch(self):
self._test_map(True, "thread", 3)

def test_out_of_order_thread_prebatch(self):
self._test_map(False, "thread", 3)

def test_in_order_process_prebatch(self):
self._test_map(True, "process", 3)

def test_out_of_order_process_prebatch(self):
self._test_map(False, "process", 3)

@parameterized.expand(
itertools.product(
[0, 7, 13],
[True], # TODO: define and fix in_order = False
[0, 1, 9], # TODO: define and fix in_order = False
[None, 3], # prebatch
)
)
def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
def test_save_load_state_thread(
self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
):
method = "thread"
batch_size = 6
n = 80
@@ -129,6 +145,7 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
in_order=in_order,
method=method,
snapshot_frequency=snapshot_frequency,
prebatch=prebatch,
)
node = Prefetcher(node, prefetch_factor=2)
run_test_save_load_state(self, node, midpoint)
@@ -138,9 +155,12 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
[0, 7, 13],
[True], # TODO: define and fix in_order = False
[0, 1, 9], # TODO: define and fix in_order = False
[None, 3], # prebatch
)
)
def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_frequency: int):
def test_save_load_state_process(
self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
):
method = "process"
batch_size = 6
n = 80
@@ -155,6 +175,7 @@ def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_f
method=method,
multiprocessing_context=multiprocessing_context,
snapshot_frequency=snapshot_frequency,
prebatch=prebatch,
)
node = Prefetcher(node, prefetch_factor=2)
run_test_save_load_state(self, node, midpoint)
3 changes: 2 additions & 1 deletion torchdata/nodes/__init__.py
@@ -6,7 +6,7 @@

from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
from .base_node import BaseNode, T
from .batch import Batcher
from .batch import Batcher, Unbatcher
from .loader import Loader
from .map import Mapper, ParallelMapper
from .pin_memory import PinMemory
@@ -31,6 +31,7 @@
"Stateful",
"StopCriteria",
"T",
"Unbatcher",
]

assert sorted(__all__) == __all__
55 changes: 54 additions & 1 deletion torchdata/nodes/batch.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

from torchdata.nodes.base_node import BaseNode, T

@@ -56,3 +56,56 @@ def next(self) -> List[T]:

def get_state(self) -> Dict[str, Any]:
return {self.SOURCE_KEY: self.source.state_dict()}


class Unbatcher(BaseNode[T]):
"""Unbatcher will flatten batches pulled from source, and
yields elements in sequential order when next() is called on it.

Args:
source (BaseNode[T]): The source node to pull batches from.
"""

SOURCE_KEY = "source"
BATCH_IDX_KEY = "batch_idx"

def __init__(self, source: BaseNode[Sequence[T]]):
super().__init__(self)
self.source = source

def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)
if initial_state is not None:
self.source.reset(initial_state[self.SOURCE_KEY])
self._cached_state_dict = initial_state[self.SOURCE_KEY]
try:
self._batch = next(self.source)
self._batch_idx = initial_state[self.BATCH_IDX_KEY]
except StopIteration:
# next(self.source) will be called upon subsequent self.next() call
# and raise StopIteration in the correct place.
self._batch = []
self._batch_idx = 0
else:
self.source.reset()
self._batch = []
self._cached_state_dict = None
self._batch_idx = 0

def next(self) -> T:
while self._batch_idx >= len(self._batch):
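# Snapshot the source state *before* pulling the next batch, so that
# get_state()/reset() can restore to the start of this batch and replay
# it up to _batch_idx.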
self._cached_state_dict = self.source.state_dict()
self._batch = next(self.source)
self._batch_idx = 0

self._batch_idx += 1
return self._batch[self._batch_idx - 1]

def get_state(self) -> Dict[str, Any]:
if self._cached_state_dict is None:
self._cached_state_dict = self.source.state_dict()

return {
self.SOURCE_KEY: self._cached_state_dict,
self.BATCH_IDX_KEY: self._batch_idx,
}
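
For reference, a short sketch of how the new Unbatcher composes with the existing Batcher. The names come from torchdata.nodes as exported in __init__.py above; the example itself is illustrative and not part of the PR.

# Illustrative round trip: Batcher groups items into lists, Unbatcher
# flattens them back into a stream of single elements.
from torchdata.nodes import Batcher, IterableWrapper, Unbatcher

src = IterableWrapper(range(10))
batched = Batcher(src, batch_size=4, drop_last=False)  # [0..3], [4..7], [8, 9]
flat = Unbatcher(batched)

assert list(flat) == list(range(10))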