Monolithic checkpointing #3876

Merged · 4 commits · Jun 26, 2025

70 changes: 53 additions & 17 deletions composer/core/state.py
@@ -453,6 +453,8 @@ def __init__(
precision_config: Optional[dict[str, Any]] = None,

# optimizers
# TODO: Deprecate optimizers and support `optimizer` instead since we
# don't support multiple optimizers
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,

# scaler
@@ -591,7 +593,8 @@ def _validate_parallelism_configs(self):
raise ValueError('load_monolith_rank0_only is not compatible with tensor parallelism (TP).')
assert self.fsdp_config is not None
error_message = ''
if self.fsdp_config.sync_module_states == False:
# FSDP2 automatically syncs module states, so we don't need to check for it
if isinstance(self.fsdp_config, FSDPConfig) and self.fsdp_config.sync_module_states == False:
error_message += textwrap.dedent(
"load_monolith_rank0_only requires parallelism_config['fsdp']['sync_module_states'] to be True. "
"Either set parallelism_config['fsdp']['sync_module_states'] = True or set load_monolith_rank0_only = False.",
@@ -911,10 +914,14 @@ def fsdp_sharded_state_dict_enabled(self):

@property
def load_monolith_rank0_only(self):
return (
self.fsdp_config is not None and self.fsdp_config.auto_wrap and
self.fsdp_config.state_dict_type == 'full' and self.fsdp_config.load_monolith_rank0_only == True
should_load_monolith_rank0_only = (
self.fsdp_config is not None and self.fsdp_config.state_dict_type == 'full' and
self.fsdp_config.load_monolith_rank0_only == True
)
# TODO: Only FSDP1 has auto_wrap; if this is a legacy config, we should remove this check
if isinstance(self.fsdp_config, FSDPConfig):
should_load_monolith_rank0_only = should_load_monolith_rank0_only and self.fsdp_config.auto_wrap
return should_load_monolith_rank0_only

def _get_integrations_state_dict(self) -> dict[str, Any]:
"""Gets a dictionary of information about integrations to store in the state dict.
@@ -1325,14 +1332,14 @@ def load_model_state(
if self.load_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
with reproducibility.seed_context(self.rank_zero_seed):
from composer.distributed import prepare_fsdp_module
self._apply_fsdp()
log.debug('Finished wrapping model with FSDP.')

# TODO (FSDP2): support calling FSDP2 wrapper depending on the config type
assert isinstance(
self.fsdp_config,
FSDPConfig,
), f'prepare_fsdp_module requires FSDPConfig, got: {type(self.fsdp_config)}'
def _apply_fsdp(self):
# Init with globally fixed seed so all FSDP/HSDP replicas have the same initial weights
with reproducibility.seed_context(self.rank_zero_seed):
if isinstance(self.fsdp_config, FSDPConfig):
from composer.distributed import prepare_fsdp_module
self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
self.model,
self.optimizers,
@@ -1341,7 +1348,22 @@ def load_model_state(
self.device,
self.auto_microbatching,
)
log.debug('Finished wrapping model with FSDP.')
elif isinstance(self.fsdp_config, FSDP2Config):
from composer import ComposerModel
from composer.distributed.prepare_distributed import parallelize_composer_model

# FSDP2 doesn't support auto_microbatching (checked earlier, just validating here to be safe)
assert not self.auto_microbatching, 'auto_microbatching is not supported with FSDP2'
# To make pyright happy (instead of just adding a type: ignore)
assert isinstance(self.model, ComposerModel)

parallelize_composer_model(
self.model,
self.optimizers[0] if self.optimizers else None,
self.fsdp_config,
)
else:
raise ValueError(f'Unsupported FSDP config type for monolithic loading: {type(self.fsdp_config)}')

def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True):
"""Load the optimizer state.
@@ -1370,15 +1392,29 @@ def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True):

optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None

# Note: `broadcast_from_rank0` is only required for FSDP2.
# - For FSDP1, `set_optimizer_state_dict` follows a different code path: it detects FSDP1 modules and handles their FlatParameters separately.
#   Either `cpu_offload` or `broadcast_from_rank0` triggers broadcasting from rank 0 for FSDP1, which is why we only need to set `cpu_offload`
#   to True for FSDP1. Setting `broadcast_from_rank0` to True for FSDP1 would essentially be a no-op, and this matches our previous FSDP1 implementation.
# - For FSDP2, we don't need to set `cpu_offload` to True because the model weights have already been sharded into DTensors on GPUs on all ranks.
#   With `broadcast_from_rank0` set to True, `set_optimizer_state_dict` uses those sharded weights to broadcast the relevant shards of the
#   optimizer state dict (held on CPU on rank 0) to the corresponding GPUs on all ranks.
cpu_offload = self.fsdp_enabled and isinstance(self.fsdp_config, FSDPConfig)
broadcast_from_rank0 = self.load_monolith_rank0_only and isinstance(self.fsdp_config, FSDP2Config)

state_dict_options = StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
broadcast_from_rank0=broadcast_from_rank0,
cpu_offload=cpu_offload,
strict=strict,
)

set_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
optim_state_dict=optim_state_dict, # type: ignore
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=self.fsdp_enabled,
),
options=state_dict_options,
)

def load_state_dict(
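The comment in `load_optim_state` above explains why FSDP1 relies on `cpu_offload` while FSDP2 relies on `broadcast_from_rank0`. As a standalone illustration of that rank-0 broadcast pattern, here is a minimal sketch using PyTorch's `set_optimizer_state_dict` directly, outside Composer. The helper name and the checkpoint layout (an `'optimizers'` key) are hypothetical; the sketch assumes `torch.distributed` is initialized and the model is already sharded with FSDP2. Composer itself passes `None` on non-zero ranks here; the empty dict below plays the same role.

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_optimizer_state_dict,
)


def load_optim_monolith_rank0(model, optimizer, ckpt_path: str) -> None:
    # Only rank 0 materializes the full (unsharded) optimizer state on CPU;
    # the other ranks pass an empty dict and receive their shards via broadcast.
    if dist.get_rank() == 0:
        full_optim_state = torch.load(ckpt_path, map_location='cpu')['optimizers']  # hypothetical layout
    else:
        full_optim_state = {}

    set_optimizer_state_dict(
        model=model,
        optimizers=optimizer,
        optim_state_dict=full_optim_state,
        options=StateDictOptions(
            full_state_dict=True,       # the checkpoint holds unsharded tensors
            broadcast_from_rank0=True,  # shard rank 0's values and broadcast (the FSDP2 path above)
        ),
    )
```
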
3 changes: 3 additions & 0 deletions composer/distributed/shared_utils.py
@@ -181,6 +181,9 @@ def update_sync_module_states_if_needed(model: nn.Module, fsdp_config: FSDP2Conf
dist.all_reduce(any_ranks_meta, reduce_operation='MAX')
requires_sync = all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1

if fsdp_config.load_monolith_rank0_only:
fsdp_config.sync_module_states = True

if not fsdp_config.sync_module_states and requires_sync:
fsdp_config.sync_module_states = True

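The two lines added above force `sync_module_states` on whenever `load_monolith_rank0_only` is set: with monolithic loading only rank 0 holds the checkpointed weights, so every other rank must receive them before sharding. A toy sketch of what that synchronization amounts to, not Composer's actual implementation (which delegates this to FSDP):

```python
import torch
import torch.distributed as dist


def sync_module_states_from_rank0(model: torch.nn.Module) -> None:
    # Equivalent in spirit to sync_module_states=True: every parameter and
    # buffer is overwritten with rank 0's values so all ranks start identical.
    for tensor in list(model.parameters()) + list(model.buffers()):
        dist.broadcast(tensor.detach(), src=0)
```
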
51 changes: 4 additions & 47 deletions composer/trainer/trainer.py
@@ -74,9 +74,7 @@
from composer.distributed import (
DDPSyncStrategy,
ddp_sync_context,
parallelize_composer_model,
prepare_ddp_module,
prepare_fsdp_module,
prepare_tp_module,
)
from composer.distributed.shared_utils import generate_oom_hook
@@ -1610,7 +1608,7 @@ def __init__(
# original model for functions like `eval_forward`, `get_metrics`, etc.
self._original_model = self.state.model

self._wrap_model_for_distributed(model, optimizers, precision, device, auto_microbatching)
self._wrap_model_for_distributed(model, optimizers)

self.engine.run_event(Event.BEFORE_LOAD)

@@ -1731,25 +1729,8 @@ def __init__(

# FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
# load_monolith_rank0_only=True but no checkpoint was loaded.
if (
not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and
self.state.load_monolith_rank0_only
):
# TODO (FSDP2): support calling FSDP2 wrapper depending on the config type
assert isinstance(
self.state.fsdp_config,
FSDPConfig,
), f'prepare_fsdp_module requires FSDPConfig, got: {type(self.state.fsdp_config)}'
# Init with globally fixed seed so all HSDP replicas have the same initial weights
with reproducibility.seed_context(self.state.rank_zero_seed):
self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
model,
optimizers,
self.state.fsdp_config,
precision,
device,
auto_microbatching,
)
if not self.state.fsdp_enabled and self.state.load_monolith_rank0_only:
self.state._apply_fsdp()

# Set the iteration timestamp to the overall timestamp if loading from a checkpoint that was created before
# iteration was introduced in Composer v0.19.1. This is necessary to ensure that the iteration timestamp is
@@ -1795,9 +1776,6 @@ def _wrap_model_for_distributed(
self,
model: ComposerModel,
optimizers: Optional[torch.optim.Optimizer],
precision: Union[str, Precision],
device: Device,
auto_microbatching: bool,
):
"""Wrap the model for distributed training (TP, FSDP, etc.).

@@ -1825,28 +1803,7 @@

# FSDP wrap if not using monolith checkpoint on rank 0 only
if self.state.fsdp_config is not None and not self.state.load_monolith_rank0_only:
# Init with globally fixed seed so all HSDP replicas have the same initial weights
with reproducibility.seed_context(self.state.rank_zero_seed):
match self.state.fsdp_config_version:
case 1:
self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
model,
optimizers,
self.state.fsdp_config, # type: ignore
precision,
device,
auto_microbatching,
self.state.seed,
)
case 2:
assert not auto_microbatching
parallelize_composer_model(
model,
optimizers,
self.state.fsdp_config, # type: ignore
)
case _:
raise ValueError(f'Unsupported FSDP config version: {self.state.fsdp_config_version}')
self.state._apply_fsdp()

@property
def saved_checkpoints(self) -> list[str]:
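For reference, an end-to-end usage sketch of what this PR enables: a Trainer that shards with FSDP2 and loads a monolithic checkpoint on rank 0 only. This is hypothetical and not from the PR; in particular the `'fsdp2'` parallelism_config key, the checkpoint path, and the tiny model are assumptions that may need adjusting for your Composer version.

```python
import torch
import torch.nn.functional as F

from composer import ComposerModel, Trainer
from composer.utils.parallelism import FSDP2Config


class TinyComposerModel(ComposerModel):
    """Minimal ComposerModel used only to make the sketch self-contained."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 2)

    def forward(self, batch):
        inputs, _ = batch
        return self.net(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)


model = TinyComposerModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

trainer = Trainer(
    model=model,
    optimizers=optimizer,
    parallelism_config={
        'fsdp2': FSDP2Config(               # assumed config key; check your Composer version
            state_dict_type='full',
            load_monolith_rank0_only=True,  # full checkpoint is read on rank 0 only, then broadcast
        ),
    },
    load_path='checkpoints/ep1-rank0.pt',   # hypothetical monolithic checkpoint
)
```
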
33 changes: 18 additions & 15 deletions composer/utils/parallelism.py
@@ -71,6 +71,11 @@ class FSDP2Config:
For 1D mesh, parameters are fully sharded across the mesh (FSDP).
For 2D mesh, parameters are sharded across the 1st dimension and replicated across the 0th dimension (HSDP).
reshard_after_forward (Union[bool, int]): Controls parameter behavior after forward.
activation_checkpointing (bool): Whether to use activation checkpointing. Defaults to False.
activation_cpu_offload (bool): Whether to use activation CPU offloading. Defaults to False.
load_monolith_rank0_only (bool): Whether to load monolithic checkpoints on rank 0 only. Defaults to False.
state_dict_type (str): Type of state dict to use. Can be 'full' or 'sharded'. Defaults to 'sharded'.
verbose (bool): Whether to print verbose output. Defaults to False.
"""

# Settable core FSDP2 attrs
@@ -80,6 +85,9 @@ class FSDP2Config:
# in most of our use cases, we can decouple these two attributes from the FSDP2Config class.
activation_checkpointing: bool = False
activation_cpu_offload: bool = False
state_dict_type: str = 'sharded'
load_monolith_rank0_only: bool = False

verbose: bool = False

# Settable attrs that are automatically set during training
@@ -132,16 +140,7 @@ def from_compatible_attrs(cls, attrs: dict[str, Any]) -> 'FSDP2Config':
# Create and return a new FSDP2Config with the valid attributes
return FSDP2Config(**valid_attrs)

### Temporary read-only properties for FSDP 1 compatibility ###
# to be supported in FSDP2
@property
def auto_wrap(self) -> bool:
return False

@property
def load_monolith_rank0_only(self) -> bool:
return False

### Read-only properties for FSDP 1 compatibility ###
@property
def load_planner(self) -> Optional[Any]:
return None
@@ -162,18 +161,22 @@ def data_parallel_shard_degree(self) -> int:
def data_parallel_replicate_degree(self) -> Optional[int]:
return None

# to be deprecated in FSDP2
@property
def state_dict_type(self) -> str:
return 'sharded'

@property
def use_orig_params(self) -> bool:
return True

def __post_init__(self):
warnings.warn('FSDP2 Config/APIs are experimental and subject to heavy changes', UserWarning)

# TODO: We might not need `load_monolith_rank0_only`, as we could theoretically use
# `self.monolith_rank0_only = self.state_dict_type == 'full'`, assuming that saving
# the model isn't affected by `load_monolith_rank0_only`.
if self.load_monolith_rank0_only and self.state_dict_type != 'full':
raise ValueError(
'load_monolith_rank0_only=True requires state_dict_type="full". '
f'Got state_dict_type="{self.state_dict_type}"',
)


@dataclass
class TPConfig:
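Per the new `__post_init__` check above, `load_monolith_rank0_only=True` is only valid together with `state_dict_type='full'`. A minimal sketch, assuming the remaining `FSDP2Config` fields keep their defaults (the config also still emits its experimental-API warning):

```python
from composer.utils.parallelism import FSDP2Config

# Valid: monolithic rank-0 loading requires a full (unsharded) state dict.
cfg = FSDP2Config(
    state_dict_type='full',
    load_monolith_rank0_only=True,
)

# Invalid: state_dict_type defaults to 'sharded', so this now raises ValueError.
try:
    FSDP2Config(load_monolith_rank0_only=True)
except ValueError as err:
    print(err)
```
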
4 changes: 3 additions & 1 deletion tests/common/compare.py
@@ -57,7 +57,9 @@ def _check_item(
_check_dict_recursively(item1, item2, path, atol=atol, rtol=rtol, ignore_keys=ignore_keys)
return
if isinstance(item1, (tuple, list)):
assert isinstance(item2, type(item1)), f'{path} differs: {item1} != {item2}'
# When we broadcast lists/tuples from rank 0 (e.g. in State.load_optim_state),
# tuples get converted to lists, so we only validate the values here, not the
# exact type.
_check_list_recursively(item1, item2, path, atol=atol, rtol=rtol)
return
if isinstance(item1, ShardedTensor):
Expand Down