
Implementation of Computation Pipelining #4403

Draft: wants to merge 3 commits into main
Conversation

manman-ren (Collaborator)

Summary:
The idea of computation pipelining comes from the FA3 paper. The goal is to divide computation ops into stages and to overlap ops that use the CUDA cores with ops that use the tensor cores.

Take flash attention with num_stages=3 as an example: currently the two loads are in stage 0 (S0) and all other ops are in the last stage (stage 2). The loop body will look like
S2(i) S2(i+1) S2(i+2)
S0(i+2) S0(i+3) S0(i+4)
With computation pipelining enabled, we can put the load for the first dot in S0, the load for the second dot in S1, the first dot in S2, and the rest in S3. The loop body will look like
S2(i+1) S2(i+2) S2(i+3)
S3(i) S3(i+1) S3(i+2)
S0(i+3) S0(i+4) S0(i+5)
S1(i+2) S1(i+3) S1(i+4)

Note that the distance between the load and the corresponding use stays the same (i.e. 2 stages from the load of k in S0(i+3) to its use in S2(i+3), and 2 stages from the load of v in S1(i+2) to its use in S3(i+2)), so the number of buffers stays at 3. The live range of the output of S2 is increased to span the loop (i.e. from S2(i+1) to S3(i+1)).
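To make the Si(j) notation and the unrolled bodies above concrete, here is a small illustrative Python sketch (not part of this PR; the helper and its stage encoding are made up for illustration) that prints the steady-state loop body for a given ordering of stages:

def print_steady_state(stages, width=3):
    # stages: stage ids in the order they appear in the pipelined body.
    # The body instance in column t runs stage s for iteration
    # i + t + (last_stage - s), so later stages work on older iterations.
    last = max(stages)
    for s in stages:
        row = []
        for t in range(width):
            off = t + last - s
            row.append(f"S{s}(i+{off})" if off else f"S{s}(i)")
        print(" ".join(row))

print_steady_state([2, 0])        # baseline: loads in S0, the rest in S2
# S2(i) S2(i+1) S2(i+2)
# S0(i+2) S0(i+3) S0(i+4)
print_steady_state([2, 3, 0, 1])  # computation pipelining with 4 stages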

Details:
SWP_FIRST_DOT: move the first dot into stage numStages - 2 instead of numStages - 1
--> for this to work, we need to support predication of barrier_wait, or use PEEL_EPILOGUE
LOAD_DIFFERENT_STAGE: put the two loads in two different stages
FIRST_LOAD_OF_USE: only consider the first use of each load; for example, the load of k is used by the first dot and indirectly used by the second dot, but we should only consider the use by the first dot when calculating the use distance.

For both createAsyncCopy and createTMAAsyncCopy, update TMAUserToWait so that there is an extra dependency from the view of the load to the barrier_wait op:
TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
In scheduleDependencies, waitOp will be added to the same stage/cluster as viewLoad.
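A minimal Python model of that bookkeeping (hypothetical names mirroring the C++ ones; this is a sketch of the idea, not the actual pass):

tma_user_to_wait = {}   # viewLoad -> waitOp; viewLoad depends on barrierWait
schedule = {}           # op -> (stage, cluster)

def record_wait_dependency(view_load, wait_op):
    # Mirrors TMAUserToWait[viewLoad] = waitOp in createTMAAsyncCopy.
    tma_user_to_wait[view_load] = wait_op

def schedule_dependencies():
    # The wait joins the same stage/cluster as its viewLoad, so the
    # barrier_wait is guaranteed to run before the loaded value is used.
    for view_load, wait_op in tma_user_to_wait.items():
        if view_load in schedule and wait_op not in schedule:
            schedule[wait_op] = schedule[view_load]

record_wait_dependency("viewLoadK", "waitK")
schedule["viewLoadK"] = (2, 0)   # e.g. the first dot's operand in stage 2
schedule_dependencies()
assert schedule["waitK"] == (2, 0)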

LOAD_DIFFERENT_STAGE=1 FIRST_LOAD_OF_USE=1 SWP_FIRST_DOT=1, together with a TMA variant of FA (pytorch/benchmark#2379), seems to improve performance for headDim 128 and 256.

This implementation is somewhat specific to flash attention and needs rework. We can make it more general by adding a dependency analyzer that analyzes the dependencies and resource usage of ttgir ops, decides how to break ops into stages, and feeds the decisions to SWP. Another, less general option is to only support num_computation_stage being 2 (i.e. separating the computation ops into two stages). In this restricted form, lifting the first dot into its own stage seems reasonable.

Open to discussions and reworks.

manman-ren requested a review from ptillet as a code owner (July 26, 2024 22:25)
manman-ren marked this pull request as draft (July 26, 2024 22:25)
manman-ren (Collaborator, Author)

@pawelszczerbuk @ThomasRaoux This is still in draft mode. I will ask Meta folks to take a look first. But if you have some high-level comments, please let me know! Thanks!

@bertmaher @chenyang78 Since Hongtao is on vacation, wondering if you have some time to take a look at this.

"SWP_FIRST_DOT",
"PEEL_EPILOGUE",
"LOAD_DIFFERENT_STAGE",
"FIRST_LOAD_OF_USE",
Contributor

Should these get a documentation line in the README.md?

manman-ren (Collaborator, Author)

Yes, I will add them to README.md. We may not need all of these env variables; maybe we should always set loads to different stages, see this comment.

  // TODO pawel: we could do more fine-grained allocation here and
  // allocate only the number of buffers that specific loads need.

Jokeren (Contributor) commented Jul 26, 2024

What does Si(j) mean?

manman-ren (Collaborator, Author)

What does Si(j) mean?

It means SWP stage i at iteration j. As an example, the prologue usually handles S0(0) and S0(1) when num_stages is 3.

Jokeren (Contributor) commented Jul 27, 2024

So this actually refers to three iterations from i to i+2?

S2(i) S2(i+1) S2(i+2)
S0(i+2) S0(i+3) S0(i+4)

manman-ren (Collaborator, Author)

So this actually refers to three iterations from i to i+2?

S2(i) S2(i+1) S2(i+2)
S0(i+2) S0(i+3) S0(i+4)

Yes, it is the main loop unrolled by 3 iterations to show the loop-carried dependencies.

ThomasRaoux (Collaborator)

Thanks, this is super interesting work and proves that we need more advanced modulo scheduling when doing software pipelining. However, I'm afraid we are reaching the limit of hardcoded heuristics, and I think this is going to be too ad hoc (meaning too much of an overfit for FA) to be pushed into the pipeliner logic.
I do think we want to be able to get that performance, but we most likely need a more programmable way to compute the schedule. This is something that came up several times when discussing with Phil and Pawel, and maybe it is time to address it.

manman-ren (Collaborator, Author)

Thanks, this is super interesting work and proves that we need more advanced modulo scheduling when doing software pipelining. However, I'm afraid we are reaching the limit of hardcoded heuristics, and I think this is going to be too ad hoc (meaning too much of an overfit for FA) to be pushed into the pipeliner logic. I do think we want to be able to get that performance, but we most likely need a more programmable way to compute the schedule. This is something that came up several times when discussing with Phil and Pawel, and maybe it is time to address it.

Thanks Thomas! I wonder if we can at least land some of the fixes. One example is the correct dependency from barrierWait to viewLoad (MemDescSubviewOp): before we use the output of a loadOp via viewLoad, we need to make sure the barrierWait is done.

We (Meta folks) are open to discussions and collaborations on the correct implementation of modulo scheduling that includes both overlapping memory with computation and overlapping computation ops.

As mentioned in the description:

We can make it more general by adding a dependency analyzer that analyzes the dependencies and resource usage of ttgir ops, decides how to break ops into stages, and feeds the decisions to SWP. Another, less general option is to only support num_computation_stage being 2 (i.e. separating the computation ops into two stages). In this restricted form, lifting the first dot into its own stage seems reasonable.

We can add a formal dependency analyzer that feeds its decisions to SWP, or just use a single knob, num_computation_stage. Let me know what the right way to collaborate on this is.

ThomasRaoux (Collaborator)

Thanks, this is super interesting work and proves that we need more advanced modulo scheduling when doing software pipelining. However, I'm afraid we are reaching the limit of hardcoded heuristics, and I think this is going to be too ad hoc (meaning too much of an overfit for FA) to be pushed into the pipeliner logic. I do think we want to be able to get that performance, but we most likely need a more programmable way to compute the schedule. This is something that came up several times when discussing with Phil and Pawel, and maybe it is time to address it.

Thanks Thomas! I wonder if we can at least land some of the fixes. One example is the correct dependency from barrierWait to viewLoad (MemDescSubviewOp): before we use the output of a loadOp via viewLoad, we need to make sure the barrierWait is done.

Yes if there are general fixes/improvements we can land those. (I haven't yet looked at this exact case but in general I agree)

We (Meta folks) are open to discussions and collaborations on the correct implementation of modulo scheduling that includes both overlapping memory with computation and overlapping computation ops.

As mentioned in the description:

We can make it more general by adding a dependency analyzer that analyzes the dependencies and resource usage of ttgir ops, decides how to break ops into stages, and feeds the decisions to SWP. Another, less general option is to only support num_computation_stage being 2 (i.e. separating the computation ops into two stages). In this restricted form, lifting the first dot into its own stage seems reasonable.

We can add a formal dependency analyzer that feeds its decisions to SWP, or just use a single knob, num_computation_stage. Let me know what the right way to collaborate on this is.

It makes sense to me. I think there are different ways this can be done. One is with a scheduler; however, this is usually a very difficult heuristic to get right. Combining it with PGO could help, but the complexity of tuning might be impractical.
The other solution is to expose more control in the language directly.
This will probably require a deeper discussion and some whiteboarding to figure out which path we want to prototype.

manman-ren (Collaborator, Author)

It makes sense to me. I think there are different ways this can be done. One is with a scheduler; however, this is usually a very difficult heuristic to get right. Combining it with PGO could help, but the complexity of tuning might be impractical.
The other solution is to expose more control in the language directly.
This will probably require a deeper discussion and some whiteboarding to figure out which path we want to prototype.

Yes, we have been thinking about using autotuning to help the scheduler/dependency analyzer, but which knobs to expose to autotuning is something that needs more discussion. PGO can help reduce the search space for tuning, or help the scheduler have more accurate heuristics. For flash attention, there are at least two different schedules we can consider:

S2 has the first dot; S3 has the softmax and the second dot.
S2(i+1) S2(i+2) S2(i+3)
S3(i)   S3(i+1) S3(i+2)
S0(i+3) S0(i+4) S0(i+5)
S1(i+2) S1(i+3) S1(i+4)
The output of the first dot will be live through the whole loop, as it is generated in S2(i+1) and used in S3(i+1). And we are overlapping S2(i+1) with the softmax in S3(i).

vs.

S2 has the first dot and the softmax; S3 has the second dot (this seems to be what is chosen in FA3).
S2(i+1)_dot     S2(i+2)_dot     S2(i+3)_dot
S3(i)           S3(i+1)         S3(i+2)
S2(i+1)_softmax S2(i+2)_softmax S2(i+3)_softmax
S0(i+3)         S0(i+4)         S0(i+5)
S1(i+2)         S1(i+3)         S1(i+4)
The output of the first dot will only be live inside one iteration, as it is generated in S2(i+1)_dot and used in S2(i+1)_softmax. And we are overlapping the execution of S3(i) with the execution of S2(i+1)_softmax.

We need to make both staging decisions and ordering decisions (i.e. stage assignment and cluster assignment, as currently implemented in SWP).

manman-ren (Collaborator, Author)

After discussions with Pawel/Phil/Thomas, we are proposing loop schedule annotations for computation pipelining:
At the ttgir level, we will support (stage, cluster) attributes on operations. At the source/ttir level, we will support schedule annotations on loops. The possible loop schedules will come from a library of schedules.

Prior to SWP, we have a single loop schedule annotation attached to the loop; one example is flash_attention_first_dot as a loop schedule, so we annotate the loop with "loop_schedule = flash_attention_first_dot". We will add a pass prior to SWP that converts the loop schedule to (stage, cluster) attributes on ttgir ops; SWP will use these (stage, cluster) attributes to generate the loop body, epilogue, prologue, etc.

Definition of a loop schedule

The easiest format for a schedule is a text file, with one line for each key operation (i.e. 2D mulf/addf/subf/exp/dot etc.) at the ttgir level:
instruction type + order (first dot, second dot, etc.)
shape rank
pipeline stage number
cluster id

Even though the loop annotation is at the source level, the schedule is described in terms of ttgir operations, so it is not straightforward for developers. We choose ttgir operations since one ttir operation can split into many ttgir operations. People updating the library of schedules need to know what the kernel looks like at the ttgir level. The instruction type needs to be general enough that it is not tied to a specific ttgir version.

The schedules may become outdated when the kernel is updated, but this will not degrade performance as long as we ignore a schedule when it no longer applies. We can also use autotuning to iterate over all possible loop schedules of a kernel, or over a pruned set of loop schedules.
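As a strawman, parsing such a schedule file could look like the sketch below (hypothetical; the field layout simply follows the four-line description above, and "*" in the order field means "all matching ops"):

from dataclasses import dataclass

@dataclass
class ScheduleEntry:
    op: str        # ttgir op kind, e.g. "dot", "exp2", "mulf", "addf"
    order: str     # "first", "second", ... or "*" for every match
    rank: str      # shape rank, e.g. "2D"
    stage: int     # pipeline stage
    cluster: int   # cluster id

def parse_schedule(text):
    entries = []
    for line in text.splitlines():
        line = line.split("<--")[0].strip()   # drop trailing annotations
        if line:
            op, order, rank, stage, cluster = line.split()
            entries.append(ScheduleEntry(op, order, rank, int(stage), int(cluster)))
    return entries

# The flash_attention_first_dot schedule from Option 1 below:
for e in parse_schedule("dot first 2D 0 0\ndot second 2D 1 1"):
    print(e)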

Flash Attention Option 1

Take flash attention as an example: we can have a schedule where the first dot is in its own stage and its own cluster.
dot first 2D 0 0 ← first 2D dot is in stage 0, cluster 0
dot second 2D 1 1 ← second 2D dot is in stage 1, cluster 1

The pipelined loop will look like this (with num_stages = 4):
MMA0(i+1) <-- stage 2, cluster 0 (first dot is in the first computation stage, thus last stage - 1: stage 2)
Softmax(i) <-- stage 3, cluster 1 (due to dependency)
MUL(i) <-- stage 3, cluster 1 (due to dependency)
MMA1(i) <-- stage 3, cluster 1 (2nd dot is in the last computation stage, thus the last stage: stage 3)
loadK(i+3) <-- stage 0
loadV(i+2) <-- stage 1

Flash Attention Option 2

For flash attention, we can also have a schedule where the second dot is in its own stage.
dot first 2D 0 0
exp2 * 2D 0 2 <-- all 2D exp2 ops will be in stage 0, cluster 2
mulf * 2D 0 2
addf * 2D 0 2
dot second 2D 1 1

The pipelined loop will look like this (with num_stages = 4):
MMA0(i+1) <-- stage 2, cluster 0
MMA1(i) <-- stage 3, cluster 1
Softmax(i+1) <-- stage 2, cluster 2
MUL(i+1) <-- stage 2, cluster 2
loadK(i+3) <-- stage 0
loadV(i+2) <-- stage 1

Implementation

We will first add support for (stage, cluster) attributes on ttgir operations. Once that is done, we can improve SWP with test cases that use (stage, cluster) attributes at the ttgir level. We need to improve SWP to correctly handle test cases with any legal (stage, cluster) attributes.

We will have a pass, prior to SWP, that converts the loop schedule to annotations on the ttgir operations. The pass will need to check whether the schedule is valid, and only annotate the ttgir operations with (stage, cluster) when it is.

During software pipelining, operations without (stage, cluster) will be assigned a (stage, cluster) according to dependencies. The pipelined loop body will be generated based on the (stage, cluster) of each operation.
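A toy Python sketch of that backfill step (hypothetical; the op names and the `users` map are made up, and a real implementation would walk SSA use-def chains in the IR):

def backfill_schedule(ops, schedule, users):
    # Ops without an explicit (stage, cluster) inherit the earliest
    # (stage, cluster) among the scheduled ops that consume their results.
    changed = True
    while changed:
        changed = False
        for op in ops:
            if op in schedule:
                continue
            slots = [schedule[u] for u in users.get(op, []) if u in schedule]
            if slots:
                schedule[op] = min(slots)
                changed = True
    return schedule

# Softmax and MUL feed MMA1, which the schedule pins to (stage 3, cluster 1),
# so both are pulled into (3, 1) "due to dependency", as in Option 1 above.
print(backfill_schedule(
    ops=["softmax", "mul"],
    schedule={"mma1": (3, 1)},
    users={"softmax": ["mul"], "mul": ["mma1"]},
))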

manman-ren (Collaborator, Author)

@ThomasRaoux @ptillet @Jokeren Wondering if you have any more feedback on the proposal and the implementation plan. Thanks!

Summary: For the test case, we need to predicate the dot for the prologue.

manman-ren force-pushed the computation-swp-first-dot branch from bc397cf to 7d0ece3 (August 21, 2024 01:25)
manman-ren (Collaborator, Author) commented Aug 21, 2024

Updated the diff to use the attribute "loop.stage" inside SWP, and added test cases to show and fix a few issues:
1> the dependency from the WaitBarrier op to the operation that actually uses the output of the async load; as an example:
%95 = triton_gpu.memdesc_subview %30[%92]
triton_nvidia_gpu.wait_barrier %95, %94
%96 = triton_gpu.memdesc_subview %28[%92, %c0_i32, %c0_i32]
%97 = tt.trans %96 {order = array<i32: 1, 0>}
%98 = triton_nvidia_gpu.warp_group_dot %23, %97, %cst_0
If we set the warp_group_dot at stage 2, we want the wait_barrier to be in stage 2 as well. Without the fix, the wait_barrier would be put in stage 3 during scheduleRemaining.
2> change distToUse to ignore dots/uses that reference a loadOp only through another dot:
in the case of loadOp --> dot1 --> dot2, where there is no path from loadOp to dot2 without going through dot1, we can ignore the use of loadOp by dot2. This reduces the number of buffers from 4 to 3 for our test case, which has dot1 in stage 2, load1 in stage 0, load2 in stage 1, and dot2 in stage 3 (see the sketch below).
3> handle more ops in predicateOp: WaitBarrierOp, etc.

The update currently doesn't handle refactoring of SWP.
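For item 2>, a small Python sketch of the idea (hypothetical; the real change lives in the C++ distToUse computation): only uses of a load that are reachable without crossing another dot count toward the load's use distance and buffer count.

def direct_uses(load, users, dots):
    # Walk forward from the load, but stop at dots: any use that is only
    # reachable through another dot is an indirect use and is ignored.
    seen, stack, direct = set(), [load], set()
    while stack:
        op = stack.pop()
        for user in users.get(op, []):
            if user not in seen:
                seen.add(user)
                direct.add(user)
                if user not in dots:
                    stack.append(user)
    return direct

# loadOp --> dot1 --> dot2: dot2 only sees the load through dot1.
users = {"load1": ["dot1"], "dot1": ["dot2"]}
print(direct_uses("load1", users, dots={"dot1", "dot2"}))   # {'dot1'}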
