WIP: Initial FSDP2 support #3394
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Great start! Left some initial comments
src/accelerate/utils/dataclasses.py
Outdated
```python
self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]
if self.fsdp_version != 2 and isinstance(self.reshard_after_forward, bool):
    raise ValueError(
        "reshard_after_forward set to bool. This is not supported in FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
    )
```
"reshard_after_forward set to bool. This is not supported in FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`" | |
f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`" |
src/accelerate/utils/dataclasses.py
Outdated
```python
if self.fsdp_version != 2 and isinstance(self.cpu_offload, CPUOffloadPolicy):
    raise ValueError(
        "cpu_offload set to `torch.distributed.fsdp.CPUOffloadPolicy`. This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`"
    )
if self.fsdp_version == 2 and not isinstance(self.cpu_offload, CPUOffloadPolicy):
    raise ValueError(
        "cpu_offload set to `bool` or `torch.distributed.fsdp.CPUOffload`. This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`"
    )
```
I feel like we can simplify this
Suggested change:
```diff
-if self.fsdp_version != 2 and isinstance(self.cpu_offload, CPUOffloadPolicy):
-    raise ValueError(
-        "cpu_offload set to `torch.distributed.fsdp.CPUOffloadPolicy`. This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`"
-    )
-if self.fsdp_version == 2 and not isinstance(self.cpu_offload, CPUOffloadPolicy):
-    raise ValueError(
-        "cpu_offload set to `bool` or `torch.distributed.fsdp.CPUOffload`. This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`"
-    )
+if isinstance(self.cpu_offload, CPUOffloadPolicy):
+    err = "`cpu_offload` set to `torch.distributed.fsdp.CPUOffloadPolicy`."
+    if self.fsdp_version != 2:
+        raise ValueError(f"{err} This is not supported in FSDP1, please set to a `bool` or an instance of `torch.distributed.fsdp.CPUOffload`")
+    else:
+        raise ValueError(f"{err} This is not supported in FSDP2, please set to an instance of `torch.distributed.fsdp.CPUOffloadPolicy`")
```
src/accelerate/utils/dataclasses.py
Outdated
```python
self.cpu_offload = CPUOffload(offload_params=self.cpu_offload)
if self.fsdp_version == 2:
    if not self.cpu_offload:
        warnings.warn("Offload_params is set to False, however FSDP2 always offloads parameters and runs optimizer step on CPU. This will be overridden to True.")
```
Note that in general, users will ignore warnings (and will get annoyed that they bloat the logs). So we should instead use `logger.warn` if we don't want to explicitly raise an error about this.
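As an illustration of that suggestion, here is a minimal sketch that routes the message through a module-level logger instead of `warnings.warn`; the function name is a placeholder, not code from the PR.

```python
# Sketch only: emit the FSDP2 CPU-offload notice through logging so it follows the
# user's logging configuration rather than the warnings filters.
import logging

logger = logging.getLogger(__name__)


def coerce_fsdp2_cpu_offload(offload_params: bool) -> bool:
    """Hypothetical helper: FSDP2 (per the diff above) always offloads, so log and override."""
    if not offload_params:
        logger.warning(
            "offload_params is set to False, however FSDP2 always offloads parameters "
            "and runs the optimizer step on CPU. This will be overridden to True."
        )
    return True
```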
```diff
@@ -1657,12 +1704,22 @@ def __post_init__(self):

         if self.use_orig_params is None:
             self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
         if self.fsdp_version == 2 and self.use_orig_params is not None:
             warnings.warn("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")
```
Instead of doing many `warnings.warn`s, let's accumulate them all and do one big `logger.warn` at the end.
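A minimal sketch of what that accumulation could look like; the names and messages are placeholders drawn from the warnings quoted elsewhere in this thread, not the PR's final code.

```python
# Sketch only: collect FSDP2-related notices during validation and emit one
# consolidated logger.warning at the end instead of many individual warnings.
import logging

logger = logging.getLogger(__name__)


def collect_fsdp2_warnings(use_orig_params, offload_params):
    """Hypothetical validation pass that batches messages instead of warning one by one."""
    messages = []
    if use_orig_params is not None:
        messages.append("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")
    if offload_params is False:
        messages.append("offload_params is set to False, but FSDP2 always offloads parameters; overriding to True.")
    if messages:
        logger.warning("FSDP2 configuration notes:\n" + "\n".join(f"- {m}" for m in messages))
```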
src/accelerate/utils/dataclasses.py
Outdated
```python
if self.fsdp_version == 2 and self.forward_prefetch is not None:
    raise ValueError(
        "forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version` set to 1"
    )
```
Suggested change:
```diff
-if self.fsdp_version == 2 and self.forward_prefetch is not None:
-    raise ValueError(
-        "forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version` set to 1"
-    )
+if self.fsdp_version == 2 and self.forward_prefetch is not None:
+    raise ValueError(
+        "forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`"
+    )
```
src/accelerate/utils/dataclasses.py
Outdated
""" | ||
Validates the mixed precision policy, abstracted away to not bring in the imports if not needed. | ||
""" | ||
from torch.distributed.fsdp import MixedPrecision, MixedPrecisionPolicy |
Will this import lead to issues on old pytorch versions?
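One hedged way to sidestep that, sketched below, is to gate the FSDP2-only import on the torch version; the 2.6.0 minimum used here is an assumption for illustration, not a confirmed requirement.

```python
# Sketch only: keep the FSDP2-only import behind a version check so that older
# torch builds (which lack MixedPrecisionPolicy) do not fail at import time.
import torch
from packaging.version import Version


def import_mixed_precision_types():
    from torch.distributed.fsdp import MixedPrecision  # present on older torch as well

    MixedPrecisionPolicy = None
    # "2.6.0" is an assumed minimum; adjust to whatever version actually ships FSDP2.
    if Version(torch.__version__.split("+")[0]) >= Version("2.6.0"):
        from torch.distributed.fsdp import MixedPrecisionPolicy
    return MixedPrecision, MixedPrecisionPolicy
```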
I see potential for collaboration on this thread with my PRs. From the changes in this PR, I understand that the design is to convert FSDP1 args to FSDP2 args and put them to use. Following my initial discussions started in #3231, in PR #3395 I propose having FSDP2 as a separate distributed training type alongside FSDP1. To support more complex combinations of parallelisms, I also propose a design as described in that PR's description. Looking forward.
Draft PR, feel free to discuss changes to the user-facing api.
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.