RMSNorm precision different from HF implementation #1132

Open
void-main opened this issue Aug 23, 2024 · 5 comments

@void-main
We noticed there's a tiny implementation difference that makes transformer_engine.pytorch.module.rmsnorm (and also TELayerNormColumnParallelLinear) generate results different from the HF implementation.

The tiny difference is in when the hidden_states are converted back to bfloat16. Here's the gap:
[Screenshot: CleanShot 2024-08-23 at 22 40 51@2x]

  • the red line is the native HF implementation, which converts hidden_states to bfloat16 before multiplying by the weight; TENorm's result is different from this implementation
  • the green line matches TENorm's implementation, which converts hidden_states to bfloat16 after multiplying by the weight.

We wonder if TE could provide an option to match the HF implementation, which converts hidden_states to bfloat16 before multiplying by the weights. Thanks.
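
As a minimal standalone sketch of the two orderings (the shapes and random values here are arbitrary, just for illustration):

import torch

torch.manual_seed(0)
# normalized hidden_states, still in float32, and a bfloat16 RMSNorm weight
hidden = torch.randn(4096, dtype=torch.float32)
weight = torch.randn(4096, dtype=torch.bfloat16)

# HF ordering: cast to bfloat16 first, then multiply by the weight
hf_style = weight * hidden.to(torch.bfloat16)
# TE ordering: multiply in float32, cast the product to bfloat16 afterwards
te_style = (weight.to(torch.float32) * hidden).to(torch.bfloat16)

print((hf_style != te_style).float().mean())   # fraction of elements that differ
print((hf_style - te_style).abs().max())       # differences are at the bfloat16 rounding level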

How to reproduce

Version: transformer-engine 1.7.0+4e7caa1

Code to reproduce:

import unittest
import torch
import torch.nn as nn
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as TELayerNorm
from copy_from_hf import HFRMSNorm

class TestLayerNormComparison(unittest.TestCase):
    def setUp(self):
        self.hidden_size = 4096
        self.batch_size = 1
        self.seq_length = 1024
        self.eps = 1e-5

        self.shared_weight = nn.Parameter(torch.randn(self.hidden_size, dtype=torch.bfloat16))

        self.te_layernorm = TELayerNorm(self.hidden_size, eps=self.eps, zero_centered_gamma=False).to(torch.bfloat16)
        self.hf_rmsnorm = HFRMSNorm(self.hidden_size, eps=self.eps).to(torch.bfloat16)

        self.te_layernorm.weight = self.shared_weight
        self.hf_rmsnorm.weight = self.shared_weight

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.te_layernorm.to(self.device)
        self.hf_rmsnorm.to(self.device)

    def test_layernorm_comparison(self):
        input_tensor = torch.randn(self.batch_size, self.seq_length, self.hidden_size,
                                   dtype=torch.bfloat16, device=self.device)

        with torch.no_grad():
            te_output = self.te_layernorm(input_tensor)
            hf_output = self.hf_rmsnorm(input_tensor)

        assert torch.allclose(te_output, hf_output, atol=1e-2)

if __name__ == '__main__':
    unittest.main()

First, define HFRMSNorm with the native HF implementation:

import torch
from torch import nn

class HFRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, config=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

The assertion should fail when we run the code with this implementation.

Now, if we change the last line from return self.weight * hidden_states.to(input_dtype) to return (self.weight.to(torch.float32) * hidden_states).to(input_dtype), the assertion passes.
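
With that one-line change applied, the forward becomes:

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # multiply by the weight in float32, then cast back (this matches TE's ordering)
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)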

@ptrendx (Member) commented Aug 23, 2024

That is correct, both RMSNorm and LayerNorm in TE perform all internal computation in FP32, so e.g. TE LayerNorm is equivalent to:

x = x.to(torch.float32)
y = nn.LayerNorm(hidden_size)(x)
y = y.to(torch.bfloat16)

The reason for that is to preserve the precision of the computation, especially since RMSNorm/LayerNorm weights are typically close to 1.
This is especially important and visible with the zero_centered_gamma option, which initializes the weight to 0 and adds 1 to it inside the normalization operation itself. Since floating point numbers are most precise around 0, and since bfloat16 does not have many mantissa bits, adding 1 in a precision other than float32 loses most of that precision - see e.g. this example:

>>> import torch
>>> a = torch.Tensor([0.003]).to(torch.bfloat16)
>>> a
tensor([0.0030], dtype=torch.bfloat16)
>>> a + 1
tensor([1.], dtype=torch.bfloat16)
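
For contrast (an addition to the example above), performing the same addition in float32 keeps the small value:

>>> a.to(torch.float32) + 1
tensor([1.0030])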

Based on this, I would argue that it is actually the HF implementation that is wrong here.

@void-main (Author)

@ptrendx Thanks for your reply. I totally agree that we should use float32 to do all the calculations, in theory.

However, we're not training from scratch. We're doing continued training of open-source models like Llama3 and Qwen2 with Megatron-LM, and if we compare the logits generated by Megatron-LM and HF transformers, the RMSNorm implementation difference causes the logits to be very different. (80% of the hidden_states elements differ by more than 0.01.)

That's why I believe we should at least provide an option to align the RMSNorm with HF transformers.

@ptrendx (Member) commented Aug 26, 2024

Yeah, I figured that's the probable reason for this ask. Could you open an issue in the HF transformers repo as well then? It would be interesting to hear their opinion on the topic and also to raise their awareness so that, hopefully, the implementations are aligned to the right precisions for new models going forward.

I need to think about how to expose that option. In the meantime, if you want to change the TE implementation yourself to do the multiplication in the lower precision, you would need to change
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh#L108-L109 and https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh#L248-L249 to use the Ktraits::weight_t type rather than compute_t.
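
In PyTorch terms (a rough model of the numerical effect only, not the actual CUDA kernel code), compute_t corresponds to float32 and weight_t to the bfloat16 weight dtype:

import torch

y_f32 = torch.randn(4096)                         # normalized value, held in compute_t (float32)
gamma = torch.randn(4096, dtype=torch.bfloat16)   # RMSNorm weight, stored in weight_t (bfloat16)

# current kernel behavior: upcast the weight and multiply in compute_t,
# rounding only the final product to the output dtype
out_compute_t = (gamma.to(torch.float32) * y_f32).to(torch.bfloat16)

# with the suggested weight_t change: round the normalized value to the
# weight dtype first and multiply there, matching the HF ordering
out_weight_t = gamma * y_f32.to(torch.bfloat16)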

@void-main (Author)

Great, thank you @ptrendx! I'll try to change the code myself.

Besides, here's the issue on HF: huggingface/transformers#33133

@fjosw (Contributor) commented Aug 29, 2024

We just stumbled upon this issue and compared the RMSNorm implementations of TransformerEngine and TensorRT-LLM. It looks like TensorRT-LLM does the weight multiplication in the lower precision, consistent with the HF transformers implementation. This likely means that a model trained with TransformerEngine will produce (at least slightly) different outputs when run for inference with TensorRT-LLM.

I agree with @ptrendx that performing the operation in higher precision sounds sensible, but I think it would be useful to have the option to align implementations across Nvidia's stack.
