
[PyTorch] Support dtype casting in fused adam #977

Merged

merged 12 commits into NVIDIA:main from jaywan/add_fused_adam on Aug 16, 2024

Conversation

@Wong4j (Contributor) commented Jul 1, 2024

Description

FusedAdam currently updates the params in place.
This PR adds dtype casting to the FusedAdam kernel: in addition to updating the master params in place, the kernel can also update extra model params, which may be of bf16, fp16, or fp8 type.

Update:
I have validated convergence with GPT training in Megatron-LM. The losses before and after enabling this feature are bit-wise identical.
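
For illustration, here is a minimal eager-PyTorch sketch of the per-parameter math this fusion covers: an fp32 Adam update of the master param followed by a cast of the result into the low-precision model param. The function name and hyperparameter defaults are hypothetical; the real implementation is a single multi-tensor CUDA kernel, and fp8 params additionally involve scale/scale_inv handling.

import torch

def adam_update_with_cast(model_param, master_param, exp_avg, exp_avg_sq,
                          lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    grad = model_param.grad.float()
    # Standard Adam moment updates in fp32
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt().add_(eps)
    # In-place fp32 master-param update
    master_param.addcdiv_(exp_avg, denom, value=-lr / bias_corr1)
    # Extra step added by this PR's fusion: cast the updated value back into
    # the bf16/fp16 model param (fp8 requires its own quantization path)
    model_param.data.copy_(master_param)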

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Add dtype casting to the FusedAdam kernel so that, in addition to the in-place fp32 master-param update, it can also write the updated values into extra bf16/fp16/fp8 model params
  • Extend the FusedAdam Python API to accept optional master weights and add corresponding tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Wong4j changed the title Support dtype casting in fused adam → [PyTorch] Support dtype casting in fused adam, Jul 1, 2024
@Wong4j changed the title [PyTorch] Support dtype casting in fused adam → [WIP] [PyTorch] Support dtype casting in fused adam, Jul 1, 2024
@Wong4j changed the title [WIP] [PyTorch] Support dtype casting in fused adam → [PyTorch] Support dtype casting in fused adam, Jul 12, 2024
@Wong4j (Contributor Author) commented Jul 12, 2024

@timmoon10 Could you please take a look?
The corresponding changes to Megatron-LM are in our internal gitlab MR#1736.

@timmoon10 (Collaborator)

/te-ci pytorch

@timmoon10 self-requested a review July 12, 2024 20:26
@Wong4j (Contributor Author) commented Jul 15, 2024

Hi @timmoon10, I encountered an issue when trying to update scale_inv inside the Adam kernel using *scale_inv_ptr = 1.0f / scale: the loss was no longer bit-wise aligned. The reason is that the TE/PyTorch build uses --use_fast_math, which compiles the reciprocal into a single approximate MUFU.RCP instruction instead of an exact division.
To achieve bit-wise alignment of the loss, I had to update scale_inv outside the Adam kernel, which leads to suboptimal performance. Do you have any suggestions to address this?
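
For reference, a minimal PyTorch-side sketch of that workaround, assuming scales and scale_invs are lists holding the fp32 scaling factors of the fp8 params; the names are hypothetical and this is not the actual TE code:

import torch

def update_scale_inv_outside_kernel(scales, scale_invs):
    # Recompute scale_inv after the fused Adam step with torch's exact fp32
    # division, avoiding the fast-math MUFU.RCP approximation inside the kernel.
    for scale, scale_inv in zip(scales, scale_invs):
        scale_inv.copy_(1.0 / scale)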

@zlsh80826 (Collaborator)

/te-ci pytorch

Resolved review threads on:
  • transformer_engine/pytorch/csrc/multi_tensor_apply.cuh
  • transformer_engine/pytorch/csrc/extensions.h
  • transformer_engine/pytorch/optimizers/fused_adam.py
  • tests/pytorch/test_fused_optimizer.py
Collaborator review comment:

I notice now that this file uses unittest, while the CI infrastructure uses pytest:

pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py

It may be better to fix that in a separate PR.
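
For illustration only, since that cleanup is deferred: pytest can collect unittest.TestCase classes as-is, so the conversion would mostly mean replacing assert methods with plain asserts. The test below is a made-up toy example, not taken from the actual file.

import unittest
import torch

class TestToyExample(unittest.TestCase):
    # current unittest style; pytest can still collect and run this
    def test_identity(self):
        self.assertTrue(torch.equal(torch.ones(4), torch.ones(4)))

def test_identity_pytest_style():
    # equivalent plain-function pytest style
    assert torch.equal(torch.ones(4), torch.ones(4))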

@timmoon10 self-requested a review July 16, 2024 18:27
@Wong4j force-pushed the jaywan/add_fused_adam branch 2 times, most recently from c47909c to 5ae1573, July 19, 2024 06:35
@timmoon10 (Collaborator)

/te-ci pytorch

@timmoon10 (Collaborator) commented Jul 24, 2024

Based on a discussion with @ptrendx, I think we should give more thought to the API. While this is primarily targeting Megatron-LM, it's important that other TE users can use it easily without relying on Mcore infrastructure.

@ptrendx's preferred API is for the optimizer to hold the model weights (including Float8Tensors) and to treat the master weights as optimizer state (similar to exp_avg and exp_avg_sq). This is similar to Option 1 in #977 (comment). The workflow should look like:

model = MyModel()  # Mix of fp32, bf16, fp8 params
optim = FusedAdam(model.parameters(), dtype=torch.float32)  # Create FP32 master weights for each non-fp32 param
optim.step()
# optim.state[bf16_param]["exp_avg"] is fp32 tensor
# optim.state[bf16_param]["exp_avg_sq"] is fp32 tensor
# optim.state[bf16_param]["master_param"] is fp32 tensor
# optim.state[fp32_param]["master_param"] is None

This API is more natural for standard PyTorch workflows and it doesn't require maintaining separate model weights/master weights like in Megatron-LM. That said, I can see value in keeping master_weights as an optional kwarg since Megatron-LM already allocates them:

model = MyModel()  # Mix of fp32, bf16, fp8 params
master_weights = [param.float() for param in model.parameters()]
optim = FusedAdam(model.parameters(), dtype=torch.float32, master_weights=master_weights)
# optim.state[param]["master_param"] is taken from master_weights

@Wong4j force-pushed the jaywan/add_fused_adam branch 2 times, most recently from 541e6e7 to ae375cd, August 6, 2024 14:21
@Wong4j (Contributor Author) commented Aug 6, 2024

Hi @timmoon10, I have modified the FusedAdam API based on your suggestions. I have already tested my changes in Megatron-LM, and the training loss matches the previous results exactly.
However, there are still some issues that need to be discussed:

  1. I require master_weights to be provided by the user, and the user-provided master_weights must be a list of tensors. If the user does not provide master_weights (i.e., master_weights=None), only the model weights are updated. Is this approach reasonable?

  2. In Megatron-LM, master_weights are created in the __init__ method of the distributed optimizer, while FusedAdam is created earlier. Therefore, I had to initially set master_weights to None and then assign optimizer.master_weights inside the __init__ method of the distributed optimizer, as in the following code:

# create optimizer
optimizer = FusedAdam(param_groups, ... , master_weights=None)
optimizer = DistributedOptimizer(optimizer, *other_args)

# inside __init__ of dist opt
master_weights = list(itertools.chain(*self.shard_fp32_from_float16_groups))
self.optimizer.master_weights = master_weights  # self.optimizer is FusedAdam

This usage is somewhat awkward, though not entirely unusual. Any suggestions?

  3. Kunlun is currently implementing MX-FP16. After some discussion, we believe it is more reasonable to create master_weights inside FusedAdam: exp_avg, exp_avg_sq, and master_param are all optimizer state, and since exp_avg and exp_avg_sq are created and updated within FusedAdam, master_param should be handled the same way. However, this change would also conflict with Megatron's design. A sketch of that alternative follows below.
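
For concreteness, a minimal sketch of that alternative, assuming a hypothetical create_master_weights flag; this is not what the current PR implements, it only illustrates keeping master_param in optimizer state next to exp_avg and exp_avg_sq:

import torch

class FusedAdamWithInternalMasters(torch.optim.Adam):  # stand-in base class for illustration
    def __init__(self, params, create_master_weights=True, **kwargs):
        super().__init__(params, **kwargs)
        if create_master_weights:
            for group in self.param_groups:
                for p in group["params"]:
                    # fp32 master copy lives in optimizer state, like exp_avg / exp_avg_sq
                    self.state[p]["master_param"] = (
                        p.detach().float().clone() if p.dtype != torch.float32 else None
                    )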

@Wong4j force-pushed the jaywan/add_fused_adam branch 5 times, most recently from ccd54cd to 267c90a, August 13, 2024 02:10
@Wong4j (Contributor Author) commented Aug 13, 2024

@timmoon10 Could you please take a look?

@timmoon10 (Collaborator) left a comment:

LGTM. Thanks for implementing all the API changes, this is much cleaner and easier to reason about. I think there are still some things that could be improved (options to construct master weights internally, cleaning up how to specify master weights, mixed FP16/BF16, fixing the tests), but those are internal changes that can be worked on later.

Resolved review thread: transformer_engine/pytorch/optimizers/fused_adam.py
@timmoon10 (Collaborator)

/te-ci pytorch

@timmoon10 (Collaborator)

/te-ci pytorch

@timmoon10 merged commit 4edcff5 into NVIDIA:main Aug 16, 2024
14 of 15 checks passed
BeingGod pushed a commit to BeingGod/TransformerEngine that referenced this pull request Aug 30, 2024
* support dtype casting fusion in FusedAdam
* minor changes
* fix lint
* changes based on review comments
* remove unused code
* code refactor
* fix typo
* refactor
* remove unused code
* Fix linter warnings
* Copy CUDA headers for framework sdists

---------

Signed-off-by: Shijie Wang <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: beinggod <[email protected]>