Skip to content

Commit

Permalink
store state only for dp rank
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
gokulavasan committed May 21, 2024
1 parent 9b07bb9 commit 80cefc0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.checkpoint import DataLoaderWrapper
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer


class TestDatasetCheckpoint:
class TestCheckpoint:
def test_c4_resumption(self):
dataset_name = "c4_mini"
dataset_path = "./torchtitan/datasets/c4_mini"
Expand All @@ -19,32 +18,32 @@ def test_c4_resumption(self):
world_size = 4
rank = 0

dl_wrapper = self._create_dataloader_wrapper(
dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)

it = iter(dl_wrapper.dataloader)
it = iter(dl)
for _ in range(250):
next(it)
state = dl_wrapper.state_dict()
state = dl.state_dict()
expected_input_ids, expected_labels = next(it)

# Create new dataloader, restore checkpoint, and check if next data yielded is the same as above
dl_wrapper = self._create_dataloader_wrapper(
dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)
dl_wrapper.load_state_dict(state)
input_ids, labels = next(iter(dl_wrapper.dataloader))
dl.load_state_dict(state)
input_ids, labels = next(iter(dl))

assert torch.equal(input_ids, expected_input_ids)
assert torch.equal(labels, expected_labels)

def _create_dataloader_wrapper(
def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer_type = "tiktoken"
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
dataloader = build_hf_data_loader(
return build_hf_data_loader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
Expand All @@ -53,4 +52,3 @@ def _create_dataloader_wrapper(
world_size=4,
rank=0,
)
return DataLoaderWrapper(dataloader)
32 changes: 1 addition & 31 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import enum
import os
import pickle
import re
import time
from multiprocessing import get_context
Expand All @@ -23,8 +22,6 @@
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger

Expand Down Expand Up @@ -63,33 +60,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)


class DataLoaderWrapper(Stateful):
def __init__(self, dataloader: DataLoader) -> None:
self.dataloader = dataloader
# Use global rank for now even though dataloader state could be same across dp groups
self.rank_id = str(
dist.get_rank() if (dist.is_available() and dist.is_initialized()) else 0
)

def state_dict(self) -> Dict[str, Any]:
if isinstance(self.dataloader, StatefulDataLoader):
return {self.rank_id: pickle.dumps(self.dataloader.state_dict())}
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if isinstance(self.dataloader, StatefulDataLoader):
# State is empty
if not state_dict:
return

if self.rank_id not in state_dict:
logger.warning(f"DataLoader state is empty for rank {self.rank_id}. ")
return

# Load state for the current rank
self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id]))


class Terminate:
pass

Expand Down Expand Up @@ -149,7 +119,7 @@ def __init__(
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
"dataloader": DataLoaderWrapper(dataloader),
"dataloader": dataloader,
}
)

Expand Down
32 changes: 30 additions & 2 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional
import pickle
from typing import Any, Dict, List, Optional

import torch
from torch.distributed.checkpoint.stateful import Stateful
Expand Down Expand Up @@ -159,6 +160,33 @@ def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DpAwareDataLoader(StatefulDataLoader):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once for DP ranks.
"""

def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid, don't log a warning
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}."
)
return
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
dataset_name: str,
dataset_path: Optional[str],
Expand All @@ -173,4 +201,4 @@ def build_hf_data_loader(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return StatefulDataLoader(hf_ds, batch_size=batch_size)
return DpAwareDataLoader(rank, hf_ds, batch_size=batch_size)
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def loss_fn(pred, labels):
dataloader=data_loader,
states={"train_state": train_state},
job_config=job_config,
dp_rank=dp_rank,
)

if job_config.checkpoint.create_seed_checkpoint:
Expand Down

0 comments on commit 80cefc0

Please sign in to comment.