Implementation of Computation Pipelining #4403
Conversation
@pawelszczerbuk @ThomasRaoux This is still in draft mode. I will ask Meta folks to take a look first. But if you have some high-level comments, please let me know! Thanks! @bertmaher @chenyang78 Since Hongtao is on vacation, wondering if you have some time to take a look at this.
include/triton/Tools/Sys/GetEnv.hpp (outdated)
    "SWP_FIRST_DOT",
    "PEEL_EPILOGUE",
    "LOAD_DIFFERENT_STAGE",
    "FIRST_LOAD_OF_USE",
Should these get a documentation line in the README.md?
Yes. I will add them to README.md. We may not need all of these env variables; maybe we should always set loads to different stages, see this comment.
// TODO pawel: we could do more fine-grained allocation here and
// allocate only the number of buffers that specific loads need.
What does Si(j) mean in this notation?
It means SWP stage i at iteration j. As an example, the prologue usually handles S0(0) and S0(1) for num_stages of 3.
So this actually refers to three iterations from i to i+2?
Yes, it is the main loop unrolled 3 iterations to show loop-carried dependencies.
Thanks, this is super interesting work and proves that we need more advanced modulo scheduling when doing software pipelining. However, I'm afraid we are reaching the limit of hardcoded heuristics, and I think this is going to be too ad hoc (meaning too much of an overfit for FA) to be pushed into the pipeliner logic.
Thanks Thomas! I wonder if we can at least land some fixes. One example is the correct dependency from barrierWait to viewLoad (MemDescSubviewOp): before we use the output of a loadOp via viewLoad, we need to make sure the barrierWait is done. We (Meta folks) are open to discussions and collaborations on the correct implementation of modulo scheduling that includes both overlapping memory with computation and overlapping computation ops. As mentioned in the description, we can add a formal dependency analyzer that feeds the decisions to SWP, or just use a single knob of num_computation_stage. Let me know what is the right way to collaborate on this.
Yes, if there are general fixes/improvements we can land those. (I haven't yet looked at this exact case, but in general I agree.)
It makes sense to me. I think there are different ways this can be done; one is with a scheduler, however that is usually a very difficult heuristic to get right. It could be combined with PGO, but the complexity of tuning might make it impractical.
Yes, we have been thinking about using autotuning to help the scheduler/dependency analyzer, but which knobs to expose to autotuning is something that needs more discussion. PGO can help reduce the search space for tuning or help the scheduler build more accurate heuristics. For Flash Attention, there are at least two different schedules we can consider: putting the first dot in its own stage vs. putting the second dot in its own stage (the two options detailed below). We need to make both staging decisions and ordering decisions (i.e. stage assignment and cluster assignment as currently implemented in SWP).
After discussions with Pawel/Phil/Thomas, we are proposing loop schedule annotation for computation pipelining: prior to SWP, we have a single loop schedule annotation attached to the loop. One example can be flash_attention_first_dot as a loop schedule, so on the loop we will annotate with "loop_schedule = flash_attention_first_dot". We will add one pass prior to SWP to convert the loop schedule to (stage, cluster) on ttgir ops; SWP will use these (stage, cluster) to generate the loop body, epilogue, prologue, etc.

Definition of a loop schedule
The easiest format for a schedule can be a text file, with one line for each key operation (i.e. 2D mulf/addf/subf/exp/dot, etc.) at ttgir level. Even though the loop annotation is at source level, the schedule is described in terms of ttgir operations, so it is not straightforward for developers. We choose ttgir operations since one ttir operation can split into many ttgir operations. People updating the library of schedules need to know what the kernel looks like at ttgir level. The instruction type needs to be general enough that it is not tied to a specific ttgir version. The schedules may become outdated when the kernel is updated, but this will not degrade performance if we ignore a schedule when it no longer applies. We can also use autotuning to iterate over all possible loop schedules of a kernel, or over a pruned set of loop schedules.

Flash Attention Option 1
Take flash attention as an example: we can have a schedule where the first dot is in its own stage and its own cluster. The pipelined loop will look like: (with num_stages = 4)

Flash Attention Option 2
For flash attention, we can also have a schedule where the 2nd dot is in its own stage. The pipelined loop will look like: (with num_stages = 4)

Implementation
We will add the support for (stage, cluster) on ttgir operations first. Once that is done, we can improve SWP with test cases using (stage, cluster) attributes at ttgir level; we need to improve SWP to correctly handle test cases with legal (stage, cluster) attributes. We will then have a pass, prior to SWP, to convert the loop schedule to annotations on the ttgir operations. The pass will need to check whether the schedule is valid, and only annotate the ttgir operations with (stage, cluster) when it is. During software pipelining, operations without (stage, cluster) will be assigned a (stage, cluster) according to dependencies. The pipelined loop body will be generated based on the (stage, cluster) of each operation. A rough sketch of the annotation step is below.
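As a minimal sketch of what that conversion pass could look like, assuming MLIR's C++ API: the attribute name "loop.stage" is used later in this PR, while "loop.cluster", the schedule callback, and the function name here are hypothetical.

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical pre-SWP pass body: look up the loop's named schedule (e.g.
// "flash_attention_first_dot") and stamp (stage, cluster) attributes onto
// the key ttgir ops in the loop body.
void applyLoopSchedule(
    scf::ForOp forOp,
    function_ref<bool(Operation *, int &stage, int &cluster)> schedule) {
  Builder b(forOp.getContext());
  forOp.getBody()->walk([&](Operation *op) {
    int stage, cluster;
    // The schedule decides whether this is a key op (load/dot/elementwise
    // math) and, if so, which (stage, cluster) it gets; other ops are left
    // for SWP to place according to dependencies.
    if (schedule(op, stage, cluster)) {
      op->setAttr("loop.stage", b.getI32IntegerAttr(stage));
      op->setAttr("loop.cluster", b.getI32IntegerAttr(cluster));
    }
  });
}
```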
@ThomasRaoux @ptillet @Jokeren Wondering if you have any more feedback on the proposal and the implementation plan. Thanks!
Summary: For the test case, we need to predicate the dot for the prologue.
Force-pushed from bc397cf to 7d0ece3.
Updated the diff to use the attribute "loop.stage" inside SWP, and added test cases to show and fix a few issues. The update currently doesn't handle refactoring of SWP.
Force-pushed from 7d0ece3 to 7cbae2a.
Summary:
The idea of computation pipelining comes from the FA3 paper. The goal is to divide computation ops into stages and to overlap ops that use the cuda cores with ops that use the tensor cores.
Using flash attention with num_stages=3 as an example: currently the two loads are in stage 0 (S0) and all other ops are in the last stage (stage 2). The loop body will look like
S2(i) S2(i+1) S2(i+2)
S0(i+2) S0(i+3) S0(i+4)
With computation pipelining enabled, we can put the load for the first dot in S0, the load for the second dot in S1, the first dot in S2, and the rest in S3. The loop body will look like
S2(i+1) S2(i+2) S2(i+3)
S3(i) S3(i+1) S3(i+2)
S0(i+3) S0(i+4) S0(i+5)
S1(i+2) S1(i+3) S1(i+4)
Note that the distance between a load and its corresponding use stays the same (i.e. 2 stages from the load of k in S0(i+3) to its use in S2(i+3), and 2 stages from the load of v in S1(i+2) to its use in S3(i+2)), and the number of buffers stays at 3. The live range of the output of S2 increases to span the loop (i.e. from S2(i+1) to S3(i+1)).
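For concreteness, the stage split described above could be recorded along these lines. This is a sketch only: the op handles and the helper name are hypothetical, and the actual pipeliner derives stages from dependency analysis rather than a fixed mapping.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Operation.h"

// Hypothetical helper recording the flash-attention stage split described
// above (num_stages = 4). The caller is assumed to have identified the two
// loads, the first dot, and the remaining compute ops (softmax + 2nd dot).
llvm::DenseMap<mlir::Operation *, int>
assignFlashAttentionStages(mlir::Operation *kLoad, mlir::Operation *vLoad,
                           mlir::Operation *firstDot,
                           llvm::ArrayRef<mlir::Operation *> restOfCompute) {
  llvm::DenseMap<mlir::Operation *, int> stageOf;
  stageOf[kLoad] = 0;    // S0: load feeding the first dot
  stageOf[vLoad] = 1;    // S1: load feeding the second dot
  stageOf[firstDot] = 2; // S2: first dot lifted into its own stage
  for (mlir::Operation *op : restOfCompute)
    stageOf[op] = 3;     // S3: softmax ops and the second dot
  return stageOf;
}
```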
Details:
SWP_FIRST_DOT: move the first dot into stage numStages - 2 instead of numStages - 1
--> for this to work, we need to support a predicate on barrier_wait or use PEEL_EPILOGUE
LOAD_DIFFERENT_STAGE: put the two loads in two different stages
FIRST_LOAD_OF_USE: only consider the first use of each load. For example, the load of k is used by the first dot and indirectly used by the 2nd dot, but we should only consider the use by the first dot when calculating the use distance.
For both createAsyncCopy and createTMAAsyncCopy, update TMAUserToWait so there will be an extra dependency from the view of the load to the barrier_wait op:
TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
In scheduleDependencies, waitOp will be added to the same stage/cluster as viewLoad.
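A sketch of that bookkeeping follows, assuming the map type and helper names; only the TMAUserToWait assignment itself appears in this PR.

```cpp
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Map from the subview of a pipelined load (viewLoad) to the barrier wait
// that must complete before the view's data is valid.
static llvm::DenseMap<Operation *, Operation *> TMAUserToWait;

// In createAsyncCopy / createTMAAsyncCopy, after the wait op is built:
void recordBarrierDependency(Operation *viewLoad, Operation *waitOp) {
  TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
}

// In scheduleDependencies: look up the wait guarding a view so the caller
// can place it in the same stage/cluster as the view itself.
Operation *getWaitForView(Operation *viewLoad) {
  auto it = TMAUserToWait.find(viewLoad);
  return it == TMAUserToWait.end() ? nullptr : it->second;
}
```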
LOAD_DIFFERENT_STAGE=1 FIRST_LOAD_OF_USE=1 SWP_FIRST_DOT=1 together with a TMA variant of FA (pytorch/benchmark#2379) seems to improve performance for headDim 128 and 256.
This implementation is somewhat specific to flash attention and needs rework. We can make it more general by adding a dependency analyzer that analyzes the dependencies and resource usage of ttgir ops, decides how to break ops into stages, and feeds the decisions to SWP. Another, less general option is to only support num_computation_stage being 2 (i.e. separating computation ops into two stages); in this restricted form, lifting the first dot into its own stage seems reasonable. A hypothetical sketch of such an analyzer's interface is below.
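As one possible starting point for that discussion, the analyzer could look like the interface below; the names and shape are assumptions, not part of this PR.

```cpp
#include <utility>

#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Hypothetical dependency/resource analyzer: examine the ttgir ops in a
// loop, classify them by resource (tensor-core dots vs. cuda-core math),
// and decide a (stage, cluster) split to feed into SWP.
class ComputationStageAnalyzer {
public:
  virtual ~ComputationStageAnalyzer() = default;

  // Returns a (stage, cluster) assignment for each key op in the loop body,
  // or an empty map if the loop should not be computation-pipelined.
  virtual llvm::DenseMap<Operation *, std::pair<int, int>>
  analyze(scf::ForOp forOp, int numStages) = 0;
};
```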
Open to discussions and reworks.