Commit

[midend][examples] Correct parameter and variable stepPlaceholder in batchmatmul optimization.
EllisLambda committed Sep 18, 2023
1 parent 5f4adec commit db11165
Showing 2 changed files with 24 additions and 24 deletions.
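For reference, the pass now registers its step size under the option key vector-size (see the Option declaration in BatchMatMulOptimize.cpp below), so a minimal standalone invocation, sketched here to mirror the linalg-batch-matmul-optimize-lower target and assuming a buddy-opt build that includes this commit, looks like:

    buddy-opt linalg-batch-matmul.mlir \
        -batchmatmul-optimize="vector-size=64" \
        -o ./log.mlir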
6 changes: 3 additions & 3 deletions examples/MLIRLinalg/makefile
@@ -138,7 +138,7 @@ linalg-matmul-optimize-run:

linalg-batch-matmul-optimize-run:
@${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="step-placeholder=64" \
-batchmatmul-optimize="step=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
@@ -174,12 +174,12 @@ linalg-batch-matmul-run:

linalg-batch-matmul-optimize-lower:
@${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="step-placeholder=64" \
-batchmatmul-optimize="step=64" \
-o ./log.mlir

linalg-batch-matmul-optimize-translate:
@${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="step-placeholder=64" \
-batchmatmul-optimize="step=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
42 changes: 21 additions & 21 deletions midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp
@@ -51,10 +51,10 @@ namespace {
class BatchMatMulOptimizePattern : public ConversionPattern {
public:
explicit BatchMatMulOptimizePattern(MLIRContext *context,
int64_t stepPlaceHolderParam)
int64_t affineStepParam)
: ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1,
context) {
stepPlaceHolder = stepPlaceHolderParam;
affineStep = affineStepParam;
}

LogicalResult
@@ -73,7 +73,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
const Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
const Value step = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(stepPlaceHolder));
loc, rewriter.getIndexAttr(affineStep));
const AffineExpr d0 = rewriter.getAffineDimExpr(0);
const AffineExpr d1 = rewriter.getAffineDimExpr(1);
const AffineExpr d2 = rewriter.getAffineDimExpr(2);
@@ -82,7 +82,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
const Value c0_dynamicType = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(A_elementType));
const Value c0_dynamicType_vec = rewriter.create<vector::SplatOp>(
loc, VectorType::get({stepPlaceHolder}, A_elementType), c0_dynamicType);
loc, VectorType::get({affineStep}, A_elementType), c0_dynamicType);

// Dims
Value BATCH = rewriter.create<memref::DimOp>(loc, A, 0); // Batch size
@@ -132,7 +132,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
[&](OpBuilder &builder, Location loc, ValueRange ivRange) {
Value ivA_row = ivRange.front();
Value applied_n = builder.create<affine::AffineApplyOp>(
loc, AffineMap::get(1, 0, d0.ceilDiv(stepPlaceHolder)),
loc, AffineMap::get(1, 0, d0.ceilDiv(affineStep)),
ValueRange{N});
affine::buildAffineLoopNest(
builder, loc, {c0}, {applied_n}, 1,
@@ -142,7 +142,7 @@
loc, A, ValueRange{ivBatch, ivA_row, ivB_row});
Value a_vec = builder.create<vector::BroadcastOp>(
loc,
VectorType::get({stepPlaceHolder}, A_elementType),
VectorType::get({affineStep}, A_elementType),
a_ele);
Value b_col_cur =
builder.create<arith::MulIOp>(loc, ivB_col, step);
@@ -156,21 +156,21 @@
Value b_vec =
builder.create<affine::AffineVectorLoadOp>(
loc,
VectorType::get({stepPlaceHolder},
VectorType::get({affineStep},
A_elementType),
B,
AffineMap::get(
3, 0, {d0, d1, d2 * stepPlaceHolder},
3, 0, {d0, d1, d2 * affineStep},
rewriter.getContext()),
ValueRange{ivBatch, ivB_row, ivB_col});
Value c_vec =
builder.create<affine::AffineVectorLoadOp>(
loc,
VectorType::get({stepPlaceHolder},
VectorType::get({affineStep},
A_elementType),
C,
AffineMap::get(
3, 0, {d0, d1, d2 * stepPlaceHolder},
3, 0, {d0, d1, d2 * affineStep},
rewriter.getContext()),
ValueRange{ivBatch, ivA_row, ivB_col});
Value result_vec;
@@ -186,7 +186,7 @@
builder.create<affine::AffineVectorStoreOp>(
loc, result_vec, C,
AffineMap::get(3, 0,
{d0, d1, d2 * stepPlaceHolder},
{d0, d1, d2 * affineStep},
rewriter.getContext()),
ValueRange{ivBatch, ivA_row, ivB_col});
builder.create<scf::YieldOp>(loc);
@@ -195,7 +195,7 @@
Value mask_vec =
builder.create<vector::CreateMaskOp>(
loc,
VectorType::get({stepPlaceHolder},
VectorType::get({affineStep},
rewriter.getI1Type()),
ValueRange{tail_len});
Value b_col_idx_tail =
@@ -204,7 +204,7 @@
Value b_vec_tail =
builder.create<vector::MaskedLoadOp>(
loc,
VectorType::get({stepPlaceHolder},
VectorType::get({affineStep},
A_elementType),
B,
ValueRange{ivBatch, ivB_row,
@@ -213,7 +213,7 @@
Value c_vec_tail =
builder.create<vector::MaskedLoadOp>(
loc,
VectorType::get({stepPlaceHolder},
VectorType::get({affineStep},
A_elementType),
C,
ValueRange{ivBatch, ivA_row,
@@ -249,7 +249,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
}

private:
int64_t stepPlaceHolder;
int64_t affineStep;
};
} // end anonymous namespace

@@ -268,8 +268,8 @@ class BatchMatMulOptimizePass
StringRef getDescription() const final { return "BatchMatMul Optimization."; }
BatchMatMulOptimizePass() = default;
BatchMatMulOptimizePass(const BatchMatMulOptimizePass &) {}
explicit BatchMatMulOptimizePass(int64_t stepPlaceHolderParam) {
stepPlaceHolder = stepPlaceHolderParam;
explicit BatchMatMulOptimizePass(int64_t affineStepParam) {
affineStep = affineStepParam;
}

void runOnOperation() override;
@@ -279,9 +279,9 @@ class BatchMatMulOptimizePass
affine::AffineDialect, VectorDialect>();
}

Option<int64_t> stepPlaceHolder{
*this, "step-placeholder",
llvm::cl::desc("Affine step placeholder size."), llvm::cl::init(64)};
Option<int64_t> affineStep{
*this, "vector-size",
llvm::cl::desc("Affine step size."), llvm::cl::init(64)};
};
} // end anonymous namespace.

@@ -297,7 +297,7 @@ void BatchMatMulOptimizePass::runOnOperation() {
target.addLegalOp<linalg::FillOp>();

RewritePatternSet patterns(context);
patterns.add<BatchMatMulOptimizePattern>(context, stepPlaceHolder);
patterns.add<BatchMatMulOptimizePattern>(context, affineStep);

if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
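The change can also be exercised through the updated makefile targets; a sketch, assuming the commands are run from examples/MLIRLinalg with BUDDY_OPT configured as the makefile expects:

    make linalg-batch-matmul-optimize-lower      # writes the optimized IR to ./log.mlir
    make linalg-batch-matmul-optimize-translate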
