diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index f71b469f2d..4c7833ca72 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -73,12 +73,13 @@
 _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
 _flash_attn_version_required = PkgVersion("2.0.6")
-_flash_attn_max_version = PkgVersion("2.5.8")
+_flash_attn_max_version = PkgVersion("2.6.3")
 _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
 _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
 _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
+_flash_attn_2_6_1_plus = _flash_attn_version >= PkgVersion("2.6.1")
 
 if _flash_attn_version >= _flash_attn_version_required:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
@@ -3090,6 +3091,7 @@ def __init__(
         attention_type: str = "self",
         layer_number: Optional[int] = None,
         deterministic: bool = False,
+        softcap: float = 0.0,
     ) -> None:
         super().__init__()
 
@@ -3099,6 +3101,10 @@
         assert (
             _flash_attn_version <= _flash_attn_max_version
         ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
+        if softcap > 0.0:
+            assert (
+                _flash_attn_2_6_1_plus
+            ), f"FlashAttention minimum version 2.6.1 is required for softcap."
 
         self.softmax_scale = softmax_scale
         self.attention_dropout_ctx = attention_dropout_ctx
@@ -3106,6 +3112,7 @@
         self.attention_type = attention_type
         self.layer_number = 1 if layer_number is None else layer_number
         self.deterministic = deterministic
+        self.softcap = softcap
 
     def forward(
         self,
@@ -3275,6 +3282,9 @@
                 fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
             if _flash_attn_2_4_1_plus:
                 fa_optional_forward_kwargs["deterministic"] = self.deterministic
+            if _flash_attn_2_6_1_plus:
+                fa_optional_forward_kwargs["softcap"] = self.softcap
+
             output = flash_attn_forward_func(
                 query_layer,
                 key_layer,
@@ -5097,6 +5107,7 @@
         cp_global_ranks: List[int] = None,
         cp_stream: torch.cuda.Stream = None,
         softmax_scale: Optional[float] = None,
+        softcap: float = 0.0,
     ) -> None:
         super().__init__()
 
@@ -5180,6 +5191,7 @@
             attention_type=attention_type,
             layer_number=layer_number,
             deterministic=self.deterministic,
+            softcap=softcap,
             **attn_kwargs,
         )
 
diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py
index 9d0f24b478..af6323f794 100644
--- a/transformer_engine/pytorch/setup.py
+++ b/transformer_engine/pytorch/setup.py
@@ -56,7 +56,7 @@
        description="Transformer acceleration library - Torch Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
+        install_requires=["torch", "flash-attn>=2.0.6,<=2.6.1,!=2.0.9,!=2.1.0"],
        tests_require=["numpy", "onnxruntime", "torchvision"],
        include_package_data=True,
        package_data={
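
For context on what the new softcap knob controls: flash-attn 2.6 added tanh soft-capping of the attention logits before the softmax, and the diff simply plumbs that option through and forwards it as the softcap kwarg when flash-attn >= 2.6.1 is present. The snippet below is a minimal, illustrative pure-PyTorch sketch of the commonly used capping rule, scores -> softcap * tanh(scores / softcap); capped_sdpa is a hypothetical helper, not part of Transformer Engine or flash-attn, and the exact point at which the kernel applies the cap relative to scaling should be checked against the flash-attn source.

# Hypothetical reference for tanh "softcap" on attention logits (illustration only).
import math
import torch


def capped_sdpa(q, k, v, softcap: float = 0.0):
    """Naive scaled-dot-product attention with optional logit soft-capping.

    q, k, v: [batch, heads, seq, head_dim]; softcap <= 0.0 disables capping.
    """
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if softcap > 0.0:
        # Smoothly bound the logits to (-softcap, softcap) before the softmax.
        scores = softcap * torch.tanh(scores / softcap)
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)


if __name__ == "__main__":
    q, k, v = (torch.randn(2, 4, 16, 64) for _ in range(3))
    out = capped_sdpa(q, k, v, softcap=30.0)
    print(out.shape)  # torch.Size([2, 4, 16, 64])

With the default softcap=0.0 introduced by the diff, the capping branch is skipped and this reduces to plain scaled-dot-product attention, mirroring the patch's behavior of leaving the flash-attn call unchanged unless a positive softcap is requested.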