enable softcap and gemma2 #288
Changes from 5 commits
vllm/attention/backends/rocm_flash_attn.py

@@ -218,12 +218,6 @@
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
-        # Batch may be composed of prefill|decodes, adjust query start indices
-        # to refer to the start of decodes when the two are split apart.
-        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
-        if self._cached_decode_metadata.query_start_loc is not None:
-            qs = self._cached_decode_metadata.query_start_loc
-            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata

    def advance_step(self,
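For context on the block this commit drops (and, per the thread at the bottom of the page, later restores): when a batch mixes prefills and decodes, the cached decode-only metadata rebases `query_start_loc` so it starts at zero. A minimal sketch of that rebasing on a bare tensor, outside the real metadata object:

```python
import torch

# Tokens laid out as [3 prefill | 6 decode]: in the combined metadata the
# decode queries sit at absolute offsets [3, 9].
query_start_loc = torch.tensor([3, 9])

# Rebase to the decode-only view by subtracting the first offset,
# exactly the `qs - qs[0]` step in the removed block: [3, 9] -> [0, 6].
decode_query_start_loc = query_start_loc - query_start_loc[0]
print(decode_query_start_loc)  # tensor([0, 6])
```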
@@ -459,10 +453,12 @@
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "ROCmFlashAttention does not support attention logits soft "
-                "capping.")
+
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
+            logits_soft_cap = 0
+        self.logits_soft_cap = logits_soft_cap
+
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)

Review comment: This effectively enables logits_soft_cap for any model, unconditionally.

Reply: In flash attention, it is 0 by default, I think. https://github.com/ROCm/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1334
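For readers who haven't met logits soft capping (the mechanism Gemma 2 depends on): the pre-softmax attention scores are squashed through a scaled tanh, and under the flash-attn convention referenced above a cap of 0 disables the squashing entirely. A rough reference sketch of that convention, not the kernel code:

```python
import torch

def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    """Apply attention-logit soft capping: cap * tanh(scores / cap).

    cap == 0 follows the flash-attn convention of "no capping at all",
    which is why the constructor can default a missing value to 0.
    """
    if cap == 0:
        return scores
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-100.0, 0.0, 100.0])
print(soft_cap(scores, 50.0))  # squashed into (-50, 50): ~[-48.2, 0.0, 48.2]
print(soft_cap(scores, 0.0))   # passthrough: [-100., 0., 100.]
```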
@@ -487,6 +483,14 @@
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
+            if logits_soft_cap is not None:
+                raise ValueError(
+                    "ROCm Triton FlashAttention does not support attention "
+                    "logits soft capping. Please try using the ROCm CK "
+                    "FA backend instead by setting the env var "
+                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
+
            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
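As a usage note, the escape hatch named in that error message is the same environment variable this file already reads. An illustrative way to force the CK path when serving a softcap model (the model name is an example, not taken from this PR):

```python
# Illustrative sketch: force the ROCm CK flash-attention backend, which is
# the path that accepts a logits soft cap, before building the engine.
import os
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

from vllm import LLM  # assumes a working ROCm build of vLLM

llm = LLM(model="google/gemma-2-9b-it")  # Gemma 2 uses attention logit softcap
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```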
@@ -505,12 +509,17 @@
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
+                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
+                if logits_soft_cap is not None:
+                    raise ValueError(
+                        "ROCm Naive FlashAttention does not support "
+                        "attention logits soft capping.")
+
                self.attn_func = _sdpa_attention
                logger.debug("Using naive (SDPA) attention in ROCmBackend")

(CI: GitHub Actions mypy check failure reported on line 512 of vllm/attention/backends/rocm_flash_attn.py for Python 3.9, 3.10, and 3.11.)

Review comment: You make sure it's not None in the constructor, so naive flash attention can never be used now.

Reply: Fixed, thanks.
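Putting the constructor pieces together: the backend is chosen in the order Triton, then CK flash-attn, then naive SDPA, and only the CK path is meant to accept a soft cap. A simplified, standalone sketch of that selection order (not the actual vLLM code; the navi/device-capability checks are omitted):

```python
from typing import Optional

def pick_rocm_attn_backend(use_triton: bool,
                           logits_soft_cap: Optional[float]) -> str:
    """Condensed view of the selection order implied by the diff."""
    if use_triton:
        if logits_soft_cap is not None:
            raise ValueError("Triton FA path does not support softcap")
        return "triton"
    try:
        import flash_attn  # noqa: F401  # CK flash-attention build
        return "ck"  # the only path here that accepts a softcap
    except ModuleNotFoundError:
        if logits_soft_cap is not None:
            raise ValueError("naive SDPA path does not support softcap")
        return "naive"
```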
@@ -663,7 +672,7 @@
                    query.dtype,
                    seq_lens,
                    make_attn_mask=False)  # type: ignore
+                out, _ = self.attn_func(
                    query,
                    key,
                    value,

(CI: GitHub Actions mypy check failure reported on line 675 for Python 3.9, 3.10, and 3.11.)
@@ -692,7 +701,7 @@
                key = key.movedim(0, key.dim() - 2)
                value = value.movedim(0, value.dim() - 2)
                # sdpa math backend attention
+                out = self.attn_func(
                    query,
                    key,
                    value,

(CI: GitHub Actions mypy check failure reported on line 704 for Python 3.9, 3.10, and 3.11.)
@@ -705,7 +714,7 @@
                    attn_masks,
                )
            else:
+                out = self.attn_func(
                    q=query,
                    k=key,
                    v=value,

(CI: GitHub Actions mypy check failure reported on line 717 for Python 3.9, 3.10, and 3.11.)
@@ -717,6 +726,7 @@
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
+                    softcap=self.logits_soft_cap,
                )

            # common code for prefill
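To make concrete what passing `softcap` into the varlen kernel does to the computation, here is a small CPU-only reference (dense shapes, no causal mask or varlen batching; a sketch of the math, not the kernel):

```python
import torch

def ref_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  scale: float, softcap: float = 0.0) -> torch.Tensor:
    """Reference attention with optional logit soft capping.

    Shapes are [num_heads, seq_len, head_dim]; softcap == 0 means the
    scaled scores go to softmax untouched, matching the kernel convention.
    """
    scores = torch.einsum("hqd,hkd->hqk", q, k) * scale
    if softcap > 0:
        scores = softcap * torch.tanh(scores / softcap)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,hkd->hqd", probs, v)

h, s, d = 2, 4, 8
q, k, v = (torch.randn(h, s, d) for _ in range(3))
print(ref_attention(q, k, v, scale=d ** -0.5, softcap=50.0).shape)  # [2, 4, 8]
```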
Review comment (on the removed query_start_loc block at the top of the diff): Why is this section being removed?

Reply: Sorry, this was an accident; it has been restored.