diff --git a/deeplink_ext/common/rms_norm/__init__.py b/deeplink_ext/common/rms_norm/__init__.py index 81b96b65..f91583f3 100644 --- a/deeplink_ext/common/rms_norm/__init__.py +++ b/deeplink_ext/common/rms_norm/__init__.py @@ -1,4 +1,4 @@ from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward -all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] \ No newline at end of file +all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] diff --git a/deeplink_ext/common/rms_norm/deeplink.py b/deeplink_ext/common/rms_norm/deeplink.py index ac83a411..f196d297 100644 --- a/deeplink_ext/common/rms_norm/deeplink.py +++ b/deeplink_ext/common/rms_norm/deeplink.py @@ -76,4 +76,3 @@ def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bia ) return [grad_input, grad_weight, grad_bias] - diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py index f97d31d4..f8c89d7e 100644 --- a/deeplink_ext/patch_lightllm.py +++ b/deeplink_ext/patch_lightllm.py @@ -52,6 +52,7 @@ def patch_token_softmax_reducev_inference(): def patch_rms_norm_lightllm(): from .common.rms_norm.deeplink import rms_norm + rms_norm_pack.rmsnorm_forward = rms_norm def patch_rotary_emb():