-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PyTorch] Fix get_swa_mask() for padding masks #1281
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Charlene Yang <[email protected]>
/te-ci pytorch |
Hi @cyanguwa, if "padding" in attn_mask_type:
if max_seqlen_q == max_seqlen_kv:
attention_mask = torch.logical_or(
> attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
E AttributeError: 'tuple' object has no attribute 'squeeze' The code in |
Yes, I think I should use Let me know if you observe any other issues too! :) Thanks! |
is applied, the bottom right corner comes from the [actual_seqlen_q[i], actual_seqlen_kv[i]] matrix, | ||
for each batch i, not the [max_seqlen_q, max_seqlen_kv] matrix.:: | ||
|
||
attn_mask_type output shape diagonal alignment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment
Description
This PR fixes the mask generation for sliding window in
UnfusedDotProductAttention
. It fixes the logic for padding and arbitrary masks inget_swa_mask()
, adds more docstring, refactors the call site, and adds more testing in the unit tests.Fixes #1271
Type of change
Changes
Please list the changes introduced in this PR:
get_swa_mask()
and its call siteChecklist: