FSDP: How to do all-gather using FP8? #1188
FSDP2 supports all-gather using FP8:
https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323
Wondering if we could do this directly using TransformerEngine instead of torch-ao?
Thanks!

Comments
Hi @vgoklani -- TE modules can be initialized under the `fp8_model_init` context so that their parameters are created directly in FP8. I don't believe anyone has tried this in practice, but at least in principle, FSDP2's per-parameter sharding should work out-of-the-box with the resulting `Float8Tensor` parameters, though there are a couple of things to be mindful of here.
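A minimal sketch of that combination, assuming a recent Transformer Engine and PyTorch (`fp8_model_init` and FSDP2's `fully_shard` are real APIs; the layer size, launch setup, and the exact `fully_shard` import path are illustrative and version-dependent):

```python
# Sketch only: assumes GPUs and a distributed launch, e.g.
#   torchrun --nproc_per_node=8 script.py
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from torch.distributed._composable.fsdp import fully_shard  # path moved in newer releases

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# fp8_model_init creates TE module weights directly as FP8 tensors,
# rather than higher-precision params plus separate FP8 compute buffers.
with te.fp8_model_init(enabled=True):
    model = te.Linear(1024, 1024)

# FSDP2's per-parameter sharding; in principle it shards the FP8 params
# directly, so the all-gather traffic would stay in FP8 as well.
fully_shard(model)
```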
If you experiment with TE + FSDP2, please share your experiences. We already support PyTorch's native FSDP, but that path involves TE modules carrying extra FP8 buffers for the compute while the FSDP communication itself stays in higher precision. It would be great to extend our FSDP support to FSDP2-style FP8 all-gathers.
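For contrast, a hedged sketch of the existing native-FSDP path described here, where the parameters (and therefore the all-gather) stay in higher precision and FP8 is used only inside the compute via `fp8_autocast`; the recipe and sizes are illustrative:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if not dist.is_initialized():  # e.g. launched via torchrun
    dist.init_process_group(backend="nccl")

# Params are BF16/FP32, so FSDP all-gathers them at full width.
model = FSDP(te.Linear(1024, 1024).cuda())
fp8_recipe = DelayedScaling()  # TE's delayed-scaling FP8 recipe

x = torch.randn(16, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = model(x)  # GEMMs run in FP8; TE keeps the extra FP8 cast buffers
```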
Adding to this, FSDP2 support should just be a matter of implementing the tensor-subclass all-gather hooks (`fsdp_pre_all_gather` / `fsdp_post_all_gather`) that the torchtitan post above describes.
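To make that mechanism concrete, here is a hypothetical sketch (not TE's actual code) of the two hooks FSDP2 looks for on a parameter subclass; the signatures follow the torchao float8 implementation from around the time of this thread and may differ across PyTorch versions, and real implementations also carry FP8 scale factors, which this sketch omits:

```python
import torch

class FP8AllGatherWeight(torch.Tensor):
    """Hypothetical weight wrapper; FSDP2 calls these hooks if present."""

    def fsdp_pre_all_gather(self, mesh):
        # Called on the *local shard* before the collective: cast to FP8 so
        # the all-gather moves 1 byte/element instead of 2 (BF16) or 4 (FP32).
        # A real implementation would also compute and attach a scale factor.
        fp8_shard = self.to(torch.float8_e4m3fn)
        return (fp8_shard,), (self.dtype,)  # (all-gather inputs, metadata)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        # Called on the *gathered* FP8 bytes: hand them to the module's
        # forward; TE kernels could consume FP8 directly, with no up-cast.
        (fp8_weight,) = all_gather_outputs
        return fp8_weight, all_gather_outputs

# Hook-contract smoke test (no process group needed):
w = torch.randn(4, 4).as_subclass(FP8AllGatherWeight)
inputs, meta = w.fsdp_pre_all_gather(mesh=None)
print(inputs[0].dtype)  # torch.float8_e4m3fn
```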
If you guys are at the CUDA-MODE hackathon this Saturday (IRL), then let's work on this!