
[JAX] Context Parallel Attention with All-Gather #1106

Conversation

mgoldfarb-nvidia
Collaborator

@mgoldfarb-nvidia mgoldfarb-nvidia commented Aug 14, 2024

Description

Adds support for context parallel fused attention based on an all-gather/reduce-scatter approach. This implementation exposes the collective communication between CP ranks.

This first CP implementation supports only the causal and no-mask configurations, without bias. Additional QKV formats and configurations will be added in subsequent PRs.
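As a sketch of the all-gather approach described above: each CP rank keeps only its shard of the query sequence, all-gathers K/V so it sees the full key/value sequence, and computes attention for its local queries with a causal mask expressed in global positions. The snippet below simulates this with NumPy (no real collectives); all names are illustrative, not the PR's actual API.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, causal_offset=0):
    # Scores over the *global* key sequence; rows are this rank's local queries.
    s = q @ k.T / np.sqrt(q.shape[-1])
    qi = np.arange(q.shape[0])[:, None] + causal_offset  # global query positions
    ki = np.arange(k.shape[0])[None, :]                  # global key positions
    s = np.where(ki <= qi, s, -np.inf)                   # causal mask
    return softmax(s) @ v

rng = np.random.default_rng(0)
seq, dim, cp = 8, 4, 2
q = rng.normal(size=(seq, dim))
k = rng.normal(size=(seq, dim))
v = rng.normal(size=(seq, dim))

full = attention(q, k, v)  # unsharded reference

# Simulated CP: each rank owns a contiguous shard of Q, while K/V are
# "all-gathered" so every rank sees the full key/value sequence.
shard = seq // cp
parts = [attention(q[r * shard:(r + 1) * shard], k, v, causal_offset=r * shard)
         for r in range(cp)]
assert np.allclose(np.concatenate(parts), full)
```

In the real kernel the gather is a collective across the CP mesh axis and the backward pass reduce-scatters the K/V gradients back to their owning ranks.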

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Adds context parallel axis resource and new primitives to fused attention.
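A "context parallel axis resource" presumably extends the existing mapping from logical parallelism kinds to mesh axis names with a new entry for the CP axis. A minimal dataclass sketch of that idea is below; the field names mirror the PR description but are illustrative, not the exact transformer_engine/jax definitions.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MeshResource:
    """Illustrative mapping from parallelism kinds to device-mesh axis names."""
    dp_resource: Optional[str] = None   # data-parallel axis
    tp_resource: Optional[str] = None   # tensor-parallel axis
    cp_resource: Optional[str] = None   # context-(sequence-)parallel axis, new in this PR

# A mesh with data-, tensor-, and context-parallel axes:
mesh_axes = MeshResource(dp_resource="dp", tp_resource="tp", cp_resource="cp")
assert mesh_axes.cp_resource == "cp"
```

The fused-attention primitives can then look up `cp_resource` to decide which mesh axis to all-gather K/V over.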

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@mgoldfarb-nvidia mgoldfarb-nvidia changed the title [JAX] WIP: Context Parallel Attention with All-Gather [JAX] DRAFT: Context Parallel Attention with All-Gather Aug 14, 2024
@ptrendx
Member

ptrendx commented Aug 15, 2024

How is this different from #1059?

@mgoldfarb-nvidia
Collaborator Author

mgoldfarb-nvidia commented Aug 15, 2024 via email

@ptrendx
Member

ptrendx commented Aug 16, 2024

Hmmm, sure, but it still feels like a duplicate of the functionality. Could you maybe collaborate with @mingxu1067 to merge your work with his?

@mgoldfarb-nvidia
Collaborator Author

Sure thing. Ming and I have already been discussing how to merge the PRs. Likely there should be an initial CP attention PR, followed by updates to other components on the JAX side as we implement the rest of the CP features for JAX.

@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch 2 times, most recently from 4cd496a to 050bce8 Compare August 21, 2024 21:10
@mgoldfarb-nvidia mgoldfarb-nvidia changed the title [JAX] DRAFT: Context Parallel Attention with All-Gather [JAX] Context Parallel Attention with All-Gather Aug 22, 2024
@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch 2 times, most recently from 43a8d78 to 988e3f6 Compare August 22, 2024 21:30

register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)


def fused_attn_fwd(
Collaborator Author


This is an important discussion point: do we keep CP hidden behind a common function, or add a separate flavor such as fused_attn_cp_allgather_fwd? The thought here was that it makes sense to keep a common interface that naturally supports CP. Other implementation strategies, e.g. ring, can be exposed via an argument.
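The "common interface with a strategy argument" design sketched in this comment might look like the following. Everything here is hypothetical, including the CPStrategy enum and the underscore-prefixed dispatch targets (stubbed out so the sketch runs); it is not the PR's actual code.

```python
from enum import Enum

class CPStrategy(Enum):
    """Hypothetical selector for a context-parallel implementation."""
    DEFAULT = 0      # let the library choose based on mesh/config
    ALL_GATHER = 1   # all-gather K/V forward, reduce-scatter gradients backward
    RING = 2         # ring-exchange K/V chunks between CP ranks

def _fused_attn_cp_all_gather_fwd(*args, **kwargs):
    return "all_gather"  # stub standing in for the all-gather primitive

def _fused_attn_cp_ring_fwd(*args, **kwargs):
    return "ring"        # stub standing in for a future ring primitive

def fused_attn_fwd(qkv, bias, mask, *, cp_strategy=CPStrategy.DEFAULT, **config):
    """Single public entry point; CP is an implementation detail chosen by argument."""
    if cp_strategy in (CPStrategy.DEFAULT, CPStrategy.ALL_GATHER):
        return _fused_attn_cp_all_gather_fwd(qkv, bias, mask, **config)
    if cp_strategy == CPStrategy.RING:
        return _fused_attn_cp_ring_fwd(qkv, bias, mask, **config)
    raise ValueError(f"unknown CP strategy: {cp_strategy}")
```

The upside of this shape is that callers keep one API as new strategies land; the downside is that the single function signature has to accommodate every strategy's configuration.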

@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch from 988e3f6 to e912b64 Compare August 23, 2024 13:27
@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch from e912b64 to 03e34c0 Compare August 27, 2024 23:03
Review threads on transformer_engine/jax/sharding.py and transformer_engine/jax/attention.py were resolved.
@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch 5 times, most recently from a83875c to df5267a Compare September 2, 2024 16:18
@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch 2 times, most recently from ba88ede to d3c9d06 Compare September 5, 2024 22:34
@zlsh80826 zlsh80826 self-requested a review September 6, 2024 03:31
Collaborator

@zlsh80826 zlsh80826 left a comment


LGTM! Thanks for the contributions, especially the _FusedAttnConfig; it greatly simplifies the argument passing.

@zlsh80826
Collaborator

Please help check the pre-commit failure. You can run pre-commit install in your working directory; the checks will then run each time you commit.

@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch from 5c6bc6c to 83a62b4 Compare September 9, 2024 21:29
@mgoldfarb-nvidia
Collaborator Author

/te-ci jax

@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch from 83a62b4 to 0691181 Compare September 9, 2024 23:20
@mgoldfarb-nvidia
Collaborator Author

/te-ci jax

1 similar comment
@ptrendx
Member

ptrendx commented Sep 10, 2024

/te-ci jax

@mgoldfarb-nvidia mgoldfarb-nvidia force-pushed the mgoldfarb-nvidia/context_parallel_attention_with_all_gather branch from 0691181 to 5194eb4 Compare September 16, 2024 15:23
@mingxu1067
Collaborator

/te-ci jax

@mgoldfarb-nvidia mgoldfarb-nvidia merged commit 9101a78 into NVIDIA:main Sep 17, 2024
14 of 15 checks passed