diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py index 65d22f91..f06d7eb3 100644 --- a/deeplink_ext/patch_internlm.py +++ b/deeplink_ext/patch_internlm.py @@ -148,7 +148,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore _find_or_mock_module("rotary_emb") _find_or_mock_module("fused_dense_lib") _find_or_mock_module("xentropy_cuda_lib") - _find_or_mock_module("flash_attn_cuda") + _find_or_mock_module("flash_attn_2_cuda") _find_flash_attn() if force_fallback: _force_fallback()