From 18ef0a083f0eb8d94707496e0c792283fa555003 Mon Sep 17 00:00:00 2001
From: Hui Liu <96135754+hliuca@users.noreply.github.com>
Date: Tue, 3 Dec 2024 17:59:33 -0800
Subject: [PATCH] enable softcap and gemma2 (#288)

* enable softcap for gemma2

* fix lint

* restore fa

* restore accidental deletion

* fix logits_soft_cap constructor

* use 0.0 instead of 0
---
 vllm/attention/backends/rocm_flash_attn.py | 25 ++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 5f28a1bbd75fd..b4f4e5bb1500a 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -459,10 +459,13 @@ def __init__(
         if blocksparse_params is not None:
             raise ValueError(
                 "ROCmFlashAttention does not support blocksparse attention.")
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "ROCmFlashAttention does not support attention logits soft "
-                "capping.")
+
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap to 0 means no soft cap.
+            self.logits_soft_cap = 0.0
+        else:
+            self.logits_soft_cap = logits_soft_cap
+
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -487,6 +490,14 @@ def __init__(
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
         if self.use_triton_flash_attn:
+            if logits_soft_cap is not None:
+                raise ValueError(
+                    "ROCm Triton FlashAttention does not support attention "
+                    "logits soft capping. "
+                    "Please try using the ROCm CK "
+                    "FA backend instead by setting the env var "
+                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
+
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
             self.attn_func = triton_attention
@@ -511,6 +522,11 @@ def __init__(
                     self.use_naive_attn = True
 
             if self.use_naive_attn:
+                if logits_soft_cap is not None:
+                    raise ValueError(
+                        "ROCm Naive FlashAttention does not support "
+                        "attention logits soft capping.")
+
                 self.attn_func = _sdpa_attention
                 logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
@@ -717,6 +733,7 @@ def forward(
                         causal=True,
                         window_size=self.sliding_window,
                         alibi_slopes=self.alibi_slopes,
+                        softcap=self.logits_soft_cap,
                     )
 
                 # common code for prefill
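
With this change, attention logit soft capping is carried through to the ROCm CK
flash-attn path, which is what Gemma 2 models need. Below is a minimal usage
sketch, not part of the patch: it assumes a ROCm build of vLLM with CK
flash-attn installed, and the model name, prompt, and sampling settings are
only illustrative.

    # Sketch: run a Gemma 2 model (which uses attention logit soft capping)
    # on the ROCm CK FA backend. The Triton FA path still rejects soft capping,
    # so select the CK backend before vLLM initializes the attention backend.
    import os
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"  # use the CK FA backend

    from vllm import LLM, SamplingParams

    llm = LLM(model="google/gemma-2-9b-it")  # illustrative model choice
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["Briefly explain logit soft capping."], params)
    print(outputs[0].outputs[0].text)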