diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index 82142e17..793bc02e 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -28,7 +28,7 @@ def backward(ctx, grad_output): grad_input = torch.empty_like(hidden_states) grad_weight = torch.empty_like(weight) grad_bias = torch.empty_like(bias) - + ext.rms_norm_backward( grad_input, grad_weight,