Support FSDP #149
Conversation
Log for mnist_fsdp.py:
Train Epoch: 2 Loss: 0.176441
Train Epoch: 3 Loss: 0.127415
Train Epoch: 4 Loss: 0.105661
Train Epoch: 5 Loss: 0.096828
Train Epoch: 6 Loss: 0.090231
Train Epoch: 7 Loss: 0.083397
Train Epoch: 8 Loss: 0.081701
Train Epoch: 9 Loss: 0.081912
Train Epoch: 10 Loss: 0.079299
Train Epoch: 11 Loss: 0.078325
Train Epoch: 12 Loss: 0.077337
Train Epoch: 13 Loss: 0.077516
Train Epoch: 14 Loss: 0.076482
Good job!
I have the following questions:
- Which line of the code synchronizes the scaling factor (meta.scale or meta.scale_inv) across GPUs? (See the sketch after this list.)
- Gradient accumulation does not seem to be supported yet. We could raise a NotImplementedError when the number of gradient-accumulation steps is larger than 1.
- Is there any comparison of the memory footprint between BF16-FSDP and FP8-FSDP?
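As a hedged illustration of the first two points, the sketch below shows one common pattern for keeping FP8 scaling metadata consistent across ranks (an all-reduce with MAX on the local amax, from which every rank derives the same scale), together with a guard that rejects gradient accumulation. The names `sync_fp8_scale`, `check_grad_accumulation`, and `accumulate_steps` are placeholders for this discussion and are not taken from the code in this PR.

```python
# Illustrative sketch only; names are hypothetical and not from this PR.
import torch
import torch.distributed as dist


def sync_fp8_scale(amax: torch.Tensor, margin: int = 0) -> torch.Tensor:
    """Synchronize the FP8 scaling factor across GPUs.

    Each rank all-reduces its local amax with MAX, so every rank computes
    the same scale (and scale_inv = 1 / scale) from the same global amax.
    """
    dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    fp8_max = 448.0  # largest representable value in e4m3
    scale = fp8_max / (2 ** margin) / amax.clamp(min=1e-12)
    return scale


def check_grad_accumulation(accumulate_steps: int) -> None:
    """Fail fast if gradient accumulation is requested but not supported."""
    if accumulate_steps > 1:
        raise NotImplementedError(
            "Gradient accumulation is not supported with FP8 FSDP yet."
        )
```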
For these 3 questions:
Regarding answer 3, it may be related to the argument. When ...
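For the memory-footprint comparison raised in question 3, one simple way to collect the numbers is to record the peak allocated GPU memory per rank in both the BF16-FSDP and FP8-FSDP runs. This is a generic PyTorch measurement sketch, not code from the PR.

```python
# Generic measurement sketch (not from this PR): record peak GPU memory
# per rank so BF16-FSDP and FP8-FSDP runs can be compared directly.
import torch
import torch.distributed as dist


def report_peak_memory(tag: str) -> None:
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"[rank {dist.get_rank()}] {tag}: peak allocated {peak_gib:.2f} GiB")


# Call torch.cuda.reset_peak_memory_stats() before training starts,
# then report_peak_memory("fp8-fsdp") after a few training steps.
```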
LGTM. Thanks!
This is an initial version of FSDP in MSAMP.
In the future, we will do:
Sure. Let's improve it in the next iteration.
Description
Support FSDP with FP8.
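For context, the sketch below shows the standard way to shard a PyTorch model with FullyShardedDataParallel. It is a minimal, generic example under assumed placeholder model and hyperparameters; it does not include the MSAMP-specific FP8 changes this PR introduces.

```python
# Minimal generic FSDP sketch (plain PyTorch; the MSAMP FP8 integration
# added by this PR is not shown here). Launch with torchrun.
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main() -> None:
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)).cuda()
    model = FSDP(model)  # shard parameters, gradients, and optimizer state
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    x = torch.randn(32, 784, device="cuda")
    loss = model(x).square().mean()  # dummy loss for illustration
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

For example, run it with `torchrun --nproc_per_node=8 train.py` on a single 8-GPU node.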
Major Revision