WIP: Initial FSDP2 support #3394

Draft · wants to merge 17 commits into base: main
Conversation

@S1ro1 (Member) commented Feb 11, 2025

Draft PR, feel free to discuss changes to the user-facing API.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Collaborator) left a comment


Great start! Left some initial comments

src/accelerate/commands/to_fsdp2.py (comment thread outdated, resolved)
self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]
if self.fsdp_version != 2 and isinstance(self.reshard_after_forward, bool):
    raise ValueError(
        "reshard_after_forward set to bool. This is not supported in FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"

Suggested change:
-         "reshard_after_forward set to bool. This is not supported in FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
+         f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"

Comment on lines 1653 to 1660
if self.fsdp_version != 2 and isinstance(self.cpu_offload, CPUOffloadPolicy):
    raise ValueError(
        "cpu_offload set to `torch.distributed.fsdp.CPUOffloadPolicy`. This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`"
    )
if self.fsdp_version == 2 and not isinstance(self.cpu_offload, CPUOffloadPolicy):
    raise ValueError(
        "cpu_offload set to `bool` or `torch.distributed.fsdp.CPUOffload`. This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`"
    )

I feel like we can simplify this

Suggested change:
- if self.fsdp_version != 2 and isinstance(self.cpu_offload, CPUOffloadPolicy):
-     raise ValueError(
-         "cpu_offload set to `torch.distributed.fsdp.CPUOffloadPolicy`. This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`"
-     )
- if self.fsdp_version == 2 and not isinstance(self.cpu_offload, CPUOffloadPolicy):
-     raise ValueError(
-         "cpu_offload set to `bool` or `torch.distributed.fsdp.CPUOffload`. This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`"
-     )
+ if isinstance(self.cpu_offload, CPUOffloadPolicy):
+     err = "`cpu_offload` set to `torch.distributed.fsdp.CPUOffloadPolicy`."
+     if self.fsdp_version != 2:
+         raise ValueError(f"{err} This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`")
+     else:
+         raise ValueError(f"{err} This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`")
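Note that the consolidated version only enters the branch when cpu_offload is already a CPUOffloadPolicy, so the FSDP2-with-bool case from the original check would no longer raise. A behavior-preserving sketch covering both directions (illustrative only, not part of the PR; it assumes CPUOffloadPolicy is imported as in the surrounding code) could be:

# Sketch only: keeps the behavior of both original checks with a single isinstance() call.
offload_is_policy = isinstance(self.cpu_offload, CPUOffloadPolicy)
if self.fsdp_version != 2 and offload_is_policy:
    raise ValueError(
        "`cpu_offload` set to `torch.distributed.fsdp.CPUOffloadPolicy` is not supported in FSDP1, "
        "please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`"
    )
if self.fsdp_version == 2 and not offload_is_policy:
    raise ValueError(
        "`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffloadPolicy` when `fsdp_version=2`, "
        f"got {type(self.cpu_offload).__name__}"
    )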

self.cpu_offload = CPUOffload(offload_params=self.cpu_offload)
if self.fsdp_version == 2:
    if not self.cpu_offload:
        warnings.warn("Offload_params is set to False, however FSDP2 always offloads parameters and runs optimizer step on CPU. This will be overridden to True.")

Note that in general, users will ignore warnings (and will get annoyed that they bloat the logs). So we should instead use logger.warn if we don't want to explicitly raise an error about this.
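A minimal sketch of what that could look like (the module-level logger setup here is an assumption, not code from the PR):

import logging

logger = logging.getLogger(__name__)  # illustrative module-level logger

# instead of warnings.warn(...):
logger.warning(
    "offload_params is set to False, however FSDP2 always offloads parameters "
    "and runs the optimizer step on CPU. This will be overridden to True."
)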

@@ -1657,12 +1704,22 @@ def __post_init__(self):

if self.use_orig_params is None:
    self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
if self.fsdp_version == 2 and self.use_orig_params is not None:
    warnings.warn("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")

Instead of doing many warnings.warn's, let's accumulate them all and do one big logger.warn at the end.
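A sketch of that accumulate-then-emit pattern (the class and field names below are illustrative stand-ins, not the PR's actual dataclass):

import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)

@dataclass
class _FSDPConfigSketch:  # illustrative stand-in for the real plugin dataclass
    fsdp_version: int = 2
    use_orig_params: Optional[bool] = True
    offload_params: bool = False

    def __post_init__(self):
        notes = []  # accumulate messages instead of emitting them one by one
        if self.fsdp_version == 2 and self.use_orig_params is not None:
            notes.append("use_orig_params is obsolete in FSDP2; original parameters are always used.")
        if self.fsdp_version == 2 and not self.offload_params:
            notes.append("offload_params=False will be overridden to True under FSDP2.")
        if notes:
            # one consolidated logger call instead of many warnings.warn calls
            logger.warning("FSDP2 configuration notes:\n- " + "\n- ".join(notes))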

Comment on lines 1719 to 1722
if self.fsdp_version == 2 and self.forward_prefetch is not None:
    raise ValueError(
        "forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version` set to 1"
    )

Suggested change:
- if self.fsdp_version == 2 and self.forward_prefetch is not None:
-     raise ValueError(
-         "forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version` set to 1"
-     )
+ if self.fsdp_version == 2 and self.forward_prefetch is not None:
+     raise ValueError(
+         "forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`"
+     )

"""
Validates the mixed precision policy, abstracted away to not bring in the imports if not needed.
"""
from torch.distributed.fsdp import MixedPrecision, MixedPrecisionPolicy

Will this import lead to issues on old pytorch versions?
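For reference, one way to guard the newer import could look like the following sketch (the 2.6.0 threshold is an assumption, since the PR does not state a minimum torch version):

# Sketch only: keep the long-standing FSDP1 import unconditional and gate the FSDP2-era one.
from packaging import version

import torch
from torch.distributed.fsdp import MixedPrecision  # available on older releases

if version.parse(torch.__version__) >= version.parse("2.6.0"):  # assumed threshold
    from torch.distributed.fsdp import MixedPrecisionPolicy
else:
    MixedPrecisionPolicy = None  # FSDP2 mixed-precision policy not available on this torch version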

@kmehant (Contributor) commented Feb 13, 2025

@muellerzr @S1ro1

I see potential for collaboration between this thread and my PRs:

  1. Support TP + FSDPv2 / HSDP or just FSDPv2 / HSDP #3395
  2. [RFC] Support FSDP2 #3231

From the changes in this PR, I understand the design is to convert FSDP1 args to FSDP2 args and put them to use. Following the initial discussion I started in #3231, in PR #3395 I propose having FSDP2 as a separate distributed training type alongside FSDP1. To support more complex combinations of parallelisms, I also propose a design as described in that PR's description.

Looking forward.
