From 4f7c08cbcf120f10a4382828dd05376ce424528c Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Thu, 16 May 2024 18:42:17 -0700 Subject: [PATCH] Integrate stateful dataloader to torchtitan Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- README.md | 1 + test/__init__.py | 5 +++ test/datasets/test_dataset_checkpoint.py | 56 ++++++++++++++++++++++++ torchtitan/checkpoint.py | 32 ++++++++++++++ torchtitan/datasets/hf_datasets.py | 51 +++++++++++++++++---- train.py | 1 + 6 files changed, 137 insertions(+), 9 deletions(-) create mode 100644 test/datasets/test_dataset_checkpoint.py diff --git a/README.md b/README.md index 21634d0be..a8d1fcc4c 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan cd torchtitan pip install -r requirements.txt pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 +pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly ``` ### Downloading a tokenizer diff --git a/test/__init__.py b/test/__init__.py index e69de29bb..2e41cd717 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/datasets/test_dataset_checkpoint.py b/test/datasets/test_dataset_checkpoint.py new file mode 100644 index 000000000..e4be71d60 --- /dev/null +++ b/test/datasets/test_dataset_checkpoint.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchtitan.checkpoint import DataLoaderWrapper +from torchtitan.datasets.hf_datasets import build_hf_data_loader +from torchtitan.datasets.tokenizer import create_tokenizer + + +class TestDatasetCheckpoint: + def test_c4_resumption(self): + dataset_name = "c4_mini" + dataset_path = "./torchtitan/datasets/c4_mini" + batch_size = 1 + seq_len = 1024 + world_size = 4 + rank = 0 + + dl_wrapper = self._create_dataloader_wrapper( + dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ) + + it = iter(dl_wrapper.dataloader) + for _ in range(250): + next(it) + state = dl_wrapper.state_dict() + expected_input_ids, expected_labels = next(it) + + # Create new dataloader, restore checkpoint, and check if next data yielded is the same as above + dl_wrapper = self._create_dataloader_wrapper( + dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ) + dl_wrapper.load_state_dict(state) + input_ids, labels = next(iter(dl_wrapper.dataloader)) + + assert torch.equal(input_ids, expected_input_ids) + assert torch.equal(labels, expected_labels) + + def _create_dataloader_wrapper( + self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ): + tokenizer_type = "tiktoken" + tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") + dataloader = build_hf_data_loader( + dataset_name=dataset_name, + dataset_path=dataset_path, + tokenizer=tokenizer, + batch_size=1, + seq_len=1024, + world_size=4, + rank=0, + ) + return DataLoaderWrapper(dataloader) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 33fe8c05c..6856db2e5 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -6,6 +6,7 @@ import enum import os +import pickle import re import time from multiprocessing import get_context @@ -22,6 +23,8 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import DataLoader +from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import init_logger, logger @@ -67,6 +70,33 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict) +class DataLoaderWrapper(Stateful): + def __init__(self, dataloader: DataLoader) -> None: + self.dataloader = dataloader + # Use global rank for now even though dataloader state could be same across dp groups + self.rank_id = str( + dist.get_rank() if (dist.is_available() and dist.is_initialized()) else 0 + ) + + def state_dict(self) -> Dict[str, Any]: + if isinstance(self.dataloader, StatefulDataLoader): + return {self.rank_id: pickle.dumps(self.dataloader.state_dict())} + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if isinstance(self.dataloader, StatefulDataLoader): + # State is empty + if not state_dict: + return + + if self.rank_id not in state_dict: + logger.warning(f"DataLoader state is empty for rank {self.rank_id}. ") + return + + # Load state for the current rank + self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id])) + + class Terminate: pass @@ -110,6 +140,7 @@ def __init__( model: nn.Module, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + dataloader: DataLoader, states: Dict[str, Any], job_config: JobConfig, ) -> None: @@ -125,6 +156,7 @@ def __init__( "model": ModelWrapper(model), "optimizer": OptimizerWrapper(model, optimizer), "lr_scheduler": lr_scheduler, + "dataloader": DataLoaderWrapper(dataloader), } ) diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index f6d09faac..a98c9467f 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -7,7 +7,9 @@ from typing import List, Optional import torch -from torch.utils.data import DataLoader, IterableDataset +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset +from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.datasets.tokenizer import Tokenizer from torchtitan.logging_utils import logger @@ -23,7 +25,7 @@ } -class HuggingFaceDataset(IterableDataset): +class HuggingFaceDataset(IterableDataset, Stateful): """PyTorch Representation of the HuggingFace Dataset. Args: @@ -99,32 +101,63 @@ def __init__( self.seq_len = seq_len self.infinite = infinite + # variables for checkpointing + self._sample_idx = 0 + self._all_tokens: List[int] = [] + def __iter__(self): max_buffer_token_len = 1 + self.seq_len - all_tokens: List[int] = [] while True: - for sample in iter(self._data): + for sample in self._get_data_iter(): sample_text = sample["text"] sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) - all_tokens.extend(sample_tokens) + self._all_tokens.extend(sample_tokens) + self._sample_idx += 1 - while len(all_tokens) >= max_buffer_token_len: - x = torch.LongTensor(all_tokens[:max_buffer_token_len]) + while len(self._all_tokens) >= max_buffer_token_len: + x = torch.LongTensor(self._all_tokens[:max_buffer_token_len]) # update tokens to the remaining tokens - all_tokens = all_tokens[max_buffer_token_len:] + self._all_tokens = self._all_tokens[max_buffer_token_len:] input = x[:-1] label = x[1:] yield input, label + if not self.infinite: logger.warning(f"Dataset {self.dataset_name} has run out of data.") break else: + # Reset offset for the next iteration + self._sample_idx = 0 logger.warning( f"Dataset {self.dataset_name} is being re-looped. " "Loss related metrics might be misleading." ) + def _get_data_iter(self): + if self._sample_idx == 0: + return iter(self._data) + + # Skip samples + if isinstance(self._data, IterableDataset): + it = iter(self._data) + # Naively iterate through the samples as skip may not be supported + for _ in range(self._sample_idx): + next(it) + return it + + # As skipping to the end throws an error in case of map-style dataset, return an empty iterator + if self._sample_idx == len(self._data): + return iter([]) + return iter(self._data.skip(self._sample_idx)) + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + self._all_tokens = state_dict["token_buffer"] + + def state_dict(self): + return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} + def build_hf_data_loader( dataset_name: str, @@ -140,4 +173,4 @@ def build_hf_data_loader( dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite ) - return DataLoader(hf_ds, batch_size=batch_size) + return StatefulDataLoader(hf_ds, batch_size=batch_size) diff --git a/train.py b/train.py index 318c7174e..a0bb337e6 100644 --- a/train.py +++ b/train.py @@ -245,6 +245,7 @@ def loss_fn(pred, labels): model=model, optimizer=optimizer, lr_scheduler=scheduler, + dataloader=data_loader, states={"train_state": train_state}, job_config=job_config, )