From f4208836a56b2f376b78230f0a20b6dd1d64b47c Mon Sep 17 00:00:00 2001
From: Lingjie Li
Date: Wed, 13 Mar 2024 15:33:03 +0800
Subject: [PATCH] fix(internevo): mock flash_attn_2_cuda

InternEVO uses FlashAttention 2.2.1, where the CUDA module is renamed
from flash_attn_cuda to flash_attn_2_cuda. This commit mocks the correct
module name.
---
 deeplink_ext/patch_internlm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py
index 65d22f91..f06d7eb3 100644
--- a/deeplink_ext/patch_internlm.py
+++ b/deeplink_ext/patch_internlm.py
@@ -148,7 +148,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):  # type: ignore
     _find_or_mock_module("rotary_emb")
     _find_or_mock_module("fused_dense_lib")
     _find_or_mock_module("xentropy_cuda_lib")
-    _find_or_mock_module("flash_attn_cuda")
+    _find_or_mock_module("flash_attn_2_cuda")
     _find_flash_attn()
     if force_fallback:
         _force_fallback()
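
Note: the body of _find_or_mock_module lies outside this hunk, so the sketch
below is an assumption about what such a helper typically does, not the
project's actual implementation. It only illustrates the mocking idea the
commit relies on: if the named CUDA extension (e.g. flash_attn_2_cuda) cannot
be found, register a stand-in under that name so later imports do not fail.

    # Hypothetical sketch of a module-mocking helper; names and behavior are
    # assumptions, not taken from deeplink_ext/patch_internlm.py itself.
    import importlib.util
    import sys
    from unittest.mock import Mock

    def _find_or_mock_module(module_name: str) -> bool:
        """Return True if module_name is importable; otherwise install a Mock
        under that name in sys.modules so `import module_name` succeeds."""
        found = importlib.util.find_spec(module_name) is not None
        if not found:
            # Any later `import flash_attn_2_cuda` resolves to this Mock.
            sys.modules[module_name] = Mock()
        return found

With such a helper, passing the post-2.2.1 name "flash_attn_2_cuda" (rather
than the old "flash_attn_cuda") is what actually prevents InternEVO's import
of the renamed CUDA extension from failing on platforms where it is absent.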