From 051b56aa139700cf2f44a4b8e7b5d3cd095849ef Mon Sep 17 00:00:00 2001
From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com>
Date: Mon, 4 Mar 2024 16:11:41 +0000
Subject: [PATCH] Update Flash-Attention to v2.5.6 (fairinternal/xformers#1044)

Fixes build for windows as well

__original_commit__ = fairinternal/xformers@70e8ba7af49eb0fe5b9843c6b01652ffd79e0de3
---
 .github/workflows/win-build.yml |  2 +-
 CHANGELOG.md                    |  2 +-
 setup.py                        |  5 -----
 third_party/flash-attention     |  2 +-
 xformers/ops/fmha/flash.py      | 18 ++++++++++--------
 5 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/win-build.yml b/.github/workflows/win-build.yml
index 4a39da3802..eb2f537b78 100644
--- a/.github/workflows/win-build.yml
+++ b/.github/workflows/win-build.yml
@@ -3,8 +3,8 @@ name: win-build
 on:
   pull_request:
     paths:
+      - "third_party/**"
       - "xformers/csrc/**"
-      - "third-party/**"
       - ".github/workflows/win-build.yml"
       - "setup.py"
       - "requirements*.txt"
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 50a3e3b56e..8c4e39b501 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - New merge_attentions function
 ### Improved
-- fMHA: Updated Flash-Attention to v2.5.2: this has a performance improvement for multiquery.
+- fMHA: Updated Flash-Attention to v2.5.6: this has a performance improvement for multiquery.
 - fMHA: triton_splitk changed and expanded. Now amalgamates using LSE. Can autotune, supports causal with a small number of queries - not just 1. Experimental support for paged attention.
 ### Removed
 
diff --git a/setup.py b/setup.py
index ed5fa52b61..163344bb5c 100644
--- a/setup.py
+++ b/setup.py
@@ -188,10 +188,6 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     if not nvcc_archs_flags:
         return []
 
-    nvcc_windows_flags = []
-    if platform.system() == "Windows":
-        nvcc_windows_flags = ["-Xcompiler", "/permissive-"]
-
     flash_root = os.path.join(this_dir, "third_party", "flash-attention")
     cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
     if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
@@ -224,7 +220,6 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
                     "--ptxas-options=-v",
                 ]
                 + nvcc_archs_flags
-                + nvcc_windows_flags
                 + common_extra_compile_args
                 + get_extra_nvcc_flags_for_build_type(cuda_version),
             },
diff --git a/third_party/flash-attention b/third_party/flash-attention
index 4687936413..6c9e60de56 160000
--- a/third_party/flash-attention
+++ b/third_party/flash-attention
@@ -1 +1 @@
-Subproject commit 468793641374d00a0cd7017cc0b2310b4c710376
+Subproject commit 6c9e60de566800538fedad2ad5e6b7b55ca7f0c5
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 8f5e7ee451..76e9371bca 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -47,14 +47,15 @@
     from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 
     FLASH_VERSION = flash_attn.__version__
-    WANTED_FLASH_VERSION = (2, 5, 2)
+    FLASH_VER_MIN = (2, 5, 2)
+    FLASH_VER_LAST = (2, 5, 6)  # last supported, inclusive
     flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
     if (
-        flash_ver_parsed != WANTED_FLASH_VERSION
-        and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-    ):
+        flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
+    ) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
         raise ImportError(
-            f"Requires Flash attention {WANTED_FLASH_VERSION} for varlen_fwd api "
+            f"Requires Flash-Attention version >={'.'.join([str(i) for i in FLASH_VER_MIN])},"
+            f"<={'.'.join([str(i) for i in FLASH_VER_LAST])} "
             f"but got {FLASH_VERSION}."
         )
 
@@ -577,7 +578,7 @@ class BwOp(AttentionBwOpBase):
     NAME = f"flshattB@{FLASH_VERSION}"
     VERSION = FLASH_VERSION
 
-    MAX_HEADDIM_SM8x = 192
+    MAX_HEADDIM_DROPOUT_SM8x = 224
 
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
@@ -589,12 +590,13 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
         device_capability = torch.cuda.get_device_capability(d.device)
         is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
         if (
-            max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
+            max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_DROPOUT_SM8x
             and not is_sm80_or_sm90
+            and d.p != 0.0
         ):
             reasons.append(
                 "requires a GPU with compute capability 8.0 "
-                f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
+                f"(A100) or 9.0 (H100) for dropout when 'query.shape[-1] > {cls.MAX_HEADDIM_DROPOUT_SM8x}'"
             )
         return reasons
 
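
Note on the new version gate in xformers/ops/fmha/flash.py: the check now accepts an inclusive range [FLASH_VER_MIN, FLASH_VER_LAST] instead of a single pinned version, and it can still be bypassed with XFORMERS_IGNORE_FLASH_VERSION_CHECK=1. The snippet below is a minimal standalone sketch of that comparison logic; the check_flash_version helper is hypothetical and not part of the patch, and it assumes only that the version string starts with dotted integers, as the patched code does.

import os

# Supported Flash-Attention range after this patch (inclusive bounds), per flash.py.
FLASH_VER_MIN = (2, 5, 2)
FLASH_VER_LAST = (2, 5, 6)


def check_flash_version(version_str: str) -> None:
    """Sketch of the range check: accept any 2.5.2 <= version <= 2.5.6.

    As in the patched module, XFORMERS_IGNORE_FLASH_VERSION_CHECK=1 bypasses the check.
    """
    # Only the first three dot-separated components are compared, as tuples of ints.
    parsed = tuple(int(s) for s in version_str.split(".")[:3])
    out_of_range = parsed < FLASH_VER_MIN or parsed > FLASH_VER_LAST
    if out_of_range and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
        raise ImportError(
            f"Requires Flash-Attention version "
            f">={'.'.join(str(i) for i in FLASH_VER_MIN)},"
            f"<={'.'.join(str(i) for i in FLASH_VER_LAST)} "
            f"but got {version_str}."
        )


check_flash_version("2.5.2")   # ok: lower bound is inclusive
check_flash_version("2.5.6")   # ok: upper bound is inclusive
# check_flash_version("2.4.2") # would raise ImportError unless the env var is set

With this, 2.5.2 and 2.5.6 both pass, while anything outside the range raises an ImportError like the patched module does, unless the environment variable is set.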