Update Flash-Attention to v2.5.6 (fairinternal/xformers#1044)
Fixes the build for Windows as well.

__original_commit__ = fairinternal/xformers@70e8ba7
danthe3rd authored and xFormers Bot committed Mar 4, 2024
1 parent 44b0d07 commit 051b56a
Showing 5 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/win-build.yml
@@ -3,8 +3,8 @@ name: win-build
 on:
   pull_request:
     paths:
+      - "third_party/**"
       - "xformers/csrc/**"
-      - "third-party/**"
       - ".github/workflows/win-build.yml"
       - "setup.py"
       - "requirements*.txt"
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - New merge_attentions function
 ### Improved
-- fMHA: Updated Flash-Attention to v2.5.2: this has a performance improvement for multiquery.
+- fMHA: Updated Flash-Attention to v2.5.6: this has a performance improvement for multiquery.
 - fMHA: triton_splitk changed and expanded. Now amalgamates using LSE. Can autotune, supports causal with a small number of queries - not just 1. Experimental support for paged attention.
 ### Removed
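
The triton_splitk entry above says partial results are now amalgamated using LSE. As a hedged illustration of that arithmetic (not the xformers API; the function, names, and tensor layout below are assumed for this sketch), attention outputs computed over disjoint key/value splits can be merged with their per-split log-sum-exp values:

# Minimal sketch of LSE-based merging of attention over key/value splits.
# Names (merge_attention_splits, outs, lses) and the (batch, seq, heads, dim)
# layout are illustrative assumptions, not the xformers API.
import torch

def merge_attention_splits(outs, lses):
    # outs: list of partial outputs, each [B, M, H, K] (one per key split)
    # lses: list of per-split log-sum-exp values, each [B, H, M]
    lse_stack = torch.stack(lses)                   # [S, B, H, M]
    lse_total = torch.logsumexp(lse_stack, dim=0)   # [B, H, M]
    weights = torch.exp(lse_stack - lse_total)      # [S, B, H, M]
    out_stack = torch.stack(outs)                   # [S, B, M, H, K]
    w = weights.permute(0, 1, 3, 2).unsqueeze(-1)   # [S, B, M, H, 1]
    # Each split is re-weighted by its share of the global softmax mass,
    # so the merged result matches attention over the full key set.
    return (w * out_stack).sum(dim=0), lse_total
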
5 changes: 0 additions & 5 deletions setup.py
@@ -188,10 +188,6 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     if not nvcc_archs_flags:
         return []
 
-    nvcc_windows_flags = []
-    if platform.system() == "Windows":
-        nvcc_windows_flags = ["-Xcompiler", "/permissive-"]
-
     flash_root = os.path.join(this_dir, "third_party", "flash-attention")
     cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
     if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
@@ -224,7 +220,6 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
                 "--ptxas-options=-v",
             ]
             + nvcc_archs_flags
-            + nvcc_windows_flags
             + common_extra_compile_args
             + get_extra_nvcc_flags_for_build_type(cuda_version),
         },
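
For context on the lines removed above: nvcc's -Xcompiler option forwards a flag to the host compiler, and MSVC's /permissive- requests strict standards conformance. A minimal sketch of the dropped pattern, assuming (per the commit message) that Flash-Attention v2.5.6 now builds on Windows without this special case:

# Sketch of the removed Windows-only special case: forward /permissive-
# to MSVC through nvcc. Shown for illustration only; no longer needed.
import platform

def host_compiler_flags():
    if platform.system() == "Windows":
        # nvcc -Xcompiler <flag> passes <flag> through to the host compiler.
        return ["-Xcompiler", "/permissive-"]
    return []
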
18 changes: 10 additions & 8 deletions xformers/ops/fmha/flash.py
@@ -47,14 +47,15 @@
     from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 
     FLASH_VERSION = flash_attn.__version__
-    WANTED_FLASH_VERSION = (2, 5, 2)
+    FLASH_VER_MIN = (2, 5, 2)
+    FLASH_VER_LAST = (2, 5, 6)  # last supported, inclusive
     flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
     if (
-        flash_ver_parsed != WANTED_FLASH_VERSION
-        and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-    ):
+        flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
+    ) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
         raise ImportError(
-            f"Requires Flash attention {WANTED_FLASH_VERSION} for varlen_fwd api "
+            f"Requires Flash-Attention version >={'.'.join([str(i) for i in FLASH_VER_MIN])},"
+            f"<={'.'.join([str(i) for i in FLASH_VER_LAST])} "
             f"but got {FLASH_VERSION}."
         )
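
The hunk above replaces the exact-version check with an inclusive version range plus an environment-variable escape hatch. Below is a hedged, self-contained restatement for illustration only; the helper name check_flash_version is invented here, and the real check lives in xformers/ops/fmha/flash.py as shown in the diff:

# Illustrative restatement of the version gate; check_flash_version is an
# invented name for this sketch.
import os

FLASH_VER_MIN = (2, 5, 2)
FLASH_VER_LAST = (2, 5, 6)  # last supported, inclusive

def check_flash_version(version_str: str) -> None:
    parsed = tuple(int(s) for s in version_str.split(".")[:3])
    if (
        parsed < FLASH_VER_MIN or parsed > FLASH_VER_LAST
    ) and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1":
        raise ImportError(
            f"Requires Flash-Attention version "
            f">={'.'.join(map(str, FLASH_VER_MIN))},"
            f"<={'.'.join(map(str, FLASH_VER_LAST))} but got {version_str}."
        )

check_flash_version("2.5.6")   # in range: no error
# check_flash_version("2.6.0") # out of range: raises ImportError unless
#                              # XFORMERS_IGNORE_FLASH_VERSION_CHECK=1 is set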

@@ -577,7 +578,7 @@ class BwOp(AttentionBwOpBase):
     NAME = f"flshattB@{FLASH_VERSION}"
     VERSION = FLASH_VERSION
 
-    MAX_HEADDIM_SM8x = 192
+    MAX_HEADDIM_DROPOUT_SM8x = 224
 
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
@@ -589,12 +590,13 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
         device_capability = torch.cuda.get_device_capability(d.device)
         is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
         if (
-            max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
+            max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_DROPOUT_SM8x
             and not is_sm80_or_sm90
+            and d.p != 0.0
         ):
             reasons.append(
                 "requires a GPU with compute capability 8.0 "
-                f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
+                f"(A100) or 9.0 (H100) for dropout when 'query.shape[-1] > {cls.MAX_HEADDIM_DROPOUT_SM8x}'"
             )
         return reasons
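
The last hunk narrows the backward-pass restriction: on GPUs other than A100 (sm80) or H100 (sm90), the 224 head-dimension cap now only applies when dropout is active. A hedged sketch of the resulting condition, with an invented helper name for illustration:

# Illustration of the gating logic after this change; needs_sm80_or_sm90 is
# an invented name, and the constant mirrors MAX_HEADDIM_DROPOUT_SM8x above.
MAX_HEADDIM_DROPOUT_SM8x = 224

def needs_sm80_or_sm90(head_dim: int, dropout_p: float, capability: tuple) -> bool:
    is_sm80_or_sm90 = capability in [(8, 0), (9, 0)]
    return (
        head_dim > MAX_HEADDIM_DROPOUT_SM8x
        and not is_sm80_or_sm90
        and dropout_p != 0.0
    )

# Example: head dim 256 with dropout on an sm86 GPU trips the check,
# while the same shape without dropout passes this particular check.
assert needs_sm80_or_sm90(256, 0.1, (8, 6)) is True
assert needs_sm80_or_sm90(256, 0.0, (8, 6)) is False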
