When testing the fused_attn_fwd_qkvpacked function in transformer_engine.pytorch.cpp_extensions.fused_attn, the head dimension is dropped in the output if the input QKV matrix has layout "t3hd" and h=3 or h=7
Hi @ajayjain, thanks for raising this issue. I understand how the h=3 case could be misinterpreted by TE, but I don't think h=7 should be a problem even before the fix. Could you please give PR 1187 a try and let me know if you still see issues with either case? Thanks.
When testing the fused_attn_fwd_qkvpacked function in transformer_engine.pytorch.cpp_extensions.fused_attn, the head dimension is dropped in the output if the input QKV matrix has layout "t3hd" and h=3 or h=7.

Minimal reproducible example:
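(A minimal sketch of such a repro; the argument order, the FusedAttnBackend and DType enums, and the batch/sequence-length values below are assumptions about the TE version in use, not values taken from the original report.)

```python
# Sketch of a repro -- enum names, argument order, and batch/seqlen values are
# assumptions about the TE version in use.
import torch
import transformer_engine_torch as tex  # C++ extension module; name differs in older TE releases
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    FusedAttnBackend,
    fused_attn_fwd_qkvpacked,
)

t, h, d = 1024, 3, 128   # total tokens, attention heads (h=3 triggers the bug), head dim
max_seqlen = 128         # assumed: 8 sequences of 128 tokens -> t = 1024

# Packed QKV in "t3hd" layout: [total_tokens, 3, num_heads, head_dim]
qkv = torch.randn(t, 3, h, d, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.arange(0, t + 1, max_seqlen, dtype=torch.int32, device="cuda")

results = fused_attn_fwd_qkvpacked(
    True,                                      # is_training
    max_seqlen,
    cu_seqlens,
    qkv,
    tex.DType.kBFloat16,                       # qkv_dtype
    FusedAttnBackend["F16_arbitrary_seqlen"],  # cuDNN arbitrary-seqlen backend (assumed)
    qkv_layout="t3hd",
    attn_bias_type="no_bias",
    attn_mask_type="padding",
)
out = results[0]  # first element taken as the output tensor (assumed return convention)
print(list(out.shape))  # the reported problem: the head dimension is missing here
```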
It should print [1024, 3, 128] instead.
h=1, 2, 4, 5, 6, and 8 work.
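A quick way to check that observation is a sweep over head counts along the lines of the sketch above (again, the call arguments are assumptions, not the reporter's exact script):

```python
# Hypothetical sweep over head counts, reusing t, d, max_seqlen, cu_seqlens
# and the assumed call arguments from the sketch above.
for h in range(1, 9):
    qkv = torch.randn(t, 3, h, d, dtype=torch.bfloat16, device="cuda")
    results = fused_attn_fwd_qkvpacked(
        True, max_seqlen, cu_seqlens, qkv, tex.DType.kBFloat16,
        FusedAttnBackend["F16_arbitrary_seqlen"],
        qkv_layout="t3hd", attn_bias_type="no_bias", attn_mask_type="padding",
    )
    print(f"h={h}: output shape {list(results[0].shape)}")
```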
Would it be possible to support the h=3 case?
Thanks!