Reduction loop GEMM to XSMM BRGEMM (triton-lang#18)
Adds an experimental rewrite that collapses a reduction loop over a GEMM into a BRGEMM ukernel.

The pattern matches the hand-written kernel using block pointers and is not compatible with IR generated by Triton pointer raising. Lowering directly to XSMM bypasses Triton's load restriction when the K dimension is not a power of two.
The pattern is quite brittle but functional for the matmul tutorial example; a sketch of the kernel shape it targets is shown below.
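
For reference, a minimal Python sketch (not part of this change) of the kind of block-pointer matmul kernel the pattern targets, loosely following the upstream matmul tutorial; the names, block shapes, single-tile grid, and float32 output are illustrative assumptions:

  import triton
  import triton.language as tl

  @triton.jit
  def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                    stride_am, stride_ak, stride_bk, stride_bn,
                    stride_cm, stride_cn,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                    BLOCK_K: tl.constexpr):
      # Block pointers over A (M x K) and B (K x N); these lower to
      # tt.make_tensor_ptr, which the rewrite requires.
      a_tile = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                                 strides=(stride_am, stride_ak), offsets=(0, 0),
                                 block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
      b_tile = tl.make_block_ptr(base=b_ptr, shape=(K, N),
                                 strides=(stride_bk, stride_bn), offsets=(0, 0),
                                 block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
      acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
      # Reduction loop: load both tiles, accumulate a dot product, and advance
      # each pointer only along its reduction dimension. The induction variable
      # is unused, matching the pattern's assumptions.
      for _ in range(0, K, BLOCK_K):
          a = tl.load(a_tile)
          b = tl.load(b_tile)
          acc = tl.dot(a, b, acc)
          a_tile = tl.advance(a_tile, (0, BLOCK_K))
          b_tile = tl.advance(b_tile, (BLOCK_K, 0))
      c_tile = tl.make_block_ptr(base=c_ptr, shape=(M, N),
                                 strides=(stride_cm, stride_cn), offsets=(0, 0),
                                 block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
      tl.store(c_tile, acc)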

The rewrite is disabled by default and can be enabled with an environment variable:
  TRITON_CPU_LOOP_BRGEMM_XSMM=1
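
For example (illustrative usage, assuming a kernel compiled through this CPU backend; the option name mirrors the `enable_loop_brgemm_xsmm` field added to CPUOptions below):

  import os

  # Set before the kernel is compiled so parse_options() picks it up.
  os.environ["TRITON_CPU_LOOP_BRGEMM_XSMM"] = "1"

  # Alternatively (hypothetical call, host code not shown), the option could be
  # passed explicitly when the compilation options are constructed:
  # matmul_kernel[grid](..., enable_loop_brgemm_xsmm=True)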
adam-smnk authored and ienkovich committed Nov 19, 2024
1 parent 70e769c commit 11b8669
Showing 4 changed files with 275 additions and 1 deletion.
8 changes: 7 additions & 1 deletion third_party/cpu/backend/compiler.py
@@ -50,6 +50,7 @@ class CPUOptions:
enable_fast_math: bool = True
enable_vector_xsmm: bool = False
enable_triton_xsmm: bool = False
enable_loop_brgemm_xsmm: bool = False
enable_raise_block_pointer: bool = False
vec_lib: Optional[str] = 'libsleef'
# TODO: Try to enable it.
@@ -110,6 +111,8 @@ def parse_options(self, opts) -> Any:
args["enable_vector_xsmm"] = os.getenv("TRITON_CPU_VECTOR_XSMM", "0") != "0"
if "enable_triton_xsmm" not in args:
args["enable_triton_xsmm"] = os.getenv("TRITON_CPU_TRITON_XSMM", "0") != "0"
if "enable_loop_brgemm_xsmm" not in args:
args["enable_loop_brgemm_xsmm"] = os.getenv("TRITON_CPU_LOOP_BRGEMM_XSMM", "0") != "0"
if "enable_raise_block_pointer" not in args:
args["enable_raise_block_pointer"] = os.getenv("TRITON_CPU_RAISE_BLOCK_POINTER", "0") != "0"
return CPUOptions(**args)
@@ -150,6 +153,9 @@ def make_ttcir(mod, metadata, opt):
pm.enable_debug()
if opt.enable_raise_block_pointer:
cpu.passes.ttcpuir.add_raise_block_pointer(pm)
if opt.enable_loop_brgemm_xsmm:
cpu.passes.ttcpuir.add_loop_to_brgemm_xsmm(pm)
passes.common.add_canonicalizer(pm)
if opt.enable_triton_xsmm:
cpu.passes.ttcpuir.add_convert_triton_to_xsmm(pm)
passes.common.add_canonicalizer(pm)
@@ -287,7 +293,7 @@ def make_so(src, metadata, options):
Path(asm_path).write_text(src)
lib_dirs = cpu_driver.library_dirs
libs = ["m", "TritonCPURuntime", "sleef"]
if options.enable_vector_xsmm or options.enable_triton_xsmm:
if options.enable_vector_xsmm or options.enable_triton_xsmm or options.enable_loop_brgemm_xsmm:
libs.extend(["xsmm", "TritonCPUXsmmRuntime"])
so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs)
with open(so, "rb") as f:
12 changes: 12 additions & 0 deletions third_party/cpu/include/Xsmm/Passes.td
@@ -27,4 +27,16 @@ def ConvertTritonToXsmm : Pass<"triton-cpu-convert-triton-to-xsmm", "mlir::ModuleOp"> {
"LLVM::LLVMDialect"];
}

def LoopToBrgemmXsmm : Pass<"triton-cpu-loop-to-brgemm-xsmm", "mlir::ModuleOp"> {
let summary = "Redution loop GEMM to BRGEMM";
let description = [{
Collapse reduction loop over GEMM to XSMM BRGEMM kernel.
}];
let dependentDialects = ["arith::ArithDialect",
"func::FuncDialect",
"memref::MemRefDialect",
"triton::cpu::TritonCPUDialect",
"LLVM::LLVMDialect"];
}

#endif
253 changes: 253 additions & 0 deletions third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp
@@ -43,6 +43,7 @@ namespace mlir {
namespace triton {
namespace cpu {
#define GEN_PASS_DEF_CONVERTTRITONTOXSMM
#define GEN_PASS_DEF_LOOPTOBRGEMMXSMM
#include "cpu/include/Xsmm/Passes.h.inc"
} // namespace cpu
} // namespace triton
@@ -279,6 +280,241 @@ struct DotToXsmm : public OpRewritePattern<triton::DotOp> {
ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis;
};

// Collapse whole reduction loop with a GEMM into equivalent BRGEMM operation.
// Rewrites the following pattern:
// %0 = tt.make_tensor_ptr %base_ptr0 : tensor<M x K>
// %1 = tt.make_tensor_ptr %base_ptr1 : tensor<K x N>
// %res:3 = scf.for %arg3 = %lb to %ub step %step
// iter_args(%acc = %init_val, %ptr_A = %0, %ptr_B = %1)
// %A = tt.load %ptr_A
// %B = tt.load %ptr_B
// %dot = tt.dot %A, %B, %acc
// %ptr_A_next = tt.advance %ptr_A, [0, %stepK]
// %ptr_B_next = tt.advance %ptr_B, [%stepK, 0]
// scf.yield %dot, %ptr_A_next, %ptr_B_next
// into:
// %A = tt.make_tensor_ptr %base_ptr0 : tensor<M x TILES x k>
// %B = tt.make_tensor_ptr %base_ptr1 : tensor<TILES x k x N>
// %res0 = BRGEMM %A, %B, %init_val
// %res1 = tt.advance %A, [0, ((%ub - %lb) / %step) * %stepK]
// %res2 = tt.advance %B, [((%ub - %lb) / %step) * %stepK, 0]
struct DotReductionLoopToBrgemm : public OpRewritePattern<triton::DotOp> {
using OpRewritePattern::OpRewritePattern;

DotReductionLoopToBrgemm(MLIRContext *context,
ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis,
PatternBenefit benefit = 10)
: OpRewritePattern<triton::DotOp>(context, benefit),
shapeAnalysis(shapeInfoAnalysis) {}

LogicalResult matchAndRewrite(triton::DotOp dotOp,
PatternRewriter &rewriter) const override {
Location loc = dotOp.getLoc();
MLIRContext *ctx = dotOp.getContext();

// Check if there is any loop around the contraction and whether the
// accumulation value comes from the loop's arguments.
TypedValue<RankedTensorType> acc = dotOp.getC();
if (acc.getType().getRank() != 2)
return rewriter.notifyMatchFailure(dotOp, "expects 2D GEMM");

auto forOp = dyn_cast<scf::ForOp>(dotOp->getParentOp());
BlockArgument accBbArg = dyn_cast<BlockArgument>(acc);
if (!forOp || !accBbArg)
return rewriter.notifyMatchFailure(dotOp, "not a reduction loop");
OpOperand *accArg = forOp.getTiedLoopInit(accBbArg);
if (!accArg)
return rewriter.notifyMatchFailure(
dotOp, "expects iter_args accumulation value");
// TODO: Relax this check. It is needed to collapse whole loop but
// alternatively only BRGEMM could be pulled out.
if (forOp.getNumRegionIterArgs() != 3)
return rewriter.notifyMatchFailure(dotOp, "invalid number of iter_args");

// Assume that the loop's range and all pointer advances are known
// statically. Thus, the induction variable should be unused.
Value loopIv = forOp.getInductionVar();
if (!loopIv.use_empty())
return rewriter.notifyMatchFailure(dotOp,
"expects unused induction variable");

// The subgraph should be a simple reduction loop containing a GEMM operation.
// Validate presence of the following chain:
// iter_arg -> contraction -> yield
// and that there are no other users.
TypedValue<RankedTensorType> res = dotOp.getD();
if (!acc.hasOneUse() || !res.hasOneUse() ||
!isa<scf::YieldOp>(*res.getUsers().begin()))
return rewriter.notifyMatchFailure(dotOp, "GEMM subgraph does not match");

auto loadMatA = dotOp.getA().getDefiningOp<triton::LoadOp>();
auto loadMatB = dotOp.getB().getDefiningOp<triton::LoadOp>();
if (!loadMatA || !loadMatB)
return rewriter.notifyMatchFailure(dotOp, "expect GEMM input loads");
if (!loadMatA->hasOneUse() || !loadMatB->hasOneUse())
return rewriter.notifyMatchFailure(dotOp,
"Input loads subgraph does not match");

// Constrain input pointers to the following subgraph:
// iter_arg -> (load, increment) -> yield
BlockArgument lhsBbArg = dyn_cast<BlockArgument>(loadMatA.getPtr());
BlockArgument rhsBbArg = dyn_cast<BlockArgument>(loadMatB.getPtr());
if (!lhsBbArg || !rhsBbArg)
return rewriter.notifyMatchFailure(dotOp, "expect block arg pointers");
OpOperand *lhsArg = forOp.getTiedLoopInit(lhsBbArg);
OpOperand *rhsArg = forOp.getTiedLoopInit(rhsBbArg);
if (!lhsArg ||
std::distance(lhsBbArg.use_begin(), lhsBbArg.use_end()) != 2 ||
!rhsArg || std::distance(rhsBbArg.use_begin(), rhsBbArg.use_end()) != 2)
return rewriter.notifyMatchFailure(dotOp, "expect iter_args pointers");

// Input sources should be block pointers.
// TODO: Account for transposed GEMM operands.
auto lhsBlockPtr = dyn_cast_or_null<triton::MakeTensorPtrOp>(
lhsArg->get().getDefiningOp());
auto rhsBlockPtr = dyn_cast_or_null<triton::MakeTensorPtrOp>(
rhsArg->get().getDefiningOp());
if (!lhsBlockPtr || lhsBlockPtr.getOrder() != ArrayRef<int32_t>{1, 0} ||
!rhsBlockPtr || rhsBlockPtr.getOrder() != ArrayRef<int32_t>{1, 0})
return rewriter.notifyMatchFailure(dotOp, "expected block pointers");

// Check for pointer increments and validate their steps.
// Each input is expected to advance only in its reduction dimension.
auto lhsAdvanceOp = forOp.getTiedLoopYieldedValue(lhsBbArg)
->get()
.getDefiningOp<triton::AdvanceOp>();
auto rhsAdvanceOp = forOp.getTiedLoopYieldedValue(rhsBbArg)
->get()
.getDefiningOp<triton::AdvanceOp>();
if (!lhsAdvanceOp || !rhsAdvanceOp)
return rewriter.notifyMatchFailure(dotOp, "expected ptr advance");
if (!lhsAdvanceOp->hasOneUse() || !rhsAdvanceOp->hasOneUse())
return rewriter.notifyMatchFailure(
dotOp, "Ptr increment subgraph does not match");

auto resShape = res.getType().getShape();
auto lhsShape = dotOp.getA().getType().getShape();
auto lhsPtrOffsets = lhsAdvanceOp.getOffsets();
auto lhsStepParallel = getConstantIntValue(lhsPtrOffsets[0]);
auto lhsStepReduction = getConstantIntValue(lhsPtrOffsets[1]);
if (!lhsStepParallel || *lhsStepParallel != 0 || !lhsStepReduction ||
*lhsStepReduction != lhsShape[1])
return rewriter.notifyMatchFailure(dotOp, "invalid lhs increments");

auto rhsPtrOffsets = rhsAdvanceOp.getOffsets();
auto rhsStepReduction = getConstantIntValue(rhsPtrOffsets[0]);
auto rhsStepParallel = getConstantIntValue(rhsPtrOffsets[1]);
if (!rhsStepReduction || *rhsStepReduction != *lhsStepReduction ||
!rhsStepParallel || *rhsStepParallel != 0)
return rewriter.notifyMatchFailure(dotOp, "invalid rhs increments");

// Collapse the loop and create equivalent BRGEMM operation.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(forOp);

// TODO: Validate if number of tiles cleanly divides the source buffer.
auto loopRange = rewriter.create<arith::SubIOp>(loc, forOp.getUpperBound(),
forOp.getLowerBound());
Value numTiles =
rewriter.create<arith::DivUIOp>(loc, loopRange, forOp.getStep());
numTiles = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
numTiles);
auto kStepCst =
rewriter.create<arith::ConstantIndexOp>(loc, *lhsStepReduction);
auto fullKDimLength =
rewriter.create<arith::MulIOp>(loc, numTiles, kStepCst);

// Create new memref views spanning the whole reduction dimension.
SmallVector<int64_t> strides(2, 1);
auto lhsMemref = extractMemRef(rewriter, lhsBlockPtr, shapeAnalysis);
auto lhsIndices =
rewriter.create<triton::cpu::ExtractIndicesOp>(loc, lhsBlockPtr)
.getResults();
auto lhsBuf = rewriter.create<memref::SubViewOp>(
loc, lhsMemref, getAsOpFoldResult(lhsIndices),
SmallVector<OpFoldResult>{getAsIndexOpFoldResult(ctx, resShape[0]),
getAsOpFoldResult(fullKDimLength)},
getAsIndexOpFoldResult(ctx, strides));

auto rhsMemref = extractMemRef(rewriter, rhsBlockPtr, shapeAnalysis);
auto rhsIndices =
rewriter.create<triton::cpu::ExtractIndicesOp>(loc, rhsBlockPtr)
.getResults();
auto rhsBuf = rewriter.create<memref::SubViewOp>(
loc, rhsMemref, getAsOpFoldResult(rhsIndices),
SmallVector<OpFoldResult>{getAsOpFoldResult(fullKDimLength),
getAsIndexOpFoldResult(ctx, resShape[1])},
getAsIndexOpFoldResult(ctx, strides));

Value accBuf =
getMemrefSource(rewriter, forOp,
dyn_cast<TypedValue<RankedTensorType>>(
accArg->get().getDefiningOp()->getResult(0)),
shapeAnalysis);

// Split reduction dimension into tiles.
// The number of tiles represents the batch dimension.
SmallVector<OpFoldResult> lhsOutSizes{
getAsIndexOpFoldResult(ctx, resShape[0]), getAsOpFoldResult(numTiles),
getAsIndexOpFoldResult(ctx, *lhsStepReduction)};
auto expandA = rewriter.create<memref::ExpandShapeOp>(
loc,
SmallVector<int64_t>{resShape[0], ShapedType::kDynamic,
*lhsStepReduction},
lhsBuf, SmallVector<ReassociationIndices>{{0}, {1, 2}}, lhsOutSizes);
SmallVector<OpFoldResult> rhsOutSizes{
getAsOpFoldResult(numTiles),
getAsIndexOpFoldResult(ctx, *rhsStepReduction),
getAsIndexOpFoldResult(ctx, resShape[1])};
auto expandB = rewriter.create<memref::ExpandShapeOp>(
loc,
SmallVector<int64_t>{ShapedType::kDynamic, *rhsStepReduction,
resShape[1]},
rhsBuf, SmallVector<ReassociationIndices>{{0, 1}, {2}}, rhsOutSizes);

// Update maps with BRGEMM indexing.
auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx);
auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx);
auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx);
SmallVector<AffineMap> indexingMaps{mapA, mapB, mapC};

// Create single equivalent BRGEMM.
SmallVector<Value> inputs{expandA, expandB, accBuf};
SmallVector<Attribute> flags;
auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(rewriter, dotOp, inputs,
indexingMaps, flags);

// Load back the result to bring it back to tensor semantics.
auto loadOp =
rewriter.create<triton::cpu::LoadOp>(loc, res.getType(), accBuf);

// Increment the base pointers such that the whole loop can be removed.
// TODO: Revisit this part.
// Only the BRGEMM could be pulled out of the loop and the rest
// could be left as is.
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value reductionStepConst =
rewriter.create<arith::ConstantIndexOp>(loc, *lhsStepReduction);
Value reductionOffset =
rewriter.create<arith::MulIOp>(loc, reductionStepConst, numTiles);
auto advanceA = rewriter.create<triton::AdvanceOp>(
loc, lhsBlockPtr.getResult().getType(), lhsBlockPtr,
ValueRange{zero, reductionOffset});
auto advanceB = rewriter.create<triton::AdvanceOp>(
loc, rhsBlockPtr.getResult().getType(), rhsBlockPtr,
ValueRange{reductionOffset, zero});

rewriter.replaceOp(forOp,
ValueRange{loadOp.getResult(), advanceA.getResult(),
advanceB.getResult()});

return success();
}

private:
ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis;
};

struct ConvertTritonToXsmm
: public triton::cpu::impl::ConvertTritonToXsmmBase<ConvertTritonToXsmm> {
using ConvertTritonToXsmmBase::ConvertTritonToXsmmBase;
@@ -296,4 +532,21 @@ struct ConvertTritonToXsmm
}
};

struct LoopToBrgemmXsmm
: public triton::cpu::impl::LoopToBrgemmXsmmBase<LoopToBrgemmXsmm> {
using LoopToBrgemmXsmmBase::LoopToBrgemmXsmmBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod);

RewritePatternSet patterns(context);
patterns.add<DotReductionLoopToBrgemm>(context, shapeInfoAnalysis);
if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace
3 changes: 3 additions & 0 deletions third_party/cpu/triton_cpu.cc
@@ -179,6 +179,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
m.def("add_convert_triton_to_xsmm", [](mlir::PassManager &pm) {
pm.addPass(mlir::triton::cpu::createConvertTritonToXsmm());
});
m.def("add_loop_to_brgemm_xsmm", [](mlir::PassManager &pm) {
pm.addPass(mlir::triton::cpu::createLoopToBrgemmXsmm());
});
}

void init_triton_cpu(py::module &&m) {
