Skip to content

Commit

Permalink
uniformly use skip for both (map-style) Dataset and IterableDataset
Browse files (browse the repository at this point in the history)
ghstack-source-id: c8f611742ffbb4859988b97e706b9e0d1b4ad6f1
Pull Request resolved: #521
  • Loading branch information
tianyu-l committed Aug 16, 2024
1 parent f339363 commit 81c555f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torchdata >= 0.8.0
datasets >= 2.19.0
datasets >= 2.21.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
Expand Down
25 changes: 6 additions & 19 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,12 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

try:
from torchdata.stateful_dataloader import StatefulDataLoader
except ImportError as e:
raise ImportError(
"Please install the latest torchdata nightly to use StatefulDataloader via:"
"pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"
) from e
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import load_dataset
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a local directory, or
Expand Down Expand Up @@ -102,7 +96,7 @@ def __init__(
else:
ds = load_dataset(dataset_path, split="train")

# TODO: support shuffling and checkpointing
# TODO: support shuffling
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
Expand Down Expand Up @@ -143,17 +137,10 @@ def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# Skip samples
if isinstance(self._data, IterableDataset):
it = iter(self._data)
# Naively iterate through the samples as skip may not be supported
for _ in range(self._sample_idx):
next(it)
return it

# Skipping to the end of a map-style dataset raises an error, so return an empty iterator instead
if self._sample_idx == len(self._data):
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
Expand All @@ -179,7 +166,7 @@ def state_dict(self) -> Dict[str, Any]:
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid, don't log a warning
# State being empty is valid
if not state_dict:
return

Expand Down

0 comments on commit 81c555f

Please sign in to comment.