diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index fc3b4455..2321627e 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -1,5 +1,5 @@ torchdata >= 0.8.0 -datasets >= 2.19.0 +datasets >= 2.21.0 tomli >= 1.1.0 ; python_version < "3.11" tensorboard sentencepiece diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index cc13012e..9db036b0 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -11,18 +11,12 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset -try: - from torchdata.stateful_dataloader import StatefulDataLoader -except ImportError as e: - raise ImportError( - "Please install the latest torchdata nightly to use StatefulDataloader via:" - "pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly" - ) from e +from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.datasets.tokenizer import Tokenizer from torchtitan.logging import logger -from datasets import load_dataset +from datasets import Dataset, load_dataset from datasets.distributed import split_dataset_by_node # map from dataset name to a local directory, or @@ -102,7 +96,7 @@ def __init__( else: ds = load_dataset(dataset_path, split="train") - # TODO: support shuffling and checkpointing + # TODO: support shuffling self.dataset_name = dataset_name self._data = split_dataset_by_node(ds, rank, world_size) self._tokenizer = tokenizer @@ -143,17 +137,10 @@ def _get_data_iter(self): if self._sample_idx == 0: return iter(self._data) - # Skip samples - if isinstance(self._data, IterableDataset): - it = iter(self._data) - # Naively iterate through the samples as skip may not be supported - for _ in range(self._sample_idx): - next(it) - return it - # As skipping to the end throws an error in case of map-style dataset, return an empty iterator - if self._sample_idx == len(self._data): + if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): return iter([]) + return iter(self._data.skip(self._sample_idx)) def load_state_dict(self, state_dict): @@ -179,7 +166,7 @@ def state_dict(self) -> Dict[str, Any]: return {self._rank_id: pickle.dumps(super().state_dict())} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - # State being empty is valid, don't log a warning + # State being empty is valid if not state_dict: return