Skip to content

Commit

Permalink
Small changes to prepare conversion to Triton for BCSR format
Browse files Browse the repository at this point in the history
pthomadakis committed Sep 11, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 03addda commit 7cfaee1
Showing 3 changed files with 94 additions and 19 deletions.
7 changes: 6 additions & 1 deletion frontends/comet_dsl/comet.cpp
Original file line number Diff line number Diff line change
@@ -245,11 +245,14 @@ static cl::opt<bool> IsLoweringtoIndexTree("convert-ta-to-it", /// Lower sparse/
static cl::opt<bool> IsLoweringtoSCF("convert-to-loops",
cl::desc("Output SCF dialect after lowering all operations"));

#ifdef ENABLE_GPU_TARGET

/// =============================================================================
/// Lowering loops to Triton
/// =============================================================================
static cl::opt<bool> IsLoweringtoTriton("convert-to-triton",
cl::desc("Output Triton dialect after lowering all operations"));
#endif

/// =============================================================================
/// Lowering to LLVM
@@ -490,7 +493,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,

optPM.addPass(mlir::comet::createSTCRemoveDeadOpsPass());
optPM.addPass(mlir::comet::createLateLoweringPass());
pm.addPass(mlir::createCanonicalizerPass());
// pm.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

#ifdef ENABLE_GPU_TARGET
@@ -512,6 +515,8 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,

}
#endif
pm.addPass(mlir::createCanonicalizerPass());

/// =============================================================================

if (isLoweringToLLVM || emitLLVM)
76 changes: 59 additions & 17 deletions lib/Conversion/GpuToTriton/GpuToTritonPass.cpp
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/Support/Casting.h"

#include <map>
#include <set>
@@ -238,8 +239,8 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
mlir::Value memMask;
for(size_t i = index_offset; i < op->getNumOperands(); i++)
{
auto affineOp = llvm::dyn_cast<affine::AffineApplyOp>(op->getOperand(i).getDefiningOp());
auto minOp = llvm::dyn_cast<arith::MinUIOp>(op->getOperand(i).getDefiningOp());
auto affineOp = llvm::dyn_cast_if_present<affine::AffineApplyOp>(op->getOperand(i).getDefiningOp());
auto minOp = llvm::dyn_cast_if_present<arith::MinUIOp>(op->getOperand(i).getDefiningOp());
if(affineOp || (minOp && (minOp->hasAttr("GuardX") || minOp->hasAttr("GuardY") || minOp->hasAttr("GuardR"))))
{
// int pidXIndex = -1;
@@ -313,7 +314,10 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
if(map.find(exp.getAsOpaquePointer()) == map.end())
{
map[exp.getAsOpaquePointer()] = rewriter.create<triton::GetProgramIdOp>(op->getLoc(), rewriter.getI32Type(), mlir::triton::ProgramIDDimAttr::get(op->getContext(), mlir::triton::ProgramIDDim::X));
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32);
if(guardX)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32);
}
}

bidX = map[exp.getAsOpaquePointer()];
@@ -323,7 +327,10 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
if(map.find(exp.getAsOpaquePointer()) == map.end())
{
map[exp.getAsOpaquePointer()] = rewriter.create<triton::GetProgramIdOp>(op->getLoc(), rewriter.getI32Type(), mlir::triton::ProgramIDDimAttr::get(op->getContext(), mlir::triton::ProgramIDDim::Y));
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32);
if(guardY)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(op->getLoc(), 0, 32);
}
}

bidY = map[exp.getAsOpaquePointer()];
@@ -338,7 +345,10 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
int blockX = op->getParentOfType<triton::FuncOp>()->getAttrOfType<IntegerAttr>("block_size_x").getInt();
auto range = rewriter.create<mlir::triton::MakeRangeOp>(op->getLoc(), RankedTensorType::get({blockX} , rewriter.getI32Type()), 0, blockX)->getResult(0);
map[exp.getAsOpaquePointer()] = rewriter.create<triton::ExpandDimsOp>(op->getLoc(), RankedTensorType::get({1, blockX}, rewriter.getI32Type()), range, 0);
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<triton::SplatOp>(op->getLoc(), RankedTensorType::get({1, blockX}, rewriter.getI32Type()), guardX );
if(guardX)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<triton::SplatOp>(op->getLoc(), RankedTensorType::get({1, blockX}, rewriter.getI32Type()), guardX );
}
}
tidX = map[exp.getAsOpaquePointer()];
}
@@ -349,7 +359,10 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
int blockY = op->getParentOfType<triton::FuncOp>()->getAttrOfType<IntegerAttr>("block_size_y").getInt();
auto range = rewriter.create<mlir::triton::MakeRangeOp>(op->getLoc(), RankedTensorType::get({blockY} , rewriter.getI32Type()), 0, blockY)->getResult(0);
map[exp.getAsOpaquePointer()] = rewriter.create<triton::ExpandDimsOp>(op->getLoc(), RankedTensorType::get({blockY, 1}, rewriter.getI32Type()), range, 1);
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<triton::SplatOp>(op->getLoc(), RankedTensorType::get({blockY, 1}, rewriter.getI32Type()), guardY );
if(guardY)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<triton::SplatOp>(op->getLoc(), RankedTensorType::get({blockY, 1}, rewriter.getI32Type()), guardY );
}

}
tidY = map[exp.getAsOpaquePointer()];
@@ -368,14 +381,20 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
{
auto loopBlockSize = blockArg.getOwner()->getParentOp()->getAttrOfType<IntegerAttr>("loop_block_size").getInt();
map[exp.getAsOpaquePointer()] = rewriter.createOrFold<triton::MakeRangeOp>(op->getLoc(), RankedTensorType::get({loopBlockSize}, rewriter.getI32Type()), 0, loopBlockSize);
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<triton::SplatOp>(op->getLoc(), RankedTensorType::get({loopBlockSize}, rewriter.getI32Type()), guardR );
if(guardR)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<triton::SplatOp>(op->getLoc(), RankedTensorType::get({loopBlockSize}, rewriter.getI32Type()), guardR );
}

}
else
{
// [TODO] map[exp.getAsOpaquePointer()] = redIdx.getIn();
map[exp.getAsOpaquePointer()] = redIdx;
mapGuard[exp.getAsOpaquePointer()] = guardR;
if(guardR)
{
mapGuard[exp.getAsOpaquePointer()] = guardR;
}
}
}
}
@@ -393,16 +412,25 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
{
if(!(barg.getOwner()->getParentOp()->hasAttr("programs_loop_x") || barg.getOwner()->getParentOp()->hasAttr("programs_loop_y")|| barg.getOwner()->getParentOp()->hasAttr("loop_block_size")))
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::IndexCastOp>(op->getLoc(), rewriter.getI32Type(), getSymOrDimOperand(aaffineop, exp))->getResult(0);
if(guardR || guardX || guardY)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::IndexCastOp>(op->getLoc(), rewriter.getI32Type(), getSymOrDimOperand(aaffineop, exp))->getResult(0);
}
}
else
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(aaffineop->getLoc(), 0, 32);
if(guardR || guardX || guardY)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(aaffineop->getLoc(), 0, 32);
}
}
}
else
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::IndexCastOp>(op->getLoc(), rewriter.getI32Type(), getSymOrDimOperand(aaffineop, exp))->getResult(0);
if(guardR || guardX || guardY)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::IndexCastOp>(op->getLoc(), rewriter.getI32Type(), getSymOrDimOperand(aaffineop, exp))->getResult(0);
}
}
}
else
@@ -412,16 +440,25 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
{
if(!(barg.getOwner()->getParentOp()->hasAttr("programs_loop_x") || barg.getOwner()->getParentOp()->hasAttr("programs_loop_y" ) || barg.getOwner()->getParentOp()->hasAttr("loop_block_size")))
{
mapGuard[exp.getAsOpaquePointer()] = getSymOrDimOperand(aaffineop, exp);
if(guardR || guardX || guardY)
{
mapGuard[exp.getAsOpaquePointer()] = getSymOrDimOperand(aaffineop, exp);
}
}
else
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(aaffineop->getLoc(), 0, 32);
if(guardR || guardX || guardY)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(aaffineop->getLoc(), 0, 32);
}
}
}
else
{
mapGuard[exp.getAsOpaquePointer()] = getSymOrDimOperand(aaffineop, exp);
if(guardR || guardX || guardY)
{
mapGuard[exp.getAsOpaquePointer()] = getSymOrDimOperand(aaffineop, exp);
}
}

}
@@ -436,15 +473,21 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
auto cst = llvm::cast<mlir::AffineConstantExpr>(exp);

map[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(aaffineop->getLoc(), cst.getValue(), 32);
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(aaffineop->getLoc(), cst.getValue(), 32);
if(guardR || guardX || guardY)
{
mapGuard[exp.getAsOpaquePointer()] = rewriter.create<arith::ConstantIntOp>(aaffineop->getLoc(), cst.getValue(), 32);
}
}
}
else
{
// auto binOp = llvm::cast<mlir::AffineBinaryOpExpr>(exp);
// for(auto m: {map /*mapGuard*/})
handleBinaryExpr(op, map, exp, rewriter);
handleBinaryExpr(op, mapGuard, exp, rewriter);
if(!mapGuard.empty())
{
handleBinaryExpr(op, mapGuard, exp, rewriter);
}
}
};};

@@ -514,7 +557,6 @@ LogicalResult convertMemoryOp(Operation* op, ConversionPatternRewriter &rewriter
// map[cast<affine::AffineApplyOp>(guardRExpr.getDefiningOp()).getAffineMap().getResult(0).getAsOpaquePointer()].dump();
guardR = map[cast<affine::AffineApplyOp>(guardRExpr.getDefiningOp()).getAffineMap().getResult(0).getAsOpaquePointer()];
guardRExpr.replaceUsesWithIf(guardR, [](OpOperand& oper){return !isa<arith::MinUIOp>(oper.getOwner());});

}

map.clear();
30 changes: 29 additions & 1 deletion lib/Conversion/ParallelLoopsToGpu/ParallelLoopsToGpu.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

#include <iostream>
#include <memory>
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -417,7 +419,6 @@ struct DetectReduction
bool reduction = is_reduction(forOp);
if (mlir::scf::ParallelOp parent = llvm::dyn_cast_or_null<mlir::scf::ParallelOp>(forOp->getParentOp()); parent && parent->hasAttr("parallelDim") && no_inner_loops && reduction)
{

// assert(parent && parent->getAttrOfType<mlir::StringAttr>("parallelDim").getValue().equals("dimX_block") && !forOp->hasAttr("reduceDim"));

auto block_size_r = rewriter.create<mlir::arith::ConstantIndexOp>(forOp->getLoc(), blockR );
@@ -511,6 +512,33 @@ class ConvertParallelLoopsToGpu: public CometParallelLoopsToGpuBase<ConvertParal
return signalPassFailure();
}

funcOp->walk([](mlir::scf::ParallelOp par_for) {
  // Attach a gpu.ParallelLoopDimMappingAttr ("mapping") to every scf.parallel
  // loop previously tagged with a "parallelDim" string attribute, so the
  // downstream parallel-loops-to-gpu mapping knows which hardware dimension
  // (grid block X/Y or thread X/Y) each loop corresponds to. Loops that
  // already carry a "mapping" attribute are left untouched.
  if (!par_for->hasAttr("parallelDim") || par_for->hasAttr("mapping"))
    return;

  mlir::OpBuilder builder(par_for);
  // Identity affine map: the loop induction variable maps 1:1 onto the
  // hardware id, with no remapping or bounds transformation.
  auto map = builder.getDimIdentityMap();

  // Read the tag once instead of re-fetching/re-converting it per comparison.
  auto dim = par_for->getAttrOfType<mlir::StringAttr>("parallelDim").str();

  mlir::gpu::Processor proc;
  if (dim == "dimY_grid")
  {
    proc = ::mlir::gpu::Processor::BlockY;
  }
  else if (dim == "dimX_grid")
  {
    proc = ::mlir::gpu::Processor::BlockX;
  }
  else if (dim == "dimX_block")
  {
    proc = ::mlir::gpu::Processor::ThreadX;
  }
  else if (dim == "dimY_block")
  {
    proc = ::mlir::gpu::Processor::ThreadY;
  }
  else
  {
    // Previously this fell through to assert(newAttr) on a default-constructed
    // (null) attribute: under NDEBUG the assert is compiled out and a null
    // mapping attribute would be attached, producing invalid IR. Bail out
    // instead so release builds never attach a null attribute.
    assert(false && "unexpected parallelDim value on scf.parallel");
    return;
  }

  auto newAttr = mlir::gpu::ParallelLoopDimMappingAttr::get(builder.getContext(), proc, map, map);
  par_for->setAttr("mapping", mlir::ArrayAttr::get(par_for->getContext(), newAttr));
});

mlir::RewritePatternSet patterns2(context);
mlir::ConversionTarget target2(*context);

0 comments on commit 7cfaee1

Please sign in to comment.