Add unbatcher node #1416

Merged: 2 commits, Dec 30, 2024
26 changes: 25 additions & 1 deletion test/nodes/test_batch.py
@@ -9,7 +9,7 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.batch import Batcher
from torchdata.nodes.batch import Batcher, Unbatcher

from .utils import MockSource, run_test_save_load_state

@@ -48,3 +48,27 @@ def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
        src = MockSource(num_samples=20)
        node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
        run_test_save_load_state(self, node, midpoint)


class TestUnbatcher(TestCase):
    def test_unbatcher(self) -> None:
        batch_size = 6
        n = 20
        src = MockSource(num_samples=n)
        node = Batcher(src, batch_size=batch_size, drop_last=False)
        node = Unbatcher(node)

        results = list(node)
        self.assertEqual(len(results), n)
        for i in range(n):
            self.assertEqual(results[i]["step"], i)
            self.assertEqual(results[i]["test_tensor"], torch.tensor([i]))
            self.assertEqual(results[i]["test_str"], f"str_{i}")

    @parameterized.expand(itertools.product([0, 2], [True, False]))
    def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
        batch_size = 6
        src = MockSource(num_samples=20)
        node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
        node = Unbatcher(node)
        run_test_save_load_state(self, node, midpoint)
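
For context, the checkpoint behavior the fast-forward test above exercises can be sketched roughly as follows. This snippet is not part of the PR; it assumes nodes can be iterated directly with next(), as the tests above do, and that state_dict()/reset() round-trip the way run_test_save_load_state expects. IterableWrapper is just a convenient stand-in source, and make_pipeline is a local helper for the sketch.

from torchdata.nodes import Batcher, IterableWrapper, Unbatcher

def make_pipeline():
    # Batch 10 integers into groups of 4, then flatten them back out.
    return Unbatcher(Batcher(IterableWrapper(range(10)), batch_size=4, drop_last=False))

node = make_pipeline()
node.reset()
first = [next(node) for _ in range(3)]    # 0, 1, 2 -- stops partway through the first batch
state = node.state_dict()                 # records the source's state plus the in-batch index

resumed = make_pipeline()
resumed.reset(state)                      # re-pulls the in-flight batch, fast-forwards to the index
rest = [next(resumed) for _ in range(7)]  # 3, 4, ..., 9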
3 changes: 2 additions & 1 deletion torchdata/nodes/__init__.py
@@ -6,7 +6,7 @@

from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
from .base_node import BaseNode, T
from .batch import Batcher
from .batch import Batcher, Unbatcher
from .loader import Loader
from .map import Mapper, ParallelMapper
from .pin_memory import PinMemory
@@ -31,6 +31,7 @@
    "Stateful",
    "StopCriteria",
    "T",
    "Unbatcher",
]

assert sorted(__all__) == __all__
55 changes: 54 additions & 1 deletion torchdata/nodes/batch.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

from torchdata.nodes.base_node import BaseNode, T

@@ -56,3 +56,56 @@ def next(self) -> List[T]:

    def get_state(self) -> Dict[str, Any]:
        return {self.SOURCE_KEY: self.source.state_dict()}


class Unbatcher(BaseNode[T]):
    """Unbatcher will flatten batches pulled from source, and
    yields elements in sequential order when next() is called on it.

    Args:
        source (BaseNode[T]): The source node to pull batches from.
    """

    SOURCE_KEY = "source"
    BATCH_IDX_KEY = "batch_idx"

    def __init__(self, source: BaseNode[Sequence[T]]):
        super().__init__(self)
        self.source = source

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        if initial_state is not None:
            self.source.reset(initial_state[self.SOURCE_KEY])
            self._cached_state_dict = initial_state[self.SOURCE_KEY]
            try:
                self._batch = next(self.source)
                self._batch_idx = initial_state[self.BATCH_IDX_KEY]
            except StopIteration:
                # next(self.source) will be called upon subsequent self.next() call
                # and raise StopIteration in the correct place.
                self._batch = []
                self._batch_idx = 0
        else:
            self.source.reset()
            self._batch = []
            self._cached_state_dict = None
            self._batch_idx = 0

    def next(self) -> T:
        while self._batch_idx >= len(self._batch):
divyanshk (Contributor) commented on Dec 26, 2024:
This should be if instead of while? i.e. if _batch_idx overshoots the current _batch, get a new _batch and reset _batch_idx to 0.
EDIT: while also works though, in case the next batch is of size 0 and we want to skip that too.

Contributor Author replied:
@divyanshk Yes, I was worried about the next batch being of size 0; I'm assuming that's unlikely, but it's an edge case nonetheless.
            self._cached_state_dict = self.source.state_dict()
            self._batch = next(self.source)
            self._batch_idx = 0

        self._batch_idx += 1
        return self._batch[self._batch_idx - 1]

    def get_state(self) -> Dict[str, Any]:
        if self._cached_state_dict is None:
            self._cached_state_dict = self.source.state_dict()

        return {
            self.SOURCE_KEY: self._cached_state_dict,
            self.BATCH_IDX_KEY: self._batch_idx,
        }
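
To make the while-vs-if discussion in the review thread above concrete, here is a rough illustration (also not part of the PR) of the empty-batch edge case: because next() loops with while, an empty batch coming out of the source is skipped rather than indexed into. The hand-built list of batches below is purely illustrative.

from torchdata.nodes import IterableWrapper, Unbatcher

# The middle "batch" is empty; with an if instead of a while, next()
# would try to index into it and raise an IndexError.
batches = IterableWrapper([[1, 2], [], [3]])
node = Unbatcher(batches)
print(list(node))  # [1, 2, 3]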