Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam #1078
Conversation
Signed-off-by: kunlunl <[email protected]>
@timmoon10 Hello, I noticed that no one has commented on this PR for a long time. Could you please take a look, or help find someone to review it?
Overall this looks good. It would be more general if we disentangled the state dtypes from the state scaling (e.g. why not allow scaled FP32 states or unscaled BF16 states?), but this does cover the specific cases in the MS-AMP paper.
For future reference, this PR adapts logic from NVIDIA/apex#1771. It is a proof-of-concept with several opportunities for future improvement (an illustrative sketch of the scale/unscale idea follows the list):
- TE kernel for computing absmax and scale
- Fusing scale/unscale within Adam kernel
- Reduce memory usage in optimizer step, perhaps by processing params in chunks
- Reduce memory usage in checkpointing, perhaps by storing checkpoint buffers in CPU
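As context for the scaled-state and "fusing scale/unscale" points above, here is a minimal plain-PyTorch sketch of how a low-precision optimizer state can be stored with an absmax-based scale and recovered before the Adam update. The function names and the scaling rule are illustrative assumptions for clarity; this is not the fused kernel path this PR (or apex#1771) implements.

```python
# Illustrative sketch only: absmax-scaled storage of an optimizer state in FP16.
# Names and scaling rule are assumptions, not the TE/Apex implementation.
import torch


def store_scaled(state_fp32: torch.Tensor, dtype: torch.dtype = torch.float16):
    """Scale an FP32 optimizer state into the target dtype's range, then cast."""
    absmax = state_fp32.abs().max().clamp(min=1e-12)
    scale = torch.finfo(dtype).max / absmax  # map absmax to the dtype's largest finite value
    return (state_fp32 * scale).to(dtype), scale


def load_unscaled(state_lowp: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an FP32 view of the state before it is used in the Adam update."""
    return state_lowp.to(torch.float32) / scale


# Round trip: the stored copy is half the size, at the cost of quantization error.
exp_avg = torch.randn(1024) * 3e-3
packed, scale = store_scaled(exp_avg)
restored = load_unscaled(packed, scale)
print((exp_avg - restored).abs().max())
```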
@@ -112,9 +149,6 @@ def __init__(
        self.set_grad_none = set_grad_none

        self.capturable = capturable

        if master_weights is not None:
            assert isinstance(master_weights, list), "master_weights must be a list if provided"
            self.master_weights = master_weights
Yes, I know about this problem. I talked with @Wong4j offline and invited him to review this PR.
His MR in MCore (fusing dtype casting) has not been merged yet, so I put the dtype-casting fusion into a new MCore MR, together with this precision-aware optimizer.
/te-ci pytorch
/te-ci pytorch
Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Kunlun Li <[email protected]>
/te-ci pytorch
LGTM, pending CI and confirmation from @Wong4j that this won't break Mcore integration.
LGTM.
Description
Add options to set the dtypes of the master weights, exp_avg, and exp_avg_sq states in FusedAdam.
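A hedged usage sketch of the options described above. The import path and the keyword names (`master_weight_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype`) are assumptions inferred from this description, not confirmed by this excerpt; check the merged FusedAdam signature in Transformer Engine before relying on them.

```python
# Assumed keyword names and import path; verify against the actual FusedAdam API.
import torch
from transformer_engine.pytorch.optimizers import FusedAdam

model = torch.nn.Linear(1024, 1024).cuda().to(torch.bfloat16)

optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,
    # Store master weights and both Adam moments in FP16 instead of the
    # default FP32, trading some precision for lower optimizer-state memory.
    master_weight_dtype=torch.float16,
    exp_avg_dtype=torch.float16,
    exp_avg_sq_dtype=torch.float16,
)
```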
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: