
FSDP 2 doesn't pad tensors? #764

Open
cassanof opened this issue Dec 29, 2024 · 4 comments
Labels
question Further information is requested

Comments

@cassanof

Hi, I ran my model with FSDP 2. One of the linear layers has a dim that's not divisible by the world size (128), so I got the following error:

torch.Size([...]) is not divisible by FSDP world size 128.

FSDP 1 circumvents this issue by padding the tensors. Is this not supported by FSDP 2? If not, will it be supported?
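
For context, a minimal sketch of the kind of setup that can hit this, assuming the FSDP2 fully_shard API (whose import path varies across torch versions) and a distributed launch via torchrun; the layer sizes here are made up for illustration:

```python
# Hypothetical repro sketch; launch with e.g.
#   torchrun --nnodes=16 --nproc_per_node=8 repro.py   (world size 128)
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # import path may differ by torch version

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# dim-0 of this Linear's weight (100) is not divisible by the world size (128)
model = nn.Sequential(nn.Linear(512, 100)).cuda()
fully_shard(model)  # FSDP2 shards each parameter along dim-0
```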

@cassanof
Author

cassanof commented Dec 29, 2024

I updated torch to the latest nightly and it now seems to work, except for float8, which produces the following error:

  RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

in torchao/float8/float8_utils.py, line 102, in tensor_to_amax

Without float8, the model trains fine. Specifically, the issue seems to be related to enable_fsdp_float8_all_gather; enabling only enable_float8_linear and precompute_float8_dynamic_scale_for_fsdp works fine.
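
For reference, a sketch of the float8 setup those flags roughly map onto in the torchao float8 API (exact signatures and defaults may differ across torchao versions; the toy model is illustrative):

```python
# Sketch only, assuming the torchao float8 training API.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # path may differ by version
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

model = nn.Sequential(nn.Linear(512, 100), nn.Linear(100, 512))

# enable_float8_linear -> swap nn.Linear modules for Float8Linear;
# enable_fsdp_float8_all_gather is the flag implicated in the error above.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)
fully_shard(model)

# precompute_float8_dynamic_scale_for_fsdp(model) would then be called
# once per training step, after optimizer.step().
```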

@awgu
Contributor

awgu commented Dec 30, 2024

FSDP2 does pad tensors on the sharded dim. For your original error, I am not sure where it is coming from; it would be helpful to show more of the stack trace. For your new error, what is the shape of the linear module you are using with float8 that has dim-0 smaller than the FSDP world size?
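
To make the dim-0 padding concrete, here is the general chunking arithmetic (an illustration, not FSDP2 internals): each rank holds a ceil-divided chunk of dim-0, and the remainder is padding.

```python
import math

dim0, world_size = 1000, 128                 # dim-0 not divisible by the world size
shard_rows = math.ceil(dim0 / world_size)    # 8 rows per rank
padded_dim0 = shard_rows * world_size        # 1024, i.e. 24 rows of padding
print(shard_rows, padded_dim0, padded_dim0 - dim0)  # 8 1024 24
```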

@tianyu-l tianyu-l added the question Further information is requested label Dec 31, 2024
@cassanof
Author

Yes, seems like the tensor's dim-0 is smaller than the world size. Is this a limitation of fp8 all-gather?

Thanks!

@awgu
Contributor

awgu commented Dec 31, 2024

Yes, I think currently the fp8 all-gather assumes the dim-0 is divisible by the world size.
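
To see why that assumption bites, here is a plain dim-0 chunking illustration (not the actual fp8 all-gather code): when dim-0 is smaller than the world size, some ranks end up holding an empty shard, and an unreduced max() over an empty tensor raises exactly the tensor_to_amax error reported above.

```python
import torch

weight = torch.randn(100, 512)   # dim-0 = 100, world size = 128
world_size = 128
chunks = list(torch.chunk(weight, world_size, dim=0))              # only 100 non-empty chunks
chunks += [weight.new_empty(0, 512)] * (world_size - len(chunks))  # ranks 100..127 hold empty shards
print(chunks[0].shape, chunks[-1].shape)  # torch.Size([1, 512]) torch.Size([0, 512])
# torch.max(chunks[-1])  # RuntimeError: max(): Expected reduction dim to be specified
#                        # for input.numel() == 0. ...
```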
