[feat] support a context for loading state_dict for FSDP (#1065)
* [fix]: add a context for supporting state_dict from a non-FSDP parent module

* formatting

Co-authored-by: Min Xu <[email protected]>
min-xu-ai and flying-x authored Sep 7, 2022
1 parent 3cc7fa8 commit 4b126c7
Showing 4 changed files with 31 additions and 4 deletions.
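
Before the diff, a minimal usage sketch of the new context manager. This is hedged: the single-process group setup, the Linear module, and the checkpoint path are illustrative assumptions, not part of the commit.

import os

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import no_pre_load_state_dict_hook

# Single-process "gloo" group so FSDP can be constructed in this sketch.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = FSDP(torch.nn.Linear(4, 4))

# Assumed scenario (see the docstring added below): the checkpoint was not
# produced by a root FSDP instance, so its keys already carry the internal
# wrapper prefixes and must not be rewritten again while loading.
state = torch.load("checkpoint_from_non_fsdp_parent.pt")  # hypothetical path

with no_pre_load_state_dict_hook():
    model.load_state_dict(state)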
8 changes: 7 additions & 1 deletion fairscale/nn/data_parallel/__init__.py
@@ -5,7 +5,13 @@

 from typing import List
 
-from .fully_sharded_data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState, auto_wrap_bn
+from .fully_sharded_data_parallel import (
+    FullyShardedDataParallel,
+    OffloadConfig,
+    TrainingState,
+    auto_wrap_bn,
+    no_pre_load_state_dict_hook,
+)
 from .sharded_ddp import ShardedDataParallel
 
 __all__: List[str] = []
20 changes: 18 additions & 2 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -51,7 +51,7 @@
 from fairscale.internal.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.internal.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.internal.state_dict import replace_by_prefix_
-from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.misc import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 
 from . import fsdp_optim_utils as ou
@@ -762,6 +762,7 @@ def _shard_parameters_(self) -> None:
                 self.numel_padded_per_param.append(0)
                 continue
             p._is_sharded = True
+            # TODO (Min): broadcast from rank 0 to avoid each rank need to init with the same seed?
 
             # Replace p.data with the relevant shard.
             orig_data = p.data
@@ -2581,10 +2582,25 @@ def apply_to_tensor(obj: torch.Tensor) -> torch.Tensor:
     return state_dict
 
 
+@contextlib.contextmanager
+def no_pre_load_state_dict_hook() -> Generator:
+    """Disable the pre-load hook.
+
+    This is needed if we are loading a state_dict that was not produced by
+    a root FSDP instance.
+    """
+    global _enable_pre_load_state_dict_hook
+    bak = _enable_pre_load_state_dict_hook
+    _enable_pre_load_state_dict_hook = False
+    yield
+    _enable_pre_load_state_dict_hook = bak
+
+
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
 ) -> None:
-    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
+    if _enable_pre_load_state_dict_hook:
+        replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
 
 
 def _clean_path(path: str) -> str:
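For context on why the hook sometimes must be skipped, here is a small illustration (not part of the commit) of the key rewriting performed by _pre_load_state_dict_hook, using the replace_by_prefix_ helper imported above; the "encoder." prefix is a made-up example.

from collections import OrderedDict

import torch

from fairscale.internal.state_dict import replace_by_prefix_

# Keys as a root FSDP instance would save them for a child wrapped under
# "encoder." (the wrapper prefixes are stripped on save).
sd = OrderedDict([("encoder.weight", torch.zeros(1)), ("encoder.bias", torch.zeros(1))])

# With the hook enabled, loading pushes the keys down one level, in place:
replace_by_prefix_(sd, "encoder.", "encoder._fsdp_wrapped_module.")
print(list(sd))
# ['encoder._fsdp_wrapped_module.weight', 'encoder._fsdp_wrapped_module.bias']

# If a checkpoint already contains "_fsdp_wrapped_module." in its keys (e.g. it
# was saved without going through a root FSDP instance), rewriting again would
# duplicate the prefix and no parameter would match; that is the situation
# no_pre_load_state_dict_hook() is meant to handle.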
2 changes: 1 addition & 1 deletion fairscale/nn/misc/__init__.py
@@ -9,7 +9,7 @@
 # in favor of fairscale.nn.checkpoint.checkpoint_wrapper.
 from fairscale.nn.checkpoint import checkpoint_wrapper
 
-from .flatten_params_wrapper import FlattenParamsWrapper
+from .flatten_params_wrapper import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from .param_bucket import GradBucket, ParamBucket
 
 __all__: List[str] = []
5 changes: 5 additions & 0 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -49,6 +49,9 @@
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
 
+# See no_pre_load_state_dict_hook context manager function in FSDP for more details.
+_enable_pre_load_state_dict_hook = True
+
 
 class FlatParameter(nn.Parameter):
     """A parameter that is initialized from a list of parameters and can be
@@ -543,6 +546,8 @@ def _post_state_dict_hook(
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any
 ) -> None:
+    if not _enable_pre_load_state_dict_hook:
+        return
     # Push everything down to ._fpw_module level.
     replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
     # The flat_param_* keys actually needs to move one level up.
