diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
index b096961ef2..860f5e8add 100644
--- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
+++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_library(MatMulOptimization
   BatchMatMulOptimize.cpp
   MatMulOptimize.cpp
+  MatMulVectorization.cpp
 )
 
 add_mlir_library(BatchMatMulOptimization
diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp
new file mode 100644
index 0000000000..1cedbb69f5
--- /dev/null
+++ b/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp
@@ -0,0 +1,271 @@
+//===- MatMulVectorization.cpp --------------------------------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the matmul vectorization.
+//
+//===----------------------------------------------------------------------===//
+
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/Func/IR/FuncOps.h>
+#include <mlir/Dialect/Linalg/IR/Linalg.h>
+#include <mlir/Dialect/MemRef/IR/MemRef.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/Vector/IR/VectorOps.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Transforms/DialectConversion.h>
+
+using namespace mlir;
+using namespace vector;
+
+//===----------------------------------------------------------------------===//
+// Rewrite Pattern
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class MatMulVectorizationPattern : public ConversionPattern {
+public:
+  explicit MatMulVectorizationPattern(MLIRContext *context,
+                                      int64_t vecSizeParam,
+                                      int64_t strideParam,
+                                      int64_t kernelMParam,
+                                      int64_t kernelNParam)
+      : ConversionPattern(linalg::MatmulOp::getOperationName(), 1, context) {
+    vecSize = vecSizeParam;
+    kernelM = kernelMParam;
+    kernelN = kernelNParam;
+    stride = strideParam;
+  }
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+
+    // Get input A, B, C.
+    Value A = op->getOperand(0);
+    Value B = op->getOperand(1);
+    Value C = op->getOperand(2);
+    // Get the shape of input and output.
+    ShapedType ATy = A.getType().cast<ShapedType>();
+    // ShapedType BTy = B.getType().cast<ShapedType>();
+    // ShapedType CTy = C.getType().cast<ShapedType>();
+
+    auto ctx = op->getContext();
+    // Currently use f32 as the element type.
+    // TODO: replace f32 with the input element type.
+    FloatType f32 = mlir::FloatType::getF32(ctx);
+    // Get i1 as the element type for the mask vector.
+    IntegerType i1 = IntegerType::get(ctx, 1);
+    // Define vector types for data and mask.
+    VectorType vectorTy32 = mlir::VectorType::get({stride}, f32);
+    VectorType vectorMaskTy = VectorType::get({stride}, i1);
+    // Some constants.
+    const Value c0 =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    const Value c1 =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+    const Value step = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+    const Value c0_f32 = rewriter.create<arith::ConstantFloatOp>(
+        loc, APFloat::getZero(f32.getFloatSemantics()), f32);
+    // Create pass-through vector.
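+    // Masked-off lanes in the tail iteration read their values from this
+    // all-zero pass-through vector instead of accessing memory.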
+    Value c0_f32_vec = rewriter.create<SplatOp>(loc, vectorTy32, c0_f32);
+
+    // Create DimOp to query the dynamic sizes of A and B.
+    const Value a_row = rewriter.create<memref::DimOp>(loc, A, c0);
+    const Value a_col = rewriter.create<memref::DimOp>(loc, A, c1);
+    const Value b_row = rewriter.create<memref::DimOp>(loc, B, c0);
+    const Value b_col = rewriter.create<memref::DimOp>(loc, B, c1);
+    // Size of strip mining: the n dimension is walked in
+    // ceil(b_col / stride) strips of `stride` elements each.
+    AffineExpr d0;
+    bindDims(ctx, d0);
+    AffineMap stripMap = AffineMap::get(1, 0, {d0.ceilDiv(stride)}, ctx);
+    const Value strip_bound =
+        rewriter.create<affine::AffineApplyOp>(loc, stripMap,
+                                               ValueRange{b_col});
+    SmallVector<Value, 8> lowerBounds(3, c0);
+    SmallVector<Value, 8> upperBounds{b_row, a_row, strip_bound};
+    SmallVector<int64_t, 8> steps(3, /*Value=*/1);
+    affine::buildAffineLoopNest(
+        rewriter, loc, lowerBounds, upperBounds, steps,
+        [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+          // Broadcast A[m, k] into a vector.
+          Value a_ele = builder.create<memref::LoadOp>(
+              loc, A, ValueRange{ivs[1], ivs[0]});
+          Value a_vec =
+              builder.create<vector::BroadcastOp>(loc, vectorTy32, a_ele);
+
+          // Maps used to load the input vectors from B and C.
+          AffineExpr m, n, k;
+          bindDims(ctx, m, n, k);
+          AffineMap BVectorMap = AffineMap::get(
+              /*dimCount=*/3, /*symbolCount=*/0, {m, k * stride}, ctx);
+          AffineExpr x, y, z;
+          bindDims(ctx, x, y, z);
+          AffineMap CVectorMap = AffineMap::get(
+              /*dimCount=*/3, /*symbolCount=*/0, {y, z * stride}, ctx);
+          // Calculate the length of the tail.
+          Value b_col_cur =
+              builder.create<arith::MulIOp>(loc, ivs[2], step);
+          Value tail_len =
+              builder.create<arith::SubIOp>(loc, b_col, b_col_cur);
+          Value tail_flag = builder.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::sge, tail_len, step);
+          builder.create<scf::IfOp>(
+              loc, tail_flag,
+              // The then branch: the current strip does not reach the tail.
+              [&](OpBuilder &builder, Location loc) {
+                Value b_vec = builder.create<affine::AffineVectorLoadOp>(
+                    loc, vectorTy32, B, BVectorMap,
+                    ValueRange{ivs[0], ivs[1], ivs[2]});
+                Value c_vec = builder.create<affine::AffineVectorLoadOp>(
+                    loc, vectorTy32, C, CVectorMap,
+                    ValueRange{ivs[0], ivs[1], ivs[2]});
+                // FMA = Fused Multiply-Add.
+                Value resultVector =
+                    builder.create<FMAOp>(loc, a_vec, b_vec, c_vec);
+                builder.create<affine::AffineVectorStoreOp>(
+                    loc, resultVector, C, CVectorMap,
+                    ValueRange{ivs[0], ivs[1], ivs[2]});
+                builder.create<scf::YieldOp>(loc);
+              },
+              // The else branch: the current strip reaches the tail.
+              [&](OpBuilder &builder, Location loc) {
+                // Create a mask according to the tail length.
+                Value mask_vec = builder.create<CreateMaskOp>(
+                    loc, vectorMaskTy, tail_len);
+                Value b_col_idx_tail =
+                    builder.create<arith::MulIOp>(loc, ivs[2], step);
+                // Masked load for input and output.
+                Value b_vec_tail = builder.create<MaskedLoadOp>(
+                    loc, vectorTy32, B,
+                    ValueRange{ivs[0], b_col_idx_tail}, mask_vec,
+                    c0_f32_vec);
+                Value c_vec_tail = builder.create<MaskedLoadOp>(
+                    loc, vectorTy32, C,
+                    ValueRange{ivs[1], b_col_idx_tail}, mask_vec,
+                    c0_f32_vec);
+                // FMA.
+                Value result_vec_tail = builder.create<FMAOp>(
+                    loc, a_vec, b_vec_tail, c_vec_tail);
+                builder.create<MaskedStoreOp>(
+                    loc, C, ValueRange{ivs[1], b_col_idx_tail}, mask_vec,
+                    result_vec_tail);
+                builder.create<scf::YieldOp>(loc);
+              });
+        });
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  int64_t vecSize;
+  int64_t kernelM;
+  int64_t kernelN;
+  int64_t stride;
+};
+} // end anonymous namespace
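+
+// A sketch of what the pattern above emits (not the exact IR; the loop
+// order is (k, m, n-strip) as built by buildAffineLoopNest):
+//
+//   for (k = 0; k < B.rows; ++k)                       // ivs[0]
+//     for (m = 0; m < A.rows; ++m)                     // ivs[1]
+//       for (ns = 0; ns < ceil(B.cols / stride); ++ns) // ivs[2]
+//         n = ns * stride;
+//         if (B.cols - n >= stride)   // full strip: vector FMA
+//           C[m, n:n+stride] += A[m, k] * B[k, n:n+stride];
+//         else                        // tail strip: masked vector FMA
+//           C[m, n:B.cols] += A[m, k] * B[k, n:B.cols];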
+
+//===----------------------------------------------------------------------===//
+// MatMulVectorizationPass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering of linalg matmul operations to a mixture of
+/// Affine + Vector operations.
+namespace {
+class MatMulVectorizationPass
+    : public PassWrapper<MatMulVectorizationPass, OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulVectorizationPass)
+  StringRef getArgument() const final { return "matmul-vectorization"; }
+  StringRef getDescription() const final { return "MatMul Vectorization."; }
+  MatMulVectorizationPass() = default;
+  MatMulVectorizationPass(const MatMulVectorizationPass &) {}
+  explicit MatMulVectorizationPass(int64_t vecSizeParam, int64_t kernelMParam,
+                                   int64_t kernelNParam) {
+    vecSize = vecSizeParam;
+    kernelM = kernelMParam;
+    kernelN = kernelNParam;
+  }
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, scf::SCFDialect,
+                    affine::AffineDialect, VectorDialect>();
+  }
+
+  Option<int64_t> stride{*this, "strip-mining",
+                         llvm::cl::desc("Strip mining size."),
+                         llvm::cl::init(32)};
+
+  Option<int64_t> vecSize{*this, "vec-size",
+                          llvm::cl::desc("Vector size."),
+                          llvm::cl::init(16)};
+
+  Option<int64_t> kernelM{*this, "kernel-m",
+                          llvm::cl::desc("Kernel size along the M dimension."),
+                          llvm::cl::init(4)};
+
+  Option<int64_t> kernelN{*this, "kernel-n",
+                          llvm::cl::desc("Kernel size along the N dimension."),
+                          llvm::cl::init(2)};
+};
+} // end anonymous namespace.
+
+void MatMulVectorizationPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getOperation();
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
+                         scf::SCFDialect, memref::MemRefDialect,
+                         VectorDialect>();
+  target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
+  target.addLegalOp<linalg::FillOp>();
+
+  RewritePatternSet patterns(context);
+  patterns.add<MatMulVectorizationPattern>(context, vecSize, stride, kernelM,
+                                           kernelN);
+
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+namespace mlir {
+namespace buddy {
+void registerMatMulVectorizationPass() {
+  PassRegistration<MatMulVectorizationPass>();
+}
+} // namespace buddy
+} // namespace mlir
diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp
index 8bde8b5711..c906af8ff3 100644
--- a/tools/buddy-opt/buddy-opt.cpp
+++ b/tools/buddy-opt/buddy-opt.cpp
@@ -56,6 +56,8 @@ void registerLowerDAPPass();
 void registerLowerRVVPass();
 void registerBatchMatMulOptimizePass();
 void registerMatMulOptimizePass();
+void registerMatMulVectorizationPass();
+
 void registerConvOptimizePass();
 void registerLowerVectorExpPass();
 void registerLowerGemminiPass();
@@ -81,6 +83,7 @@ int main(int argc, char **argv) {
 
   // Register Several Optimize Pass.
   mlir::buddy::registerMatMulOptimizePass();
+  mlir::buddy::registerMatMulVectorizationPass();
   mlir::buddy::registerBatchMatMulOptimizePass();
   mlir::buddy::registerConvOptimizePass();
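
Usage sketch (the file name and function are hypothetical; the pass name and
the strip-mining option come from getArgument() and the Option declarations
above):

  // matmul.mlir
  func.func @matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>,
                    %c: memref<?x?xf32>) {
    linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
                  outs(%c : memref<?x?xf32>)
    return
  }

  $ buddy-opt matmul.mlir -matmul-vectorization="strip-mining=32"

Note that strip-mining sets the vector width used along the n dimension;
vec-size, kernel-m, and kernel-n are registered but not yet consumed by the
rewrite pattern.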