Add option to FSDP wrap by groups of blocks #340

Merged: 3 commits, Oct 26, 2023

Member

@epwalsh commented Oct 26, 2023

Closes #289.

Adds the ability to wrap multiple sequential transformer blocks together into a single parent FSDP wrapper by setting --model.block_group_size to the desired size and --fsdp.wrapping_strategy to "by_block_group". For example:

scripts/train.py ... \
  --model.block_group_size=4 \
  --fsdp.wrapping_strategy=by_block_group

In this case every group of 4 sequential blocks would be wrapped together.
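
As a rough illustration of the grouping described above, the sketch below partitions a list of block indices into consecutive groups of block_group_size. The helper name group_blocks is hypothetical and not taken from this PR. It also assumes that a trailing partial group is simply left shorter, which may not match how the actual implementation treats a layer count that is not divisible by the group size.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_blocks(blocks: Sequence[T], block_group_size: int) -> List[List[T]]:
    """Split `blocks` into consecutive groups of `block_group_size` items.

    Hypothetical helper for illustration only; not the PR's implementation.
    """
    if block_group_size < 1:
        raise ValueError("block_group_size must be at least 1")
    return [
        list(blocks[i : i + block_group_size])
        for i in range(0, len(blocks), block_group_size)
    ]


# 8 blocks with a group size of 4 give two groups; under "by_block_group"
# each group would be the unit that gets its own FSDP wrapper.
groups = group_blocks(list(range(8)), block_group_size=4)
```

In this sketch a group size of 1 puts every block in its own group, so grouping would then coincide with wrapping block by block.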

@epwalsh mentioned this pull request Oct 26, 2023
4 tasks
Comment on lines +1055 to +1056
if recurse:
    return True  # always recurse for simplicity
Member

I don't understand why this works. When does this ever get called with recurse == False?

Member Author

From what I understand, this function gets called twice on every module: once with recurse=False to check whether that module itself should be wrapped, and once with recurse=True to check whether to descend into the current module's submodules and potentially wrap any of those.

Member Author

So the meaning of the return value changes depending on the value of recurse, which is confusing... and is the reason we had the wrapping bug in the first place where we thought we were wrapping by block but we were actually just wrapping the whole model.
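
The two meanings of the return value can be sketched with a small pure-Python toy that needs no PyTorch. Everything here (ToyModule, toy_auto_wrap, both policy functions) is hypothetical and only loosely modeled on a recursive auto-wrap traversal; it is not PyTorch's or this PR's code. The second policy shows one way a bug of the kind described above could arise, which is an assumption rather than a reconstruction of the actual bug.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyModule:
    """Stand-in for torch.nn.Module: a name, a kind, and child modules."""

    name: str
    kind: str
    children: List["ToyModule"] = field(default_factory=list)


Policy = Callable[[ToyModule, bool], bool]


def toy_auto_wrap(module: ToyModule, policy: Policy, wrapped: List[str]) -> None:
    """Simplified imitation of a recursive auto-wrap traversal."""
    # recurse=True asks: "should the traversal look inside this module?"
    if not policy(module, True):
        return
    for child in module.children:
        toy_auto_wrap(child, policy, wrapped)
    # recurse=False asks: "should this module itself be wrapped?"
    if policy(module, False):
        wrapped.append(module.name)


def by_block_policy(module: ToyModule, recurse: bool) -> bool:
    if recurse:
        return True  # always recurse for simplicity
    return module.kind == "block"


def ignores_recurse_policy(module: ToyModule, recurse: bool) -> bool:
    # Gives the "is this a block?" answer to both questions, so the
    # traversal stops at the root (not a block) and never reaches a block.
    return module.kind == "block"


model = ToyModule(
    "model",
    "model",
    [
        ToyModule("wte", "embedding"),
        ToyModule("block0", "block"),
        ToyModule("block1", "block"),
    ],
)

wrapped_ok: List[str] = []
toy_auto_wrap(model, by_block_policy, wrapped_ok)

wrapped_bad: List[str] = []
toy_auto_wrap(model, ignores_recurse_policy, wrapped_bad)
```

In this toy, by_block_policy ends up wrapping both blocks, while ignores_recurse_policy wraps nothing below the root, so only an outer wrapper around the whole model would remain.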

Member

@dirkgr left a comment

This looks quite good. Can't wait to try it.

        return isinstance(module, OlmoBlock)

    return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_group:
Member

With this strategy, the input and output embeddings are never wrapped. I think that's fine at this point in time, but we should experiment with it.

Member Author

Agreed. I didn't want to change too many things.

Collaborator

@2015aroras left a comment

LGTM

olmo/config.py Outdated
block_group_size: int = 1
"""
The number of blocks to group together into a single parent block.
This is has no affect on the number of parameters in the model and is only used to wrap groups
Collaborator

nit: is has -> has

@epwalsh merged commit cd73387 into main on Oct 26, 2023
8 of 9 checks passed
@epwalsh deleted the epwalsh/block-groups branch on October 26, 2023 22:13
Development

Successfully merging this pull request may close these issues.

FSDP optimizations to try
3 participants