diff --git a/composer/core/state.py b/composer/core/state.py
index bf196b771c..e645f7b02c 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -453,6 +453,8 @@ def __init__(
         precision_config: Optional[dict[str, Any]] = None,
 
         # optimizers
+        # TODO: Deprecate optimizers and support `optimizer` instead since we
+        # don't support multiple optimizers
         optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
 
         # scaler
@@ -591,7 +593,8 @@ def _validate_parallelism_configs(self):
                 raise ValueError('load_monolith_rank0_only is not compatible with tensor parallelism (TP).')
             assert self.fsdp_config is not None
             error_message = ''
-            if self.fsdp_config.sync_module_states == False:
+            # FSDP2 automatically syncs module states, so we don't need to check for it
+            if isinstance(self.fsdp_config, FSDPConfig) and self.fsdp_config.sync_module_states == False:
                 error_message += textwrap.dedent(
                     "load_monolith_rank0_only requires parallelism_config['fsdp']['sync_module_states'] to be True. "
                     "Either set parallelism_config['fsdp']['sync_module_states'] = True or set load_monolith_rank0_only = False.",
@@ -911,10 +914,14 @@ def fsdp_sharded_state_dict_enabled(self):
 
     @property
     def load_monolith_rank0_only(self):
-        return (
-            self.fsdp_config is not None and self.fsdp_config.auto_wrap and
-            self.fsdp_config.state_dict_type == 'full' and self.fsdp_config.load_monolith_rank0_only == True
+        should_load_monolith_rank0_only = (
+            self.fsdp_config is not None and self.fsdp_config.state_dict_type == 'full' and
+            self.fsdp_config.load_monolith_rank0_only == True
         )
+        # TODO: Only FSDP1 has auto_wrap; if this is a legacy config, we should remove this check
+        if isinstance(self.fsdp_config, FSDPConfig):
+            should_load_monolith_rank0_only = should_load_monolith_rank0_only and self.fsdp_config.auto_wrap
+        return should_load_monolith_rank0_only
 
     def _get_integrations_state_dict(self) -> dict[str, Any]:
         """Gets a dictionary of information about integrations to store in the state dict.
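
For readers skimming the property change above: the new logic only requires `auto_wrap` when the legacy FSDP1 config is in use, since FSDP2 has no auto-wrap concept. A minimal standalone restatement, illustrative only and not Composer code:

from typing import Optional, Union

from composer.utils.parallelism import FSDP2Config, FSDPConfig


def should_load_monolith_rank0_only(cfg: Optional[Union[FSDPConfig, FSDP2Config]]) -> bool:
    result = cfg is not None and cfg.state_dict_type == 'full' and cfg.load_monolith_rank0_only
    if isinstance(cfg, FSDPConfig):
        # Only FSDP1 exposes auto_wrap; FSDP2Config has no such attribute.
        result = result and cfg.auto_wrap
    return result
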
@@ -1325,14 +1332,14 @@ def load_model_state(
         if self.load_monolith_rank0_only:
             assert self.fsdp_config is not None
             log.info('Wrapping model with FSDP after loading model_state.')
-            with reproducibility.seed_context(self.rank_zero_seed):
-                from composer.distributed import prepare_fsdp_module
+            self._apply_fsdp()
+            log.debug('Finished wrapping model with FSDP.')
 
-                # TODO (FSDP2): support calling FSDP2 wrapper depending on the config type
-                assert isinstance(
-                    self.fsdp_config,
-                    FSDPConfig,
-                ), f'prepare_fsdp_module requires FSDPConfig, got: {type(self.fsdp_config)}'
+    def _apply_fsdp(self):
+        # Init with globally fixed seed so all FSDP/HSDP replicas have the same initial weights
+        with reproducibility.seed_context(self.rank_zero_seed):
+            if isinstance(self.fsdp_config, FSDPConfig):
+                from composer.distributed import prepare_fsdp_module
                 self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(
                     self.model,
                     self.optimizers,
@@ -1341,7 +1348,22 @@
                     self.device,
                     self.auto_microbatching,
                 )
-            log.debug('Finished wrapping model with FSDP.')
+            elif isinstance(self.fsdp_config, FSDP2Config):
+                from composer import ComposerModel
+                from composer.distributed.prepare_distributed import parallelize_composer_model
+
+                # FSDP2 doesn't support auto_microbatching (checked earlier, just validating here to be safe)
+                assert not self.auto_microbatching, 'auto_microbatching is not supported with FSDP2'
+                # To make pyright happy (instead of just adding a type: ignore)
+                assert isinstance(self.model, ComposerModel)
+
+                parallelize_composer_model(
+                    self.model,
+                    self.optimizers[0] if self.optimizers else None,
+                    self.fsdp_config,
+                )
+            else:
+                raise ValueError(f'Unsupported FSDP config type for monolithic loading: {type(self.fsdp_config)}')
 
     def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True):
         """Load the optimizer state.
@@ -1370,15 +1392,29 @@ def load_optim_state(self, state_dict: dict[str, Any], strict: bool = True):
             optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None
 
+            # Note: 'broadcast_from_rank0' is only required for FSDP2.
+            # - In `set_optimizer_state_dict`, FSDP1 follows a different code path where it detects FSDP1 modules and handles FlatParameters differently.
+            #   Essentially, either `cpu_offload` or `broadcast_from_rank0` in FSDP1 causes broadcasting from rank 0, which is why we only need to set
+            #   `cpu_offload` to True for FSDP1. Setting `broadcast_from_rank0` to True for FSDP1 is essentially a no-op, and this follows our previous
+            #   implementation for FSDP1.
+            # - In FSDP2, we don't need to set `cpu_offload` to True as the model weights have already been sharded to DTensors on GPUs on all ranks.
+            #   `set_optimizer_state_dict` will utilize those sharded weights to broadcast the relevant shards of the optimizer state dict (on CPU on rank 0)
+            #   to the relevant GPUs on all ranks when `broadcast_from_rank0` is set to True.
+            cpu_offload = self.fsdp_enabled and isinstance(self.fsdp_config, FSDPConfig)
+            broadcast_from_rank0 = self.load_monolith_rank0_only and isinstance(self.fsdp_config, FSDP2Config)
+
+            state_dict_options = StateDictOptions(
+                full_state_dict=self.fsdp_state_dict_type == 'full',
+                broadcast_from_rank0=broadcast_from_rank0,
+                cpu_offload=cpu_offload,
+                strict=strict,
+            )
+
             set_optimizer_state_dict(
                 model=self.model,
                 optimizers=optimizer,
                 optim_state_dict=optim_state_dict,  # type: ignore
-                options=StateDictOptions(
-                    full_state_dict=self.fsdp_state_dict_type == 'full',
-                    strict=strict,
-                    cpu_offload=self.fsdp_enabled,
-                ),
+                options=state_dict_options,
             )
 
     def load_state_dict(
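
To make the FSDP1/FSDP2 split described in the comment above concrete, here is a minimal sketch (not Composer internals) of how the two option sets differ when loading a monolithic optimizer checkpoint. `model`, `optimizer`, and `full_optim_sd` (the deserialized optimizer state, only needed on rank 0) are assumed to already exist, and the process group is assumed to be initialized:

import torch.distributed as torch_dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict

# FSDP1 (FlatParameter path): cpu_offload alone already triggers the rank-0 broadcast.
fsdp1_options = StateDictOptions(full_state_dict=True, cpu_offload=True)

# FSDP2 (DTensor path): broadcast_from_rank0 scatters the rank-0 CPU state dict
# to the matching DTensor shards on every rank; no cpu_offload needed.
fsdp2_options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)

set_optimizer_state_dict(
    model=model,
    optimizers=optimizer,
    # Only rank 0 needs the real contents when broadcasting; other ranks pass a placeholder.
    optim_state_dict=full_optim_sd if torch_dist.get_rank() == 0 else {},
    options=fsdp2_options,
)
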
diff --git a/composer/distributed/shared_utils.py b/composer/distributed/shared_utils.py
index 51489179df..c25c6a0b6a 100644
--- a/composer/distributed/shared_utils.py
+++ b/composer/distributed/shared_utils.py
@@ -181,6 +181,9 @@ def update_sync_module_states_if_needed(model: nn.Module, fsdp_config: FSDP2Conf
     dist.all_reduce(any_ranks_meta, reduce_operation='MAX')
     requires_sync = all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1
 
+    if fsdp_config.load_monolith_rank0_only:
+        fsdp_config.sync_module_states = True
+
     if not fsdp_config.sync_module_states and requires_sync:
         fsdp_config.sync_module_states = True
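
The hook above forces weight sync for monolithic loading because only rank 0 is expected to hold real weights. A sketch of the init pattern this supports, mirroring the FSDP2 test added at the end of this diff; `model_cls` and its `param_init_fn` are assumed:

import torch.nn as nn

from composer.utils import dist


def build_model_for_monolith_load(model_cls, **kwargs) -> nn.Module:
    # Rank 0 materializes real weights; every other rank stays on the meta device
    # and is filled in when module states are broadcast after wrapping.
    device = 'cuda' if dist.get_global_rank() == 0 else 'meta'
    model = model_cls(device=device, **kwargs)
    if dist.get_global_rank() == 0:
        model.apply(model.param_init_fn)  # assumed: the model defines param_init_fn
    return model
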
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index bc4152ec47..cfe22067c9 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -74,9 +74,7 @@
 from composer.distributed import (
     DDPSyncStrategy,
     ddp_sync_context,
-    parallelize_composer_model,
     prepare_ddp_module,
-    prepare_fsdp_module,
     prepare_tp_module,
 )
 from composer.distributed.shared_utils import generate_oom_hook
@@ -1610,7 +1608,7 @@
         # original model for functions like `eval_forward`, `get_metrics`, etc.
         self._original_model = self.state.model
 
-        self._wrap_model_for_distributed(model, optimizers, precision, device, auto_microbatching)
+        self._wrap_model_for_distributed(model, optimizers)
 
         self.engine.run_event(Event.BEFORE_LOAD)
@@ -1731,25 +1729,8 @@ def __init__(
             # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
             # load_monolith_rank0_only=True but no checkpoint was loaded.
-            if (
-                not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and
-                self.state.load_monolith_rank0_only
-            ):
-                # TODO (FSDP2): support calling FSDP2 wrapper depending on the config type
-                assert isinstance(
-                    self.state.fsdp_config,
-                    FSDPConfig,
-                ), f'prepare_fsdp_module requires FSDPConfig, got: {type(self.state.fsdp_config)}'
-                # Init with globally fixed seed so all HSDP replicas have the same initial weights
-                with reproducibility.seed_context(self.state.rank_zero_seed):
-                    self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
-                        model,
-                        optimizers,
-                        self.state.fsdp_config,
-                        precision,
-                        device,
-                        auto_microbatching,
-                    )
+            if not self.state.fsdp_enabled and self.state.load_monolith_rank0_only:
+                self.state._apply_fsdp()
 
             # Set the iteration timestamp to the overall timestamp if loading from a checkpoint that was created before
             # iteration was introduced in Composer v0.19.1. This is necessary to ensure that the iteration timestamp is
@@ -1795,9 +1776,6 @@ def _wrap_model_for_distributed(
         self,
         model: ComposerModel,
         optimizers: Optional[torch.optim.Optimizer],
-        precision: Union[str, Precision],
-        device: Device,
-        auto_microbatching: bool,
     ):
         """Wrap the model for distributed training (TP, FSDP, etc.).
@@ -1825,28 +1803,7 @@ def _wrap_model_for_distributed(
         # FSDP wrap if not using monolith checkpoint on rank 0 only
         if self.state.fsdp_config is not None and not self.state.load_monolith_rank0_only:
-            # Init with globally fixed seed so all HSDP replicas have the same initial weights
-            with reproducibility.seed_context(self.state.rank_zero_seed):
-                match self.state.fsdp_config_version:
-                    case 1:
-                        self.state.automicrobatch_fsdp_hook_handles, self.state.fsdp_modules = prepare_fsdp_module(
-                            model,
-                            optimizers,
-                            self.state.fsdp_config,  # type: ignore
-                            precision,
-                            device,
-                            auto_microbatching,
-                            self.state.seed,
-                        )
-                    case 2:
-                        assert not auto_microbatching
-                        parallelize_composer_model(
-                            model,
-                            optimizers,
-                            self.state.fsdp_config,  # type: ignore
-                        )
-                    case _:
-                        raise ValueError(f'Unsupported FSDP config version: {self.state.fsdp_config_version}')
+            self.state._apply_fsdp()
 
     @property
     def saved_checkpoints(self) -> list[str]:
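
At the user level, the Trainer changes above mean monolithic FSDP2 loading is driven entirely by the parallelism config: wrapping is deferred until after the rank-0 checkpoint load, then `State._apply_fsdp()` shards the model. A hedged usage sketch; `MyModel`, `my_dataloader`, and the checkpoint path are placeholders:

from composer import Trainer
from composer.utils.parallelism import FSDP2Config, ParallelismConfig

parallelism_config = ParallelismConfig(
    fsdp2=FSDP2Config(state_dict_type='full', load_monolith_rank0_only=True),
)

trainer = Trainer(
    model=MyModel(),                    # real weights on rank 0, meta elsewhere
    train_dataloader=my_dataloader,
    max_duration='1ep',
    parallelism_config=parallelism_config,
    load_path='checkpoints/latest-rank0.pt',  # placeholder path
)
trainer.fit()
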
diff --git a/composer/utils/parallelism.py b/composer/utils/parallelism.py
index 38ae71e67e..8e5eaed012 100644
--- a/composer/utils/parallelism.py
+++ b/composer/utils/parallelism.py
@@ -71,6 +71,11 @@ class FSDP2Config:
             For 1D mesh, parameters are fully sharded across the mesh (FSDP).
             For 2D mesh, parameters are sharded across the 1st dimension and replicated across the 0th dimension (HSDP).
         reshard_after_forward (Union[bool, int]): Controls parameter behavior after forward.
+        activation_checkpointing (bool): Whether to use activation checkpointing. Defaults to False.
+        activation_cpu_offload (bool): Whether to use activation CPU offloading. Defaults to False.
+        load_monolith_rank0_only (bool): Whether to load monolithic checkpoints on rank 0 only. Defaults to False.
+        state_dict_type (str): Type of state dict to use. Can be 'full' or 'sharded'. Defaults to 'sharded'.
+        verbose (bool): Whether to print verbose output. Defaults to False.
     """
 
     # Settable core FSDP2 attrs
@@ -80,6 +85,9 @@ class FSDP2Config:
     # in most of our use cases, we can decouple these two attributes from the FSDP2Config class.
     activation_checkpointing: bool = False
     activation_cpu_offload: bool = False
+    state_dict_type: str = 'sharded'
+    load_monolith_rank0_only: bool = False
+    verbose: bool = False
 
     # Settable attrs that are automatically set during training
@@ -132,16 +140,7 @@ def from_compatible_attrs(cls, attrs: dict[str, Any]) -> 'FSDP2Config':
         # Create and return a new FSDP2Config with the valid attributes
         return FSDP2Config(**valid_attrs)
 
-    ### Temporary read-only properties for FSDP 1 compatibility ###
-    # to be supported in FSDP2
-    @property
-    def auto_wrap(self) -> bool:
-        return False
-
-    @property
-    def load_monolith_rank0_only(self) -> bool:
-        return False
-
+    ### Read-only properties for FSDP 1 compatibility ###
     @property
     def load_planner(self) -> Optional[Any]:
         return None
@@ -162,11 +161,6 @@ def data_parallel_shard_degree(self) -> int:
     def data_parallel_replicate_degree(self) -> Optional[int]:
         return None
 
-    # to be deprecated in FSDP2
-    @property
-    def state_dict_type(self) -> str:
-        return 'sharded'
-
     @property
     def use_orig_params(self) -> bool:
         return True
@@ -174,6 +168,15 @@ def use_orig_params(self) -> bool:
     def __post_init__(self):
         warnings.warn('FSDP2 Config/APIs are experimental and subject to heavy changes', UserWarning)
 
+        # TODO: We might not need `load_monolith_rank0_only` as we can theoretically use
+        # self.monolith_rank0_only = self.state_dict_type == 'full' assuming that saving
+        # the model doesn't get affected by `load_monolith_rank0_only`
+        if self.load_monolith_rank0_only and self.state_dict_type != 'full':
+            raise ValueError(
+                'load_monolith_rank0_only=True requires state_dict_type="full". '
+                f'Got state_dict_type="{self.state_dict_type}"',
+            )
+
 
 @dataclass
 class TPConfig:
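
Since `state_dict_type` and `load_monolith_rank0_only` are now plain settable fields, the `__post_init__` guard above is easy to exercise directly; this mirrors the new test at the bottom of the diff:

from composer.utils.parallelism import FSDP2Config

# Valid: monolithic loading requires a full (unsharded) state dict.
FSDP2Config(state_dict_type='full', load_monolith_rank0_only=True)

# Invalid: __post_init__ raises because the state dict would be sharded.
try:
    FSDP2Config(state_dict_type='sharded', load_monolith_rank0_only=True)
except ValueError as err:
    print(err)  # load_monolith_rank0_only=True requires state_dict_type="full". ...
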
diff --git a/tests/common/compare.py b/tests/common/compare.py
index 4e2e0b9ce6..53f492f4f4 100644
--- a/tests/common/compare.py
+++ b/tests/common/compare.py
@@ -57,7 +57,9 @@ def _check_item(
         _check_dict_recursively(item1, item2, path, atol=atol, rtol=rtol, ignore_keys=ignore_keys)
         return
     if isinstance(item1, (tuple, list)):
-        assert isinstance(item2, type(item1)), f'{path} differs: {item1} != {item2}'
+        # When we are broadcasting lists/tuples from rank0 (e.g. State.load_optim_state),
+        # tuples get converted to lists, so we don't want to validate the type, just
+        # the values.
         _check_list_recursively(item1, item2, path, atol=atol, rtol=rtol)
         return
     if isinstance(item1, ShardedTensor):
diff --git a/tests/trainer/test_fsdp2.py b/tests/trainer/test_fsdp2.py
index 3dd76fd2ae..f1b2579df3 100644
--- a/tests/trainer/test_fsdp2.py
+++ b/tests/trainer/test_fsdp2.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 import pathlib
 from typing import Optional
@@ -20,6 +21,7 @@
     SimpleWeightTiedModel,
     world_size,
 )
+from tests.trainer.test_fsdp_checkpoint import _assert_checkpoints_equivalent
 
 _INIT_DEVICES = ['cuda', 'meta']
@@ -27,6 +29,8 @@
 def create_trainer_with_model(
     model: ComposerClassifier,
     num_classes: int = 10,
+    dataset_size: int = 2,
+    batch_size: int = 1,
     max_duration: str = '10ep',
     use_fsdp2: bool = True,
     optimizer: Optional[torch.optim.Optimizer] = None,
@@ -34,31 +38,44 @@
     activation_cpu_offload: bool = False,
     auto_microbatching: bool = False,
     fsdp1_sync_module_states: bool = False,
+    state_dict_type: str = 'sharded',
+    load_monolith_rank0_only: bool = False,
+    save_folder: Optional[str] = None,
+    save_filename: str = 'ba{batch}-rank{rank}.pt',
+    save_interval: str = '10ba',
+    load_path: Optional[str] = None,
 ) -> Trainer:
     """Helper function to create a Trainer with a model, dataloader, and FSDP2 configuration."""
-    dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
-    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+    dataset = RandomClassificationDataset(shape=(num_classes,), size=dataset_size, num_classes=num_classes)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset), batch_size=batch_size)
 
     parallelism_config = ParallelismConfig()
     if use_fsdp2:
         parallelism_config.fsdp2 = FSDP2Config(
             activation_checkpointing=activation_checkpointing,
             activation_cpu_offload=activation_cpu_offload,
+            state_dict_type=state_dict_type,
+            load_monolith_rank0_only=load_monolith_rank0_only,
         )
     else:
         parallelism_config.fsdp = FSDPConfig(
-            state_dict_type='sharded',
+            state_dict_type=state_dict_type,
             sync_module_states=fsdp1_sync_module_states,
         )
 
     if optimizer is None:
         optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+
     trainer = Trainer(
         model=model,
-        optimizers=optimizer,
         train_dataloader=dataloader,
+        optimizers=optimizer,
         max_duration=max_duration,
         parallelism_config=parallelism_config,
         device_train_microbatch_size='auto' if auto_microbatching else None,
+        save_folder=save_folder,
+        save_filename=save_filename,
+        save_interval=save_interval,
+        load_path=load_path,
     )
     return trainer
@@ -605,3 +622,83 @@ def test_fsdp2_sync_module_states_mixed_init_weight_equivalence(
         )
 
         self._compare_weights(fsdp1_weights, fsdp2_weights, tolerance=1e-5)
+
+
+@pytest.mark.gpu
+@world_size(2)
+@pytest.mark.parametrize('model_class', [SimpleComposerMLP, SimpleWeightTiedModel, PartialWeightTiedModel])
+@pytest.mark.parametrize('device', ['cuda', 'cpu'])
+def test_fsdp2_monolithic_checkpoint_save_and_load(
+    world_size: int,
+    model_class: type,
+    tmp_path: pathlib.Path,
+    device: str,
+):
+    """Test FSDP2 monolithic checkpoint saving and loading with proper model initialization."""
+    NUM_FEATURES = 10
+    NUM_CLASSES = 10
+    BATCH_SIZE = 2
+    DATASET_SIZE = 16  # 16 samples -> 8 per rank -> 4 batches per rank per epoch at batch_size=2
+
+    # Use tmp_path from all ranks to ensure consistency
+    tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
+    save_folder = tmp_paths[0]
+
+    save_interval = '1ba'
+    save_filename = 'ba{batch}-rank{rank}.pt'
+    resume_file = 'ba1-rank{rank}.pt'
+    final_checkpoint = 'latest-rank{rank}.pt'
+    total_batches_str = '8ba'
+
+    # Create the initial model and trainer on CUDA for training
+    kwargs = {'num_classes': NUM_CLASSES} if model_class == SimpleComposerMLP else {}
+    model1 = model_class(num_features=NUM_FEATURES, device='cuda', **kwargs)
+    model1.add_fsdp_wrap_attribute_to_children()
+    if dist.get_local_rank() == 0:
+        model1.apply(model1.param_init_fn)
+
+    # Train for 8 batches but save every batch
+    trainer1 = create_trainer_with_model(
+        model=model1,
+        num_classes=NUM_CLASSES,
+        dataset_size=DATASET_SIZE,
+        batch_size=BATCH_SIZE,
+        max_duration=total_batches_str,
+        use_fsdp2=True,
+        save_folder=os.path.join(save_folder, 'first'),
+        save_filename=save_filename,
+        state_dict_type='full',
+        save_interval=save_interval,
+    )
+
+    trainer1.fit()
+    trainer1.close()
+
+    # Create a second trainer that loads the monolithic checkpoint,
+    # on either CPU or GPU depending on the device parameter
+    resolved_device = device if dist.get_local_rank() == 0 else 'meta'
+    model2 = model_class(num_features=NUM_FEATURES, device=resolved_device, **kwargs)
+    model2.add_fsdp_wrap_attribute_to_children()
+    resume_path = os.path.join(save_folder, 'first', resume_file)
+
+    # Train to 8 batches total so the final checkpoints can be compared
+    trainer2 = create_trainer_with_model(
+        model=model2,
+        num_classes=NUM_CLASSES,
+        dataset_size=DATASET_SIZE,
+        batch_size=BATCH_SIZE,
+        max_duration=total_batches_str,
+        use_fsdp2=True,
+        state_dict_type='full',
+        load_monolith_rank0_only=True,
+        load_path=resume_path,
+        save_folder=os.path.join(save_folder, 'second'),
+        save_filename=save_filename,
+    )
+    trainer2.fit()
+    trainer2.close()
+
+    _assert_checkpoints_equivalent(
+        os.path.join(save_folder, 'first', final_checkpoint),
+        os.path.join(save_folder, 'second', final_checkpoint),
+    )
diff --git a/tests/trainer/test_fsdp2_config.py b/tests/trainer/test_fsdp2_config.py
index ed76fa9fcd..90dcd57e08 100644
--- a/tests/trainer/test_fsdp2_config.py
+++ b/tests/trainer/test_fsdp2_config.py
@@ -12,7 +12,6 @@ def test_fsdp2_config():
     config = FSDP2Config()
 
     # Test reading properties (should succeed)
-    assert config.auto_wrap is False
     assert config.load_monolith_rank0_only is False
     assert config.sync_module_states is False
     assert config.activation_cpu_offload is False
@@ -20,14 +19,12 @@
     assert config.data_parallel_replicate_degree is None
     assert config.state_dict_type == 'sharded'
     assert config.use_orig_params is True
+    assert config.load_monolith_rank0_only is False
 
     # Test setting properties (should fail)
     read_only_props = [
-        ('auto_wrap', False),
-        ('load_monolith_rank0_only', True),
         ('data_parallel_shard_degree', 2),
         ('data_parallel_replicate_degree', 2),
-        ('state_dict_type', 'full'),
         ('use_orig_params', False),
     ]
@@ -90,3 +87,21 @@ def test_fsdp2config_from_fsdp1_multiple_invalid_attributes():
     assert any('invalid_attribute2: value2' in msg for msg in warning_messages)
     assert any('auto_wrap: True' in msg for msg in warning_messages)
     assert any('sync_module_states: True' in msg for msg in warning_messages)
+
+
+def test_fsdp2_config_monolithic_validation():
+    """Test FSDP2Config validation for monolithic checkpointing."""
+    # Test valid monolithic config
+    config = FSDP2Config(
+        state_dict_type='full',
+        load_monolith_rank0_only=True,
+    )
+    assert config.state_dict_type == 'full'
+    assert config.load_monolith_rank0_only is True
+
+    # Test invalid monolithic config
+    with pytest.raises(ValueError, match='load_monolith_rank0_only=True requires state_dict_type="full"'):
+        FSDP2Config(
+            state_dict_type='sharded',
+            load_monolith_rank0_only=True,
+        )