diff --git a/deeplink_ext/internlm_ops/rms_norm/__init__.py b/deeplink_ext/internlm_ops/rms_norm/__init__.py index 36aac826..77909d1e 100644 --- a/deeplink_ext/internlm_ops/rms_norm/__init__.py +++ b/deeplink_ext/internlm_ops/rms_norm/__init__.py @@ -7,7 +7,10 @@ "[deeplink_ext] rms_norm is not implemented in diopi. Falling back to the slower implementation.\n", end="", ) - from .fallback import RMSNorm as DeepLinkRMSNorm + from .fallback import ( + RMSNorm as DeepLinkRMSNorm, + RMSNorm as DeepLinkRMSNormWithNormalizedShape, # TODO(lljbash): check how this works + ) from . import fallback __all__ = ["DeepLinkRMSNorm", "DeepLinkRMSNormWithNormalizedShape", "fallback"] diff --git a/tests/test_rms_internlm.py b/tests/test_rms_internlm.py index 8513c74d..72bca9ef 100644 --- a/tests/test_rms_internlm.py +++ b/tests/test_rms_internlm.py @@ -12,14 +12,14 @@ def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): x_intern = x_base.clone() x_intern.retain_grad() - hidden_szie = 5 + hidden_size = 5 - model_base = BaseRmsNorm(hidden_szie).cuda() + model_base = BaseRmsNorm(hidden_size).cuda() out_base = model_base(x_base) out_base.backward(torch.ones_like(x_base)) grad_x_base = x_base.grad.cpu().numpy() - model_deeplink = DeeplinkRmsNorm(hidden_szie).cuda() + model_deeplink = DeeplinkRmsNorm(hidden_size).cuda() out_deeplink = model_deeplink(x_intern) out_deeplink.backward(torch.ones_like(x_base)) grad_x_intern = x_intern.grad.cpu().numpy()