wraps DDP models with DSD
Summary:
Distributed State Dict (DSD) is PyTorch's currently recommended way to ensure that the state dicts of parallelized models remain compatible with save/load in single-process or re-sharding scenarios.

This diff updates dcp_saver to use DSD for DDP models. A good next step would be to wrap all models in TNT with DSD, as this could replace some of the wrapper logic for FSDP and would guarantee future compatibility.
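
For illustration only (not part of this diff), a minimal sketch of what DSD does for a DDP-wrapped module: get_model_state_dict returns single-device style keys (without the "module." prefix that DDP's own state_dict adds), so the result can be loaded straight into an unwrapped module. The single-process gloo group and the toy Linear module below are assumptions made purely for the example:

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.nn.parallel import DistributedDataParallel

# single-process group, only so DDP can be constructed in this example
dist.init_process_group("gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)

ddp_model = DistributedDataParallel(torch.nn.Linear(4, 4))

print(list(ddp_model.state_dict().keys()))            # ['module.weight', 'module.bias']
print(list(get_model_state_dict(ddp_model).keys()))   # ['weight', 'bias']

# the DSD-style state dict loads directly into a plain, unwrapped module;
# set_model_state_dict(ddp_model, sd) covers the reverse direction on restore
unwrapped = torch.nn.Linear(4, 4)
unwrapped.load_state_dict(get_model_state_dict(ddp_model))

dist.destroy_process_group()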


N5551629 also contains a workaround for DDP checkpoints saved before this diff, by manually removing the "module." prefix from the checkpoint keys.
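
(The notebook itself isn't reproduced here; roughly, the fix-up amounts to renaming keys before loading. Sketch only — old_state_dict and model are placeholders, not names from this codebase:)

# strip the DDP "module." prefix from the keys of a checkpoint written before this diff
fixed_state_dict = {
    key.removeprefix("module."): value for key, value in old_state_dict.items()
}
model.load_state_dict(fixed_state_dict)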

Differential Revision: D59234083
LucasLLC authored and facebook-github-bot committed Jul 2, 2024
1 parent 5dad8d3 commit 5818bb8
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -19,8 +19,12 @@
    DefaultSavePlanner,
)
from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.checkpoint.storage import StorageReader, StorageWriter

from torch.nn.parallel import DistributedDataParallel
from torchtnt.framework.callbacks._checkpoint_utils import (
    _prepare_app_state_for_checkpoint,
    _prepare_app_state_for_restore,
@@ -41,6 +45,7 @@
from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

from torchtnt.utils.stateful import MultiStateful, Stateful


@@ -63,6 +68,24 @@
)


class DSDModelWrapper(Stateful):
    """This wrapper converts state dicts to Distributed State Dicts, essentially generating
    state dicts as if they were created using single-device methods. This is useful when
    checkpointed models might be resharded, or loaded in notebooks or otherwise non-distributed
    settings.
    """

    def __init__(self, mod: torch.nn.Module) -> None:
        self.mod: torch.nn.Module = mod

    def state_dict(self) -> Dict[str, Any]:
        return get_model_state_dict(self.mod)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        set_model_state_dict(self.mod, state_dict)


class DistributedCheckpointSaver(BaseCheckpointer):
    """
    A callback which periodically saves the application state during training using `Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_.
@@ -148,6 +171,11 @@ def _checkpoint_impl(
        curr_snapshot_wait = hook == "on_train_end"

        app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)

        for key, obj in app_state.items():
            if isinstance(obj, DistributedDataParallel):
                app_state[key] = DSDModelWrapper(obj)

        # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
        if self._async_checkpoint:
            with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
@@ -315,14 +343,17 @@ def restore(
        )

        # necessary for loading optimizers since states are initialized lazily
        for obj in app_state.values():
        for key, obj in app_state.items():
            # sometimes optimizers are actually held in a wrapper which handles calling
            # state_dict and load_state_dict, as is the case for
            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
            optimizer = getattr(obj, "optimizer", obj)
            if isinstance(optimizer, torch.optim.Optimizer):
                init_optim_state(optimizer)

            if isinstance(obj, DistributedDataParallel):
                app_state[key] = DSDModelWrapper(obj)

        try:
            dcp.load(
                {"app_state": MultiStateful(app_state)},
