[PyTorch] Support dtype casting in fused adam #977
Conversation
Force-pushed aa11601 to 4277bd1
Force-pushed fd68cdd to f65a320
Force-pushed f6f7c49 to b4c90a8
@timmoon10 Could you please take a look?

/te-ci pytorch

Hi @timmoon10, I encountered an issue when trying to update scale_inv inside the Adam kernel using …

/te-ci pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu (Outdated)
I notice now that this file uses unittest, while the CI infrastructure uses pytest:

pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
It may be better to fix that in a separate PR.
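For context, pytest can collect unittest.TestCase subclasses as-is, so the command above still exercises the existing unittest-style tests even before that cleanup. A minimal sketch of why that works; the class and test names below are illustrative, not the actual TE test suite:

import unittest
import torch

class TestAdamState(unittest.TestCase):
    # pytest collects unittest.TestCase subclasses directly, so
    # "pytest -v -s test_fused_optimizer.py" would pick this test up too.
    def test_exp_avg_dtype(self):
        param = torch.nn.Parameter(torch.randn(4))
        param.grad = torch.randn_like(param)
        opt = torch.optim.Adam([param])
        opt.step()
        self.assertEqual(opt.state[param]["exp_avg"].dtype, torch.float32)

if __name__ == "__main__":
    unittest.main()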
Force-pushed 4c2d42e to 47a448b
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu (Outdated)
Force-pushed c47909c to 5ae1573
/te-ci pytorch
Based on a discussion with @ptrendx, I think we should give more thought to the API. While this is primarily targeting Megatron-LM, it's important that other TE users can use it easily without relying on Mcore infrastructure. @ptrendx's preferred API is for the optimizer to hold the model weights (including …):

model = MyModel()  # Mix of fp32, bf16, fp8 params
optim = FusedAdam(model.parameters(), dtype=torch.float32)  # Create FP32 master weights for each non-fp32 param
optim.step()
# optim.state[bf16_param]["exp_avg"] is fp32 tensor
# optim.state[bf16_param]["exp_avg_sq"] is fp32 tensor
# optim.state[bf16_param]["master_param"] is fp32 tensor
# optim.state[fp32_param]["master_param"] is None

This API is more natural for standard PyTorch workflows and it doesn't require maintaining separate model weights/master weights like in Megatron-LM. That said, I can see value in keeping …:

model = MyModel()  # Mix of fp32, bf16, fp8 params
master_weights = [param.float() for param in model.parameters()]
optim = FusedAdam(model.parameters(), dtype=torch.float32, master_weights=master_weights)
# optim.state[param]["master_param"] is from my_master_weights
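For readers outside this thread, here is a minimal sketch in plain PyTorch of the master-weight bookkeeping the proposed API implies. It uses torch.optim.Adam with hand-written casts, whereas FusedAdam fuses the cast into the optimizer kernel itself; the helper names (build_master_params, step_with_master_weights) are illustrative only:

import torch

def build_master_params(params, dtype=torch.float32):
    # One fp32 master copy per non-fp32 param, None otherwise,
    # mirroring the optim.state[param]["master_param"] layout described above.
    return [p.detach().clone().to(dtype) if p.dtype != dtype else None for p in params]

model = torch.nn.Linear(8, 8).bfloat16()   # stand-in for "MyModel" with bf16 params
params = list(model.parameters())
masters = build_master_params(params)

# Optimizer state (exp_avg, exp_avg_sq) lives in fp32 on the master copies.
opt = torch.optim.Adam([m for m in masters if m is not None], lr=1e-3)

def step_with_master_weights():
    # Gradients computed on the bf16 params drive an fp32 update on the masters,
    # then the updated masters are cast back into the model params.
    for p, m in zip(params, masters):
        if m is not None:
            m.grad = p.grad.to(m.dtype)
    opt.step()
    with torch.no_grad():
        for p, m in zip(params, masters):
            if m is not None:
                p.copy_(m.to(p.dtype))

loss = model(torch.randn(2, 8, dtype=torch.bfloat16)).float().sum()
loss.backward()
step_with_master_weights()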
Force-pushed 541e6e7 to ae375cd
Hi @timmoon10, I have made modifications to the FusedAdam API based on your suggestions. I already tested my changes in Megatron-LM, and the training loss matches the previous results exactly.

# create optimizer
optimizer = FusedAdam(param_groups, ..., master_weights=None)
optimizer = DistributedOptimizer(optimizer, *other_args)

# inside __init__ of dist opt
master_weights = list(itertools.chain(*self.shard_fp32_from_float16_groups))
self.optimizer.master_weights = master_weights  # self.optimizer is FusedAdam

This usage is somewhat uncomfortable, but not entirely unusual. Any suggestions?
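As a toy illustration of the deferred wiring described above (class and attribute names other than FusedAdam's master_weights are made up, not Megatron-LM's real ones): the wrapper only knows its fp32 copies after it has grouped the params, so it attaches them to the already-constructed inner optimizer.

import itertools

class ToyDistOptimizer:
    def __init__(self, optimizer, float16_groups):
        self.optimizer = optimizer
        # Build one fp32 copy per fp16 param, grouped the same way as the model params.
        self.shard_fp32_from_float16_groups = [
            [p.detach().float() for p in group] for group in float16_groups
        ]
        # The step called "somewhat uncomfortable" above: mutate the inner
        # optimizer after construction instead of passing master_weights up front.
        self.optimizer.master_weights = list(
            itertools.chain(*self.shard_fp32_from_float16_groups)
        )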
Force-pushed ccd54cd to 267c90a
@timmoon10 Could you please take a look?
Force-pushed 267c90a to 44dca61
LGTM. Thanks for implementing all the API changes; this is much cleaner and easier to reason about. I think there are still some things that could be improved (options to construct master weights internally, cleaning up how to specify master weights, mixed FP16/BF16, fixing the tests), but those are internal changes that can be worked on later.
/te-ci pytorch
/te-ci pytorch
* support dtype casting fusion in FusedAdam
* minor changes
* fix lint
* changes based on review comments
* remove unused code
* code refactor
* fix typo
* refactor
* remove unused code
* Fix linter warnings
* Copy CUDA headers for framework sdists

Signed-off-by: Shijie Wang <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: beinggod <[email protected]>
Description
Currently, FusedAdam updates the params in place.
This PR adds dtype casting to the FusedAdam kernel: in addition to updating the master params in place, it can also update extra model params, which may be of bf16, fp16, or fp8 type.
Update:
I have validated convergence with GPT training in Megatron-LM. The losses before and after enabling this feature are bit-wise identical.
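A hedged usage sketch of the feature as described in this PR and the review thread above; the import path and keyword names are assumptions that may differ across Transformer Engine versions, and the fused kernel requires a CUDA device:

import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # import path is an assumption; check your TE version

model = torch.nn.Linear(16, 16).bfloat16().cuda()   # bf16 model params
params = list(model.parameters())

# fp32 master copies held outside the model, one per param, as in the thread above.
master_weights = [p.detach().clone().float() for p in params]

optim = FusedAdam(params, lr=1e-3, master_weights=master_weights)

out = model(torch.randn(4, 16, dtype=torch.bfloat16, device="cuda"))
out.float().sum().backward()
optim.step()  # fused kernel updates the fp32 masters and casts the results back into the bf16 params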
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: