Implementation of Computation Pipelining #4403

Draft · wants to merge 3 commits into base: main
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Schedule.h
@@ -84,8 +84,10 @@ class CoarseSchedule {
return true;
}

void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg);
void
insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg,
DenseMap<Operation *, Operation *> *additionalDep = nullptr);

void erase(Operation *op) { opToStageAndCluster.erase(op); }

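Note: the new optional additionalDep parameter lets a caller thread op-to-op dependencies into scheduling that are not visible as SSA operands; in this PR it carries the TMAUserToWait map, so a TMA shared-memory subview pulls its guarding wait_barrier into the same stage and cluster. Below is a minimal standalone sketch of that idea in plain C++ over a toy op graph; none of the names are Triton or MLIR APIs.

// Minimal sketch (plain C++, illustrative only; not the Triton API): an
// auxiliary "additional dependency" map supplies edges that do not exist as
// operands, e.g. a shared-memory subview that must be scheduled together with
// the barrier wait that protects its data.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  std::vector<ToyOp *> operands; // explicit (SSA-like) dependencies
};

static std::map<ToyOp *, int> stageOf; // the coarse schedule: op -> stage

static void insertDepsOfOp(ToyOp *op, int stage,
                           std::map<ToyOp *, ToyOp *> *additionalDep) {
  // Extra dependency first (mirrors the TMAUserToWait lookup).
  if (additionalDep) {
    auto it = additionalDep->find(op);
    if (it != additionalDep->end() && !stageOf.count(it->second)) {
      stageOf[it->second] = stage;
      insertDepsOfOp(it->second, stage, additionalDep);
    }
  }
  // Then the ordinary operand chain.
  for (ToyOp *def : op->operands) {
    if (!stageOf.count(def)) {
      stageOf[def] = stage;
      insertDepsOfOp(def, stage, additionalDep);
    }
  }
}

int main() {
  ToyOp wait{"wait_barrier", {}};
  ToyOp view{"memdesc_subview", {}};    // no SSA edge to the wait
  ToyOp dot{"warp_group_dot", {&view}}; // anchor op
  std::map<ToyOp *, ToyOp *> additionalDep{{&view, &wait}};

  stageOf[&dot] = 2; // anchor scheduled at stage 2
  insertDepsOfOp(&dot, 2, &additionalDep);
  for (ToyOp *o : {&dot, &view, &wait})
    std::printf("%-16s -> stage %d\n", o->name.c_str(), stageOf[o]);
  // Without the map, wait_barrier would never be scheduled with its user.
}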
82 changes: 68 additions & 14 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -56,7 +56,8 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
tt::CoarseSchedule &schedule,
tt::CoarseSchedule::Cluster prefetchCluster,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
int numStages) {
int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
OpBuilder builder(forOp);
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
// Replace the load with insert/extract slice.
@@ -113,6 +114,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
loadOffsets[0] = extractIdx;
auto viewLoad =
builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
TMAUserToWait[viewLoad] = wait; // viewLoad will depend on barrierWait
if (isMMV3Load) {
auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
@@ -157,7 +159,8 @@ static void createTMAAsyncCopy(
scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc,
Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp,
Value phase, tt::CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages) {
llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
assert(phase && "Phase value is required for TMA async copy.");
OpBuilder builder(forOp);
Attribute sharedMemorySpace =
@@ -189,6 +192,7 @@ static void createTMAAsyncCopy(
loadOffsets[0] = extractIdx;
auto viewLoad =
builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
if (isMMV3Load) {
auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
@@ -563,7 +567,12 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
// Non-LoadOp(s) are the root uses of all LoadOp(s) and should be
// always present in the opInfo
if (!isa<tt::LoadOp>(use)) {
schedule.insert(use, numStages - 1, rootUsersCluster);
int stage = numStages - 1;
if (use->hasAttr("loop.stage"))
stage = cast<IntegerAttr>(use->getAttr("loop.stage"))
.getValue()
.getZExtValue();
schedule.insertIfAbsent(use, stage, rootUsersCluster);
rootUsers.insert(use);
}
}
@@ -577,13 +586,53 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
if (loadToInfo.count(loadOp) == 0)
continue;
int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads;
schedule.insert(loadOp, stage, loadsClusters[indLevel]);
if (loadOp->hasAttr("loop.stage"))
stage = cast<IntegerAttr>(loadOp->getAttr("loop.stage"))
.getValue()
.getZExtValue();
schedule.insertIfAbsent(loadOp, stage, loadsClusters[indLevel]);
}

// Distance from the load to the use.
if (forOp->hasAttr(tt::kNumStagesAttrName)) {
for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) {
if (loadToInfo.count(loadOp) == 0)
continue;
loadToInfo[loadOp].distToUse =
schedule[use].first - schedule[loadOp].first;
}
return loadToInfo;
}
// If there is a use chain of load -> dot -> dot, we can ignore the second dot
// here.
// Start from loadOp, check uses and stop the recursion when hitting a dot.
DenseSet<Operation *> seen;
llvm::SmallVector<std::tuple<Operation *, Operation *>> loadOpToDirectUses;
std::function<void(Operation * op, Operation *)> dfsUse =
[&](Operation *op, Operation *use) {
if (!seen.insert(use).second)
return;
if (use->hasTrait<OpTrait::DotLike>()) {
loadOpToDirectUses.push_back(std::make_tuple(op, use));
return;
}
for (auto &tUse : use->getUses()) {
Operation *useOp = tUse.getOwner();
if (useOp && useOp->getBlock() == op->getBlock()) {
dfsUse(op, useOp);
}
}
};
DenseSet<Operation *> loadOps;
for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) {
if (loadToInfo.count(loadOp) == 0)
continue;
if (!loadOps.insert(loadOp).second)
continue;
seen.clear();
dfsUse(loadOp, loadOp);
}
for (auto [loadOp, use] : loadOpToDirectUses) {
loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first;
}
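Note: two behaviors change here. A load or root user carrying a loop.stage attribute keeps that stage instead of the computed default, and distToUse is now measured only to the first dot-like use reachable from the load, so a chain load -> dot -> dot contributes a single (load, first dot) pair. Below is a minimal standalone model of that forward traversal in plain C++ (toy types, not Triton APIs); the stage values mirror the test at the end of this PR, where the K load sits at stage 0 and the first dot at stage 2.

// Minimal sketch (plain C++, illustrative only): walk forward from a load
// through its uses and record only the first dot-like ops reached, so uses
// past a dot are ignored when computing distToUse.
#include <cstdio>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct ToyOp {
  std::string name;
  bool isDotLike = false;
  int stage = 0;              // stage already assigned by the scheduler
  std::vector<ToyOp *> users; // forward (use) edges
};

static void collectDirectDotUses(ToyOp *load, ToyOp *op,
                                 std::set<ToyOp *> &seen,
                                 std::vector<std::pair<ToyOp *, ToyOp *>> &out) {
  for (ToyOp *use : op->users) {
    if (!seen.insert(use).second)
      continue;
    if (use->isDotLike) {
      out.push_back({load, use}); // stop here; do not look past the dot
      continue;
    }
    collectDirectDotUses(load, use, seen, out);
  }
}

int main() {
  ToyOp loadK{"load K", false, /*loop.stage=*/0, {}};
  ToyOp dot1{"dot1", true, 2, {}};
  ToyOp dot2{"dot2", true, 3, {}};
  loadK.users = {&dot1};
  dot1.users = {&dot2}; // second dot is reachable only through dot1

  std::set<ToyOp *> seen;
  std::vector<std::pair<ToyOp *, ToyOp *>> pairs;
  collectDirectDotUses(&loadK, &loadK, seen, pairs);
  for (auto &[ld, use] : pairs) // distToUse = stage(use) - stage(load) = 2
    std::printf("%s -> %s, distToUse = %d\n", ld->name.c_str(),
                use->name.c_str(), use->stage - ld->stage);
}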

@@ -652,16 +701,18 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule,

// Add dependencies of anchor ops to the coarse schedule. Schedule them to
// the same stage and ordering cluster as the anchor op.
static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
int numStages) {
static void
scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
SmallVector<std::tuple<Operation *, int, tt::CoarseSchedule::Cluster>>
opsInOrder = schedule.getOpsInOrder(forOp);
// Schedule dependencies stage by stage.
for (int stage = 0; stage < numStages; stage++) {
for (auto [op, stage_, cluster] : opsInOrder) {
if (stage_ != stage)
continue;
schedule.insertDepsOfOp(op, stage, cluster, false);
schedule.insertDepsOfOp(op, stage, cluster, false, &TMAUserToWait);
}
}
}
@@ -818,7 +869,7 @@ struct AsyncLoad {
};

// Create barriers and wait ops for the async loads. Barriers may be shared by
// multiple loads is the schedule allows it.
// multiple loads if the schedule allows it.
static void createTMABarrierAndWait(
scf::ForOp &forOp, SmallVector<AsyncLoad> &asyncLoads, Value insertIdx,
Value extractIdx, Value phase, int numBuffers, tt::CoarseSchedule &schedule,
@@ -926,7 +977,8 @@ static void createTMABarrierAndWait(
static SmallVector<Value>
createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
SmallVector<Value> &barriers, int numStages) {
SmallVector<Value> &barriers, int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
// Calculate the number of buffers needed for each load.
// TODO pawel: we could do more fine-grained allocation here and
// allocate only the number of buffers that specific loads need.
@@ -1017,12 +1069,13 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
for (AsyncLoad &asyncLoad : asyncLoads) {
if (auto loadOp = dyn_cast<tt::LoadOp>(asyncLoad.loadOp)) {
createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
schedule, prefetchCluster, loadToInfo, numStages);
schedule, prefetchCluster, loadToInfo, numStages,
TMAUserToWait);
} else {
auto descLoad = cast<tt::ExperimentalDescriptorLoadOp>(asyncLoad.loadOp);
createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx,
extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase,
schedule, loadToInfo, numStages);
schedule, loadToInfo, numStages, TMAUserToWait);
}
}
SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
@@ -1071,9 +1124,10 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
});

SmallVector<Value> barriers;
DenseMap<Operation *, Operation *> TMAUserToWait;
// Convert the loads into async loads and create the allocs.
SmallVector<Value> allocs =
createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages);
SmallVector<Value> allocs = createAsyncOps(
forOp, coarseSchedule, loadToInfo, barriers, numStages, TMAUserToWait);

LLVM_DEBUG({
LDBG("Coarse schedule with async loads:");
@@ -1087,7 +1141,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
coarseSchedule.dump();
});

scheduleDependencies(forOp, coarseSchedule, numStages);
scheduleDependencies(forOp, coarseSchedule, numStages, TMAUserToWait);
LLVM_DEBUG({
LDBG("Coarse schedule with dependencies:");
coarseSchedule.dump();
Original file line number Diff line number Diff line change
@@ -74,6 +74,16 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
return op;
}

if (isa<ttng::WarpGroupDotOp>(op))
return op;
if (auto wait = dyn_cast<ttng::WaitBarrierOp>(op)) {
rewriter.setInsertionPoint(wait);
auto ifOp =
rewriter.create<scf::IfOp>(wait->getLoc(), pred, /*else=*/false);
rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin());
return ifOp;
}

assert("don't know how to predicate this op" && false);
return op;
}
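Note: restating the new WaitBarrierOp case with comments (same rewriter calls as the hunk above, reusing its op and pred; not a separate implementation). The wait is wrapped in an scf.if on the pipeline predicate so it only executes on iterations where the predicate is true, and the scf.if is returned as the predicated replacement.

// Commented restatement of the WaitBarrierOp case in predicateOp.
if (auto wait = dyn_cast<ttng::WaitBarrierOp>(op)) {
  rewriter.setInsertionPoint(wait);
  // Build `scf.if %pred { }` with no else region, right where the wait was.
  auto ifOp =
      rewriter.create<scf::IfOp>(wait->getLoc(), pred, /*else=*/false);
  // Move the wait into the then-block, so the barrier is only waited on when
  // this iteration's predicate is true.
  rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin());
  // Hand the scf.if back as the predicated replacement for the original op.
  return ifOp;
}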
14 changes: 10 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
@@ -15,9 +15,15 @@ namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;

void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
tt::CoarseSchedule::Cluster cluster,
bool includeArg) {
void tt::CoarseSchedule::insertDepsOfOp(
Operation *op, int stage, tt::CoarseSchedule::Cluster cluster,
bool includeArg, DenseMap<Operation *, Operation *> *additionalDep) {
// Look in additionalDep.
if (additionalDep && additionalDep->find(op) != additionalDep->end()) {
Operation *wait = (*additionalDep)[op];
if (insertIfAbsent(wait, stage, cluster))
insertDepsOfOp(wait, stage, cluster, includeArg, additionalDep);
}
for (Value operand : op->getOperands()) {
Value v = operand;
llvm::SmallDenseSet<Value> seen;
@@ -36,7 +42,7 @@ void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
Operation *defOp = v.getDefiningOp();
if (defOp && defOp->getBlock() == op->getBlock()) {
if (insertIfAbsent(defOp, stage, cluster)) {
insertDepsOfOp(defOp, stage, cluster, includeArg);
insertDepsOfOp(defOp, stage, cluster, includeArg, additionalDep);
}
}
}
102 changes: 102 additions & 0 deletions test/TritonGPU/comp-pipeline.mlir
@@ -0,0 +1,102 @@
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=4 -debug-only=triton-matmul-loop-pipeline 2>&1 | FileCheck %s --check-prefix=DEBUG
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=4 | FileCheck %s

// DEBUG: Final coarse schedule:
// DEBUG: Ops in stage 2
// DEBUG-DAG: triton_nvidia_gpu.wait_barrier
// DEBUG-DAG: triton_nvidia_gpu.warp_group_dot
// DEBUG: Ops in stage 3
// DEBUG: triton_nvidia_gpu.wait_barrier
// DEBUG: Original loop:

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 4], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @_attn_fwd_tma(%arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg6: f32, %arg8: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32, %arg11: i64, %arg14: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c128_i32 = arith.constant 128 : i32
%c0_i32 = arith.constant 0 : i32
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%25 = tt.experimental_descriptor_load %arg3[%arg9, %c0_i32] : !tt.ptr<i8> -> tensor<128x128xf16, #blocked1>
%26 = triton_gpu.local_alloc %25 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory>
%27 = arith.extsi %arg14 : i32 to i64
%28 = tt.splat %arg6 : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%29 = tt.splat %arg6 : f32 -> tensor<128x128xf32, #mma>
%30 = arith.extsi %arg17 : i32 to i64
// CHECK: tt.experimental_descriptor_load
// CHECK: %[[QLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<128x128xf16
// CHECK: %[[KLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16
// CHECK: %[[VLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16
// CHECK: %[[KBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64
// CHECK: %[[VBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64
// stage 0 iteration 0
// CHECK: %[[K0:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c0_i32, %c0_i32, %c0_i32]
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K0]]
// stage 0 iteration 1
// CHECK: %[[K1:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c1_i32, %c0_i32, %c0_i32]
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K1]]
// stage 1 iteration 0
// CHECK: %[[V0:.+]] = triton_gpu.memdesc_subview %[[VLOC]][%c0_i32, %c0_i32, %c0_i32]
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[V0]]
// stage 2 iteration 0
// CHECK: %[[FIRSTDOT:.+]] = triton_nvidia_gpu.warp_group_dot
// stage 0 iteration 2
// CHECK: %[[K2:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c2_i32, %c0_i32, %c0_i32]
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K2]]
// stage 1 iteration 1
// CHECK: %[[V1:.+]] = triton_gpu.memdesc_subview %[[VLOC]][%c1_i32, %c0_i32, %c0_i32]
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[V1]]
// CHECK: scf.for {{.*}} %[[ARG:.+]] = %[[FIRSTDOT]]
// CHECK: %[[KBARSUB:.+]] = triton_gpu.memdesc_subview %[[KBAR]][%[[KBARIDX:.+]]]
// CHECK: scf.if
// CHECK: triton_nvidia_gpu.wait_barrier %[[KBARSUB]]
// CHECK: %[[KLOOP:.+]] = triton_gpu.memdesc_subview %[[KLOC]]
// CHECK: tt.trans %[[KLOOP]]
// CHECK: %[[FIRSTDOTLOOP:.+]] = triton_nvidia_gpu.warp_group_dot
// CHECK: %[[WAIT:.+]]:{{[0-9]+}} = triton_nvidia_gpu.warp_group_dot_wait
// CHECK: "tt.reduce"(%[[ARG]])
// CHECK: %[[VBARSUB:.+]] = triton_gpu.memdesc_subview %[[VBAR]]
// CHECK: triton_nvidia_gpu.wait_barrier %[[VBARSUB]]
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local
// CHECK: scf.yield {{.*}}%[[WAIT]]#0
// arg26 is acc
%31:1 = scf.for %arg24 = %c0_i32 to %arg23 step %c128_i32 iter_args(%arg26 = %cst_2) -> (tensor<128x128xf32, #mma>) : i32 {
%48 = arith.divsi %arg11, %27 : i64
%49 = arith.trunci %48 : i64 to i32
%50 = arith.addi %arg24, %49 : i32
// loads in different stages
%51 = tt.experimental_descriptor_load %arg4[%50, %c0_i32] {loop.stage = 0 : i32} : !tt.ptr<i8> -> tensor<128x128xf16, #blocked1>
%52 = triton_gpu.local_alloc %51 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory>
%53 = tt.trans %52 {order = array<i32: 1, 0>} : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory>
%54 = triton_nvidia_gpu.warp_group_dot %26, %53, %cst_2 {inputPrecision = 0 : i32, loop.stage = 2} : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
%55 = "tt.reduce"(%54) <{axis = 1 : i32}> ({
^bb0(%arg28: f32 loc(unknown), %arg29: f32 loc(unknown)):
%80 = arith.maxnumf %arg28, %arg29 : f32
tt.reduce.return %80 : f32
}) : (tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%56 = arith.mulf %55, %28 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%58 = arith.mulf %54, %29 : tensor<128x128xf32, #mma>
%59 = tt.expand_dims %56 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
%60 = tt.broadcast %59 : tensor<128x1xf32, #mma> -> tensor<128x128xf32, #mma>
%61 = arith.subf %58, %60 : tensor<128x128xf32, #mma>
%62 = math.exp2 %61 : tensor<128x128xf32, #mma>
%71 = arith.divsi %arg11, %30 : i64
%72 = arith.extsi %arg24 : i32 to i64
%73 = arith.addi %71, %72 : i64
%74 = arith.trunci %73 : i64 to i32
%75 = tt.experimental_descriptor_load %arg5[%74, %c0_i32] {loop.stage = 1 : i32} : !tt.ptr<i8> -> tensor<128x128xf16, #blocked1>
%76 = triton_gpu.local_alloc %75 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory>
%77 = arith.truncf %62 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%78 = triton_gpu.convert_layout %77 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
%79 = triton_nvidia_gpu.warp_group_dot %78, %76, %arg26 {inputPrecision = 0 : i32, loop.stage = 3 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
scf.yield %79 : tensor<128x128xf32, #mma>
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
%42 = arith.truncf %31#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%43 = triton_gpu.convert_layout %42 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked1>
tt.experimental_descriptor_store %arg8[%arg10, %c0_i32], %43 : !tt.ptr<i8>, tensor<128x128xf16, #blocked1>
tt.return
}
}
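
Usage note, stated as an assumption rather than something this diff documents: the new staging hooks are driven entirely by the loop.stage integer attribute, so a front end or an earlier transform could pin an op to a pipeline stage before the pipeline pass runs by attaching that attribute from C++. A minimal sketch follows; pinToStage is a hypothetical helper, not part of this PR.

// Hypothetical helper (assumption): attaching loop.stage from C++ is
// equivalent to writing `{loop.stage = N : i32}` in the IR, as in the test
// above, and scheduleLoads will pick it up via getZExtValue().
#include <cstdint>

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

static void pinToStage(mlir::Operation *op, int32_t stage) {
  mlir::OpBuilder b(op->getContext());
  op->setAttr("loop.stage", b.getI32IntegerAttr(stage));
}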