fix number of buffers
manman-ren committed Aug 21, 2024
1 parent 5eb95f0 commit 7cbae2a
Showing 2 changed files with 40 additions and 4 deletions.
36 changes: 36 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -594,9 +594,45 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
  }

  // Distance from the load to the use.
  if (forOp->hasAttr(tt::kNumStagesAttrName)) {
    for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) {
      if (loadToInfo.count(loadOp) == 0)
        continue;
      loadToInfo[loadOp].distToUse =
          schedule[use].first - schedule[loadOp].first;
    }
    return loadToInfo;
  }
  // If there is a use chain of load -> dot -> dot, we can ignore the second
  // dot here. Start from loadOp, check uses and stop the recursion when
  // hitting a dot.
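  // For example, given a chain %a = tt.load ... -> %d1 = tt.dot %a, ... ->
  // %d2 = tt.dot %d1, ..., only the pair (load, %d1) is recorded; the walk
  // stops at %d1 and never reaches %d2.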
  DenseSet<Operation *> seen;
  llvm::SmallVector<std::tuple<Operation *, Operation *>> loadOpToDirectUses;
  std::function<void(Operation *, Operation *)> dfsUse =
      [&](Operation *op, Operation *use) {
        if (!seen.insert(use).second)
          return;
        if (use->hasTrait<OpTrait::DotLike>()) {
          loadOpToDirectUses.push_back(std::make_tuple(op, use));
          return;
        }
        for (auto &tUse : use->getUses()) {
          Operation *useOp = tUse.getOwner();
          if (useOp && useOp->getBlock() == op->getBlock()) {
            dfsUse(op, useOp);
          }
        }
      };
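  // A load can appear in loadOpToIndLevelAndUse once per use; walk each
  // load's use chains only once, resetting the visited set per load.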
  DenseSet<Operation *> loadOps;
  for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) {
    if (loadToInfo.count(loadOp) == 0)
      continue;
    if (!loadOps.insert(loadOp).second)
      continue;
    seen.clear();
    dfsUse(loadOp, loadOp);
  }
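  // Only the direct (pre-dot) uses determine the pipelining distance.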
  for (auto [loadOp, use] : loadOpToDirectUses) {
    loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first;
  }

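The effect shows up in the test below: stopping at the first dot shortens distToUse by one stage for the K and V loads, so the pipeliner allocates one fewer shared-memory buffer for each. A minimal sketch of the arithmetic, assuming hypothetical stage assignments and the buffers-per-distance mapping implied by the 4 -> 3 change in the test (not taken from the pipeliner itself):

#include <cstdio>

// Hypothetical stages: a load in stage 0 feeding load -> dot1 -> dot2,
// with dot1 in stage 2 and dot2 in stage 3.
int main() {
  int loadStage = 0, dot1Stage = 2, dot2Stage = 3;
  // Before this commit the transitive use (dot2) set the distance.
  int oldDistToUse = dot2Stage - loadStage; // 3
  // After it the DFS stops at the first DotLike op, so only dot1 counts.
  int newDistToUse = dot1Stage - loadStage; // 2
  // One buffer per stage of distance, plus one in flight (assumed mapping,
  // consistent with the 4 -> 3 change in comp-pipeline.mlir below).
  std::printf("buffers: %d -> %d\n", oldDistToUse + 1, newDistToUse + 1);
  return 0;
}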
8 changes: 4 additions & 4 deletions test/TritonGPU/comp-pipeline.mlir
@@ -27,10 +27,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
%30 = arith.extsi %arg17 : i32 to i64
// CHECK: tt.experimental_descriptor_load
// CHECK: %[[QLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<128x128xf16
- // CHECK: %[[KLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<4x128x128xf16
- // CHECK: %[[VLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<4x128x128xf16
- // CHECK: %[[KBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<4xi64
- // CHECK: %[[VBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<4xi64
+ // CHECK: %[[KLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16
+ // CHECK: %[[VLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16
+ // CHECK: %[[KBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64
+ // CHECK: %[[VBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64
// stage 0 iteration 0
// CHECK: %[[K0:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c0_i32, %c0_i32, %c0_i32]
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K0]]
