
Questions on DotProductAttention API Usage in Flash Attention thd Mode #1409

Open

pipSu opened this issue Jan 14, 2025 · 1 comment

@pipSu commented Jan 14, 2025

We are using Megatron-LM with TE (Transformer Engine) for Flash Attention, specifically in THD mode, and we have some questions about the API usage.

  1. What is the specific difference between cu_seqlens_q and cu_seqlens_q_padded?
    From the documentation and example code, it seems that both parameters are passed the padded values. How are they handled differently internally, and what does "sequence lengths with/without offset" mean?

  2. We are conducting SFT (Supervised Fine-Tuning) training and aim to construct the attention mask as

    attention_mask = causal_inputs_mask * padding_mask * segment_mask

    However, we are encountering difficulties in ensuring the correctness of padding_mask and causal_inputs_mask when tokens are padded for 2*CP. For example, with cp=2 a packed sequence looks like [1, 2, 3, pad, 4, 5, pad, pad], and currently both cu_seqlens_q and cu_seqlens_q_padded are set to [0, 4, 8]. Our attempts to address this by setting cu_seqlens_q differently from cu_seqlens_q_padded have consistently resulted in NaN errors. How should we correctly set the attention mask to handle these padding tokens? (See the illustrative mask sketch after this list.)
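To make the setup concrete, here is a small, purely illustrative sketch (not from the original issue) of the mask composition described in item 2, for the cp=2 packed layout [1, 2, 3, pad, 4, 5, pad, pad]. The helper names and the boolean convention (True = "may attend") are assumptions for illustration only.

```python
# Illustrative sketch of attention_mask = causal_inputs_mask * padding_mask * segment_mask
# for the packed layout [1, 2, 3, pad, 4, 5, pad, pad] (two sequences, each padded to 4 tokens).
# Boolean convention (True = may attend) and helper names are assumptions.
import torch

T = 8  # total packed tokens
# Which sequence each packed token belongs to (-1 marks padding).
segment_ids = torch.tensor([0, 0, 0, -1, 1, 1, -1, -1])

# Causal within the packed buffer: token i may look at tokens j <= i.
causal_inputs_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))

# Padding: neither queries nor keys may involve pad tokens.
is_real = segment_ids >= 0
padding_mask = is_real.unsqueeze(1) & is_real.unsqueeze(0)

# Segment: tokens may only attend within their own sequence.
segment_mask = segment_ids.unsqueeze(1) == segment_ids.unsqueeze(0)

attention_mask = causal_inputs_mask & padding_mask & segment_mask
```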

@cyanguwa
Collaborator

@pipSu, please take a look here for the differences between cu_seqlens_q and cu_seqlens_q_padded.

In your case, I believe you should use cu_seqlens_q=[0, 3, 5] and cu_seqlens_q_padded=[0,4,8]. If your intended mask is causal, you can just pass in attn_mask_type=padding_causal and attention_mask=None (which is the default).
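To make that concrete, below is a minimal sketch of such a call (not taken from the thread). It assumes transformer_engine.pytorch.DotProductAttention with qkv_format="thd" and self-attention; the exact set of forward keyword arguments (e.g. cu_seqlens_kv_padded) and required shapes/dtypes should be checked against your TE version.

```python
# Minimal sketch: THD-format self-attention with padding between packed sequences,
# using the cu_seqlens values suggested above. Shapes, dtypes, and some keyword
# names are assumptions to verify against your Transformer Engine version.
import torch
from transformer_engine.pytorch import DotProductAttention

num_heads, head_dim = 8, 64
total_tokens = 8  # packed buffer [1, 2, 3, pad, 4, 5, pad, pad]

attn = DotProductAttention(
    num_attention_heads=num_heads,
    kv_channels=head_dim,
    qkv_format="thd",
)

# THD layout: [total_tokens, num_heads, head_dim], on GPU in bf16/fp16.
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Actual (unpadded) cumulative lengths: sequences of 3 and 2 tokens.
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
# Padded cumulative offsets: each sequence occupies a 4-token slot in the buffer.
cu_seqlens_padded = torch.tensor([0, 4, 8], dtype=torch.int32, device="cuda")

out = attn(
    q, k, v,
    attention_mask=None,                   # no explicit mask needed
    attn_mask_type="padding_causal",       # causal + padding handled internally
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,              # self-attention: same lengths for K/V
    cu_seqlens_q_padded=cu_seqlens_padded,
    cu_seqlens_kv_padded=cu_seqlens_padded,
)
```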

The FlashAttention backend doesn't support THD + padding between sequences, so you will be running with the FusedAttention backend. If you run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1, you should be able to see Running with FusedAttention backend (sub-backend 1) in the logging, rather than Running with FlashAttention backend (version 2.x.x).
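For reference, one way to enable that logging from Python rather than the shell (assuming the variables are set before the first attention call, when the backend is selected):

```python
# Sketch: enable Transformer Engine's attention-backend debug logging so the
# selected backend (FusedAttention vs. FlashAttention) is reported at runtime.
# Equivalently, export these variables in the shell before launching training.
import os
os.environ["NVTE_DEBUG"] = "1"
os.environ["NVTE_DEBUG_LEVEL"] = "1"
```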
