[JAX] Context Parallel Attention with All-Gather #1106
Conversation
How is this different from #1059?
The other PR implements ring attention with point-to-point comms. Both will ultimately provide context-parallel attention, but some forms of attention, e.g. window attention, are more easily supported with the all-gather approach. We may also get better scaling of the all-gather comms on multi-node setups.
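For illustration only, here is a minimal sketch (not the PR's actual kernels) of the all-gather flavor in JAX: Q stays sharded along the sequence ("cp") axis while K/V are all-gathered so each rank attends its local queries against the full sequence. The mesh/axis names, layouts, and the plain-softmax `_local_attention` placeholder are assumptions; the real implementation calls the fused cuDNN attention primitive and handles causal masking.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def _local_attention(q, k, v):
    # Placeholder for the fused attention kernel (no masking shown).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)

def cp_allgather_attention(q, k, v, mesh):
    def fwd(q, k, v):
        # Gather the K/V sequence shards from every CP rank (tiled=True
        # concatenates along the sequence dimension), then attend locally.
        k_full = jax.lax.all_gather(k, "cp", axis=1, tiled=True)
        v_full = jax.lax.all_gather(v, "cp", axis=1, tiled=True)
        return _local_attention(q, k_full, v_full)

    spec = P(None, "cp", None, None)  # [batch, seq, heads, head_dim]
    return shard_map(fwd, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)(q, k, v)

# Example usage (assumes the available devices form the "cp" axis):
# mesh = Mesh(np.array(jax.devices()), axis_names=("cp",))
# out = cp_allgather_attention(q, k, v, mesh)
```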
Hmmm, sure, but it still feels like a duplicate of the functionality. Could you maybe collaborate with @mingxu1067 to merge your work with his?
Sure thing. Ming and I have already been discussing how to merge the PRs. Likely there should be an initial CP attention PR, followed by updates to other components on the JAX side as we implement the rest of the CP features for JAX.
register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)

def fused_attn_fwd(
This is an important discussion point: do we keep CP hidden behind a common function, or add a separate flavor such as fused_attn_cp_allgather_fwd? The thought here was that it makes sense to keep a common interface that naturally supports CP. Other implementation strategies, e.g. ring, can be exposed via an argument.
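As a hedged illustration of that option (a single public entry point with the CP strategy selected by argument), the dispatch might look roughly like the sketch below; the enum and stub implementations are hypothetical, not TE's actual API.

```python
from enum import Enum, auto

class CPStrategy(Enum):
    DEFAULT = auto()     # no context parallelism
    ALL_GATHER = auto()  # this PR's all-gather/reduce-scatter approach
    RING = auto()        # point-to-point ring attention (as in #1059)

# Stubs standing in for the real primitives.
_IMPLS = {
    CPStrategy.DEFAULT: lambda q, k, v, **kw: "plain fused attention",
    CPStrategy.ALL_GATHER: lambda q, k, v, **kw: "AG/RS context-parallel attention",
    CPStrategy.RING: lambda q, k, v, **kw: "ring context-parallel attention",
}

def fused_attn_fwd(q, k, v, *, context_parallel_strategy=CPStrategy.DEFAULT, **kwargs):
    # A single public entry point; the CP flavor is an argument rather than a
    # separate fused_attn_cp_allgather_fwd function.
    return _IMPLS[context_parallel_strategy](q, k, v, **kwargs)
```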
LGTM! Thanks for the contributions, especially the _FusedAttnConfig; that greatly simplifies the argument passing.
Please help check the pre-commit failure. You can run the pre-commit checks locally to reproduce it.
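As an aside, the config-object pattern mentioned above can be sketched as a frozen dataclass that bundles the static attention parameters so the primitives take one hashable object instead of a long argument list; the field names and values here are illustrative, not necessarily those of the actual _FusedAttnConfig.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FusedAttnConfig:
    attn_bias_type: str
    attn_mask_type: str
    qkv_layout: str
    scaling_factor: float
    dropout_probability: float
    is_training: bool

# One object threaded through the primitives instead of repeating
# every attribute at each call site:
config = FusedAttnConfig(
    attn_bias_type="no_bias",
    attn_mask_type="causal",
    qkv_layout="bshd_bshd_bshd",
    scaling_factor=0.125,
    dropout_probability=0.0,
    is_training=True,
)
```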
/te-ci jax
Description
Adds support for context parallel fused attention based on an all-gather/reduce-scatter approach. This implementation exposes the collective communication between CP ranks.
This first implementation of CP only supports causal and no-mask attention without bias. Additional QKV formats and configurations will be added in subsequent PRs.
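For intuition only, a small sketch of the reduce-scatter half of the approach, assuming a mesh axis named "cp" and per-rank gradients computed against the fully gathered K/V; the function name is hypothetical and this is not the PR's code.

```python
import jax

def scatter_kv_grads(dk_full, dv_full):
    # Intended to run inside a shard_map/pmap with a "cp" mesh axis.
    # psum_scatter sums the per-rank dK/dV contributions (every rank computed
    # gradients against the full gathered sequence) and leaves each rank with
    # only its own sequence shard, mirroring the forward all-gather.
    dk = jax.lax.psum_scatter(dk_full, "cp", scatter_dimension=1, tiled=True)
    dv = jax.lax.psum_scatter(dv_full, "cp", scatter_dimension=1, tiled=True)
    return dk, dv
```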
Fixes # (issue)
Type of change
Changes
Adds context parallel axis resource and new primitives to fused attention.
Checklist: