Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
POI-WX committed Jul 30, 2024
1 parent d6823de commit 057df1a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
6 changes: 2 additions & 4 deletions deeplink_ext/easyllm_ops/rms_norm_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@


def rms_norm_torch(x, weight, epsilon):
input_dtype = x.dtype
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = x * torch.rsqrt(variance + epsilon)

if weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(weight.dtype)

return hidden_states * weight
return (hidden_states * weight).to(input_dtype)
4 changes: 2 additions & 2 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def allclose(expected_vals: list, real_vals: list, rtol=1e-05, atol=1e-08):
if isinstance(expected_vals[i], torch.Tensor):
assert isinstance(real_vals[i], torch.Tensor)
return torch.allclose(
expected_vals[i].to(real_vals[i].dtype).cpu(),
real_vals[i].cpu(),
expected_vals[i].cpu().to(torch.float32),
real_vals[i].cpu().to(torch.float32),
rtol,
atol,
)
Expand Down

0 comments on commit 057df1a

Please sign in to comment.