Skip to content

Commit

Permalink
2024-12-21 nightly release (62092dd)
Browse files: browse the repository at this point in the history
pytorchbot committed Dec 21, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent c17eb5e commit bdd6635
Showing 2 changed files with 15 additions and 6 deletions.
18 changes: 12 additions & 6 deletions torchdata/nodes/base_node.py
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@ def __iter__(self):
def reset(self, initial_state: Optional[dict] = None):
"""Resets the iterator to the beginning, or to the state passed in by initial_state.
Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called.
Reset is a good place to put expensive initialization, as it will be lazily called when ``next()`` or ``state_dict()`` is called.
Subclasses must call ``super().reset(initial_state)``.
Args:
@@ -57,14 +57,18 @@ def reset(self, initial_state: Optional[dict] = None):
self.__initialized = True

def get_state(self) -> Dict[str, Any]:
"""Subclasses must implement this method, instead of state_dict(). Should only be called by BaseNode.
:return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future
"""Subclasses must implement this method, instead of ``state_dict()``. Should only be called by BaseNode.
Returns:
Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future
"""
raise NotImplementedError(type(self))

def next(self) -> T:
"""Subclasses must implement this method, instead of ``__next``. Should only be called by BaseNode.
:return: T - the next value in the sequence, or throw StopIteration
"""Subclasses must implement this method, instead of ``__next__``. Should only be called by BaseNode.
Returns:
T - the next value in the sequence, or throw StopIteration
"""
raise NotImplementedError(type(self))

@@ -83,7 +87,9 @@ def __next__(self):

def state_dict(self) -> Dict[str, Any]:
"""Get a state_dict for this BaseNode.
:return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future.
Returns:
Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future.
"""
try:
self.__initialized
3 changes: 3 additions & 0 deletions torchdata/nodes/samplers/multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
@@ -24,13 +24,15 @@ class MultiNodeWeightedSampler(BaseNode[T]):
weights for sampling. `seed` is used to initialize the random number generator.
The node implements the state using the following keys:
- DATASET_NODE_STATES_KEY: A dictionary of states for each source node.
- DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.
- EPOCH_KEY: An epoch counter used to initialize the random number generator.
- NUM_YIELDED_KEY: The number of items yielded.
- WEIGHTED_SAMPLER_STATE_KEY: The state of the weighted sampler.
We support multiple stopping criteria:
- CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets are exhausted. This is the default behavior.
- FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
- ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted.
@@ -203,6 +205,7 @@ class _WeightedSampler:
"""A weighted sampler that samples from a list of weights.
The class implements the state using the following keys:
- g_state: The state of the random number generator.
- g_rank_state: The state of the random number generator for the rank.
- offset: The offset of the batch of indices.

0 comments on commit bdd6635

Please sign in to comment.