
Support FSDP #149

Merged
tocean merged 18 commits into main from yuxiang/fsdp_opt on Jan 16, 2024
Conversation

@tocean tocean (Contributor) commented Jan 8, 2024

Description
Support FSDP with FP8. A minimal usage sketch follows the list below.

Major Revision

  • Add fsdp package
  • Add mnist example
  • Add FSDPAdam and FSDPAdamW optimizers
  • Add documentation
  • Add unit tests
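
The sketch below is a hedged illustration of how the new pieces might fit together; it is not code taken from this PR. The import path for FSDPAdamW, the use of PyTorch's stock FullyShardedDataParallel wrapper, and the build helper are all assumptions.

```python
# Hypothetical sketch only: module paths and wrapper usage are assumptions;
# FSDPAdamW is the optimizer named in this PR, but its exact API may differ.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from msamp.optim import FSDPAdamW  # import path assumed


def build(model: torch.nn.Module):
    """Shard a model with FSDP and pair it with the FSDP-aware FP8 optimizer."""
    dist.init_process_group("nccl")  # assumes a torchrun-style launch
    fsdp_model = FSDP(model.cuda(), use_orig_params=True)
    # Per the discussion below, the optimizer step also synchronizes the FP8
    # scaling factors (amax/scale) across ranks.
    optimizer = FSDPAdamW(fsdp_model.parameters(), lr=1e-3)
    return fsdp_model, optimizer
```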

@tocean tocean (Contributor, Author) commented Jan 8, 2024

Log for mnist_fsdp.py:
Train Epoch: 1 Loss: 0.536632
Test set: Average loss: 0.1431, Accuracy: 9581/10000 (95.81%)

Train Epoch: 2 Loss: 0.176441
Test set: Average loss: 0.0877, Accuracy: 9729/10000 (97.29%)

Train Epoch: 3 Loss: 0.127415
Test set: Average loss: 0.0687, Accuracy: 9793/10000 (97.93%)

Train Epoch: 4 Loss: 0.105661
Test set: Average loss: 0.0608, Accuracy: 9813/10000 (98.13%)

Train Epoch: 5 Loss: 0.096828
Test set: Average loss: 0.0551, Accuracy: 9826/10000 (98.26%)

Train Epoch: 6 Loss: 0.090231
Test set: Average loss: 0.0527, Accuracy: 9829/10000 (98.29%)

Train Epoch: 7 Loss: 0.083397
Test set: Average loss: 0.0506, Accuracy: 9833/10000 (98.33%)

Train Epoch: 8 Loss: 0.081701
Test set: Average loss: 0.0497, Accuracy: 9833/10000 (98.33%)

Train Epoch: 9 Loss: 0.081912
Test set: Average loss: 0.0488, Accuracy: 9839/10000 (98.39%)

Train Epoch: 10 Loss: 0.079299
Test set: Average loss: 0.0487, Accuracy: 9841/10000 (98.41%)

Train Epoch: 11 Loss: 0.078325
Test set: Average loss: 0.0482, Accuracy: 9837/10000 (98.37%)

Train Epoch: 12 Loss: 0.077337
Test set: Average loss: 0.0481, Accuracy: 9838/10000 (98.38%)

Train Epoch: 13 Loss: 0.077516
Test set: Average loss: 0.0479, Accuracy: 9836/10000 (98.36%)

Train Epoch: 14 Loss: 0.076482
Test set: Average loss: 0.0479, Accuracy: 9837/10000 (98.37%)

@wkcn wkcn (Contributor) left a comment

Good job!

I have the following questions:

  1. Which line of the code synchronizes the scaling factor (meta.scale or meta.scale_inv) across GPUs?
  2. Gradient accumulation does not seem to be supported yet. We could raise a NotImplementedError when the number of gradient accumulation steps is larger than 1.
  3. Is there any comparison of the memory footprint between BF16-FSDP and FP8-FSDP?

@tocean tocean (Contributor, Author) commented Jan 9, 2024


For these 3 questions:

  1. See FSDPAdamW::step. In this function, all_reduce of amax is called (a conceptual sketch follows this list).
  2. Add checking logic in _fp8_post_backward_hook.
  3. I checked the memory saving using the T5 example here. The model I used is t5-3b. The memory footprints for BF16, FP32 and MS-AMP are 28GB, 40GB and 34GB respectively. It is a bit strange that MS-AMP uses more memory than BF16.
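
To make answers 1 and 2 concrete, below is a hedged, self-contained sketch. The names (the meta dictionary, FP8_E4M3_MAX, sync_scale, and the hook signature) are illustrative and are not the actual bodies of FSDPAdamW::step or _fp8_post_backward_hook.

```python
# Illustrative sketch of (1) synchronizing the scaling factor across GPUs and
# (2) rejecting gradient accumulation; names and the scaling rule are assumptions.
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0  # largest representable magnitude of the E4M3 format


def sync_scale(meta: dict) -> None:
    """MAX-reduce the local amax across ranks, then refresh scale and scale_inv."""
    dist.all_reduce(meta["amax"], op=dist.ReduceOp.MAX)
    # One common rule: map the global amax onto the FP8 dynamic range.
    meta["scale"] = FP8_E4M3_MAX / torch.clamp(meta["amax"], min=1e-12)
    meta["scale_inv"] = 1.0 / meta["scale"]


def fp8_post_backward_hook(grad_accumulation_steps: int = 1) -> None:
    """Guard mirroring answer 2: gradient accumulation is not supported yet."""
    if grad_accumulation_steps > 1:
        raise NotImplementedError(
            "Gradient accumulation (more than 1 step) is not supported with FP8 FSDP yet."
        )
    # ... the reduce-scatter of FP8 gradients would happen here ...
```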

@wkcn wkcn (Contributor) commented Jan 10, 2024


Regarding answer 3, it may be related to the mixed_precision argument of FSDP.

When mixed_precision is not None, FSDP creates an FP32 master weight for FP8Linear. This leads to duplicated master weights in FSDP and FP8Optimizer.
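
A hedged configuration sketch of this point follows. Whether the T5 run above actually passed a bfloat16 MixedPrecision policy is an assumption inferred from the discussion, not something this PR states; the wrap helper is illustrative.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision


def wrap(model: torch.nn.Module, use_fsdp_mixed_precision: bool) -> FSDP:
    # Assumes the default process group is already initialized (e.g. via torchrun).
    if use_fsdp_mixed_precision:
        # FSDP keeps full-precision parameters and casts to bf16 for compute;
        # combined with the FP8 optimizer's own FP32 master weights this
        # duplicates state, which could explain the 34GB vs 28GB gap above.
        policy = MixedPrecision(param_dtype=torch.bfloat16,
                                reduce_dtype=torch.bfloat16,
                                buffer_dtype=torch.bfloat16)
    else:
        # Leave mixed precision to the FP8 optimizer so there is a single
        # copy of the master weights.
        policy = None
    return FSDP(model.cuda(), mixed_precision=policy)
```

Under that reading, leaving mixed_precision as None and letting the FP8 optimizer hold the only master copy should narrow the MS-AMP vs BF16 memory gap reported above.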

@wkcn wkcn (Contributor) left a comment

LGTM. Thanks!

@tocean tocean requested a review from penghouwen January 11, 2024 06:02
@penghouwen penghouwen left a comment

An initial version of FSDP in MS-AMP.

@penghouwen

In the future, we will do:

  • gradient accumulation
  • speed optimization
  • accuracy calibration

@penghouwen penghouwen closed this Jan 16, 2024
@penghouwen penghouwen reopened this Jan 16, 2024
@tocean tocean merged commit 2fbe898 into main Jan 16, 2024
17 checks passed
@tocean tocean deleted the yuxiang/fsdp_opt branch January 16, 2024 02:59
@tocean tocean (Contributor, Author) commented Jan 16, 2024


Sure. Let's improve it in the next iteration.
