Skip to content

Commit

Permalink
enable softcap and gemma2 (#288)
Browse files Browse the repository at this point in the history
* enable softcap for gemma2

* fix lint

* restore fa

* restore accidental deletion

* fix logits_soft_cap constructor

* use 0.0 instead of 0
  • Loading branch information
hliuca authored Dec 4, 2024
1 parent 0cee60d commit 18ef0a0
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,13 @@ def __init__(
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")

if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
self.logits_soft_cap = 0.0
else:
self.logits_soft_cap = logits_soft_cap

self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
Expand All @@ -487,6 +490,14 @@ def __init__(
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Triton FlashAttention does not support attention "
"logits soft capping."
" please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")

from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
Expand All @@ -511,6 +522,11 @@ def __init__(
self.use_naive_attn = True

if self.use_naive_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Naive FlashAttention does not support "
"attention logits soft capping.")

self.attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend")

Expand Down Expand Up @@ -717,6 +733,7 @@ def forward(
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)

# common code for prefill
Expand Down

0 comments on commit 18ef0a0

Please sign in to comment.