From a62f3173c63d861158a67eb75dbf20d12dd2963b Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Tue, 21 May 2024 14:45:23 -0700 Subject: [PATCH] Address PR comments --- .github/workflows/integration_test_periodic.yaml | 1 + pyproject.toml | 2 +- torchtitan/datasets/hf_datasets.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration_test_periodic.yaml b/.github/workflows/integration_test_periodic.yaml index bc717cd1..488fc4da 100644 --- a/.github/workflows/integration_test_periodic.yaml +++ b/.github/workflows/integration_test_periodic.yaml @@ -34,6 +34,7 @@ jobs: - name: Install dependencies run: | pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly python -m pip install -r requirements.txt python -m pip install -r dev-requirements.txt - name: Run test_runner.py diff --git a/pyproject.toml b/pyproject.toml index 2a8f9557..a5c1b72f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ keywords = ["pytorch", "training", "llm"] dependencies = [ # Hugging Face integrations - "datasets", + "datasets>=2.19.0", # Tokenization "blobfile", diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 3710b1fd..52c13697 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -160,7 +160,7 @@ def state_dict(self): return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} -class DpAwareDataLoader(StatefulDataLoader, Stateful): +class DPAwareDataLoader(StatefulDataLoader, Stateful): """ A wrapper around the StatefulDataLoader that ensures that the state is stored only once for DP ranks. """ @@ -201,4 +201,4 @@ def build_hf_data_loader( dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite ) - return DpAwareDataLoader(rank, hf_ds, batch_size=batch_size) + return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)