fused_attn_fwd_qkvpacked silently doesn't support 3 or 7 heads #1182

Open · ajayjain opened this issue Sep 15, 2024 · 1 comment · May be fixed by #1187

@ajayjain

When testing the fused_attn_fwd_qkvpacked function in transformer_engine.pytorch.cpp_extensions.fused_attn, the heads dimension is silently dropped from the output when the input QKV tensor has layout "t3hd" and h=3 or h=7.

Minimal reproducible example:

import torch

from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd_qkvpacked,
    FusedAttnBackend
)
from transformer_engine.pytorch.constants import TE_DType


total = 1024
num_heads = 3
head_dim = 128
qkv = torch.randn(total, 3, num_heads, head_dim, dtype=torch.bfloat16).cuda()  # packed QKV in "t3hd" layout
cu_seqlens = torch.tensor([0, 512, 1024], dtype=torch.int64).cuda()  # cumulative sequence lengths: two sequences of 512 tokens each

output, _ = fused_attn_fwd_qkvpacked(
    False,  # is_training
    512,  # max_seqlen
    cu_seqlens,
    qkv,
    TE_DType[qkv.dtype],  # qkv_dtype
    FusedAttnBackend["F16_arbitrary_seqlen"],  # fused_attention_backend,
    None,  # attn_bias
    cu_seqlens,  # cu_seqlens_padded,
    None,  # d_scale_qkv
    0,  # d_scale_qkv_offset
    None,  # d_scale_s
    0,  # d_scale_s_offset
    None,  # q_scale_s
    0,  # q_scale_s_offset
    None,  # q_scale_o
    0,  # q_scale_o_offset
    None,  # amax_s
    0,  # amax_s_offset
    None,  # amax_o
    0,  # amax_o_offset
    None,  # attn_scale
    0.0,  # dropout_p
    True,  # fast_zero_fill
    "t3hd",  # qkv_layout
    "no_bias",  # attn_bias_type
    "padding",  # attn_mask_type
    (-1, -1),  # window_size
    None,  # rng_gen
)

print(output.shape)  # Prints torch.Size([1024, 128])
assert output.shape == (total, num_heads, head_dim)  # Assertion fails

It should print [1024, 3, 128] instead.

h=1, 2, 4, 5, 6, and 8 work.
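
For reference, a sweep along these lines shows which values of h are affected. This is just a sketch that re-issues the same call as above for each head count, reusing the setup from the snippet; run_repro is a helper introduced here for illustration, not a TE API:

def run_repro(h):
    # Same fused_attn_fwd_qkvpacked call as above, with the packed QKV tensor built for h heads
    qkv_h = torch.randn(total, 3, h, head_dim, dtype=torch.bfloat16).cuda()
    out, _ = fused_attn_fwd_qkvpacked(
        False,  # is_training
        512,  # max_seqlen
        cu_seqlens,
        qkv_h,
        TE_DType[qkv_h.dtype],  # qkv_dtype
        FusedAttnBackend["F16_arbitrary_seqlen"],  # fused_attention_backend
        None,  # attn_bias
        cu_seqlens,  # cu_seqlens_padded
        None, 0,  # d_scale_qkv, d_scale_qkv_offset
        None, 0,  # d_scale_s, d_scale_s_offset
        None, 0,  # q_scale_s, q_scale_s_offset
        None, 0,  # q_scale_o, q_scale_o_offset
        None, 0,  # amax_s, amax_s_offset
        None, 0,  # amax_o, amax_o_offset
        None,  # attn_scale
        0.0,  # dropout_p
        True,  # fast_zero_fill
        "t3hd",  # qkv_layout
        "no_bias",  # attn_bias_type
        "padding",  # attn_mask_type
        (-1, -1),  # window_size
        None,  # rng_gen
    )
    return tuple(out.shape)

for h in range(1, 9):
    # h=3 and h=7 are the ones that come back with the heads dimension dropped
    print(h, run_repro(h))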

Would it be possible to support the h=3 case?

Thanks!

@cyanguwa cyanguwa linked a pull request Sep 16, 2024 that will close this issue
@cyanguwa cyanguwa self-assigned this Sep 16, 2024

@cyanguwa (Collaborator)

Hi @ajayjain, thanks for raising this issue. I understand how the h=3 case could be misinterpreted by TE, but I don't think h=7 should be a problem even before the fix. Could you give PR 1187 a try and let me know if you still have issues with either case? Thanks.
