computation pipelining implementation

Summary:
SWP_FIRST_DOT: schedule the first dot in stage numStages - 2 instead of numStages - 1 (see the sketch after this list)
--> for this to work, we need to support predication of barrier_wait
PEEL_EPILOGUE: peel the epilogue of the pipelined loop (sets options.peelEpilogue)
LOAD_DIFFERENT_STAGE: put the two loads in two different stages
FIRST_LOAD_OF_USE: when a load has multiple uses, compute distToUse only from its first use

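A minimal, self-contained sketch of the SWP_FIRST_DOT stage choice (plain C++; the op names, numStages value, and the simple getenv check stand in for the real MLIR ops and ::triton::tools::getBoolEnv, and are illustrative assumptions only):

  #include <cstdlib>
  #include <iostream>
  #include <string>
  #include <vector>

  int main() {
    const int numStages = 3;
    // Stand-in for ::triton::tools::getBoolEnv("SWP_FIRST_DOT").
    const bool swpFirstDot = std::getenv("SWP_FIRST_DOT") != nullptr;

    // Root uses of the pipelined loads (e.g. the dots of two chained GEMMs).
    const std::vector<std::string> rootUsers = {"dot0", "dot1"};

    bool firstDot = true;
    for (const std::string &use : rootUsers) {
      int stage = numStages - 1;               // default: last stage
      if (swpFirstDot && firstDot)
        stage = numStages - 2;                 // first dot moves one stage earlier
      firstDot = false;
      std::cout << use << " -> stage " << stage << "\n";
    }
    return 0;
  }

With numStages = 3 and SWP_FIRST_DOT set, dot0 lands in stage 1 while dot1 stays in stage 2; this earlier placement is what requires barrier_wait to be predicable (see the predicateOp hunk below).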
For both createAsyncCopy and createTMAAsyncCopy, update TMAUserToWait so that there is an extra dependency
from the load's subview (viewLoad) to the barrier_wait op:
  TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
In scheduleDependencies, waitOp is then added to the same stage/cluster as viewLoad.
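
A simplified, self-contained model of this extra dependency (std::map and strings stand in for mlir::Operation* and tt::CoarseSchedule; the stage/cluster numbers are made up for illustration):

  #include <iostream>
  #include <map>
  #include <string>
  #include <utility>

  using Op = std::string;                      // stand-in for mlir::Operation*
  using StageCluster = std::pair<int, int>;    // (stage, cluster)

  int main() {
    std::map<Op, StageCluster> schedule;       // stand-in for tt::CoarseSchedule
    std::map<Op, Op> tmaUserToWait;            // viewLoad -> barrier_wait

    // createAsyncCopy / createTMAAsyncCopy record the extra dependency.
    tmaUserToWait["viewLoad"] = "barrier_wait";

    // The anchor op (dot) and its operand viewLoad are already scheduled.
    schedule["dot"] = {2, 0};
    schedule["viewLoad"] = {2, 0};

    // scheduleDependencies / insertDepsOfOp: an op with an entry in the
    // additional-dependency map pulls its barrier_wait into the same
    // stage/cluster (insertIfAbsent semantics).
    for (const auto &[user, wait] : tmaUserToWait) {
      auto it = schedule.find(user);
      if (it != schedule.end())
        schedule.emplace(wait, it->second);
    }

    for (const auto &[op, sc] : schedule)
      std::cout << op << " -> stage " << sc.first << ", cluster " << sc.second << "\n";
    return 0;
  }

The real implementation threads this map through as the new additionalDep parameter of CoarseSchedule::insertDepsOfOp (see the Schedule.cpp hunk at the bottom of the diff).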

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Jul 26, 2024
1 parent 309484c commit bc397cf
Showing 5 changed files with 96 additions and 22 deletions.
3 changes: 2 additions & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Schedule.h
@@ -85,7 +85,8 @@ class CoarseSchedule {
}

void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg);
bool includeArg,
DenseMap<Operation *, Operation *> *additionalDep);

void erase(Operation *op) { opToStageAndCluster.erase(op); }

4 changes: 4 additions & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -30,6 +30,10 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_LLVM_DEBUG_ONLY",
"USE_TTGIR_LOC",
"NVPTX_ENABLE_DUMP",
"SWP_FIRST_DOT",
"PEEL_EPILOGUE",
"LOAD_DIFFERENT_STAGE",
"FIRST_LOAD_OF_USE",
// clang-format on
};

88 changes: 71 additions & 17 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -16,6 +16,7 @@
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -68,7 +69,8 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
tt::CoarseSchedule &schedule,
tt::CoarseSchedule::Cluster prefetchCluster,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
int numStages) {
int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
OpBuilder builder(forOp);
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
// Replace the load with insert/extract slice.
@@ -125,6 +127,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
loadOffsets[0] = extractIdx;
auto viewLoad =
builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
TMAUserToWait[viewLoad] = wait; // viewLoad will depend on barrierWait
if (isMMV3Load) {
auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
@@ -169,7 +172,8 @@ static void createTMAAsyncCopy(
scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc,
Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp,
Value phase, tt::CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages) {
llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
assert(phase && "Phase value is required for TMA async copy.");
OpBuilder builder(forOp);
Attribute sharedMemorySpace =
@@ -201,6 +205,7 @@ static void createTMAAsyncCopy(
loadOffsets[0] = extractIdx;
auto viewLoad =
builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
if (isMMV3Load) {
auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
@@ -569,13 +574,21 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,

tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront();
// Put the root uses of the loads in the last stage.
bool firstDot = true;
for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) {
if (loadToInfo.count(loadOp) == 0)
continue;
// Non-LoadOp(s) are the root uses of all LoadOp(s) and should be
// always present in the opInfo
if (!isa<tt::LoadOp>(use)) {
schedule.insert(use, numStages - 1, rootUsersCluster);
if (::triton::tools::getBoolEnv("SWP_FIRST_DOT")) {
// check to see if it is first dot.
schedule.insert(use, firstDot ? numStages - 2 : numStages - 1,
rootUsersCluster);
firstDot = false;
} else {
schedule.insert(use, numStages - 1, rootUsersCluster);
}
rootUsers.insert(use);
}
}
@@ -585,17 +598,36 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
loadsClusters.push_back(schedule.clusters.newAtBack());
}
// Assign stages to the loads.
unsigned iter = 0;
DenseSet<Operation *> addedLoads;
for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) {
if (loadToInfo.count(loadOp) == 0)
continue;
int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads;
// Hard-code for the case of maxIndirectionLevel is 0.
if (::triton::tools::getBoolEnv("LOAD_DIFFERENT_STAGE")) {
if (addedLoads.count(loadOp))
continue;
stage = iter;
++iter;
}
addedLoads.insert(loadOp);
schedule.insert(loadOp, stage, loadsClusters[indLevel]);
}

// Distance from the load to the use.
DenseSet<Operation *> seenLoads;
for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) {
if (loadToInfo.count(loadOp) == 0)
continue;
// For the case where loadOp has multiple uses with indLevel of 0, should we
// ignore one of the uses? As an example, load -> dot1 -> dot2, can we
// ignore the use of load -> dot2?
if (::triton::tools::getBoolEnv("FIRST_LOAD_OF_USE")) {
if (seenLoads.count(loadOp))
continue;
}
seenLoads.insert(loadOp);
loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first;
}

@@ -664,16 +696,18 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule,

// Add dependencies of anchor ops to the coarse schedule. Schedule them to
// the same stage and ordering cluster as the anchor op.
static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
int numStages) {
static void
scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
SmallVector<std::tuple<Operation *, int, tt::CoarseSchedule::Cluster>>
opsInOrder = schedule.getOpsInOrder(forOp);
// Schedule dependencies stage by stage.
for (int stage = 0; stage < numStages; stage++) {
for (auto [op, stage_, cluster] : opsInOrder) {
if (stage_ != stage)
continue;
schedule.insertDepsOfOp(op, stage, cluster, false);
schedule.insertDepsOfOp(op, stage, cluster, false, &TMAUserToWait);
}
}
}
@@ -715,14 +749,14 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp,
// Exception: Schedule loads with a distance of 1 together
// with the current op.
schedule.insertIfAbsent(defOp, stage, cluster);
schedule.insertDepsOfOp(defOp, stage, cluster, true);
schedule.insertDepsOfOp(defOp, stage, cluster, true, nullptr);
} else {
if (dist1Cluster.count(&cluster) == 0) {
dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster);
}
schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]);
schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster],
true);
true, nullptr);
}
}
}
@@ -938,7 +972,8 @@ static void createTMABarrierAndWait(
static SmallVector<Value>
createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
SmallVector<Value> &barriers, int numStages) {
SmallVector<Value> &barriers, int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
// Calculate the number of buffers needed for each load.
// TODO pawel: we could do more fine-grained allocation here and
// allocate only the number of buffers that specific loads need.
@@ -1029,14 +1064,21 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
for (AsyncLoad &asyncLoad : asyncLoads) {
if (auto loadOp = dyn_cast<tt::LoadOp>(asyncLoad.loadOp)) {
createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
schedule, prefetchCluster, loadToInfo, numStages);
schedule, prefetchCluster, loadToInfo, numStages,
TMAUserToWait);
} else {
auto descLoad = cast<tt::ExperimentalDescriptorLoadOp>(asyncLoad.loadOp);
createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx,
extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase,
schedule, loadToInfo, numStages);
schedule, loadToInfo, numStages, TMAUserToWait);
}
}
// Make sure each copy has a unique waitOp.
DenseSet<Operation *> uniqueWaits;
for (auto [copy, wait] : TMAUserToWait) {
assert(!uniqueWaits.count(wait));
uniqueWaits.insert(wait);
}
SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
if (phase)
newYieldOperands.push_back(phase);
@@ -1083,9 +1125,10 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
});

SmallVector<Value> barriers;
DenseMap<Operation *, Operation *> TMAUserToWait;
// Convert the loads into async loads and create the allocs.
SmallVector<Value> allocs =
createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages);
SmallVector<Value> allocs = createAsyncOps(
forOp, coarseSchedule, loadToInfo, barriers, numStages, TMAUserToWait);

LLVM_DEBUG({
LDBG("Coarse schedule with async loads:");
@@ -1099,7 +1142,7 @@
coarseSchedule.dump();
});

scheduleDependencies(forOp, coarseSchedule, numStages);
scheduleDependencies(forOp, coarseSchedule, numStages, TMAUserToWait);
LLVM_DEBUG({
LDBG("Coarse schedule with dependencies:");
coarseSchedule.dump();
@@ -1128,7 +1171,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
std::vector<std::pair<Operation *, unsigned>> &s) {
s = std::move(schedule);
};
options.peelEpilogue = false;
options.peelEpilogue = ::triton::tools::getBoolEnv("PEEL_EPILOGUE");
options.predicateFn = tt::predicateOp;
options.supportDynamicLoops = true;
options.annotateFn = [](Operation *op,
@@ -1392,7 +1435,8 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
// in the loop's iter_args. (Rule (2) above ensures this is well-defined.)
//
static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
scf::ForOp forOp) {
scf::ForOp forOp,
bool firstDot) {
LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp);

// Rule 1: All shmem operands are multi-buffered.
Expand Down Expand Up @@ -1470,6 +1514,13 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
})) {
return iterArgIdx;
}
// For the first dot that is in a different stage, it is only used by yield
// For the second dot, it is only used by yield, will be used by the next
// iteration
if (::triton::tools::getBoolEnv("SWP_FIRST_DOT")) {
// if (firstDot) return iterArgIdx;
return iterArgIdx;
}

// Rule 3b: Are all users of the dot's result from iteration i-1 after the
// first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be
@@ -1603,9 +1654,11 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
// the yield op.
IRRewriter builder(forOp.getContext());
llvm::MapVector<Operation *, int /*iterArgIdx*/> properlyAsyncDots;
bool firstDot = true;
for (auto WarpGroupDotOp : forOp.getBody()->getOps<ttng::WarpGroupDotOp>()) {
WarpGroupDotOp.setIsAsync(true);
if (auto iterArgIdx = dotCanBeProperlyAsync(WarpGroupDotOp, forOp)) {
if (auto iterArgIdx =
dotCanBeProperlyAsync(WarpGroupDotOp, forOp, firstDot)) {
properlyAsyncDots[WarpGroupDotOp] = *iterArgIdx;
} else {
builder.setInsertionPointAfter(WarpGroupDotOp);
@@ -1615,6 +1668,7 @@
SmallVector<Value> waitOperands = {WarpGroupDotOp.getResult()};
threadValuesThroughWait(wait, waitOperands);
}
firstDot = false;
}

if (properlyAsyncDots.empty()) {

@@ -74,6 +74,15 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
return op;
}

// Create if statement around wait_barrier
if (auto wait = dyn_cast<ttng::WaitBarrierOp>(op)) {
rewriter.setInsertionPoint(wait);
auto ifOp =
rewriter.create<scf::IfOp>(wait->getLoc(), pred, /*else=*/false);
// move wait to ifOp
rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin());
return ifOp;
}
assert("don't know how to predicate this op" && false);
return op;
}
14 changes: 10 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
@@ -15,9 +15,15 @@ namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;

void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
tt::CoarseSchedule::Cluster cluster,
bool includeArg) {
void tt::CoarseSchedule::insertDepsOfOp(
Operation *op, int stage, tt::CoarseSchedule::Cluster cluster,
bool includeArg, DenseMap<Operation *, Operation *> *additionalDep) {
// Look in additionalDep.
if (additionalDep && additionalDep->find(op) != additionalDep->end()) {
Operation *wait = (*additionalDep)[op];
if (insertIfAbsent(wait, stage, cluster))
insertDepsOfOp(wait, stage, cluster, includeArg, additionalDep);
}
for (Value operand : op->getOperands()) {
Value v = operand;
llvm::SmallDenseSet<Value> seen;
@@ -36,7 +42,7 @@ void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
Operation *defOp = v.getDefiningOp();
if (defOp && defOp->getBlock() == op->getBlock()) {
if (insertIfAbsent(defOp, stage, cluster)) {
insertDepsOfOp(defOp, stage, cluster, includeArg);
insertDepsOfOp(defOp, stage, cluster, includeArg, additionalDep);
}
}
}
