
Potential RMSNorm precision issue #33133

Closed

void-main opened this issue Aug 27, 2024 · 3 comments

@void-main

System Info

  • transformers version: 4.44.1
  • Platform: Linux-5.15.0-88-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.4
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0a0+f70bd71a48.nv24.06 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: distributed
  • Using GPU in script?: YES

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

There's an implementation difference between HF transformers' RMSNorm and NVIDIA transformer_engine's RMSNorm.

Version: transformer-engine 1.7.0+4e7caa1

First, define HFRMSNorm, copied from the modeling_llama implementation in the transformers library:

import torch
from torch import nn

class HFRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, config=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # upcast to fp32 for the variance computation
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # downcast back to the input dtype *before* multiplying by the weight
        return self.weight * hidden_states.to(input_dtype)
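
To see the effect of the downcast ordering in isolation, here is a minimal sketch with random data (not part of the original report; all values are arbitrary):

import torch

torch.manual_seed(0)
hidden_size = 4096
x = torch.randn(hidden_size, dtype=torch.float32)    # normalized hidden states, still in fp32
w = torch.randn(hidden_size, dtype=torch.bfloat16)   # RMSNorm weight in bf16

# HF order: downcast to bf16 first, then multiply by the weight in bf16
out_bf16_mul = w * x.to(torch.bfloat16)
# alternative order: multiply by the weight in fp32, downcast only at the end
out_fp32_mul = (w.to(torch.float32) * x).to(torch.bfloat16)

print((out_bf16_mul.float() - out_fp32_mul.float()).abs().max())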

Next, run the test code:

import unittest
import torch
import torch.nn as nn
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as TELayerNorm
from copy_from_hf import HFRMSNorm

class TestLayerNormComparison(unittest.TestCase):
    def setUp(self):
        self.hidden_size = 4096
        self.batch_size = 1
        self.seq_length = 1024
        self.eps = 1e-5

        self.shared_weight = nn.Parameter(torch.randn(self.hidden_size, dtype=torch.bfloat16))

        self.te_layernorm = TELayerNorm(self.hidden_size, eps=self.eps, zero_centered_gamma=False).to(torch.bfloat16)
        self.hf_rmsnorm = HFRMSNorm(self.hidden_size, eps=self.eps).to(torch.bfloat16)

        self.te_layernorm.weight = self.shared_weight
        self.hf_rmsnorm.weight = self.shared_weight

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.te_layernorm.to(self.device)
        self.hf_rmsnorm.to(self.device)

    def test_layernorm_comparison(self):
        input_tensor = torch.randn(self.batch_size, self.seq_length, self.hidden_size,
                                   dtype=torch.bfloat16, device=self.device)

        with torch.no_grad():
            te_output = self.te_layernorm(input_tensor)
            hf_output = self.hf_rmsnorm(input_tensor)

        assert torch.allclose(te_output, hf_output, atol=1e-2)

if __name__ == '__main__':
    unittest.main()

The assertion will fail.

Expected behavior

If we change the last line of HFRMSNorm from return self.weight * hidden_states.to(input_dtype) to return (self.weight.to(torch.float32) * hidden_states).to(input_dtype), the assertion should pass.
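
In other words, a minimal sketch of the adjusted forward (only the final multiplication changes; the rest is identical to HFRMSNorm above):

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # multiply by the weight in fp32, then downcast once at the very end
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)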

We have a discussion here, and I agree that we should do all internal computation in FP32. What's your opinion on the HF side?

@LysandreJik
Member

Thanks for opening an issue @void-main! cc @ArthurZucker

@ArthurZucker
Collaborator

#23535 and #30236 are related. I am not sure we can change this without breaking anything at this point 🤗
Moreover, the "original" code does not compute everything in float32; of course, it does not cover extended usage either.

I am down to updating it as long as there is a strong motivation: better performance, better MMLU scores, anything that can prove that this is a good idea!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Oct 4, 2024