From 11b866938af7371032b81875af5297a1a43685c2 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Tue, 12 Nov 2024 21:39:17 +0100
Subject: [PATCH] Reduction loop GEMM to XSMM BRGEMM (#18)

Adds an experimental rewrite that collapses a reduction loop over a GEMM
into a BRGEMM ukernel. The pattern matches hand-written kernels that use
block pointers and is not compatible with IR generated by Triton pointer
raising.

Lowering directly to XSMM makes it possible to bypass Triton's load
restriction when the K dimension is not a power of two. The pattern is
quite brittle but functional for the matmul tutorial example.

The rewrite is disabled by default and can be enabled with an
environment variable:

  TRITON_CPU_LOOP_BRGEMM_XSMM=1
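For illustration, the rewrite targets kernels of roughly the following
shape (a minimal hypothetical sketch in the style of the matmul
tutorial; the kernel name, signature, and the elided prologue/epilogue
are placeholders, not the exact tutorial code):

    import triton
    import triton.language as tl

    @triton.jit
    def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      stride_am, stride_ak, stride_bk, stride_bn,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                      BLOCK_K: tl.constexpr):
        # Inputs are accessed through block pointers.
        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                                        strides=(stride_am, stride_ak),
                                        offsets=(0, 0),
                                        block_shape=(BLOCK_M, BLOCK_K),
                                        order=(1, 0))
        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N),
                                        strides=(stride_bk, stride_bn),
                                        offsets=(0, 0),
                                        block_shape=(BLOCK_K, BLOCK_N),
                                        order=(1, 0))
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # The induction variable must stay unused; this whole reduction
        # loop is collapsed into a single BRGEMM ukernel call.
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            a = tl.load(a_block_ptr)
            b = tl.load(b_block_ptr)
            acc = tl.dot(a, b, acc)
            a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
            b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
        # ... store acc through a block pointer to c_ptr ...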
---
 third_party/cpu/backend/compiler.py           |   8 +-
 third_party/cpu/include/Xsmm/Passes.td        |  12 +
 .../cpu/lib/Xsmm/ConvertTritonToXsmm.cpp      | 253 ++++++++++++++++++
 third_party/cpu/triton_cpu.cc                 |   3 +
 4 files changed, 275 insertions(+), 1 deletion(-)

diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py
index 8dfd16e7f4e9..8106a5cb7163 100644
--- a/third_party/cpu/backend/compiler.py
+++ b/third_party/cpu/backend/compiler.py
@@ -50,6 +50,7 @@ class CPUOptions:
     enable_fast_math: bool = True
     enable_vector_xsmm: bool = False
     enable_triton_xsmm: bool = False
+    enable_loop_brgemm_xsmm: bool = False
     enable_raise_block_pointer: bool = False
     vec_lib: Optional[str] = 'libsleef'
     # TODO: Try to enable it.
@@ -110,6 +111,8 @@ def parse_options(self, opts) -> Any:
             args["enable_vector_xsmm"] = os.getenv("TRITON_CPU_VECTOR_XSMM", "0") != "0"
         if "enable_triton_xsmm" not in args:
             args["enable_triton_xsmm"] = os.getenv("TRITON_CPU_TRITON_XSMM", "0") != "0"
+        if "enable_loop_brgemm_xsmm" not in args:
+            args["enable_loop_brgemm_xsmm"] = os.getenv("TRITON_CPU_LOOP_BRGEMM_XSMM", "0") != "0"
         if "enable_raise_block_pointer" not in args:
             args["enable_raise_block_pointer"] = os.getenv("TRITON_CPU_RAISE_BLOCK_POINTER", "0") != "0"
         return CPUOptions(**args)
@@ -150,6 +153,9 @@ def make_ttcir(mod, metadata, opt):
         pm.enable_debug()
         if opt.enable_raise_block_pointer:
             cpu.passes.ttcpuir.add_raise_block_pointer(pm)
+        if opt.enable_loop_brgemm_xsmm:
+            cpu.passes.ttcpuir.add_loop_to_brgemm_xsmm(pm)
+            passes.common.add_canonicalizer(pm)
         if opt.enable_triton_xsmm:
             cpu.passes.ttcpuir.add_convert_triton_to_xsmm(pm)
             passes.common.add_canonicalizer(pm)
@@ -287,7 +293,7 @@ def make_so(src, metadata, options):
             Path(asm_path).write_text(src)
             lib_dirs = cpu_driver.library_dirs
             libs = ["m", "TritonCPURuntime", "sleef"]
-            if options.enable_vector_xsmm or options.enable_triton_xsmm:
+            if options.enable_vector_xsmm or options.enable_triton_xsmm or options.enable_loop_brgemm_xsmm:
                 libs.extend(["xsmm", "TritonCPUXsmmRuntime"])
             so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs)
             with open(so, "rb") as f:
diff --git a/third_party/cpu/include/Xsmm/Passes.td b/third_party/cpu/include/Xsmm/Passes.td
index 5527c233de7a..08ecdf76c3cf 100644
--- a/third_party/cpu/include/Xsmm/Passes.td
+++ b/third_party/cpu/include/Xsmm/Passes.td
@@ -27,4 +27,16 @@ def ConvertTritonToXsmm : Pass<"triton-cpu-convert-triton-to-xsmm", "mlir::Modul
                            "LLVM::LLVMDialect"];
 }
 
+def LoopToBrgemmXsmm : Pass<"triton-cpu-loop-to-brgemm-xsmm", "mlir::ModuleOp"> {
+  let summary = "Reduction loop GEMM to BRGEMM";
+  let description = [{
+    Collapse a reduction loop over GEMM into an XSMM BRGEMM kernel.
+  }];
+  let dependentDialects = ["arith::ArithDialect",
+                           "func::FuncDialect",
+                           "memref::MemRefDialect",
+                           "triton::cpu::TritonCPUDialect",
+                           "LLVM::LLVMDialect"];
+}
+
 #endif
diff --git a/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp b/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp
index d2e7a3522ca4..c3b28736ff56 100644
--- a/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp
+++ b/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp
@@ -43,6 +43,7 @@ namespace mlir {
 namespace triton {
 namespace cpu {
 #define GEN_PASS_DEF_CONVERTTRITONTOXSMM
+#define GEN_PASS_DEF_LOOPTOBRGEMMXSMM
 #include "cpu/include/Xsmm/Passes.h.inc"
 } // namespace cpu
 } // namespace triton
@@ -279,6 +280,241 @@ struct DotToXsmm : public OpRewritePattern<triton::DotOp> {
   ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis;
 };
 
+// Collapse a whole reduction loop with a GEMM into an equivalent BRGEMM
+// operation.
+// Rewrites the following pattern:
+//   %0 = tt.make_tensor_ptr %base_ptr0 : tensor<...>
+//   %1 = tt.make_tensor_ptr %base_ptr1 : tensor<...>
+//   %res:3 = scf.for %arg3 = %lb to %ub step %step
+//       iter_args(%acc = %init_val, %ptr_A = %0, %ptr_B = %1)
+//     %A = tt.load %ptr_A
+//     %B = tt.load %ptr_B
+//     %dot = tt.dot %A, %B, %acc
+//     %ptr_A_next = tt.advance %ptr_A, [0, %stepK]
+//     %ptr_B_next = tt.advance %ptr_B, [%stepK, 0]
+//     scf.yield %dot, %ptr_A_next, %ptr_B_next
+// into:
+//   %A = tt.make_tensor_ptr %base_ptr0 : tensor<...>
+//   %B = tt.make_tensor_ptr %base_ptr1 : tensor<...>
+//   %res0 = BRGEMM %A, %B, %init_val
+//   %res1 = tt.advance %A, [0, ((%ub - %lb) / %step) * %stepK]
+//   %res2 = tt.advance %B, [((%ub - %lb) / %step) * %stepK, 0]
+struct DotReductionLoopToBrgemm : public OpRewritePattern<triton::DotOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  DotReductionLoopToBrgemm(MLIRContext *context,
+                           ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis,
+                           PatternBenefit benefit = 10)
+      : OpRewritePattern<triton::DotOp>(context, benefit),
+        shapeAnalysis(shapeInfoAnalysis) {}
+
+  LogicalResult matchAndRewrite(triton::DotOp dotOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = dotOp.getLoc();
+    MLIRContext *ctx = dotOp.getContext();
+
+    // Check if there is any loop around the contraction and if the
+    // accumulation value comes from the loop's arguments.
+    TypedValue<RankedTensorType> acc = dotOp.getC();
+    if (acc.getType().getRank() != 2)
+      return rewriter.notifyMatchFailure(dotOp, "expects 2D GEMM");
+
+    auto forOp = dyn_cast<scf::ForOp>(dotOp->getParentOp());
+    BlockArgument accBbArg = dyn_cast<BlockArgument>(acc);
+    if (!forOp || !accBbArg)
+      return rewriter.notifyMatchFailure(dotOp, "not a reduction loop");
+    OpOperand *accArg = forOp.getTiedLoopInit(accBbArg);
+    if (!accArg)
+      return rewriter.notifyMatchFailure(
+          dotOp, "expects iter_args accumulation value");
+    // TODO: Relax this check. It is needed to collapse the whole loop but,
+    // alternatively, only the BRGEMM could be pulled out.
+    if (forOp.getNumRegionIterArgs() != 3)
+      return rewriter.notifyMatchFailure(dotOp, "invalid number of iter_args");
+
+    // Assume that the loop's range and all pointer advances are known
+    // statically. Thus, the induction variable should be unused.
+    Value loopIv = forOp.getInductionVar();
+    if (!loopIv.use_empty())
+      return rewriter.notifyMatchFailure(dotOp,
+                                         "expects unused induction variable");
+
+    // The subgraph should be a simple reduction loop containing a GEMM
+    // operation. Validate presence of the following chain:
+    //   iter_arg -> contraction -> yield
+    // and that there are no other users.
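+    // I.e., in the pattern above, %acc is consumed only by tt.dot and the
+    // %dot result is consumed only by scf.yield.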
+    TypedValue<RankedTensorType> res = dotOp.getD();
+    if (!acc.hasOneUse() || !res.hasOneUse() ||
+        !isa<scf::YieldOp>(*res.getUsers().begin()))
+      return rewriter.notifyMatchFailure(dotOp, "GEMM subgraph does not match");
+
+    auto loadMatA = dotOp.getA().getDefiningOp<triton::LoadOp>();
+    auto loadMatB = dotOp.getB().getDefiningOp<triton::LoadOp>();
+    if (!loadMatA || !loadMatB)
+      return rewriter.notifyMatchFailure(dotOp, "expect GEMM input loads");
+    if (!loadMatA->hasOneUse() || !loadMatB->hasOneUse())
+      return rewriter.notifyMatchFailure(dotOp,
+                                         "input loads subgraph does not match");
+
+    // Constrain input pointers to the following subgraph:
+    //   iter_arg -> (load, increment) -> yield
+    BlockArgument lhsBbArg = dyn_cast<BlockArgument>(loadMatA.getPtr());
+    BlockArgument rhsBbArg = dyn_cast<BlockArgument>(loadMatB.getPtr());
+    if (!lhsBbArg || !rhsBbArg)
+      return rewriter.notifyMatchFailure(dotOp, "expect block arg pointers");
+    OpOperand *lhsArg = forOp.getTiedLoopInit(lhsBbArg);
+    OpOperand *rhsArg = forOp.getTiedLoopInit(rhsBbArg);
+    if (!lhsArg ||
+        std::distance(lhsBbArg.use_begin(), lhsBbArg.use_end()) != 2 ||
+        !rhsArg ||
+        std::distance(rhsBbArg.use_begin(), rhsBbArg.use_end()) != 2)
+      return rewriter.notifyMatchFailure(dotOp, "expect iter_args pointers");
+
+    // Input sources should be block pointers.
+    // TODO: Account for transposed GEMM operands.
+    auto lhsBlockPtr = dyn_cast_or_null<triton::MakeTensorPtrOp>(
+        lhsArg->get().getDefiningOp());
+    auto rhsBlockPtr = dyn_cast_or_null<triton::MakeTensorPtrOp>(
+        rhsArg->get().getDefiningOp());
+    if (!lhsBlockPtr || lhsBlockPtr.getOrder() != ArrayRef<int32_t>{1, 0} ||
+        !rhsBlockPtr || rhsBlockPtr.getOrder() != ArrayRef<int32_t>{1, 0})
+      return rewriter.notifyMatchFailure(dotOp, "expected block pointers");
+
+    // Check for pointer increments and validate their steps.
+    // Each input is expected to advance only in its reduction dimension.
+    auto lhsAdvanceOp = forOp.getTiedLoopYieldedValue(lhsBbArg)
+                            ->get()
+                            .getDefiningOp<triton::AdvanceOp>();
+    auto rhsAdvanceOp = forOp.getTiedLoopYieldedValue(rhsBbArg)
+                            ->get()
+                            .getDefiningOp<triton::AdvanceOp>();
+    if (!lhsAdvanceOp || !rhsAdvanceOp)
+      return rewriter.notifyMatchFailure(dotOp, "expected ptr advance");
+    if (!lhsAdvanceOp->hasOneUse() || !rhsAdvanceOp->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          dotOp, "ptr increment subgraph does not match");
+
+    auto resShape = res.getType().getShape();
+    auto lhsShape = dotOp.getA().getType().getShape();
+    auto lhsPtrOffsets = lhsAdvanceOp.getOffsets();
+    auto lhsStepParallel = getConstantIntValue(lhsPtrOffsets[0]);
+    auto lhsStepReduction = getConstantIntValue(lhsPtrOffsets[1]);
+    if (!lhsStepParallel || *lhsStepParallel != 0 || !lhsStepReduction ||
+        *lhsStepReduction != lhsShape[1])
+      return rewriter.notifyMatchFailure(dotOp, "invalid lhs increments");
+
+    auto rhsPtrOffsets = rhsAdvanceOp.getOffsets();
+    auto rhsStepReduction = getConstantIntValue(rhsPtrOffsets[0]);
+    auto rhsStepParallel = getConstantIntValue(rhsPtrOffsets[1]);
+    if (!rhsStepReduction || *rhsStepReduction != *lhsStepReduction ||
+        !rhsStepParallel || *rhsStepParallel != 0)
+      return rewriter.notifyMatchFailure(dotOp, "invalid rhs increments");
+
+    // Collapse the loop and create an equivalent BRGEMM operation.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(forOp);
+
+    // TODO: Validate that the number of tiles cleanly divides the source
+    // buffer.
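+    // The batch (tile) count is derived from the loop's static range:
+    //   numTiles = (%ub - %lb) / %step
+    // e.g., a loop from 0 to K/BLOCK_K with step 1 and stepK == BLOCK_K
+    // gives numTiles == K/BLOCK_K and a full reduction length of
+    // numTiles * stepK == K.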
+    auto loopRange = rewriter.create<arith::SubIOp>(loc, forOp.getUpperBound(),
+                                                    forOp.getLowerBound());
+    Value numTiles =
+        rewriter.create<arith::DivSIOp>(loc, loopRange, forOp.getStep());
+    numTiles = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), numTiles);
+    auto kStepCst =
+        rewriter.create<arith::ConstantIndexOp>(loc, *lhsStepReduction);
+    auto fullKDimLength =
+        rewriter.create<arith::MulIOp>(loc, numTiles, kStepCst);
+
+    // Create new memref views spanning the whole reduction dimension.
+    SmallVector<int64_t> strides(2, 1);
+    auto lhsMemref = extractMemRef(rewriter, lhsBlockPtr, shapeAnalysis);
+    auto lhsIndices =
+        rewriter.create<triton::cpu::ExtractIndicesOp>(loc, lhsBlockPtr)
+            .getResults();
+    auto lhsBuf = rewriter.create<memref::SubViewOp>(
+        loc, lhsMemref, getAsOpFoldResult(lhsIndices),
+        SmallVector<OpFoldResult>{getAsIndexOpFoldResult(ctx, resShape[0]),
+                                  getAsOpFoldResult(fullKDimLength)},
+        getAsIndexOpFoldResult(ctx, strides));
+
+    auto rhsMemref = extractMemRef(rewriter, rhsBlockPtr, shapeAnalysis);
+    auto rhsIndices =
+        rewriter.create<triton::cpu::ExtractIndicesOp>(loc, rhsBlockPtr)
+            .getResults();
+    auto rhsBuf = rewriter.create<memref::SubViewOp>(
+        loc, rhsMemref, getAsOpFoldResult(rhsIndices),
+        SmallVector<OpFoldResult>{getAsOpFoldResult(fullKDimLength),
+                                  getAsIndexOpFoldResult(ctx, resShape[1])},
+        getAsIndexOpFoldResult(ctx, strides));
+
+    Value accBuf =
+        getMemrefSource(rewriter, forOp,
+                        dyn_cast<TypedValue<RankedTensorType>>(
+                            accArg->get().getDefiningOp()->getResult(0)),
+                        shapeAnalysis);
+
+    // Split the reduction dimension into tiles.
+    // The number of tiles represents the batch dimension.
+    SmallVector<OpFoldResult> lhsOutSizes{
+        getAsIndexOpFoldResult(ctx, resShape[0]), getAsOpFoldResult(numTiles),
+        getAsIndexOpFoldResult(ctx, *lhsStepReduction)};
+    auto expandA = rewriter.create<memref::ExpandShapeOp>(
+        loc,
+        SmallVector<int64_t>{resShape[0], ShapedType::kDynamic,
+                             *lhsStepReduction},
+        lhsBuf, SmallVector<ReassociationIndices>{{0}, {1, 2}}, lhsOutSizes);
+    SmallVector<OpFoldResult> rhsOutSizes{
+        getAsOpFoldResult(numTiles),
+        getAsIndexOpFoldResult(ctx, *rhsStepReduction),
+        getAsIndexOpFoldResult(ctx, resShape[1])};
+    auto expandB = rewriter.create<memref::ExpandShapeOp>(
+        loc,
+        SmallVector<int64_t>{ShapedType::kDynamic, *rhsStepReduction,
+                             resShape[1]},
+        rhsBuf, SmallVector<ReassociationIndices>{{0, 1}, {2}}, rhsOutSizes);
+
+    // Update maps with BRGEMM indexing.
+    auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx);
+    auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx);
+    auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx);
+    SmallVector<AffineMap> indexingMaps{mapA, mapB, mapC};
+
+    // Create a single equivalent BRGEMM.
+    SmallVector<Value> inputs{expandA, expandB, accBuf};
+    SmallVector<Attribute> flags;
+    auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(rewriter, dotOp, inputs,
+                                                   indexingMaps, flags);
+
+    // Load back the result to bring it back to tensor semantics.
+    auto loadOp =
+        rewriter.create<bufferization::ToTensorOp>(loc, res.getType(), accBuf);
+
+    // Increment the base pointers such that the whole loop can be removed.
+    // TODO: Revisit this part.
+    // Only the BRGEMM could be pulled out of the loop and the rest
+    // could be left as is.
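+    // The net advance equals the per-iteration step times the trip count:
+    // A moves by [0, numTiles * stepK] and B by [numTiles * stepK, 0],
+    // mirroring the offsets the original loop would have accumulated.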
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value reductionStepConst =
+        rewriter.create<arith::ConstantIndexOp>(loc, *lhsStepReduction);
+    Value reductionOffset =
+        rewriter.create<arith::MulIOp>(loc, reductionStepConst, numTiles);
+    auto advanceA = rewriter.create<triton::AdvanceOp>(
+        loc, lhsBlockPtr.getResult().getType(), lhsBlockPtr,
+        ValueRange{zero, reductionOffset});
+    auto advanceB = rewriter.create<triton::AdvanceOp>(
+        loc, rhsBlockPtr.getResult().getType(), rhsBlockPtr,
+        ValueRange{reductionOffset, zero});
+
+    rewriter.replaceOp(forOp,
+                       ValueRange{loadOp.getResult(), advanceA.getResult(),
+                                  advanceB.getResult()});
+
+    return success();
+  }
+
+private:
+  ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis;
+};
+
 struct ConvertTritonToXsmm
     : public triton::cpu::impl::ConvertTritonToXsmmBase<ConvertTritonToXsmm> {
   using ConvertTritonToXsmmBase::ConvertTritonToXsmmBase;
@@ -296,4 +532,21 @@ struct ConvertTritonToXsmm
   }
 };
 
+struct LoopToBrgemmXsmm
+    : public triton::cpu::impl::LoopToBrgemmXsmmBase<LoopToBrgemmXsmm> {
+  using LoopToBrgemmXsmmBase::LoopToBrgemmXsmmBase;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp mod = getOperation();
+
+    ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod);
+
+    RewritePatternSet patterns(context);
+    patterns.add<DotReductionLoopToBrgemm>(context, shapeInfoAnalysis);
+    if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
 } // namespace
diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc
index 3a3bc18f9cb7..c5530cb24cee 100644
--- a/third_party/cpu/triton_cpu.cc
+++ b/third_party/cpu/triton_cpu.cc
@@ -179,6 +179,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
   m.def("add_convert_triton_to_xsmm", [](mlir::PassManager &pm) {
     pm.addPass(mlir::triton::cpu::createConvertTritonToXsmm());
   });
+  m.def("add_loop_to_brgemm_xsmm", [](mlir::PassManager &pm) {
+    pm.addPass(mlir::triton::cpu::createLoopToBrgemmXsmm());
+  });
 }
 
 void init_triton_cpu(py::module &&m) {