
Add FP8 support to CP implementation with KV P2P #1114

Merged: 206 commits into NVIDIA:main from xren/cp_fp8 on Aug 21, 2024

Conversation

@xrennvidia (Collaborator) commented on Aug 15, 2024

Description

Enable FP8 attention with context parallelism (CP) in the KV P2P (ring-exchange) implementation.
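
For orientation, a minimal usage sketch of the feature this PR enables; this is not code from the PR itself. The API names follow TransformerEngine's public PyTorch interface, while the group sizes, tensor shapes, and recipe flags are illustrative assumptions:

```python
# Hedged sketch: FP8 attention under context parallelism (CP) in
# TransformerEngine. Assumes torch.distributed is already initialized
# with at least 2 GPUs and that this script runs on every rank.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

cp_ranks = [0, 1]                          # illustrative: CP group of size 2
cp_group = dist.new_group(ranks=cp_ranks)
cp_stream = torch.cuda.Stream()            # side stream used by the CP kernels

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
# Route KV exchange through the point-to-point (ring) CP implementation.
attn.set_context_parallel_group(cp_group, cp_ranks, cp_stream)

# Each rank holds its slice of the sequence (sbhd layout: seq, batch, head, dim).
q = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# FP8 + CP requires amax reduction across the CP ranks (this PR asserts it),
# so fp8_group must span the CP group. fp8_dpa enables FP8 dot-product
# attention (assumption: available in this TE version).
recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=cp_group):
    out = attn(q, k, v)
```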

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

xrennvidia and others added 30 commits May 30, 2024 19:09
xrennvidia and others added 6 commits August 14, 2024 22:32
@xrennvidia (Collaborator, Author): /te-ci pytorch

@xrennvidia (Collaborator, Author): /te-ci pytorch

@ptrendx added the 1.10.0 label on Aug 16, 2024
@xrennvidia (Collaborator, Author): /te-ci pytorch

@xrennvidia (Collaborator, Author): /te-ci pytorch

@cyanguwa (Collaborator) left a comment: LGTM

@xrennvidia (Collaborator, Author): /te-ci pytorch

@cyanguwa merged commit 26c8fcc into NVIDIA:main on Aug 21, 2024
26 checks passed
@xrennvidia deleted the xren/cp_fp8 branch on August 21, 2024 03:32
BeingGod pushed a commit to BeingGod/TransformerEngine that referenced this pull request Aug 30, 2024
* add window_size to AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* add seq_offsets_qkvo for cudnn thd

Signed-off-by: Xiaowei Ren <[email protected]>

* add seq_offsets_qkvo to AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_offsets calculation of cudnn thd

Signed-off-by: Xiaowei Ren <[email protected]>

* remove a thd assert

Signed-off-by: Xiaowei Ren <[email protected]>

* fix bias for thd test

Signed-off-by: Xiaowei Ren <[email protected]>

* add thd test for cudnn FA with CP

Signed-off-by: Xiaowei Ren <[email protected]>

* skip GQA/MQA test for cuDNN THD

Signed-off-by: Xiaowei Ren <[email protected]>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd when CP>1

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_offsets inputs

Signed-off-by: Xiaowei Ren <[email protected]>

* remove two comments

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn mask type for cudnn thd with cp

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn_mask_type check

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attn_mask_type for cudnn fa with thd

Signed-off-by: Xiaowei Ren <[email protected]>

* fix a typo

Signed-off-by: Xiaowei Ren <[email protected]>

* fix out dout in bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* assert cudnn+thd does not support attn bias

Signed-off-by: Xiaowei Ren <[email protected]>

* check if attn_mask_type has padding

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* change cp test batch size to 2

Signed-off-by: Xiaowei Ren <[email protected]>

* fix code format

Signed-off-by: Xiaowei Ren <[email protected]>

* fix two assert info

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comment

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comments

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert comments

Signed-off-by: Xiaowei Ren <[email protected]>

* assert swa+CP cannot work with thd format

Signed-off-by: Xiaowei Ren <[email protected]>

* add a new CP function for swa

Signed-off-by: Xiaowei Ren <[email protected]>

* add a missing dgrad

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* add draft fwd function for swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* enable flash attention for swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* remove an assert of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* call SWAFuncWithCP for swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* use 2hd layout

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change qkv_format check

Signed-off-by: Xiaowei Ren <[email protected]>

* add a code comment

Signed-off-by: Xiaowei Ren <[email protected]>

* tensor shape bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor shape fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add function to compute cu_seqlens of a cp rank (the chunk layout this relies on is sketched after this commit list)

Signed-off-by: Xiaowei Ren <[email protected]>

* add cu_seqlens and cu_seqlens_padded to context parallelism

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* fix FlashAttention output sequence length

Signed-off-by: Xiaowei Ren <[email protected]>

* fix cu_seqlens_kv_per_step calculation

Signed-off-by: Xiaowei Ren <[email protected]>

* zero dQKV for ending padded tokens

Signed-off-by: Xiaowei Ren <[email protected]>

* zero dQKV tensors of FlashAttention

Signed-off-by: Xiaowei Ren <[email protected]>

* fix softmax_lse correction

Signed-off-by: Xiaowei Ren <[email protected]>

* remove padded tokens of KV to save communication

Signed-off-by: Xiaowei Ren <[email protected]>

* do not need to zero dkv for FlashAttention anymore

Signed-off-by: Xiaowei Ren <[email protected]>

* zero out tensors

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

* fix CP unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* fix kv shape of cp test with thd format

Signed-off-by: Xiaowei Ren <[email protected]>

* update cp unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add simple code framework

Signed-off-by: Xiaowei Ren <[email protected]>

* try not to have a separate CP function for SWA

Signed-off-by: Xiaowei Ren <[email protected]>

* backup some code change

Signed-off-by: Xiaowei Ren <[email protected]>

* back up code

Signed-off-by: Xiaowei Ren <[email protected]>

* clean up fwd implementation of SWAFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

* code cleaning

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert info

Signed-off-by: Xiaowei Ren <[email protected]>

* reduce kv chunk concat overheads

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* make AttnFuncWithCP and SWAFuncWithCP have same API

Signed-off-by: Xiaowei Ren <[email protected]>

* add a docstring

Signed-off-by: Xiaowei Ren <[email protected]>

* preliminary implementation of SWAFuncWithCP forward seems working

Signed-off-by: Xiaowei Ren <[email protected]>

* fix output shape of SWAFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* code refactoring for FlashAttention and add a code placeholder for bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* use gather_along_first_dim

Signed-off-by: Xiaowei Ren <[email protected]>

* finish the preliminary implementation of bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant code

Signed-off-by: Xiaowei Ren <[email protected]>

* fix assert condition

Signed-off-by: Xiaowei Ren <[email protected]>

* add draft implementation of SWA+CP with FusedAttention

Signed-off-by: Xiaowei Ren <[email protected]>

* fix attention mask type of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* code cleaning

Signed-off-by: Xiaowei Ren <[email protected]>

* add qkv_layout

Signed-off-by: Xiaowei Ren <[email protected]>

* add missing window_size argument

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix kv shape of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* bug and typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix dout shape

Signed-off-by: Xiaowei Ren <[email protected]>

* add multi stream in fwd of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* save chunk_ids_to_kv_ag in fwd

Signed-off-by: Xiaowei Ren <[email protected]>

* add multi stream in bwd of swa+cp

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix to cp stream sync

Signed-off-by: Xiaowei Ren <[email protected]>

* rename AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* check if window size is None

Signed-off-by: Xiaowei Ren <[email protected]>

* fix docstring of AttnFuncWithCP

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add env var for users to choose KV ag or KV p2p

Signed-off-by: Xiaowei Ren <[email protected]>

* update cp tests

Signed-off-by: Xiaowei Ren <[email protected]>

* fix window size in cp unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* fix pytest skip messages

Signed-off-by: Xiaowei Ren <[email protected]>

* add cp_comm_type into API

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code cleaning

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add deterministic knob in cuDNN fused attn backend

Signed-off-by: Xiaowei Ren <[email protected]>

* pass fp8 and fp8_meta to attn_func_with_cp

Signed-off-by: Xiaowei Ren <[email protected]>

* assert only Fused Attn can support FP8+CP

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant assert

Signed-off-by: Xiaowei Ren <[email protected]>

* add a fwd draft implementation of FP8 + CP

Signed-off-by: Xiaowei Ren <[email protected]>

* save fp8 and fp8_meta

Signed-off-by: Xiaowei Ren <[email protected]>

* assert sequence length divisible requirements

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove a redundant qkv_layout compute

Signed-off-by: Xiaowei Ren <[email protected]>

* if condition change

Signed-off-by: Xiaowei Ren <[email protected]>

* some typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add support table of context parallelism

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo and code format fix

Signed-off-by: Xiaowei Ren <[email protected]>

* do not print multiple disabling messages

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix aux_ctx_tensors of FP8

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix device in torch.arange and adjust code for the PR of MLA

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* commit code change for FP8+CP

Signed-off-by: Xiaowei Ren <[email protected]>

* commit more code change for FP8+CP

Signed-off-by: Xiaowei Ren <[email protected]>

* commit more fp8 code for FP8+CP

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fixes

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* cast merged CP results from FP32 to BF16

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* fix softmax_lse

Signed-off-by: Xiaowei Ren <[email protected]>

* fix some bugs of FP8 dkv exchange

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add FP8 unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* fix typos and clean asserts

Signed-off-by: Xiaowei Ren <[email protected]>

* fix get_p2p_comm_info

Signed-off-by: Xiaowei Ren <[email protected]>

* fix dkv p2p exchange

Signed-off-by: Xiaowei Ren <[email protected]>

* minor fix

Signed-off-by: Xiaowei Ren <[email protected]>

* change FP8 dkv P2P to A2A

Signed-off-by: Xiaowei Ren <[email protected]>

* add FP8+CP unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* assert amax reduction is needed for FP8+CP

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove duplicated code

Signed-off-by: Xiaowei Ren <[email protected]>

* destroy process group in CP unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* remove interval from fp8_recipe because it has been deprecated

Signed-off-by: Xiaowei Ren <[email protected]>

* try to fix the failed CP test with the latest CI pipeline

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove redundant f before string

Signed-off-by: Xiaowei Ren <[email protected]>

* change META_O_CP

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <[email protected]>
Signed-off-by: beinggod <[email protected]>
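
Several commits in the message above center on how each CP rank derives its slice of the sequence and the matching cu_seqlens ("add function to compute cu_seqlens of a cp rank", "fix seq_offsets calculation of cudnn thd"). For reference, a minimal sketch of the load-balanced chunking that TE's context parallelism is built around: each sequence is split into 2*cp_size chunks and rank r keeps chunks r and 2*cp_size-1-r, so every rank gets a similar amount of causal-attention work. The helper name below is illustrative, not TE's API:

```python
import torch

def shard_for_cp_rank(x: torch.Tensor, cp_size: int, rank: int) -> torch.Tensor:
    """Illustrative (not TE's API): keep the two sequence chunks assigned
    to `rank` under the load-balanced 2*cp_size chunking scheme."""
    chunks = x.chunk(2 * cp_size, dim=0)            # sequence on dim 0
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]], dim=0)

tokens = torch.arange(16)                            # a 16-token sequence
print(shard_for_cp_rank(tokens, cp_size=2, rank=0))  # tensor([ 0,  1,  2,  3, 12, 13, 14, 15])
print(shard_for_cp_rank(tokens, cp_size=2, rank=1))  # tensor([ 4,  5,  6,  7,  8,  9, 10, 11])
```

Because each rank's slice pairs an early chunk with a late chunk, the per-rank cu_seqlens/seq_offsets are not a simple contiguous split, which is what the fixes above deal with for the thd and padded cases.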
ptrendx pushed a commit that referenced this pull request Aug 31, 2024
Labels: 1.10.0 · Projects: none yet · 3 participants