Add high_precision_init_val to model params when using fp8_model_init #1121

Open

kunlunl wants to merge 2 commits into main from add_high_precision_init_val

Conversation

@kunlunl (Contributor) commented Aug 19, 2024

Description

When using fp8_model_init to create a model, the weights are cast to Float8Tensor. However, in scenarios where high-precision (FP32) master weights are needed, initializing the master weights from these FP8 weights can hurt loss convergence compared to initializing them from the original bf16/fp16 weights (especially in the early stages of training).
This PR stores the original bf16/fp16 params as CPU tensors attached to the FP8 weights, so that other frameworks such as MCore can use them to initialize master weights.
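
A rough usage sketch of what this enables is below; the flag and method names follow the discussion in this PR and are illustrative rather than a confirmed final API:

```python
import torch
import transformer_engine.pytorch as te

# Build the model with FP8 parameters, asking TE to also keep a CPU copy of the
# original high-precision initialization (flag name as proposed in this PR).
with te.fp8_model_init(preserve_high_precision_init_val=True):
    linear = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# An optimizer that keeps FP32 master weights (e.g. MCore's distributed optimizer)
# can seed them from the preserved copy instead of the already-quantized FP8 values.
for param in linear.parameters():
    init_val = param.get_high_precision_init_val()  # CPU tensor; method name illustrative
    if init_val is not None:
        master_weight = init_val.to(torch.float32)   # FP32 master weight from original init
        param.clear_high_precision_init_val()        # free the CPU copy once consumed
```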

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Stores the original bf16/fp16 params as CPU tensors attached to the FP8 weights

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ptrendx (Member) commented Aug 19, 2024

Hmmm, I see the problem you are trying to solve, although I don't think I like the approach (don't think I have an alternative ready yet though :-( ). Did you consider any other ways to solve this issue?

@kunlunl (Contributor, Author) commented Aug 20, 2024

I also feel this is a pretty ugly approach, but I can't think of a better way to do it. I asked @timmoon10 whether he had any insight; I quote his ideas here. @timmoon10, feel free to add more comments if you have any.

  • Storing on CPU is a pretty hacky approach, so we could modify te.fp8_model_init so that the CPU copy is optional.
  • I wonder how much the problem is because the initial scaling factor is 1, which is likely too low and results in many underflows. One approach is to do the FP8 cast twice: once to figure out the amax and again with an optimal scaling factor. One problem is that this doesn't handle tensor parallelism well, since we want the amaxes to be synchronized over the TP group. We either need to have many TP max all-reduces (one per param) or we need to make structural changes in how we initialize FP8 params
  • If we have master weights in FP32, how about perturbing them to approximate the original distribution:
    • fp32_params = fp8_params.from_float8()
    • fp32_params += torch.abs(fp32_params) * (torch.rand_like(fp32_params) - 0.5) * fp8_eps
    • This is simple, but the numerics are subtly different and I'm not sure if it'll also affect convergence
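
A minimal sketch of idea 3, assuming from_float8() dequantizes the FP8 weights and taking fp8_eps as the relative spacing of the FP8 format (the helper name and eps value are illustrative):

```python
import torch

def perturbed_master_weights(fp8_params, fp8_eps: float = 2.0 ** -3):
    """Dequantize the FP8 weights and add uniform noise proportional to each
    element's magnitude, roughly recovering the spread removed by FP8 rounding.
    The default eps (~E4M3 relative spacing) is a guess, not a tuned value."""
    fp32_params = fp8_params.from_float8().float()           # dequantize to FP32
    noise = (torch.rand_like(fp32_params) - 0.5) * fp8_eps   # uniform in [-eps/2, eps/2)
    return fp32_params + torch.abs(fp32_params) * noise      # relative perturbation
```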

I feel that method 1 (making the CPU copy optional) and method 3 (adding a random perturbation to the master weights) are the more feasible options. You can decide whether to adopt method 1, and I can test whether method 3 helps convergence.

However, for method 3, my concerns are:

  • The distributions of the master weights and the FP8 weights may become inconsistent after adding random perturbations. Even if I make them consistent through some hard-coded method, their distributions will differ again if the initialization parameters of the FP8 weights change in the future.
  • Even if I test it and find that it helps convergence, it may still not work on other models that I haven't tested; after all, I can't test every model.
  • (In addition, I'm not sure whether introducing such random perturbations would be acceptable to MCore.)

@ptrendx (Member) commented Aug 20, 2024

Right... The best option would have been to create the master weights first, but that is not really possible due to the API of PyTorch.

Ok, so let's maybe do this:

  • create a preserve_high_precision_initialization option in fp8_model_init, which would be the trigger to save the copy on the CPU. We should document it properly.
  • add a function to the FP8 parameters to clear the high-precision weights, so that they can be freed properly once they have been stored in the master weights.

Then people can use this option for pretraining, while inference/LoRA/etc., where those weights come pretrained, will not incur the CPU memory overhead.
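
A rough sketch of the parameter-side bookkeeping this plan implies (the class, attribute, and method names are illustrative, not the actual Float8Tensor implementation):

```python
import torch

class FP8ParamWithInitVal:
    """Illustrative lifecycle only: keep the original high-precision values on
    CPU at construction time (when the option is enabled), expose them for
    master-weight initialization, then free them explicitly."""

    def __init__(self, fp8_data, init_val: torch.Tensor,
                 preserve_high_precision_init_val: bool = False):
        self.fp8_data = fp8_data
        self._high_precision_init_val = (
            init_val.detach().cpu() if preserve_high_precision_init_val else None
        )

    def get_high_precision_init_val(self):
        return self._high_precision_init_val

    def clear_high_precision_init_val(self):
        # Call once the master weights have been initialized, to release the CPU copy.
        self._high_precision_init_val = None
```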

@kunlunl (Contributor, Author) commented Aug 21, 2024

Ok, so let's maybe do this:

  • create a preserve_high_precision_initialization option in fp8_model_init, which would be the trigger to save the copy on the CPU. We should document it properly.
  • add a function to the FP8 parameters to clear the high-precision weights, so that they can be freed properly once they have been stored in the master weights.

Ok, I'll do this.

@kunlunl force-pushed the add_high_precision_init_val branch 2 times, most recently from f0e5850 to 5f2b65b on August 27, 2024 10:50
@kunlunl (Contributor, Author) commented Aug 27, 2024

@ptrendx I've finished the revision, could you help find someone to review it?

@kunlunl force-pushed the add_high_precision_init_val branch from 5c649e9 to 669cd4d on August 27, 2024 11:27
@timmoon10 (Collaborator) left a comment

This feature is opt-in and simple, so overall it looks fine. Some warnings:

  • I'm still a bit iffy about putting this logic within TE, since it's a general problem with low-precision models. For example, you would experience a similar loss of precision if you initialized an FP16 model in plain PyTorch.
  • This feature is not supported in the experimental operation-based API (see https://github.com/NVIDIA/TransformerEngine/tree/main/transformer_engine/pytorch/ops)
  • A test would be helpful for making sure we don't accidentally break this feature.
  • Please sign your commit to get past the DCO check.

@kunlunl force-pushed the add_high_precision_init_val branch from 9801e12 to 8866920 on August 28, 2024 08:19
@kunlunl force-pushed the add_high_precision_init_val branch from ccb6b63 to db3e139 on August 28, 2024 08:24
@kunlunl (Contributor, Author) commented Aug 28, 2024

Added a unit test and signed off my commit.

@kunlunl (Contributor, Author) commented Sep 3, 2024

@timmoon10 Can you take a look and merge this?

@ksivaman (Member) left a comment

Makes sense overall and the use case is clear, but a couple of thoughts:

  1. Unsure about the name preserve_high_precision_init_val. It is explicit, and it conveys that only the initial weight is stored without being updated, but maybe we could think of something clearer since this is a documented arg.
  2. Are we relying on the user to clear the additional memory? An alternative could be to free this memory during the forward pass or something, if the only use here is to initialize master weights.

@kunlunl (Contributor, Author) commented Sep 3, 2024

Are we relying on the user to clear the additional memory? An alternative could be to free this memory during the forward pass or something, if the only use here is to initialize master weights.

@ksivaman Yes, in my previous idea there were two ways to clear this variable: 1. Clear the variable inside the get_xxx method, so that the memory is automatically reclaimed after the user accesses it once; 2. The user manually clears the variable by calling the clear_xxx method. I chose the latter because, when I used it in MCore, I found that I needed to access it more than once.

I think what you said makes a lot of sense, but besides automatically reclaiming this resource after the first forward, should we still keep the manual delete method? That way, users can reclaim this resource in advance when needed, avoiding some corner cases (for example, running out of CPU memory during the first forward, although I'm not sure whether this can happen...).

Also, do you have any suggestions on where to put the code for "automatically reclaiming this resource after the first forward"? Should I put it in the forward() of each module? (That would require modifying the forward() code of each module.)
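
One possible shape for that, sketched with a standard PyTorch forward pre-hook so no module's forward() needs editing (the clear_high_precision_init_val name is illustrative):

```python
import torch

def register_auto_clear_hook(module: torch.nn.Module):
    """Clear the preserved CPU copies the first time this module runs forward.
    Sketch only; it assumes parameters expose clear_high_precision_init_val()."""
    def _clear_once(mod, inputs):
        for param in mod.parameters(recurse=False):
            clear = getattr(param, "clear_high_precision_init_val", None)
            if clear is not None:
                clear()
        handle.remove()  # run only on the first forward
    handle = module.register_forward_pre_hook(_clear_once)
    return handle
```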

but maybe we could think of something clearer since this is a documented arg.

Do you have any suggestions for this also?
