FSDP 2 doesn't pad tensors? #764
Comments
I updated torch to the latest nightly, and it seems to now work, except for float8, which produces the following error:
Without float8, the model trains fine. Specifically, this issue seems to be related to
FSDP2 does pad tensors on the sharded dim. For your original error, I am not sure where it is coming from; it would be helpful to show more of the stack trace. For your new error, what is the shape of the linear module you are using with float8 that has dim-0 smaller than the FSDP world size?
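To see how FSDP2 splits a parameter's dim-0 across ranks on a concrete module, here is a minimal sketch. It assumes a recent PyTorch where `fully_shard` is exported from `torch.distributed.fsdp` (older nightlies expose it from `torch.distributed._composable.fsdp`); the `nn.Linear(256, 100)` layer and the `torchrun` launch are illustrative, not taken from this issue.

```python
# Minimal sketch: launch with `torchrun --nproc_per_node=<N> inspect_shards.py`
# to print each rank's local shard of a Linear whose dim-0 is not divisible
# by the world size. Module sizes here are made up for illustration.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # older nightlies: torch.distributed._composable.fsdp

def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # dim-0 (out_features=100) is deliberately not divisible by typical world sizes.
    model = nn.Linear(256, 100, device="cuda")
    fully_shard(model)

    # After fully_shard, parameters are DTensors sharded on dim-0; comparing the
    # global shape with each rank's local shard shows how dim-0 is split across
    # ranks (shards can be uneven, or even empty, when dim-0 < world_size).
    for name, param in model.named_parameters():
        print(
            f"rank {rank}/{world_size} {name}: "
            f"global {tuple(param.shape)}, local {tuple(param.to_local().shape)}"
        )

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Running this with out_features both divisible and not divisible by the world size makes the uneven-shard situation discussed in this thread easy to observe.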
Yes, it seems like the tensor's dim-0 is smaller than the world size. Is this a limitation of the fp8 all-gather? Thanks!
Yes, I think the fp8 all-gather currently assumes that dim-0 is divisible by the world size.
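If it helps, here is a small, hypothetical helper (not part of any library in this thread) for spotting the layers that hit this limitation: it scans a model for `nn.Linear` modules whose dim-0 (`out_features`) is not divisible by the world size, so they could be excluded from float8 conversion or resized before training. The function name and the example model are made up.

```python
import torch.nn as nn

def find_fp8_incompatible_linears(model: nn.Module, world_size: int):
    """Return (name, weight_shape) for every nn.Linear whose dim-0 is not
    divisible by the FSDP world size -- the case the fp8 all-gather is said
    to assume away in the discussion above."""
    offending = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and module.out_features % world_size != 0:
            offending.append((name, tuple(module.weight.shape)))
    return offending

# Example: a Linear with out_features=100 is fine for world_size=4 (100 % 4 == 0)
# but not for world_size=128 (100 % 128 != 0), as in this issue.
model = nn.Sequential(nn.Linear(256, 100), nn.Linear(100, 512))
print(find_fp8_incompatible_linears(model, world_size=128))
```

The returned names could then be passed to whatever module filter the float8 conversion utility accepts, or used to decide which layers to resize.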
Hi, I ran my model with FSDP 2, and one of the linear layers has a dim that is not divisible by the world size (128), so I got the following error:
FSDP 1 circumvents this issue by padding the tensors. Is this not supported by FSDP 2? If not, will it be supported?
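To make the "padding" idea concrete, here is a rough, hypothetical workaround sketch; it is not how FSDP1 implements its internal padding, just an illustration of padding dim-0 up to a multiple of the world size. The `Dim0PaddedLinear` wrapper and the example sizes are invented for this sketch.

```python
import torch
import torch.nn as nn

class Dim0PaddedLinear(nn.Module):
    """Hypothetical workaround: wrap a Linear so its weight's dim-0 is a
    multiple of world_size (the property the fp8 all-gather above assumes).
    The padded rows are zero-initialized and dropped from the output."""

    def __init__(self, linear: nn.Linear, world_size: int):
        super().__init__()
        self.true_out = linear.out_features
        padded_out = -(-linear.out_features // world_size) * world_size  # ceil to multiple
        self.inner = nn.Linear(linear.in_features, padded_out, bias=linear.bias is not None)
        with torch.no_grad():
            self.inner.weight.zero_()
            self.inner.weight[: self.true_out].copy_(linear.weight)
            if linear.bias is not None:
                self.inner.bias.zero_()
                self.inner.bias[: self.true_out].copy_(linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Drop the padded rows so downstream shapes are unchanged.
        return self.inner(x)[..., : self.true_out]

# Example: out_features=100 with world_size=128 gets padded up to 128 rows.
padded = Dim0PaddedLinear(nn.Linear(256, 100), world_size=128)
print(padded.inner.weight.shape)  # torch.Size([128, 256])
```

This only illustrates what padding dim-0 means; FSDP1's own padding is internal and transparent to the user, so the wrapper is not a drop-in replacement for it.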