Update attention.py #1416
Conversation
Modify the gpt_bigcode attention code. This modification makes the KV cache work correctly when multiple new tokens are passed at once.
Hi @DongHande, thank you for the PR. I will have a look shortly!
Thank you @DongHande for the notice, this is indeed a significant bug in our code base. Passing a non-None attn_mask to SDPA prevents dispatching to Flash Attention, so we could treat the cases separately, along these lines:

# We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
if self.training:
    is_causal = True
    attn_mask = None
elif batch_size == 1 and query_length == 1:
    is_causal = False
    attn_mask = None
elif batch_size == 1 and kv_seq_len == query_length:
    is_causal = True
    attn_mask = None
elif attention_mask is not None:
    mask_value = self._get_mask_value(query.device, query.dtype)

    # gpt_bigcode has the bad taste to use a causal mask of shape
    # [batch_size, target_length, 1, source_length], which is different from
    # **all** other architectures and not compatible with SDPA.
    # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
    # but it is probably not worth it.
    attention_mask = attention_mask.transpose(1, 2)
    attn_mask = torch.where(attention_mask, 0.0, mask_value)
    is_causal = False
else:
    attn_mask = None
    is_causal = True

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

WDYT?
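A side note on the branches above (an illustrative sketch, not part of the PR; the tensor shapes and the CUDA device are assumptions): torch.nn.functional.scaled_dot_product_attention can only select the Flash Attention kernel when no materialized attn_mask is passed, which is why the snippet keeps attn_mask=None and relies on is_causal wherever possible.

import torch
import torch.nn.functional as F

q = torch.randn(1, 16, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Eligible for the Flash Attention kernel: causal flag instead of an explicit mask.
out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Passing a materialized mask (even an equivalent causal one) rules out the Flash kernel
# and falls back to the math / memory-efficient implementations.
causal_mask = torch.ones(128, 128, device="cuda", dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)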
Thank you for your reply. I still have two questions: (1) I don't understand why we should consider batch_size == 1 here. The attn_mask has already been computed in the outer forward function, so why not use it directly? In other words, this function is an SDPA implementation meant to replace the attention operation of the Transformers library (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L128-L203), and the Transformers library does not special-case batch_size == 1, so why should Optimum? (2) Maybe in your snippet, the final scaled_dot_product_attention call should be modified to use the attn_mask and is_causal computed above? For the first question, you probably have other reasons for writing it this way; if the context is long and hard to explain, feel free to skip it to save your time.
What concerns me about your proposed change is that it would never dispatch to FA/FA2.
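To make that concern concrete, one way to check the dispatch (a hedged sketch, not from the conversation; it relies on the torch.backends.cuda.sdp_kernel context manager available in PyTorch versions of that time) is to disable every backend except Flash Attention: if the inputs are not eligible, for example because a non-None attn_mask is passed, the call raises an error instead of silently falling back.

import torch
import torch.nn.functional as F

q = torch.randn(1, 16, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the Flash Attention backend; ineligible inputs now raise a RuntimeError.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(q, k, v, is_causal=True)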
OK, I have modified my PR according to your suggestions. Please review and merge it. Thank you.
LGTM thank you!
I'll keep in mind to update other archs as well :)
Modify the gpt_bigcode attention code.
This modification makes the KV cache work correctly with multiple new tokens.
What does this PR do?
When we use StarCoder to generate text/code with a KV cache and multiple new tokens, the output becomes wrong because of a possible bug in the torch.nn.functional.scaled_dot_product_attention() function. I have opened an issue in PyTorch (pytorch/pytorch#110144), but until PyTorch fixes it, Optimum can work correctly with minor changes.
How to reproduce the error in the current version:
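Something along these lines reproduces the mismatch (a minimal sketch rather than the original report's exact script; the checkpoint name, prompt, and tolerance are assumptions): run a full forward pass over a sequence, then run the same sequence in two chunks where the second chunk feeds several new tokens on top of the KV cache, and compare the logits.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

model_id = "bigcode/starcoderbase-1b"  # assumption: any gpt_bigcode checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
model = BetterTransformer.transform(model)

input_ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids.to("cuda")

with torch.no_grad():
    # Reference: a single forward pass over the whole sequence.
    full_logits = model(input_ids).logits
    # Build a KV cache on a prefix, then feed the remaining tokens together (query_length > 1).
    past = model(input_ids[:, :-3], use_cache=True).past_key_values
    chunk_logits = model(input_ids[:, -3:], past_key_values=past, use_cache=True).logits

# Before this fix, the cached path disagrees with the reference on the multi-token chunk.
print(torch.allclose(full_logits[:, -3:], chunk_logits, atol=1e-3))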
Before submitting