Add a CP implementation variant with KV all-gather. (#1060)
* add window_size to AttnFuncWithCP
* add seq_offsets_qkvo for cuDNN THD
* add seq_offsets_qkvo to AttnFuncWithCP
* fix seq_offsets calculation of cuDNN THD
* remove a THD assert
* fix bias for THD test
* add THD test for cuDNN FA with CP
* skip GQA/MQA test for cuDNN THD
* make sure seq_offsets are computed with qkv_group of hd_hd_hd when CP > 1
* fix seq_offsets inputs
* remove two comments
* fix attn mask type for cuDNN THD with CP
* fix attn_mask_type check
* fix attn_mask_type for cuDNN FA with THD
* fix a typo
* fix out/dout in bwd
* assert that cuDNN+THD does not support attn bias
* check if attn_mask_type has padding
* minor change
* change CP test batch size to 2
* fix code format
* fix two assert messages
* fix assert comment
* fix assert comments
* minor fix
* fix assert comments
* assert that SWA+CP cannot work with THD format
* add a new CP function for SWA
* add missing dgrads
* minor change
* add draft fwd function for SWA+CP
* minor change
* enable flash attention for SWA+CP
* remove an assert of SWA+CP
* call SWAFuncWithCP for SWA+CP
* typo fix
* use 2hd layout
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* change qkv_format check
* add a code comment
* tensor shape bug fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* tensor shape fix
* add function to compute cu_seqlens of a CP rank (see the sketch after this list)
* add cu_seqlens and cu_seqlens_padded to context parallelism
* typo fix
* minor change
* fix FlashAttention output sequence length
* fix cu_seqlens_kv_per_step calculation
* zero dQKV for ending padded tokens
* zero dQKV tensors of FlashAttention
* fix softmax_lse correction (the merge rule is sketched below)
* remove padded tokens of KV to save communication
* do not need to zero dkv for FlashAttention anymore
* zero out tensors
* remove redundant code
* fix CP unit test
* fix kv shape of CP test with THD format
* update CP unit test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add simple code framework
* try not to have a separate CP function for SWA
* back up some code changes
* back up code
* clean up fwd implementation of SWAFuncWithCP
* remove redundant code
* code cleaning
* fix an assert message
* reduce KV chunk concat overheads
* minor change
* make AttnFuncWithCP and SWAFuncWithCP have the same API
* add a docstring
* preliminary implementation of SWAFuncWithCP forward seems to work
* fix output shape of SWAFuncWithCP
* code refactoring for FlashAttention and add a code placeholder for bwd
* use gather_along_first_dim
* finish the preliminary implementation of bwd
* remove redundant code
* fix assert condition
* add draft implementation of SWA+CP with FusedAttention
* fix attention mask type of SWA+CP
* code cleaning
* add qkv_layout
* add missing window_size argument
* typo fix
* fix kv shape of SWA+CP
* bug and typo fix
* fix dout shape
* add multi-stream in fwd of SWA+CP (the overlap pattern is sketched below)
* save chunk_ids_to_kv_ag in fwd
* add multi-stream in bwd of SWA+CP
* minor fix to CP stream sync
* rename AttnFuncWithCP
* check if window size is None
* fix docstring of AttnFuncWithCP
* minor fix
* add env var for users to choose KV all-gather or KV P2P (the all-gather variant is sketched below)
* update CP tests
* fix window size in CP unit test
* fix pytest skip messages
* add cp_comm_type into API
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* code cleaning
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* assert sequence-length divisibility requirements
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add support table of context parallelism
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* typo and code format fix
* do not print multiple disabling messages
* bug fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix device in torch.arange and adjust code for the MLA PR
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix typos and clean asserts

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <[email protected]>
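The headline change is a second communication scheme for context parallelism: instead of exchanging KV chunks ring-style over P2P, each rank all-gathers K and V across the CP group and attends its local Q chunk against the full sequence. Below is a minimal PyTorch sketch of that idea, not TransformerEngine's actual implementation; the real code also reorders chunks for causal load balancing and overlaps communication with compute, and `attn_fn` and the tensor layout here are assumptions.

```python
import torch
import torch.distributed as dist

def attn_with_kv_allgather(q, k_local, v_local, cp_group, attn_fn):
    """q, k_local, v_local: [seq_chunk, batch, heads, dim] slices on this CP rank."""
    cp_size = dist.get_world_size(cp_group)
    # Gather the full K and V along the sequence dimension (dim 0).
    k_full = torch.empty((cp_size * k_local.shape[0], *k_local.shape[1:]),
                         dtype=k_local.dtype, device=k_local.device)
    v_full = torch.empty_like(k_full)
    dist.all_gather_into_tensor(k_full, k_local.contiguous(), group=cp_group)
    dist.all_gather_into_tensor(v_full, v_local.contiguous(), group=cp_group)
    # Each rank attends its local Q chunk against the gathered global KV.
    return attn_fn(q, k_full, v_full)
```

The trade-off versus the P2P ring is one big collective (and the memory for the gathered KV) instead of cp_size - 1 pipelined exchanges, which can win at short-to-moderate sequence lengths.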
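The log also adds a function to compute the cu_seqlens of a CP rank for packed (THD) inputs. Under the divisibility requirement the PR asserts (every sequence length divisible by 2 × cp_size, so each rank holds two chunks of every sequence), the per-rank offsets follow directly from the global ones. A hypothetical sketch, not the function added in the PR:

```python
import torch

def cu_seqlens_on_cp_rank(cu_seqlens: torch.Tensor, cp_size: int) -> torch.Tensor:
    # cu_seqlens: [num_seqs + 1] cumulative sequence lengths of the global batch.
    # With every sequence length divisible by cp_size (asserted upstream), each
    # rank keeps seqlen / cp_size tokens of every sequence, so the cumulative
    # offsets scale down uniformly.
    return cu_seqlens // cp_size
```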
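"fix softmax_lse correction" refers to the standard rule for merging attention outputs computed over disjoint KV chunks: each partial result is rescaled by its share of the combined log-sum-exp. A self-contained sketch of that rule (shapes are illustrative, not TE's internals):

```python
import torch

def merge_attn_partials(out1, lse1, out2, lse2):
    """Merge two partial attention outputs computed over disjoint KV chunks.

    out1/out2: [tokens, heads, dim] partial outputs; lse1/lse2: [tokens, heads]
    per-query log-sum-exp of the attention logits for the respective chunk.
    """
    lse = torch.logaddexp(lse1, lse2)  # combined normalizer, in log space
    out = (out1 * torch.exp(lse1 - lse).unsqueeze(-1)
           + out2 * torch.exp(lse2 - lse).unsqueeze(-1))
    return out, lse
```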
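The "add multi-stream" entries overlap per-chunk attention by alternating CUDA streams. A generic sketch of the pattern, assuming a hypothetical `attn_fn` that computes attention for one KV chunk; TE's actual stream management differs:

```python
import torch

def chunked_attn_two_streams(q, kv_chunks, attn_fn):
    """Launch per-chunk attention on two CUDA streams so launch/setup work on
    one stream can overlap kernel execution on the other. Sketch only."""
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())  # inputs produced on the default stream
    streams = [torch.cuda.current_stream(), side]
    partials = []
    for i, (k, v) in enumerate(kv_chunks):
        with torch.cuda.stream(streams[i % 2]):
            partials.append(attn_fn(q, k, v))
    torch.cuda.current_stream().wait_stream(side)  # join before consuming the partials
    return partials
```

Production code must also respect the caching allocator's stream semantics (e.g. `Tensor.record_stream`) for tensors shared across streams; the sketch omits that.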
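Finally, "add cp_comm_type into API" exposes the choice between the two schemes to callers. The usage below is an assumption built from the log's wording; the method and argument names are illustrative, not quoted from the PR:

```python
# Hypothetical: select the CP communication scheme when wiring up the
# context-parallel group ("p2p" ring exchange vs. "all_gather" of KV).
attn.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream,
                                cp_comm_type="all_gather")
```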