Flash attention support softcap. #1013

Open · wants to merge 11 commits into main
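For context: softcap here is tanh soft-capping of the attention logits, as used by Gemma-2-style models, and flash-attn >= 2.6 applies it inside the fused kernel. A rough reference sketch of the transform (not part of this diff; names are illustrative):

```python
import torch

def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash unbounded QK^T logits into (-softcap, softcap) via tanh before
    # softmax; softcap <= 0.0 means "disabled", matching the default added below.
    if softcap <= 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)
```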
12 changes: 11 additions & 1 deletion transformer_engine/pytorch/attention.py
@@ -73,12 +73,13 @@

_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_max_version = PkgVersion("2.6.1")
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_6_1_plus = _flash_attn_version >= PkgVersion("2.6.1")

if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
@@ -3090,6 +3091,7 @@ def __init__(
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
softcap: float = 0.0,
) -> None:
super().__init__()

@@ -3099,13 +3101,18 @@ def __init__(
assert (
_flash_attn_version <= _flash_attn_max_version
), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
if softcap > 0.0:
assert (
_flash_attn_2_6_1_plus
), f"FlashAttention minimum version {PkgVersion('2.6.1')} is required for softcap."

self.softmax_scale = softmax_scale
self.attention_dropout_ctx = attention_dropout_ctx
self.attention_dropout = attention_dropout
self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softcap = softcap

def forward(
self,
@@ -3286,6 +3293,7 @@ def forward(
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
softcap=self.softcap,
Member:
Shouldn't it be added to the fa_optional_forward_kwargs instead? Otherwise the previous versions of FA will complain.

Author:
You are right, that was my mistake.

**fa_optional_forward_kwargs,
)
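A minimal sketch of the reviewer's suggestion, assuming `fa_optional_forward_kwargs` is the version-gated kwargs dict already built earlier in `forward()` (the non-softcap entries below are illustrative):

```python
# Gate softcap behind the flash-attn 2.6.1+ check so that older flash-attn
# builds never receive an unexpected "softcap" keyword argument.
fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus:
    fa_optional_forward_kwargs["window_size"] = window_size      # existing gating, illustrative
if _flash_attn_2_4_plus:
    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes    # existing gating, illustrative
if _flash_attn_2_6_1_plus:
    fa_optional_forward_kwargs["softcap"] = self.softcap         # new in this PR
```

The call site then keeps only `**fa_optional_forward_kwargs` and drops the explicit `softcap=self.softcap` argument shown above.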

@@ -5097,6 +5105,7 @@ def __init__(
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
softmax_scale: Optional[float] = None,
softcap: float = 0.0,
) -> None:
super().__init__()

@@ -5180,6 +5189,7 @@
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
softcap=softcap,
**attn_kwargs,
)
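A hedged usage sketch, assuming the new argument is exposed on `transformer_engine.pytorch.DotProductAttention`, whose `__init__` this hunk modifies; every argument other than `softcap` is a standard constructor argument, and the value 50.0 is only an example:

```python
from transformer_engine.pytorch import DotProductAttention

# softcap=50.0 caps attention logits at roughly +/-50 via tanh inside the
# flash-attn 2.6+ kernel; the default 0.0 leaves the logits untouched.
attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.0,
    softcap=50.0,
)
```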

2 changes: 1 addition & 1 deletion transformer_engine/pytorch/setup.py
@@ -56,7 +56,7 @@
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
install_requires=["torch", "flash-attn>=2.0.6,<=2.6.1,!=2.0.9,!=2.1.0"],
tests_require=["numpy", "onnxruntime", "torchvision"],
include_package_data=True,
package_data={