diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index 6b377e577..6f408de3b 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -138,7 +138,7 @@ linalg-matmul-optimize-run: linalg-batch-matmul-optimize-run: @${BUDDY_OPT} linalg-matmul.mlir ${MLIR_OPT_OPTIONS} \ - -batchmatmul-optimize="step-placeholder=64" \ + -batchmatmul-optimize="step=64" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ @@ -174,12 +174,12 @@ linalg-batch-matmul-run: linalg-batch-matmul-optimize-lower: @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ - -batchmatmul-optimize="step-placeholder=64" \ + -batchmatmul-optimize="step=64" \ -o ./log.mlir linalg-batch-matmul-optimize-translate: @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ - -batchmatmul-optimize="step-placeholder=64" \ + -batchmatmul-optimize="step=64" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index 9b3924b7d..9ce7acbcb 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -51,10 +51,10 @@ namespace { class BatchMatMulOptimizePattern : public ConversionPattern { public: explicit BatchMatMulOptimizePattern(MLIRContext *context, - int64_t stepPlaceHolderParam) + int64_t affineStepParam) : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, context) { - stepPlaceHolder = stepPlaceHolderParam; + affineStep = affineStepParam; } LogicalResult @@ -73,7 +73,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { const Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); const Value step = rewriter.create( - loc, rewriter.getIndexAttr(stepPlaceHolder)); + loc, rewriter.getIndexAttr(affineStep)); const AffineExpr d0 = rewriter.getAffineDimExpr(0); 
const AffineExpr d1 = rewriter.getAffineDimExpr(1); const AffineExpr d2 = rewriter.getAffineDimExpr(2); @@ -82,7 +82,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { const Value c0_dynamicType = rewriter.create( loc, rewriter.getZeroAttr(A_elementType)); const Value c0_dynamicType_vec = rewriter.create( - loc, VectorType::get({stepPlaceHolder}, A_elementType), c0_dynamicType); + loc, VectorType::get({affineStep}, A_elementType), c0_dynamicType); // Dims Value BATCH = rewriter.create(loc, A, 0); // Batch size @@ -132,7 +132,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, ValueRange ivRange) { Value ivA_row = ivRange.front(); Value applied_n = builder.create( - loc, AffineMap::get(1, 0, d0.ceilDiv(stepPlaceHolder)), + loc, AffineMap::get(1, 0, d0.ceilDiv(affineStep)), ValueRange{N}); affine::buildAffineLoopNest( builder, loc, {c0}, {applied_n}, 1, @@ -142,7 +142,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { loc, A, ValueRange{ivBatch, ivA_row, ivB_row}); Value a_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, A_elementType), + VectorType::get({affineStep}, A_elementType), a_ele); Value b_col_cur = builder.create(loc, ivB_col, step); @@ -156,21 +156,21 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value b_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineStep}, A_elementType), B, AffineMap::get( - 3, 0, {d0, d1, d2 * stepPlaceHolder}, + 3, 0, {d0, d1, d2 * affineStep}, rewriter.getContext()), ValueRange{ivBatch, ivB_row, ivB_col}); Value c_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineStep}, A_elementType), C, AffineMap::get( - 3, 0, {d0, d1, d2 * stepPlaceHolder}, + 3, 0, {d0, d1, d2 * affineStep}, rewriter.getContext()), ValueRange{ivBatch, ivA_row, ivB_col}); Value result_vec; @@ -186,7 +186,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern 
{ builder.create( loc, result_vec, C, AffineMap::get(3, 0, - {d0, d1, d2 * stepPlaceHolder}, + {d0, d1, d2 * affineStep}, rewriter.getContext()), ValueRange{ivBatch, ivA_row, ivB_col}); builder.create(loc); @@ -195,7 +195,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value mask_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineStep}, rewriter.getI1Type()), ValueRange{tail_len}); Value b_col_idx_tail = @@ -204,7 +204,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value b_vec_tail = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineStep}, A_elementType), B, ValueRange{ivBatch, ivB_row, @@ -213,7 +213,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value c_vec_tail = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineStep}, A_elementType), C, ValueRange{ivBatch, ivA_row, @@ -249,7 +249,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { } private: - int64_t stepPlaceHolder; + int64_t affineStep; }; } // end anonymous namespace @@ -268,8 +268,8 @@ class BatchMatMulOptimizePass StringRef getDescription() const final { return "BatchMatMul Optimization."; } BatchMatMulOptimizePass() = default; BatchMatMulOptimizePass(const BatchMatMulOptimizePass &) {} - explicit BatchMatMulOptimizePass(int64_t stepPlaceHolderParam) { - stepPlaceHolder = stepPlaceHolderParam; + explicit BatchMatMulOptimizePass(int64_t affineStepParam) { + affineStep = affineStepParam; } void runOnOperation() override; @@ -279,9 +279,9 @@ class BatchMatMulOptimizePass affine::AffineDialect, VectorDialect>(); } - Option stepPlaceHolder{ - *this, "step-placeholder", - llvm::cl::desc("Affine step placeholder size."), llvm::cl::init(64)}; + Option affineStep{ + *this, "step", /* option key as passed on the command line, e.g. -batchmatmul-optimize="step=64" */ + llvm::cl::desc("Affine step size."), llvm::cl::init(64)}; }; } // end anonymous namespace. 
@@ -297,7 +297,7 @@ void BatchMatMulOptimizePass::runOnOperation() { target.addLegalOp(); RewritePatternSet patterns(context); - patterns.add(context, stepPlaceHolder); + patterns.add(context, affineStep); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure();