diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h
index 1dd1fc686034a..1b111fe53782e 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h
@@ -85,7 +85,8 @@ class CoarseSchedule {
   }
 
   void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
-                      bool includeArg);
+                      bool includeArg,
+                      DenseMap<Operation *, Operation *> *additionalDep);
 
   void erase(Operation *op) { opToStageAndCluster.erase(op); }
 
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index b873fe236aa3b..fc5e81d08dcad 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -30,6 +30,10 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_LLVM_DEBUG_ONLY",
     "USE_TTGIR_LOC",
    "NVPTX_ENABLE_DUMP",
+    "SWP_FIRST_DOT",
+    "PEEL_EPILOGUE",
+    "LOAD_DIFFERENT_STAGE",
+    "FIRST_LOAD_OF_USE",
     // clang-format on
 };
 
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index e18d9312daa80..5c941c30c579e 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -16,6 +16,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -68,7 +69,8 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                             tt::CoarseSchedule &schedule,
                             tt::CoarseSchedule::Cluster prefetchCluster,
                             llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
-                            int numStages) {
+                            int numStages,
+                            DenseMap<Operation *, Operation *> &TMAUserToWait) {
   OpBuilder builder(forOp);
   Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
   // Replace the load with insert/extract slice.
@@ -125,6 +127,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
   loadOffsets[0] = extractIdx;
   auto viewLoad =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
+  TMAUserToWait[viewLoad] = wait; // viewLoad will depend on barrierWait
   if (isMMV3Load) {
     auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
     replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
@@ -169,7 +172,8 @@ static void createTMAAsyncCopy(
     scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc,
     Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp,
     Value phase, tt::CoarseSchedule &schedule,
-    llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages) {
+    llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages,
+    DenseMap<Operation *, Operation *> &TMAUserToWait) {
   assert(phase && "Phase value is required for TMA async copy.");
   OpBuilder builder(forOp);
   Attribute sharedMemorySpace =
@@ -201,6 +205,7 @@ static void createTMAAsyncCopy(
   loadOffsets[0] = extractIdx;
   auto viewLoad =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
+  TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
   if (isMMV3Load) {
     auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
     replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
@@ -569,13 +574,21 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
   tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront();
   // Put the root uses of the loads in the last stage.
+  bool firstDot = true;
   for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) {
     if (loadToInfo.count(loadOp) == 0)
       continue;
     // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be
     // always present in the opInfo
     if (!isa<tt::LoadOp>(use)) {
-      schedule.insert(use, numStages - 1, rootUsersCluster);
+      if (::triton::tools::getBoolEnv("SWP_FIRST_DOT")) {
+        // Check whether this is the first dot.
+        schedule.insert(use, firstDot ? numStages - 2 : numStages - 1,
+                        rootUsersCluster);
+        firstDot = false;
+      } else {
+        schedule.insert(use, numStages - 1, rootUsersCluster);
+      }
       rootUsers.insert(use);
     }
   }
 
@@ -585,17 +598,36 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
     loadsClusters.push_back(schedule.clusters.newAtBack());
   }
   // Assign stages to the loads.
+  unsigned iter = 0;
+  DenseSet<Operation *> addedLoads;
   for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) {
     if (loadToInfo.count(loadOp) == 0)
       continue;
     int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads;
+    // Hard-coded for the case where maxIndirectionLevel is 0.
+    if (::triton::tools::getBoolEnv("LOAD_DIFFERENT_STAGE")) {
+      if (addedLoads.count(loadOp))
+        continue;
+      stage = iter;
+      ++iter;
+    }
+    addedLoads.insert(loadOp);
     schedule.insert(loadOp, stage, loadsClusters[indLevel]);
   }
 
   // Distance from the load to the use.
+  DenseSet<Operation *> seenLoads;
   for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) {
     if (loadToInfo.count(loadOp) == 0)
       continue;
+    // When loadOp has multiple uses with an indLevel of 0, should we ignore
+    // one of the uses? As an example, for load -> dot1 -> dot2, can we ignore
+    // the use load -> dot2?
+    if (::triton::tools::getBoolEnv("FIRST_LOAD_OF_USE")) {
+      if (seenLoads.count(loadOp))
+        continue;
+    }
+    seenLoads.insert(loadOp);
     loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first;
   }
 
@@ -664,8 +696,10 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule,
 
 // Add dependencies of anchor ops to the coarse schedule. Schedule them to
 // the same stage and ordering cluster as the anchor op.
-static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
-                                 int numStages) {
+static void
+scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
+                     int numStages,
+                     DenseMap<Operation *, Operation *> &TMAUserToWait) {
   SmallVector<std::tuple<Operation *, int, tt::CoarseSchedule::Cluster>>
       opsInOrder = schedule.getOpsInOrder(forOp);
   // Schedule dependencies stage by stage.
@@ -673,7 +707,7 @@ static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule,
     for (auto [op, stage_, cluster] : opsInOrder) {
       if (stage_ != stage)
         continue;
-      schedule.insertDepsOfOp(op, stage, cluster, false);
+      schedule.insertDepsOfOp(op, stage, cluster, false, &TMAUserToWait);
     }
   }
 }
@@ -715,14 +749,14 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp,
         // Exception: Schedule loads with a distance of 1 together
         // with the current op.
        schedule.insertIfAbsent(defOp, stage, cluster);
-        schedule.insertDepsOfOp(defOp, stage, cluster, true);
+        schedule.insertDepsOfOp(defOp, stage, cluster, true, nullptr);
      } else {
        if (dist1Cluster.count(&cluster) == 0) {
          dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster);
        }
        schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]);
        schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster],
-                                true);
+                                true, nullptr);
      }
    }
  }
@@ -938,7 +972,8 @@ static void createTMABarrierAndWait(
 static SmallVector<Value>
 createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
                llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
-               SmallVector<Value> &barriers, int numStages) {
+               SmallVector<Value> &barriers, int numStages,
+               DenseMap<Operation *, Operation *> &TMAUserToWait) {
   // Calculate the number of buffers needed for each load.
   // TODO pawel: we could do more fine-grained allocation here and
   // allocate only the number of buffers that specific loads need.
@@ -1029,14 +1064,21 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
   for (AsyncLoad &asyncLoad : asyncLoads) {
     if (auto loadOp = dyn_cast<tt::LoadOp>(asyncLoad.loadOp)) {
       createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
-                      schedule, prefetchCluster, loadToInfo, numStages);
+                      schedule, prefetchCluster, loadToInfo, numStages,
+                      TMAUserToWait);
     } else {
       auto descLoad = cast<tt::ExperimentalDescriptorLoadOp>(asyncLoad.loadOp);
       createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx,
                          extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase,
-                         schedule, loadToInfo, numStages);
+                         schedule, loadToInfo, numStages, TMAUserToWait);
     }
   }
+  // Make sure each copy has a unique waitOp.
+  DenseSet<Operation *> uniqueWaits;
+  for (auto [copy, wait] : TMAUserToWait) {
+    assert(!uniqueWaits.count(wait));
+    uniqueWaits.insert(wait);
+  }
   SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
   if (phase)
     newYieldOperands.push_back(phase);
@@ -1083,9 +1125,10 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
   });
 
   SmallVector<Value> barriers;
+  DenseMap<Operation *, Operation *> TMAUserToWait;
   // Convert the loads into async loads and create the allocs.
-  SmallVector<Value> allocs =
-      createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages);
+  SmallVector<Value> allocs = createAsyncOps(
+      forOp, coarseSchedule, loadToInfo, barriers, numStages, TMAUserToWait);
 
   LLVM_DEBUG({
     LDBG("Coarse schedule with async loads:");
@@ -1099,7 +1142,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
     coarseSchedule.dump();
   });
-  scheduleDependencies(forOp, coarseSchedule, numStages);
+  scheduleDependencies(forOp, coarseSchedule, numStages, TMAUserToWait);
   LLVM_DEBUG({
     LDBG("Coarse schedule with dependencies:");
     coarseSchedule.dump();
   });
@@ -1128,7 +1171,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
                  std::vector<std::pair<Operation *, unsigned>> &s) {
        s = std::move(schedule);
      };
-  options.peelEpilogue = false;
+  options.peelEpilogue = ::triton::tools::getBoolEnv("PEEL_EPILOGUE");
   options.predicateFn = tt::predicateOp;
   options.supportDynamicLoops = true;
   options.annotateFn = [](Operation *op,
@@ -1392,7 +1435,8 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
 // in the loop's iter_args. (Rule (2) above ensures this is well-defined.)
 //
 static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
-                                                scf::ForOp forOp) {
+                                                scf::ForOp forOp,
+                                                bool firstDot) {
   LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp);
 
   // Rule 1: All shmem operands are multi-buffered.
@@ -1470,6 +1514,13 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
       })) {
     return iterArgIdx;
   }
+  // For the first dot, which lives in a different stage, the result is only
+  // used by the yield. For the second dot, the result is also only used by
+  // the yield and will be consumed by the next iteration.
+  if (::triton::tools::getBoolEnv("SWP_FIRST_DOT")) {
+    // if (firstDot) return iterArgIdx;
+    return iterArgIdx;
+  }
 
   // Rule 3b: Are all users of the dot's result from iteration i-1 after the
   // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be
@@ -1603,9 +1654,11 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
   // the yield op.
   IRRewriter builder(forOp.getContext());
   llvm::MapVector<ttng::WarpGroupDotOp, int> properlyAsyncDots;
+  bool firstDot = true;
   for (auto WarpGroupDotOp : forOp.getBody()->getOps<ttng::WarpGroupDotOp>()) {
     WarpGroupDotOp.setIsAsync(true);
-    if (auto iterArgIdx = dotCanBeProperlyAsync(WarpGroupDotOp, forOp)) {
+    if (auto iterArgIdx =
+            dotCanBeProperlyAsync(WarpGroupDotOp, forOp, firstDot)) {
       properlyAsyncDots[WarpGroupDotOp] = *iterArgIdx;
     } else {
       builder.setInsertionPointAfter(WarpGroupDotOp);
@@ -1615,6 +1668,7 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
       SmallVector<Value> waitOperands = {WarpGroupDotOp.getResult()};
       threadValuesThroughWait(wait, waitOperands);
     }
+    firstDot = false;
   }
 
   if (properlyAsyncDots.empty()) {
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
index fe8b035400a20..97cb557f1ed98 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
@@ -74,6 +74,15 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
     return op;
   }
 
+  // Create an if statement around the wait_barrier.
+  if (auto wait = dyn_cast<ttng::WaitBarrierOp>(op)) {
+    rewriter.setInsertionPoint(wait);
+    auto ifOp =
+        rewriter.create<scf::IfOp>(wait->getLoc(), pred, /*else=*/false);
+    // Move the wait into the then block of the ifOp.
+    rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin());
+    return ifOp;
+  }
   assert("don't know how to predicate this op" && false);
   return op;
 }
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
index 1116b70a0262a..1d10595c3d9e2 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
@@ -15,9 +15,15 @@ namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 namespace ttng = mlir::triton::nvidia_gpu;
 
-void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
-                                        tt::CoarseSchedule::Cluster cluster,
-                                        bool includeArg) {
+void tt::CoarseSchedule::insertDepsOfOp(
+    Operation *op, int stage, tt::CoarseSchedule::Cluster cluster,
+    bool includeArg, DenseMap<Operation *, Operation *> *additionalDep) {
+  // Look up op in additionalDep and schedule the extra dependency first.
+  if (additionalDep && additionalDep->find(op) != additionalDep->end()) {
+    Operation *wait = (*additionalDep)[op];
+    if (insertIfAbsent(wait, stage, cluster))
+      insertDepsOfOp(wait, stage, cluster, includeArg, additionalDep);
+  }
   for (Value operand : op->getOperands()) {
     Value v = operand;
     llvm::SmallDenseSet<Value, 8> seen;
@@ -36,7 +42,7 @@ void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
     Operation *defOp = v.getDefiningOp();
     if (defOp && defOp->getBlock() == op->getBlock()) {
       if (insertIfAbsent(defOp, stage, cluster)) {
-        insertDepsOfOp(defOp, stage, cluster, includeArg);
+        insertDepsOfOp(defOp, stage, cluster, includeArg, additionalDep);
      }
    }
  }
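Note: all of the new behavior in this patch is opt-in through boolean environment variables read with the existing getBoolEnv helper declared in triton/Tools/Sys/GetEnv.hpp, and the variables are registered in CACHE_INVALIDATING_ENV_VARS so that toggling them invalidates the compilation cache. A minimal sketch of that gating pattern follows; it is not part of the patch, and the variable name MY_EXPERIMENTAL_KNOB and the helper function are made up for illustration:

// Sketch only: gate an experimental code path on a boolean environment
// variable, mirroring how SWP_FIRST_DOT and PEEL_EPILOGUE are used above.
// MY_EXPERIMENTAL_KNOB is a hypothetical name, not one added by this patch.
#include "triton/Tools/Sys/GetEnv.hpp"

static bool useExperimentalPath() {
  // getBoolEnv returns true when the variable is set to a truthy value
  // (for example "1" or "true") and false when it is unset.
  return mlir::triton::tools::getBoolEnv("MY_EXPERIMENTAL_KNOB");
}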