diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
index e122f15fd901..22349c50e308 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
@@ -5,23 +5,28 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#include <deque>
-
-#define GEN_PASS_CLASSES
-#include "TritonAMDGPUTransforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 namespace ttg = mlir::triton::gpu;
-namespace tt = mlir::triton;
-
-static bool isLocalLoadOrDotLayoutConversion(Operation *op) {
-  if (isa<ttg::LocalLoadOp>(op))
-    return true;
-  if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(op))
-    return isa<ttg::DotOperandEncodingAttr>(cvt.getType().getEncoding());
-  return false;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+// Return true if the given moduleOp contains a pure matmul problem; i.e.,
+// single dot in the main loop.
+static bool isPureMatmulProblem(ModuleOp moduleOp) {
+  for (auto forOp : moduleOp.getOps<scf::ForOp>()) {
+    int counter = 0;
+    forOp.walk([&counter](triton::DotOp dotOp) { ++counter; });
+    if (counter != 1)
+      return false;
+  }
+  return true;
 }
 
 // Search through block to find earliest insertion point for move op. This can
@@ -61,194 +66,233 @@ findEarlyInsertionPoint(Block *block, Operation *move) {
   return ipnt;
 }
 
+// Return the first user in the same block of the given op. If the user is in a
+// nested block then return the op owning the block. Return nullptr if no such
+// user exists.
+static Operation *getFirstUseInSameBlock(Operation *op) {
+  SmallVector<Operation *> usersInSameBlock;
+  for (auto user : op->getUsers()) {
+    if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user))
+      usersInSameBlock.push_back(ancestor);
+  }
+  auto minOpIt =
+      llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) {
+        return a->isBeforeInBlock(b);
+      });
+  return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr;
+}
+
 // Check if the operation opInsideLoop is inside any scf::ForOp and
 // opOutsideLoop is not inside the same loop.
-bool isCrossLoopBoundary(mlir::Operation *opInsideLoop,
-                         mlir::Operation *opOutsideLoop) {
+static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop,
+                                mlir::Operation *opOutsideLoop) {
   scf::ForOp parentForOp = opInsideLoop->getParentOfType<scf::ForOp>();
   return parentForOp && !parentForOp->isAncestor(opOutsideLoop);
 }
 
-class TritonAMDGPUReorderInstructionsPass
-    : public TritonAMDGPUReorderInstructionsBase<
-          TritonAMDGPUReorderInstructionsPass> {
-public:
-  TritonAMDGPUReorderInstructionsPass() = default;
-
-  Operation *getFirstUse(Operation *op) {
-    std::vector<Operation *> users;
-    for (auto user : op->getUsers()) {
-      if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user))
-        users.push_back(ancestor);
-    }
-    auto minOpIt = std::min_element(users.begin(), users.end(),
-                                    [](mlir::Operation *a, mlir::Operation *b) {
-                                      return a->isBeforeInBlock(b);
-                                    });
-    return minOpIt != users.end() ? *minOpIt : nullptr;
-  }
+//===----------------------------------------------------------------------===//
+// Reorder mechanisms
+//===----------------------------------------------------------------------===//
 
-  void runOnOperation() override {
-    ModuleOp m = getOperation();
+// Sink dot layout conversions into loops to decrease register pressure when
+// possible.
+static void sinkDotConversion(ModuleOp moduleOp) {
+  DenseMap<Operation *, Operation *> opToMove;
+  moduleOp.walk([&](ttg::ConvertLayoutOp op) {
+    Attribute encoding = op.getType().getEncoding();
+    if (!isa_and_nonnull<ttg::DotOperandEncodingAttr>(encoding))
+      return;
+    if (!op->hasOneUse())
+      return;
+    Operation *user = *op->getUsers().begin();
+    if (user->getParentOfType<scf::ForOp>() ==
+        op->getParentOfType<scf::ForOp>())
+      return;
+    opToMove[op] = user;
+  });
 
-    // Sink shared memory loads and layout conversions into loops to decrease
-    // register pressure when possible.
-    DenseMap<Operation *, Operation *> opToMove;
-    m.walk([&](Operation *op) {
-      if (!isLocalLoadOrDotLayoutConversion(op))
-        return;
-      if (!op->hasOneUse())
-        return;
-      Operation *user = *op->getUsers().begin();
-      if (user->getParentOfType<scf::ForOp>() ==
-          op->getParentOfType<scf::ForOp>())
+  for (auto &kv : opToMove)
+    kv.first->moveBefore(kv.second);
+}
+
+// Adjust the placement of shared memory writes and reads to immediately follow
+// the definition of their operands in case where shared memory write is in the
+// loop but its operand is not.
+//
+// This is a heuristic driven by optimizing fused attention by hoisting Q tensor
+// shared memory read/write operations outside of the loop, as Q is a loop
+// invariant and can be loaded once before entering the loop. But it should be
+// generally applicable.
+//
+// There are two possible patterns for this adjustment depending on whether the
+// write to shared memory is performed using an optional `local_alloc` argument
+// or a `local_store` instruction.
+//
+// 1) %1 = some_op ... (typically a load or an operation that scales the tensor
+//    after loading)
+//    %2 = local_alloc %1
+//    %3 = local_load %2
+//
+// 2) %1 = some_op ...
+//    %2 = local_alloc
+//    %3 = local_store %1, %2
+//    %4 = local_load %2
+static void hoistLocalLoad(ModuleOp moduleOp) {
+  moduleOp.walk([&](ttg::LocalLoadOp localLoad) {
+    auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>();
+    if (!localAlloc)
+      return;
+
+    // Case when localAlloc has operands
+    if (localAlloc->getNumOperands() == 1) {
+      if (!localAlloc->hasOneUse())
        return;
-      opToMove.insert({op, user});
-    });
-    for (auto &kv : opToMove)
-      kv.first->moveBefore(kv.second);
-    opToMove.clear();
-
-    // Adjust the placement of LDS writes and reads to immediately follow the
-    // definition of their operands in case where LDS write is in the
-    // loop but it's operand is not. This is a heuristic for optimizing fused
-    // attention by hoisting Q tensor LDS read/write operations outside of the
-    // loop, as Q is a loop invariant and can be loaded once before entering the
-    // loop.
-    // There are two possible patterns for this adjustment depending on
-    // whether the write to LDS is performed using an optional `local_alloc`
-    // argument or a `local_store` instruction.
-    //
-    // clang-format off
-    //
-    // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading)
-    //    %2 = local_alloc %1
-    //    %3 = local_load %2
-    //
-    // 2) %1 = some_op ...
-    //    %2 = local_alloc
-    //    %3 = local_store %1, %2
-    //    %4 = local_load %2
-    //
-    // clang-format on
-    m.walk([&](ttg::LocalLoadOp localLoad) {
-      auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>();
-      if (!localAlloc)
+
+      auto srcTensorOp = localAlloc.getSrc().getDefiningOp();
+      // Check if localAlloc is in the loop but its src tensor defining op is
+      // outside of it.
+      if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp))
         return;
 
-      // Case when localAlloc has operands
-      if (localAlloc->getNumOperands() == 1) {
-        if (!localAlloc->hasOneUse())
-          return;
+      localAlloc->moveAfter(srcTensorOp);
+      localLoad->moveAfter(localAlloc);
+      return;
+    }
 
-        auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp();
-        // Check if localAlloc is in the loop but it's src tensor defining op is
-        // outside of it.
-        if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) {
-          return;
-        }
+    // Case when localAlloc has no operands
+    assert(localAlloc->getNumOperands() < 1);
+    auto allocVal = localAlloc->getResult(0);
 
-        localAlloc->moveAfter(srcTensorOp);
-        localLoad->moveAfter(localAlloc);
-        return;
-      }
+    // Check if the localAlloc has exactly two uses (localStore and localLoad)
+    int numUses = std::distance(allocVal.use_begin(), allocVal.use_end());
+    if (numUses != 2)
+      return;
 
-      // Case when localAlloc has no operands
-      assert(localAlloc->getNumOperands() < 1);
-      auto allocVal = localAlloc->getResult(0);
+    // localStore comes before localLoad in block.
+    Operation *localStore = getFirstUseInSameBlock(localAlloc);
+    if (!isa<ttg::LocalStoreOp>(localStore))
+      return;
 
-      // Check if the localAlloc has exactly two uses (localStore and localLoad)
-      int numUses = std::distance(allocVal.use_begin(), allocVal.use_end());
-      if (numUses != 2)
-        return;
+    auto srcTensorOp = localStore->getOperand(0).getDefiningOp();
+    // Check if localStore is in the loop but its src tensor defining op is
+    // outside of it.
+    if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) {
+      return;
+    }
 
-      // localStore comes before localLoad in block.
-      Operation *localStore = getFirstUse(localAlloc);
-      if (!isa<ttg::LocalStoreOp>(localStore))
-        return;
+    localAlloc->moveAfter(srcTensorOp);
+    localStore->moveAfter(localAlloc);
+    localLoad->moveAfter(localStore);
+  });
+}
 
-      auto srcTensorOp = localStore->getOperand(0).getDefiningOp();
-      // Check if localStore is in the loop but it's src tensor defining op is
-      // outside of it.
-      if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) {
-        return;
-      }
+// Sink conversion after the last dealloc but before the first use in its
+// block. This helps to avoid unnecessary shared memory allocation.
+static void moveDownCoversion(ModuleOp moduleOp) {
+  SmallVector<Operation *> convertOps;
+  moduleOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); });
 
-      localAlloc->moveAfter(srcTensorOp);
-      localStore->moveAfter(localAlloc);
-      localLoad->moveAfter(localStore);
-    });
+  for (auto op : convertOps) {
+    Operation *user = getFirstUseInSameBlock(op);
+    for (auto it = Block::iterator(op), ie = op->getBlock()->end();
+         it != ie && &*it != user; ++it)
+      if (isa<ttg::LocalDeallocOp>(&*it))
+        op->moveAfter(&*it);
+  }
+}
 
-    // Sink conversion after the last dealloc but before the first use ancestor
-    // in its block. This helps to avoid unnecessary shared memory allocation.
-    m.walk([&](triton::gpu::ConvertLayoutOp op) {
-      auto curr = mlir::Block::iterator(op);
-      for (; &*curr != getFirstUse(op); curr++)
-        if (isa<triton::gpu::LocalDeallocOp>(&*curr))
-          op->moveAfter(&*curr);
-    });
+// Move transpositions just after their definition.
+static void moveUpTranspose(ModuleOp moduleOp) {
+  SmallVector<Operation *> transOps;
+  moduleOp.walk([&](triton::TransOp op) { transOps.push_back(op); });
 
-    // Move transpositions just after their definition.
-    m.walk([&](triton::TransOp op) {
-      if (Operation *argOp = op.getSrc().getDefiningOp())
-        op->moveAfter(argOp);
-    });
+  for (auto op : transOps)
+    if (Operation *argOp = op.getSrc().getDefiningOp())
+      op->moveAfter(argOp);
+}
 
-    SmallVector<Operation *> moveOps;
-    // Move global loads early to prefetch. This may increase register pressure
-    // but it enables issuing global loads early.
-    m.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
-    // Move local_stores early if dependence distance greater than
-    // one iteration.
-    // Best perf on GEMM when these precede global loads.
-    m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
-
-    for (auto op : llvm::reverse(moveOps)) {
-      // Gather use-def chain in block.
-      Block *block = op->getBlock();
-      bool leadsToLoad = false;
-      SetVector<Operation *> backwardSet;
-
-      BackwardSliceOptions options;
-      options.omitBlockArguments = true;
-      options.inclusive = false;
-      options.filter = [&](Operation *defOp) -> bool {
-        Block *defBlock = defOp->getBlock();
-        if (!block->findAncestorOpInBlock(*defOp))
-          return false;
-        // Check for a `load` dependent path.
-        leadsToLoad |= isa<triton::LoadOp>(defOp);
-        // Only move ops residing in the same block.
-        return defBlock == block;
-      };
-      mlir::getBackwardSlice(op, &backwardSet, options);
-      backwardSet.insert(op);
-
-      // Don't move a local_store if its source is a load from
-      // the same iteration.
-      if (isa<ttg::LocalStoreOp>(op) && leadsToLoad)
-        continue;
-
-      auto ipoint = findEarlyInsertionPoint(block, op);
-      // Remove ops that already precede the insertion point. This is done
-      // before moves happen to avoid `Operation::isBeforeInBlock` N^2
-      // complexity.
-
-      SmallVector<Operation *> dfg = backwardSet.takeVector();
-      if (ipoint != block->end()) {
-        // Move ops to insertion point.
-        llvm::erase_if(
-            dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); });
-        for (auto *dfgop : llvm::reverse(dfg))
-          dfgop->moveAfter(block, ipoint);
-      } else {
-        // Move ops to block begin.
-        for (auto *dfgop : llvm::reverse(dfg))
-          dfgop->moveBefore(block, block->begin());
-      }
+// Schedule global load and local store ops for better GEMM performance.
+static void scheduleGlobalLoadLocalStore(ModuleOp m) {
+  SmallVector<Operation *> moveOps;
+  // Move global loads early to prefetch. This may increase register pressure
+  // but it enables issuing global loads early.
+  m.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
+  // Move local_stores early if dependence distance greater than one iteration.
+  // Best perf on GEMM when these precede global loads.
+  m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
+
+  for (auto op : llvm::reverse(moveOps)) {
+    // Gather use-def chain in block.
+    Block *block = op->getBlock();
+    bool leadsToLoad = false;
+    SetVector<Operation *> backwardSet;
+
+    BackwardSliceOptions options;
+    options.omitBlockArguments = true;
+    options.inclusive = false;
+    options.filter = [&](Operation *defOp) -> bool {
+      Block *defBlock = defOp->getBlock();
+      if (!block->findAncestorOpInBlock(*defOp))
+        return false;
+      // Check for a `load` dependent path.
+      leadsToLoad |= isa<triton::LoadOp>(defOp);
+      // Only move ops residing in the same block.
+      return defBlock == block;
+    };
+    mlir::getBackwardSlice(op, &backwardSet, options);
+    backwardSet.insert(op);
+
+    // Don't move a local_store if its source is a load from
+    // the same iteration.
+    if (isa<ttg::LocalStoreOp>(op) && leadsToLoad)
+      continue;
+
+    auto ipoint = findEarlyInsertionPoint(block, op);
+    // Remove ops that already precede the insertion point. This is done
+    // before moves happen to avoid `Operation::isBeforeInBlock` N^2
+    // complexity.
+
+    SmallVector<Operation *> dfg = backwardSet.takeVector();
+    if (ipoint != block->end()) {
+      // Move ops to insertion point.
+      llvm::erase_if(
+          dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); });
+      for (auto *dfgop : llvm::reverse(dfg))
+        dfgop->moveAfter(block, ipoint);
+    } else {
+      // Move ops to block begin.
+      for (auto *dfgop : llvm::reverse(dfg))
+        dfgop->moveBefore(block, block->begin());
     }
   }
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_CLASSES
+#include "TritonAMDGPUTransforms/Passes.h"
+
+namespace {
+struct TritonAMDGPUReorderInstructionsPass
+    : public TritonAMDGPUReorderInstructionsBase<
+          TritonAMDGPUReorderInstructionsPass> {
+  void runOnOperation() override {
+    ModuleOp m = getOperation();
+
+    hoistLocalLoad(m);
+
+    sinkDotConversion(m);
+    moveDownCoversion(m);
+
+    moveUpTranspose(m);
+
+    if (isPureMatmulProblem(m))
+      scheduleGlobalLoadLocalStore(m);
+  }
 };
+} // namespace
 
 std::unique_ptr<Pass> mlir::createTritonAMDGPUReorderInstructionsPass() {
   return std::make_unique<TritonAMDGPUReorderInstructionsPass>();