diff --git a/deeplink_ext/easyllm_ops/rms_norm_fallback.py b/deeplink_ext/easyllm_ops/rms_norm_fallback.py
index 73e4047..80e9594 100644
--- a/deeplink_ext/easyllm_ops/rms_norm_fallback.py
+++ b/deeplink_ext/easyllm_ops/rms_norm_fallback.py
@@ -6,10 +6,8 @@
 
 
 def rms_norm_torch(x, weight, epsilon):
+    input_dtype = x.dtype
     variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
     hidden_states = x * torch.rsqrt(variance + epsilon)
-    if weight.dtype in [torch.float16, torch.bfloat16]:
-        hidden_states = hidden_states.to(weight.dtype)
-
-    return hidden_states * weight
+    return (hidden_states * weight).to(input_dtype)
 
diff --git a/tests/core.py b/tests/core.py
index 22938e8..ffa1ca1 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -68,8 +68,8 @@ def allclose(expected_vals: list, real_vals: list, rtol=1e-05, atol=1e-08):
         if isinstance(expected_vals[i], torch.Tensor):
             assert isinstance(real_vals[i], torch.Tensor)
             return torch.allclose(
-                expected_vals[i].to(real_vals[i].dtype).cpu(),
-                real_vals[i].cpu(),
+                expected_vals[i].cpu().to(torch.float32),
+                real_vals[i].cpu().to(torch.float32),
                 rtol,
                 atol,
             )