Add high_precision_init_val to model params when using fp8_model_init #1121
base: main
Conversation
Hmmm, I see the problem you are trying to solve, although I don't think I like the approach (don't think I have an alternative ready yet though :-( ). Did you consider any other ways to solve this issue?
I also feel it is a very ugly approach, but I can't think of a better way to do it. So I asked @timmoon10 whether he had any insight; I quote his ideas here, and @timmoon10 can add more comments if he has any.
I feel that method 1 (making the CPU copy optional) and method 3 (adding a random perturbation to the master weights) are the more feasible methods. You can decide whether to adopt method 1, and I can test whether method 3 helps convergence. However, for method 3, my concerns are:
Right... The best option would have been to create the master weights first, but that is not really possible due to the API of PyTorch. Ok, so let's maybe do this:
Then for pretraining people can use this option, while for inference/LoRA/etc., where the weights come pretrained, they will not incur the CPU memory overhead.
Ok, I'll do this.
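For reference, a minimal sketch of how the opt-in option discussed above might look from the user side. This assumes a TE build containing this PR; the `preserve_high_precision_init_val` flag is the name discussed later in this thread, and the accessor on the weight is an assumed name, not necessarily the final API.

```python
import torch
import transformer_engine.pytorch as te

# Opt in to keeping a high-precision CPU copy of the initial weight values.
# `preserve_high_precision_init_val` is the flag proposed in this PR.
with te.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    linear = te.Linear(1024, 1024)  # weights are created as Float8Tensor

# A framework that builds FP32 master weights (e.g. MCore) could then read
# the preserved high-precision copy instead of upcasting the FP8 data.
init_val = linear.weight.get_high_precision_init_val()  # assumed accessor name
if init_val is not None:
    master_weight = init_val.float()  # CPU BF16/FP16 copy -> FP32 master weight
```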
@ptrendx I've finished the revision, could you help find someone to review it?
This feature is opt-in and simple, so overall it looks fine. Some warnings:
- I'm still a bit iffy about putting this logic within TE, since it's a general problem with low-precision models. For example, you would experience a similar loss of precision if you initialized an FP16 model in plain PyTorch.
- This feature is not supported in the experimental operation-based API (see https://github.com/NVIDIA/TransformerEngine/tree/main/transformer_engine/pytorch/ops)
- A test would be helpful for making sure we don't accidentally break this feature (a rough sketch of such a test follows below).
- Please sign your commit to get past the DCO check.
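A test along these lines could simply check that the preserved CPU copy keeps the original precision and matches the FP8 weight values. This is only a sketch under assumptions: the flag and accessor names follow the discussion in this PR, and the final comparison assumes `Float8Tensor` supports dequantization via `.float()`.

```python
import pytest
import torch
import transformer_engine.pytorch as te


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_preserve_high_precision_init_val():
    # Flag and accessor names are assumed from the discussion in this PR.
    with te.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
        module = te.Linear(32, 32, params_dtype=torch.bfloat16)
    init_val = module.weight.get_high_precision_init_val()

    # The preserved copy should exist, live on CPU, and keep the original
    # high-precision dtype and shape of the weight.
    assert init_val is not None
    assert init_val.device.type == "cpu"
    assert init_val.dtype == torch.bfloat16
    assert init_val.shape == module.weight.shape

    # It should also stay close to the (lossy) FP8 weight when both are
    # compared in FP32; loose tolerance because the weight itself is FP8.
    torch.testing.assert_close(
        init_val.float(),
        module.weight.float().cpu(),
        atol=0.1,
        rtol=0.1,
    )
```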
Added a unit test and signed off my commit.
@timmoon10 Can you take a look and merge this?
Makes sense overall and the use case is clear, but a couple of thoughts:
- Unsure about the name `preserve_high_precision_init_val`. It is explicit, and only the initial weight is stored without being updated, but maybe we could think of something clearer since this is a documented arg.
- Are we relying on the user to clear the additional memory? An alternative could be to free this memory during the forward pass or something, if the only use here is to initialize master weights.
@ksivaman Yes, in my previous idea there were two ways to clear this variable: 1. Clear the variable inside the […] I think what you said makes a lot of sense, but besides automatically reclaiming this resource after the first forward, should we still keep the manual delete method? That way, users can reclaim this resource in advance when needed, avoiding some corner cases (for example, running out of CPU memory during the first forward, although I'm not sure if this can happen...). Also, do you have any suggestions on where to put the code for "automatically reclaiming this resource after the first forward"? Should I put it in the […]
Do you have any suggestions for this as well?
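For illustration, the two clean-up paths discussed here could coexist. A rough sketch under assumptions: the flag, getter, and clear method names follow the discussion above and may differ in the final API.

```python
import torch
import transformer_engine.pytorch as te

with te.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = te.Linear(1024, 1024)

# Option A: manual clean-up. The framework copies the preserved values into
# its FP32 master weights, then frees the extra CPU memory explicitly.
master_weights = []
for p in model.parameters():
    getter = getattr(p, "get_high_precision_init_val", None)  # assumed name
    init_val = getter() if getter is not None else None
    if init_val is not None:
        master_weights.append(init_val.float())      # BF16/FP16 CPU copy -> FP32
        p.clear_high_precision_init_val()            # assumed manual-release method
    else:
        # Non-FP8 params (e.g. bias) have no preserved copy; use the param itself.
        master_weights.append(p.detach().float().cpu())

# Option B: automatic clean-up. The module itself could drop the preserved
# copies at the end of its first forward pass (e.g. behind a one-shot flag in
# the forward/prepare_forward path), so the CPU memory is reclaimed even if
# the user never calls the manual method.
```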
Description
When using `fp8_model_init` to create a model, the weights will be cast to `Float8Tensor`. However, in scenarios where high-precision (FP32) master weights are needed, initializing the master weights from these FP8 weights can affect loss convergence compared to initializing them from BF16/FP16 weights (especially in the early stages of training).
This PR stores the original BF16/FP16 params as CPU tensors within the FP8 weights, which can be used to initialize master weights in other frameworks such as MCore.
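To make the convergence concern concrete, here is a small plain-PyTorch sketch (using `torch.float8_e4m3fn` as a stand-in for TE's FP8 storage, so it assumes a PyTorch version that provides that dtype). It shows how much more information is lost when FP32 master weights are rebuilt from FP8 data rather than from the original BF16 initialization:

```python
import torch

torch.manual_seed(0)
w_init = torch.randn(1024, 1024) * 0.02        # typical init scale, FP32
w_bf16 = w_init.to(torch.bfloat16)             # high-precision model weights
w_fp8 = w_bf16.to(torch.float8_e4m3fn)         # FP8 model weights (lossy, 3 mantissa bits)

master_from_bf16 = w_bf16.float()              # master weights built from BF16
master_from_fp8 = w_fp8.float()                # master weights built from FP8

err_bf16 = (master_from_bf16 - w_init).abs().max().item()
err_fp8 = (master_from_fp8 - w_init).abs().max().item()
print(f"max |error| from BF16 init: {err_bf16:.2e}")
print(f"max |error| from FP8 init:  {err_fp8:.2e}")  # noticeably larger
```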
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: