enable softcap and gemma2 #288

Merged: 9 commits merged into develop from softcap_fix on Dec 4, 2024
Conversation

@hliuca (Collaborator) commented on Nov 20, 2024

The Gemma2 model needs the softcap feature from flash attention.
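
For background, soft-capping bounds the attention logits with a tanh before the softmax, which Gemma2 relies on. A minimal sketch of the idea (illustrative only, not the fused flash-attention kernel; the 50.0 value below is Gemma2's typical attn_logit_softcapping):

import torch


def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    """Soft-cap attention logits with tanh, squashing them into (-softcap, softcap).

    Following the flash-attn convention referenced later in this PR,
    a softcap of 0 means "no capping".
    """
    if softcap == 0:
        return scores
    return softcap * torch.tanh(scores / softcap)


# Illustrative use inside a plain (non-fused) attention computation:
# scores = (q @ k.transpose(-1, -2)) * scale
# scores = softcap_logits(scores, softcap=50.0)  # Gemma2-style attention soft cap
# probs = torch.softmax(scores, dim=-1)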

@hliuca requested a review from gshtras on November 20, 2024 21:57
@hliuca changed the title from "enable softcap for gemma2" to "enable softcap and gemma2" on Dec 2, 2024
@gshtras (Collaborator) left a comment

The linter warnings are due to attn_func always being called with the softcap parameter, even though not all implementations support it.
Per the other comments, please make sure this parameter is only used for the models that require it.
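
One way to address this kind of warning (a hedged sketch; run_attn, backend_supports_softcap, and the keyword names are illustrative, not the PR's actual change) is to build the optional keyword arguments conditionally, so backends without softcap support keep receiving the exact call they did before:

from typing import Any, Callable, Dict, Optional


def run_attn(
    attn_func: Callable[..., Any],
    query: Any,
    key: Any,
    value: Any,
    scale: float,
    logits_soft_cap: Optional[float],
    backend_supports_softcap: bool,
) -> Any:
    """Hypothetical helper: forward the soft-cap argument only when the model
    requests it and the selected attention backend accepts it."""
    kwargs: Dict[str, Any] = {"softmax_scale": scale, "causal": True}
    if backend_supports_softcap and logits_soft_cap:
        kwargs["softcap"] = logits_soft_cap
    return attn_func(query, key, value, **kwargs)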

@@ -218,12 +218,6 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
# Batch may be composed of prefill|decodes, adjust query start indices
@gshtras (Collaborator):

Why is this section being removed?

@hliuca (Author):

Sorry, this was an accident; it has been restored.


if logits_soft_cap is None:
    # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
    logits_soft_cap = 0
@gshtras (Collaborator):

This effectively enables logits_soft_cap for any model, unconditionally.

@@ -511,6 +515,11 @@ def __init__(
            self.use_naive_attn = True

        if self.use_naive_attn:
            if logits_soft_cap is not None:
@gshtras (Collaborator):

You make sure it's not None in the constructor, so naive flash attention can never be used now.

@hliuca (Author):

fixed. thanks.
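
For readers following along: because the constructor normalizes None to 0, an `is not None` check in the naive path is always true, which is what the review flags. A hedged sketch of the kind of guard that avoids this (the function name is illustrative; the merged commit may differ):

from typing import Optional


def validate_naive_attn_softcap(logits_soft_cap: Optional[float]) -> None:
    """Reject soft-capping only when it is actually requested.

    Treating both None and 0 as "no soft cap" keeps the naive attention
    fallback usable for models that never set the option.
    """
    if logits_soft_cap:  # truthy: set and non-zero
        raise ValueError(
            "Naive attention fallback does not support attention logits soft capping."
        )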

@hliuca requested a review from gshtras on December 3, 2024 22:51
@gshtras merged commit 18ef0a0 into develop on Dec 4, 2024
7 of 8 checks passed
@gshtras deleted the softcap_fix branch on December 4, 2024 01:59