diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py index 316b3c9d..35915e2d 100644 --- a/tests/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -41,5 +41,9 @@ 1e-6 ) +print("Output:", output) +print("Grad Input:", grad_input) +print("Grad Weight:", grad_weight) +print("Grad Bias:", grad_bias) b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight assert torch.allclose(output, b)