System Info
transformers version: 4.44.1

Who can help?
@ArthurZucker

Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
There's an implementation difference between HF transformers' RMSNorm and NVIDIA transformer_engine's RMSNorm.
Version: transformer-engine 1.7.0+4e7caa1
First define the HFRMSNorm code, which is copied from the modeling_llama implementation in the transformers library.
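For reference, a sketch of the HFRMSNorm in copy_from_hf, assuming it mirrors transformers' LlamaRMSNorm from modeling_llama (treat this as an approximation rather than the exact file used):

import torch
import torch.nn as nn


class HFRMSNorm(nn.Module):
    """Assumed copy of transformers' LlamaRMSNorm (modeling_llama)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # The normalization itself is done in FP32 ...
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # ... but the weight multiply happens after casting back to the input dtype.
        return self.weight * hidden_states.to(input_dtype)

Next, run the test code: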
import unittest

import torch
import torch.nn as nn
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as TELayerNorm

from copy_from_hf import HFRMSNorm


class TestLayerNormComparison(unittest.TestCase):
    def setUp(self):
        self.hidden_size = 4096
        self.batch_size = 1
        self.seq_length = 1024
        self.eps = 1e-5
        # Share a single bfloat16 weight so both norms use identical parameters.
        self.shared_weight = nn.Parameter(torch.randn(self.hidden_size, dtype=torch.bfloat16))
        self.te_layernorm = TELayerNorm(self.hidden_size, eps=self.eps, zero_centered_gamma=False).to(torch.bfloat16)
        self.hf_rmsnorm = HFRMSNorm(self.hidden_size, eps=self.eps).to(torch.bfloat16)
        self.te_layernorm.weight = self.shared_weight
        self.hf_rmsnorm.weight = self.shared_weight
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.te_layernorm.to(self.device)
        self.hf_rmsnorm.to(self.device)

    def test_layernorm_comparison(self):
        input_tensor = torch.randn(self.batch_size, self.seq_length, self.hidden_size,
                                   dtype=torch.bfloat16, device=self.device)
        with torch.no_grad():
            te_output = self.te_layernorm(input_tensor)
            hf_output = self.hf_rmsnorm(input_tensor)
        # The two RMSNorm implementations are expected to agree within tolerance.
        assert torch.allclose(te_output, hf_output, atol=1e-2)


if __name__ == '__main__':
    unittest.main()
The assertion will fail.
Expected behavior
If we change the last line of HFRMSNorm from return self.weight * hidden_states.to(input_dtype) to return (self.weight.to(torch.float32) * hidden_states).to(input_dtype), the assertion should pass.
We have a discussion here, and I agree that we should do all internal computation in FP32. So what's your opinion on the HF side?
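Concretely, the proposal is to also do the weight multiplication in FP32 before casting the result back. A sketch under that assumption (the class name HFRMSNormFP32Weight is made up for illustration; only the return line differs from the HFRMSNorm above):

import torch
import torch.nn as nn


class HFRMSNormFP32Weight(nn.Module):
    """RMSNorm variant that keeps the weight multiply in FP32 (hypothetical name)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Proposed change: multiply by the weight in FP32, then cast the product back.
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)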
#23535 and #30236 are related. I am not sure we can change this without breaking anything at this point 🤗
Moreover the "original" code does not compute everything in float32. Now of course they don't cover extended usage.
I am down to update as long as there is a strong motivation: better performance, better MMLU scores, anything that can prove that this is a good idea!
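For context, the RMSNorm in Meta's original LLaMA reference code (sketched here from memory, so treat the details as an assumption) also multiplies by the weight only after casting back to the input dtype, which appears to be the point about the "original" code not computing everything in float32:

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Sketch of the RMSNorm from the original LLaMA reference implementation."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in FP32, cast back, and only then apply the weight.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight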
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.