
Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam #1078

Merged: 3 commits merged into NVIDIA:main on Nov 1, 2024

Conversation

@kunlunl (Contributor) commented Aug 5, 2024

Description

Add options to set the dtypes of master weights, exp_avg and exp_avg_sq of FusedAdam.
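
For illustration, here is a minimal usage sketch of the new options. The keyword names (master_weights, master_weight_dtype, exp_avg_dtype, exp_avg_sq_dtype) and the use of torch.uint8 to request fp8 storage are assumptions based on this description and may not match the merged API exactly.

import torch
from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam

model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

# Hypothetical keyword names for the dtype options described above.
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,                # keep a separate master copy of the weights
    master_weight_dtype=torch.float16,  # fp16 master weights instead of fp32
    exp_avg_dtype=torch.uint8,          # fp8 storage (plus per-tensor scale) for the first moment
    exp_avg_sq_dtype=torch.float16,     # fp16 second moment
)

loss = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)).sum()
loss.backward()
optimizer.step()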

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Support using fp32/fp16 master weights
  • Support using fp32/fp16/fp8 exp_avg
  • Support using fp32/fp16/fp8 exp_avg_sq (a rough per-parameter memory sketch for these combinations follows below)
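
As a rough back-of-the-envelope illustration (nominal dtype widths, not measured numbers), the optimizer-owned bytes per parameter (master weight + exp_avg + exp_avg_sq, ignoring the small per-tensor scales needed for fp16/fp8 states) for a few of these combinations:

# Illustrative per-parameter memory of the optimizer state for a few of the
# supported dtype combinations; not measurements from this PR.
configs = {
    "fp32 master, fp32 moments (baseline)": 4 + 4 + 4,  # 12 bytes/param
    "fp16 master, fp16 moments": 2 + 2 + 2,             #  6 bytes/param
    "fp16 master, fp8 moments": 2 + 1 + 1,              #  4 bytes/param
}
for name, nbytes in configs.items():
    print(f"{name}: {nbytes} bytes/param")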

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@kunlunl kunlunl changed the title Add MX-FP16 Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam Aug 6, 2024
@kunlunl kunlunl changed the title Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam Draft: Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam Aug 6, 2024
@kunlunl kunlunl changed the title Draft: Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam Oct 29, 2024
@kunlunl (Contributor, Author) commented Oct 30, 2024

@timmoon10 Hello, I noticed that no one has commented on this PR for a long time. Could you please take a look, or help find someone to review it?

@timmoon10 (Collaborator) left a comment


Overall this looks good. It would be more general if we disentangled the state dtypes and state scaling (e.g. why not have scaled FP32 states or unscaled BF16 states?), but this does cover the specific cases in the MS-AMP paper.

For future reference, this PR adapts logic from NVIDIA/apex#1771. This is a proof of concept with several opportunities for future improvement (a rough sketch of the scale/unscale round trip follows this list):

  • TE kernel for computing absmax and scale
  • Fusing scale/unscale within Adam kernel
  • Reduce memory usage in optimizer step, perhaps by processing params in chunks
  • Reduce memory usage in checkpointing, perhaps by storing checkpoint buffers in CPU memory
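
To make the scale/unscale round trip concrete, here is a minimal, unfused pure-PyTorch sketch of the pattern the first two bullets would move into TE/Adam kernels. It assumes per-tensor scaling into torch.float8_e4m3fn (max finite value 448, available in recent PyTorch versions) and is illustrative only, not the code in this PR.

import torch

FP8_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quantize_state(state_fp32: torch.Tensor):
    """Store an fp32 optimizer state as fp8 plus a per-tensor scale."""
    amax = state_fp32.abs().max().clamp(min=1e-12)  # absmax (a TE kernel could compute this)
    scale = FP8_MAX / amax                          # map the largest magnitude near FP8_MAX
    state_fp8 = (state_fp32 * scale).to(torch.float8_e4m3fn)
    return state_fp8, scale

def dequantize_state(state_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an approximate fp32 state before the Adam update (fusable into the kernel)."""
    return state_fp8.to(torch.float32) / scale

exp_avg = torch.randn(1024) * 1e-3
stored, scale = quantize_state(exp_avg)
restored = dequantize_state(stored, scale)
print((exp_avg - restored).abs().max())  # quantization error introduced by fp8 storage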

Review threads on transformer_engine/pytorch/optimizers/fused_adam.py (4 threads, outdated, resolved)
@@ -112,9 +149,6 @@ def __init__(
        self.set_grad_none = set_grad_none

        self.capturable = capturable

        if master_weights is not None:
            assert isinstance(master_weights, list), "master_weights must be a list if provided"
        self.master_weights = master_weights
Collaborator:

This removes the use-case where the master weights are provided externally (added in #977). I personally like this change since it makes things cleaner, but will it have an effect on Mcore integration? Pinging @Wong4j.

@kunlunl (Contributor, Author):

Yes, I'm aware of this issue. I talked with @Wong4j offline and invited him to review this PR.
His MR in MCore (fusing the dtype cast) has not been merged yet, so I have put the dtype-cast fusion into a new MCore MR together with this precision-aware optimizer.

Review thread on tests/pytorch/test_fused_optimizer.py (outdated, resolved)
@timmoon10 (Collaborator)

/te-ci pytorch

@yaox12 (Collaborator) commented Oct 31, 2024

/te-ci pytorch

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Kunlun Li <[email protected]>
@timmoon10 (Collaborator)

/te-ci pytorch

@timmoon10 (Collaborator) left a comment


LGTM, pending CI and confirmation from @Wong4j that this won't break Mcore integration.

@Wong4j (Contributor) commented Nov 1, 2024

LGTM.
@timmoon10 This design is better. My MCore PR has not been merged yet, so this change won't break the MCore integration.

@timmoon10 merged commit 05c0fb0 into NVIDIA:main on Nov 1, 2024
26 checks passed