From 1b58206553e884c846c94c9888d9fe435942f526 Mon Sep 17 00:00:00 2001 From: heyi Date: Tue, 26 Sep 2023 06:08:13 +0000 Subject: [PATCH 01/11] add Convolution NHWC HWCF Optimize Pass file --- .../ConvOptimization/ConvBroadcast.cpp | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp new file mode 100644 index 0000000000..209392f643 --- /dev/null +++ b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -0,0 +1,23 @@ +//====- ConvBroadcast.cpp --------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv Broadcast Optmize for linalg.conv_2d_nhwc_hwcf +// +//===----------------------------------------------------------------------===// + +#include +#include + From 41494a0480e7afc6f99ea2618cf8cfc8816571a2 Mon Sep 17 00:00:00 2001 From: heyi Date: Tue, 26 Sep 2023 10:26:44 +0000 Subject: [PATCH 02/11] add test mlir file --- examples/ConvOpt/conv2d_nhwc_hwcf.mlir | 5 +++++ .../Conversion/ConvOptimization/CMakeLists.txt | 1 + .../Conversion/ConvOptimization/ConvBroadcast.cpp | 15 +++++++++++++++ 3 files changed, 21 insertions(+) create mode 100644 examples/ConvOpt/conv2d_nhwc_hwcf.mlir diff --git a/examples/ConvOpt/conv2d_nhwc_hwcf.mlir b/examples/ConvOpt/conv2d_nhwc_hwcf.mlir new file mode 100644 index 0000000000..ab04005e84 --- /dev/null +++ b/examples/ConvOpt/conv2d_nhwc_hwcf.mlir @@ -0,0 +1,5 @@ +func.func @conv_2d_nchw_fchw(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nchw_fchw ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) + return +} \ No newline at end of file diff --git a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt index fc88a92ef6..7cc0fceb1d 100644 --- a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt @@ -1,3 +1,4 @@ add_mlir_library(ConvOptimization ConvOptimize.cpp + ConvBroadcast.cpp ) diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp index 209392f643..7031bfb73f 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp +++ b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -20,4 +20,19 @@ #include #include +#include +#include +#include +using namespace mlir; +using namespace vector; + +//===-------------------------------------------- +// Rewrite Pattern 
+//===-------------------------------------------- + +// namespace { +// class ConvBroadcastOptimizePattern : public ConversionPattern { + +// } +// } // end anonymous namespace \ No newline at end of file From 9fe0cac35c5f5de8d38ff3a48687d84d1fa3b26e Mon Sep 17 00:00:00 2001 From: heyi Date: Tue, 26 Sep 2023 14:49:32 +0000 Subject: [PATCH 03/11] Definition ConvBroadcast Pattern and Pass --- .../ConvOptimization/ConvBroadcast.cpp | 95 ++++++++++++++++++- 1 file changed, 91 insertions(+), 4 deletions(-) diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp index 7031bfb73f..86f6af044e 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp +++ b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -31,8 +32,94 @@ using namespace vector; // Rewrite Pattern //===-------------------------------------------- -// namespace { -// class ConvBroadcastOptimizePattern : public ConversionPattern { +namespace { +class ConvBroadcastOptimizePattern : public ConversionPattern { +public: + explicit ConvBroadcastOptimizePattern(MLIRContext *context, int64_t strideParam, + ArrayRef tileParam) + : ConversionPattern(linalg::Conv2DNhwcHwcfOp::getOperationName(), 1, context) { + + stride = strideParam; + tileSizes = tileParam; + } -// } -// } // end anonymous namespace \ No newline at end of file + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + return failure(); + } + +private: + int64_t stride; + ArrayRef tileSizes; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConvBroadcastNhwcHwcf +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg conv2d_nhwc_hwcf to mixture of +/// Affine + Vector + Std 
operations. +namespace +{ +class ConvBroadcastNhwcHwcfPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvBroadcastNhwcHwcfPass) + StringRef getArgument() const final { return "conv-broadcast"; } + StringRef getDescription() const final { + return "Convolution Broadcast optimize for conv2d_nhwc_hwcf"; + } + ConvBroadcastNhwcHwcfPass() = default; + ConvBroadcastNhwcHwcfPass(const ConvBroadcastNhwcHwcfPass &) {} + explicit ConvBroadcastNhwcHwcfPass(int64_t strideParam, + ArrayRef tileParam) { + stride = strideParam; + tile = tileParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®iistery) const override { + regiistery.insert(); + } + + Option stride{*this, "strip-mining", + llvm::cl::desc("Strip mining size."), + llvm::cl::init(32)}; + ListOption tile{*this, "tile-sizes", llvm::cl::desc("Tile sizes"), + llvm::cl::ZeroOrMore}; +}; +} // end anonymous namespace + +void ConvBroadcastNhwcHwcfPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, stride, tile); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + { + signalPassFailure(); + } + +} + +namespace mlir { +namespace buddy { +void registerConvBroadcastNhwcHwcfPass() { + PassRegistration(); +} +} +} From e5c62053893e4a08e17e1f71455b020e5be48958 Mon Sep 17 00:00:00 2001 From: heyi Date: Thu, 28 Sep 2023 09:31:54 +0000 Subject: [PATCH 04/11] finish CB --- .../ConvOptimization/ConvBroadcast.cpp | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp index 86f6af044e..f2584618ee 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp +++ 
b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -25,6 +25,8 @@ #include #include +#include "Utils/Utils.h" + using namespace mlir; using namespace vector; @@ -46,6 +48,114 @@ class ConvBroadcastOptimizePattern : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + // Currently use f32 as the element type. + FloatType f32 = mlir::FloatType::getF32(ctx); + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + // Define `*Type`. + VectorType vectorTy32 = mlir::VectorType::get({stride}, f32); + VectorType vectorMaskTy = VectorType::get({stride}, i1); + // Create constant index. + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value cStride = rewriter.create(loc, stride); + Value f0 = rewriter.create( + loc, APFloat::getZero(f32.getFloatSemantics()), f32 + ); + + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy32, f0); + // Get input, kernel and output. + Value input = op->getOperand(0); + Value kernel = op->getOperand(1); + Value output = op->getOperand(2); + // Create DimOp. + Value kernelRow = rewriter.create(loc, kernel, c0); + Value kernelCol = rewriter.create(loc, kernel, c1); + Value outputRow = rewriter.create(loc, output, c0); + Value outputCol = rewriter.create(loc, output, c1); + // Size of strip mining. + AffineExpr d0; + bindDims(ctx, d0); + AffineMap stripMap = AffineMap::get(1, 0, {d0.ceilDiv(stride)}, ctx); + SmallVector lowerBounds(3, c0); + SmallVector uperBounds{outputRow, kernelRow, kernelCol}; + SmallVector steps(3, 1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create strip mining loop. 
+ builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{outputCol}, stripMap, /*Step=*/1, std::nullopt, + [&](OpBuilder &nestedBuilder, Location nestedLoc,Value iv, + ValueRange itrArgs) { + // Vectorize the kernel. + // Broadcast element of the kernel. + Value kernelValue = builder.create( + loc, kernel, ValueRange{ivs[1], ivs[2]} + ); + // Coefficients handle, if kernel item == 0, skip compute + Value kernelNonZeroCond = buddy::zeroCond( + builder, loc, f32, kernelValue, buddy::indexToF32(builder, loc, c0) + ); + builder.create( + loc, kernelNonZeroCond, + [&](OpBuilder &builder, Location loc) { + Value kernelVector = builder.create( + loc, vectorTy32, kernelValue + ); + // Load input vector from memref. + AffineExpr m, n, k, j; + bindDims(ctx, m, n, k, j); + AffineMap inputVectorMap = AffineMap::get( + 4, 0, + {m + n, k + j * stride}, ctx + ); + // Calculate the tail. + Value currCol = + nestedBuilder.create(loc, iv, cStride); + Value tail = nestedBuilder.create( + loc, outputCol, currCol); + Value tailCond = rewriter.create( + loc, arith::CmpIPredicate::sge, tail, cStride); + // If the current column does not reach the tail. + builder.create( + loc, tailCond, + [&](OpBuilder &builder, Location loc) { + Value inputVector = + nestedBuilder.create( + loc, vectorTy32, input, inputVectorMap, + ValueRange{ivs[0], ivs[1], ivs[2], iv}); + // Define AffineMap. + // The `outputVector` and `resultVector` share the + // same AffineMap. 
+ AffineExpr x, y; + bindDims(ctx, x, y); + AffineMap outputVectorMap = AffineMap::get( + 2, 0, + {x, y * stride}, ctx); + Value outputVector = + nestedBuilder.create( + loc, vectorTy32, output, outputVectorMap, + ValueRange{ivs[0], iv}); + // Multiply InputVector and KernelBroadcastVector then Add OutputVector to OutputVector + Value resultVector = nestedBuilder.create( + loc, inputVector, kernelVector, outputVector); + nestedBuilder.create( + loc, resultVector, output, outputVectorMap, + ValueRange{ivs[0], iv}); + builder.create(loc); + } + ); + } + ); + } + ); + } + ); return failure(); } From abac9f487e754db6624544e2ea2fa9d7cd45815b Mon Sep 17 00:00:00 2001 From: heyi Date: Thu, 28 Sep 2023 10:44:54 +0000 Subject: [PATCH 05/11] registerPass --- .../ConvOptimization/ConvBroadcast.cpp | 36 +++++++++++++++++-- tools/buddy-opt/buddy-opt.cpp | 2 ++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp index f2584618ee..810cdb5edd 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp +++ b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -148,15 +148,47 @@ class ConvBroadcastOptimizePattern : public ConversionPattern { loc, resultVector, output, outputVectorMap, ValueRange{ivs[0], iv}); builder.create(loc); + }, + // The else branch (the current column reaches the tail). + [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value tailMask = builder.create(loc, vectorMaskTy, tail); + // Calculate the index of the input and output. + Value inputRow = nestedBuilder.create(loc, ivs[0], ivs[1]); + Value outputCol = nestedBuilder.create(loc, iv, cStride); + Value inputCol = nestedBuilder.create(loc, ivs[2], outputCol); + // Masked load input and output. 
+ Value maskedInputVec = builder.create( + loc, vectorTy32, input, + ValueRange{inputRow, inputCol}, tailMask, + passThroughVec + ); + Value maskedOutputVec = builder.create( + loc, vectorTy32, output, + ValueRange{ivs[0], outputCol}, tailMask, + passThroughVec + ); + // FMA + Value resultVec = builder.create(loc, maskedInputVec, kernelVector, maskedOutputVec); + // Masked store the result to output. + builder.create( + loc, output, ValueRange{ivs[0], outputCol}, + tailMask, resultVec + ); + builder.create(loc); } ); + builder.create(loc); } ); - } + nestedBuilder.create(nestedLoc); + } ); } ); - return failure(); + // Remove the origin convolution operation + rewriter.eraseOp(op); + return success(); } private: diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index c906af8ff3..2d86b6c3b8 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -48,6 +48,7 @@ namespace mlir { namespace buddy { void registerConvVectorizationPass(); +void registerConvBroadcastNhwcHwcfPass(); void registerPointwiseConvToGemmPass(); void registerPoolingVectorizationPass(); void registerLowerBudPass(); @@ -71,6 +72,7 @@ int main(int argc, char **argv) { mlir::buddy::registerPointwiseConvToGemmPass(); // Register Vectorization of Convolution. mlir::buddy::registerConvVectorizationPass(); + mlir::buddy::registerConvBroadcastNhwcHwcfPass(); // Register Vectorization of Pooling. 
mlir::buddy::registerPoolingVectorizationPass(); mlir::buddy::registerLowerBudPass(); From 3d9d093bce353832fc8ce95a1f5701ac33f78880 Mon Sep 17 00:00:00 2001 From: heyi Date: Fri, 29 Sep 2023 07:40:57 +0000 Subject: [PATCH 06/11] test affine api --- midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp index 810cdb5edd..17e718f7b3 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp +++ b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -60,6 +60,8 @@ class ConvBroadcastOptimizePattern : public ConversionPattern { // Create constant index. Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); Value cStride = rewriter.create(loc, stride); Value f0 = rewriter.create( loc, APFloat::getZero(f32.getFloatSemantics()), f32 @@ -74,8 +76,11 @@ class ConvBroadcastOptimizePattern : public ConversionPattern { // Create DimOp. Value kernelRow = rewriter.create(loc, kernel, c0); Value kernelCol = rewriter.create(loc, kernel, c1); - Value outputRow = rewriter.create(loc, output, c0); - Value outputCol = rewriter.create(loc, output, c1); + Value outputRow = rewriter.create(loc, output, c1); + Value outputCol = rewriter.create(loc, output, c2); + Value batch = rewriter.create(loc, input, c0); + Value feature = rewriter.create(loc, kernel, c3); + Value channel = rewriter.create(loc, kernel, c2); // Size of strip mining. 
AffineExpr d0; bindDims(ctx, d0); From 1d3b439c159af91d0f02faf0215eda2ee226bea8 Mon Sep 17 00:00:00 2001 From: wangaobo Date: Wed, 4 Oct 2023 00:37:16 +0800 Subject: [PATCH 07/11] Finish and Pass test --- examples/ConvOpt/conv2d_nhwc_hwcf.mlir | 4 +- .../ConvOptimization/ConvBroadcast.cpp | 175 +++++++----------- .../conv2d_nhwc_hwcf-broadcast.mlir | 89 +++++++++ 3 files changed, 156 insertions(+), 112 deletions(-) create mode 100644 tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir diff --git a/examples/ConvOpt/conv2d_nhwc_hwcf.mlir b/examples/ConvOpt/conv2d_nhwc_hwcf.mlir index ab04005e84..5a369fd288 100644 --- a/examples/ConvOpt/conv2d_nhwc_hwcf.mlir +++ b/examples/ConvOpt/conv2d_nhwc_hwcf.mlir @@ -1,5 +1,5 @@ -func.func @conv_2d_nchw_fchw(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_2d_nchw_fchw ins (%arg0, %arg1: memref, memref) +func.func @conv_2d_nhwc_hwcf(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nhwc_hwcf ins (%arg0, %arg1: memref, memref) outs (%arg2: memref) return } \ No newline at end of file diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp index 17e718f7b3..8ec5e8f03a 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp +++ b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include "Utils/Utils.h" @@ -67,8 +68,6 @@ class ConvBroadcastOptimizePattern : public ConversionPattern { loc, APFloat::getZero(f32.getFloatSemantics()), f32 ); - // Create pass through vector. - Value passThroughVec = rewriter.create(loc, vectorTy32, f0); // Get input, kernel and output. 
Value input = op->getOperand(0); Value kernel = op->getOperand(1); @@ -85,112 +84,68 @@ class ConvBroadcastOptimizePattern : public ConversionPattern { AffineExpr d0; bindDims(ctx, d0); AffineMap stripMap = AffineMap::get(1, 0, {d0.ceilDiv(stride)}, ctx); - SmallVector lowerBounds(3, c0); - SmallVector uperBounds{outputRow, kernelRow, kernelCol}; - SmallVector steps(3, 1); - affine::buildAffineLoopNest( - rewriter, loc, lowerBounds, uperBounds, steps, - [&](OpBuilder &builder, Location loc, ValueRange ivs) { - // Create strip mining loop. - builder.create( - loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{outputCol}, stripMap, /*Step=*/1, std::nullopt, - [&](OpBuilder &nestedBuilder, Location nestedLoc,Value iv, - ValueRange itrArgs) { - // Vectorize the kernel. - // Broadcast element of the kernel. - Value kernelValue = builder.create( - loc, kernel, ValueRange{ivs[1], ivs[2]} - ); - // Coefficients handle, if kernel item == 0, skip compute - Value kernelNonZeroCond = buddy::zeroCond( - builder, loc, f32, kernelValue, buddy::indexToF32(builder, loc, c0) - ); - builder.create( - loc, kernelNonZeroCond, - [&](OpBuilder &builder, Location loc) { - Value kernelVector = builder.create( - loc, vectorTy32, kernelValue - ); - // Load input vector from memref. - AffineExpr m, n, k, j; - bindDims(ctx, m, n, k, j); - AffineMap inputVectorMap = AffineMap::get( - 4, 0, - {m + n, k + j * stride}, ctx - ); - // Calculate the tail. - Value currCol = - nestedBuilder.create(loc, iv, cStride); - Value tail = nestedBuilder.create( - loc, outputCol, currCol); - Value tailCond = rewriter.create( - loc, arith::CmpIPredicate::sge, tail, cStride); - // If the current column does not reach the tail. - builder.create( - loc, tailCond, - [&](OpBuilder &builder, Location loc) { - Value inputVector = - nestedBuilder.create( - loc, vectorTy32, input, inputVectorMap, - ValueRange{ivs[0], ivs[1], ivs[2], iv}); - // Define AffineMap. 
- // The `outputVector` and `resultVector` share the - // same AffineMap. - AffineExpr x, y; - bindDims(ctx, x, y); - AffineMap outputVectorMap = AffineMap::get( - 2, 0, - {x, y * stride}, ctx); - Value outputVector = - nestedBuilder.create( - loc, vectorTy32, output, outputVectorMap, - ValueRange{ivs[0], iv}); - // Multiply InputVector and KernelBroadcastVector then Add OutputVector to OutputVector - Value resultVector = nestedBuilder.create( - loc, inputVector, kernelVector, outputVector); - nestedBuilder.create( - loc, resultVector, output, outputVectorMap, - ValueRange{ivs[0], iv}); - builder.create(loc); - }, - // The else branch (the current column reaches the tail). - [&](OpBuilder &builder, Location loc) { - // Create mask according to the tail. - Value tailMask = builder.create(loc, vectorMaskTy, tail); - // Calculate the index of the input and output. - Value inputRow = nestedBuilder.create(loc, ivs[0], ivs[1]); - Value outputCol = nestedBuilder.create(loc, iv, cStride); - Value inputCol = nestedBuilder.create(loc, ivs[2], outputCol); - // Masked load input and output. - Value maskedInputVec = builder.create( - loc, vectorTy32, input, - ValueRange{inputRow, inputCol}, tailMask, - passThroughVec - ); - Value maskedOutputVec = builder.create( - loc, vectorTy32, output, - ValueRange{ivs[0], outputCol}, tailMask, - passThroughVec - ); - // FMA - Value resultVec = builder.create(loc, maskedInputVec, kernelVector, maskedOutputVec); - // Masked store the result to output. 
- builder.create( - loc, output, ValueRange{ivs[0], outputCol}, - tailMask, resultVec - ); - builder.create(loc); - } - ); - builder.create(loc); - } - ); - nestedBuilder.create(nestedLoc); - } - ); - } - ); + SmallVector lowerBounds(6, c0); + SmallVector uperBounds{batch, outputRow, kernelRow, kernelCol, channel, feature}; + SmallVector steps(6, 1); + affine::buildAffineLoopNest(rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create stride loop. + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{outputCol}, stripMap, 1, std::nullopt, + [&](OpBuilder &nestedBuilder, Location channelLoc, Value iv, ValueRange itrArgs) { + // Get element frome kernel. + Value kernelValue = builder.create( + loc, kernel, ValueRange{ivs[2], ivs[3], ivs[4], ivs[5]} + ); + // Coefficients handle, if kernel item == 0, skip compute + Value kernelNonZeroCond = buddy::zeroCond( + builder, loc, f32, kernelValue, buddy::indexToF32(builder, loc, c0) + ); + builder.create(loc, kernelNonZeroCond, + [&](OpBuilder &builder, Location loc) { + // Broadcast element of the kernel. + Value kernelVector = builder.create( + loc, vectorTy32, kernelValue + ); + // Calculate the tail. + Value currCol = nestedBuilder.create(loc, iv, cStride); + Value tail = nestedBuilder.create(loc, outputCol, currCol); + + Value inputRowTailIdx = builder.create(loc, ivs[1], ivs[2]); + //Value outputColTailIdx = builder.create(loc, iv, cStride); + Value inputColTailIdx = builder.create(loc, ivs[3], currCol); + Value tailMask = builder.create(loc, vectorMaskTy, tail); + // Define AffineMap (d0, d1, d2, d3) -> (d2) + AffineExpr d0, d1, d2, d3; + bindDims(ctx, d0, d1, d2, d3); + AffineMap tansposeMap = AffineMap::get(4, 0, {d2}, ctx); + SmallVector inBounds(1, true); + // Load input/output vector from memref. 
+ Value inputVector = builder.create( + loc, vectorTy32, input, ValueRange{ivs[0], inputRowTailIdx, inputColTailIdx, ivs[4]}, + AffineMapAttr::get(tansposeMap), f0, tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + ); + Value outputVector = builder.create( + loc, vectorTy32, output, ValueRange{ivs[0], ivs[1], currCol, ivs[5]}, + AffineMapAttr::get(tansposeMap), f0, tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + ); + + // Multiply input vector and kernel vector then Add OutputVector(FMA). + Value resultVector = builder.create( + loc, inputVector, kernelVector, outputVector + ); + // Store result vector to output. + builder.create( + loc, resultVector, output, ValueRange{ivs[0], ivs[1], currCol, ivs[5]}, + AffineMapAttr::get(tansposeMap), tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + ); + builder.create(loc); + }); + nestedBuilder.create(channelLoc); + } + ); + }); // Remove the origin convolution operation rewriter.eraseOp(op); return success(); @@ -233,8 +188,8 @@ class ConvBroadcastNhwcHwcfPass affine::AffineDialect, VectorDialect, func::FuncDialect>(); } - Option stride{*this, "strip-mining", - llvm::cl::desc("Strip mining size."), + Option stride{*this, "stride", + llvm::cl::desc("Stride size."), llvm::cl::init(32)}; ListOption tile{*this, "tile-sizes", llvm::cl::desc("Tile sizes"), llvm::cl::ZeroOrMore}; diff --git a/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir b/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir new file mode 100644 index 0000000000..89b534973d --- /dev/null +++ b/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir @@ -0,0 +1,89 @@ +// RUN: buddy-opt %s \ +// RUN: -conv-broadcast="stride=32" \ +// RUN: -convert-linalg-to-loops -convert-vector-to-scf -lower-affine \ +// RUN: -convert-scf-to-cf -convert-vector-to-llvm \ +// RUN: -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -convert-cf-to-llvm -reconcile-unrealized-cas \ +// RUN: | mlir-cpu-runner -e main 
-entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_lib \ +// RUN: -shared-libs=%mlir_c_runner_utils_lib \ +// RUN: | FileCheck %s +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %val: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %arg5 = %c0 to %arg0 step %c1 { + scf.for %arg6 = %c0 to %arg1 step %c1 { + scf.for %arg7 = %c0 to %arg2 step %c1 { + scf.for %arg8 = %c0 to %arg3 step %c1 { + memref.store %val, %0[%arg5, %arg6, %arg7, %arg8] : memref + } + } + } + } + return %0 : memref + } + + func.func @conv_2d_nhwc_hwcf(%a: memref, %b: memref, %c: memref) { + linalg.conv_2d_nhwc_hwcf + ins(%a, %b: memref, memref) + outs(%c: memref) + return + } + + func.func @main() { + // Input and kernel value. + %cst = arith.constant 1.000000e+00 : f32 + // Output value. + %cst_0 = arith.constant 0.000000e+00 : f32 + + // Define layout. + %input_n = arith.constant 1 : index + %input_h = arith.constant 3 : index + %input_w = arith.constant 3 : index + %input_c = arith.constant 2 : index + + %kernel_h = arith.constant 2 : index + %kernel_w = arith.constant 2 : index + %kernel_c = arith.constant 2 : index + %kernel_f = arith.constant 2 : index + + %output_n = arith.constant 1 : index + %output_h = arith.constant 2 : index + %output_w = arith.constant 2 : index + %output_c = arith.constant 2 : index + + + // Define input, kernel, and output memref. 
+ %input = call @alloc_f32(%input_n, %input_h, %input_w, %input_c, %cst) : (index, index, index, index, f32) -> memref + %kernel = call @alloc_f32(%kernel_h, %kernel_w, %kernel_c, %kernel_f, %cst) : (index, index, index, index, f32) -> memref + %output = call @alloc_f32(%output_n, %output_h, %output_w, %output_c, %cst_0) : (index, index, index, index, f32) -> memref + + // Perform convolution + call @conv_2d_nhwc_hwcf(%input, %kernel, %output) : (memref, memref, memref) -> () + + // Print the output + // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 2, 2, 2] strides = [8, 4, 2, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-SAME: [ + // CHECK-SAME: [8, 8], + // CHECK-NEXT: [8, 8] + // CHECK-SAME: ], + // CHECK-NEXT: [ + // CHECK-SAME: [8, 8], + // CHECK-NEXT: [8, 8] + // CHECK-SAME: ] + // CHECK-SAME: ] + // CHECK-SAME: ] + %print_output = memref.cast %output : memref to memref<*xf32> + call @printMemrefF32(%print_output) : (memref<*xf32>) -> () + + memref.dealloc %output : memref + memref.dealloc %input : memref + memref.dealloc %kernel : memref + return + } +} \ No newline at end of file From 028abc88f47f1ea396e7536ed3b7b5fa93104e1a Mon Sep 17 00:00:00 2001 From: wangaobo Date: Thu, 5 Oct 2023 20:29:30 +0800 Subject: [PATCH 08/11] Finish Data Layout transpose --- .../ConvOptimization/ConvBroadcast.cpp | 478 ++++++++++++++---- 1 file changed, 369 insertions(+), 109 deletions(-) diff --git a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp index 8ec5e8f03a..84eb94b0e5 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp +++ b/midend/lib/Conversion/ConvOptimization/ConvBroadcast.cpp @@ -31,129 +31,384 @@ using namespace mlir; using namespace vector; -//===-------------------------------------------- -// Rewrite Pattern -//===-------------------------------------------- 
+//===----------------------------------------------------------------------===// +// Pattern Details +//===----------------------------------------------------------------------===// -namespace { -class ConvBroadcastOptimizePattern : public ConversionPattern { -public: - explicit ConvBroadcastOptimizePattern(MLIRContext *context, int64_t strideParam, - ArrayRef tileParam) - : ConversionPattern(linalg::Conv2DNhwcHwcfOp::getOperationName(), 1, context) { - - stride = strideParam; - tileSizes = tileParam; - } +void populateTransposePattern(Operation *op, int64_t stride, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + // Currently use f32 as the element type. + FloatType f32 = mlir::FloatType::getF32(ctx); + VectorType vectorTy32 = mlir::VectorType::get({stride}, f32); + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + // Define `*Type`. + VectorType vectorMaskTy = VectorType::get({stride}, i1); + // Create constant index. + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + Value cStride = rewriter.create(loc, stride); + Value f0 = rewriter.create( + loc, APFloat::getZero(f32.getFloatSemantics()), f32 + ); + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy32, f0); + // Get input, kernel and output. + Value input = op->getOperand(0); + Value kernel = op->getOperand(1); + Value output = op->getOperand(2); + // Create DimOp. 
+ Value kernelRow = rewriter.create(loc, kernel, c0); + Value kernelCol = rewriter.create(loc, kernel, c1); + Value inputRow = rewriter.create(loc, input, c1); + Value outputRow = rewriter.create(loc, output, c1); + Value outputCol = rewriter.create(loc, output, c2); + Value batch = rewriter.create(loc, input, c0); + Value inputCol = rewriter.create(loc, input, c2); + Value feature = rewriter.create(loc, kernel, c3); + Value channel = rewriter.create(loc, kernel, c2); - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto ctx = op->getContext(); - // Currently use f32 as the element type. - FloatType f32 = mlir::FloatType::getF32(ctx); - // Get i1 as the element type for mask vector. - IntegerType i1 = IntegerType::get(ctx, 1); - // Define `*Type`. - VectorType vectorTy32 = mlir::VectorType::get({stride}, f32); - VectorType vectorMaskTy = VectorType::get({stride}, i1); - // Create constant index. - Value c0 = rewriter.create(loc, 0); - Value c1 = rewriter.create(loc, 1); - Value c2 = rewriter.create(loc, 2); - Value c3 = rewriter.create(loc, 3); - Value cStride = rewriter.create(loc, stride); - Value f0 = rewriter.create( - loc, APFloat::getZero(f32.getFloatSemantics()), f32 + // Size of strip mining. 
+ AffineExpr d0; + bindDims(ctx, d0); + AffineMap stripMap = AffineMap::get(1, 0, {d0.ceilDiv(stride)}, ctx); + + // Define affine bounds + SmallVector transposeLower(4, c0); + SmallVector inputUpper{batch, inputRow, inputCol, channel}; + SmallVector transposeStep(4, 1); + SmallVector kernelUpper{kernelRow, kernelCol, channel, feature}; + SmallVector outputUpper{batch, outputRow, outputCol, feature}; + SmallVector lowerBounds(6, c0); + SmallVector uperBounds{batch, feature, channel, outputRow, kernelRow, kernelCol}; + SmallVector steps(6, 1); + + // Transpose DataLayout NHWC to NCHW / HWCF to FCHW + MemRefType memType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic, + ShapedType::kDynamic, ShapedType::kDynamic}, f32); + Value inputTranspose = rewriter.create(loc, memType, ValueRange{batch, channel, inputRow, inputCol}); + Value kernelTranspose = rewriter.create(loc, memType, ValueRange{feature, channel, kernelRow, kernelCol}); + Value outputTranspose = rewriter.create(loc, memType, ValueRange{batch, feature, outputRow, outputCol}); + + affine::buildAffineLoopNest(rewriter, loc, transposeLower, inputUpper, transposeStep, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Load data from input NHWC + Value val = builder.create( + loc, input, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]} ); + // Store data to transpose NCHW + builder.create( + loc, val, inputTranspose, ValueRange{ivs[0], ivs[3], ivs[1], ivs[2]} + ); + }); - // Get input, kernel and output. - Value input = op->getOperand(0); - Value kernel = op->getOperand(1); - Value output = op->getOperand(2); - // Create DimOp. 
- Value kernelRow = rewriter.create(loc, kernel, c0); - Value kernelCol = rewriter.create(loc, kernel, c1); - Value outputRow = rewriter.create(loc, output, c1); - Value outputCol = rewriter.create(loc, output, c2); - Value batch = rewriter.create(loc, input, c0); - Value feature = rewriter.create(loc, kernel, c3); - Value channel = rewriter.create(loc, kernel, c2); - // Size of strip mining. - AffineExpr d0; - bindDims(ctx, d0); - AffineMap stripMap = AffineMap::get(1, 0, {d0.ceilDiv(stride)}, ctx); - SmallVector lowerBounds(6, c0); - SmallVector uperBounds{batch, outputRow, kernelRow, kernelCol, channel, feature}; - SmallVector steps(6, 1); - affine::buildAffineLoopNest(rewriter, loc, lowerBounds, uperBounds, steps, + affine::buildAffineLoopNest(rewriter, loc, transposeLower, kernelUpper, transposeStep, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Load data from kernel HWCF + Value val = builder.create( + loc, kernel, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]} + ); + // Store to transpose FCHW + builder.create( + loc, val, kernelTranspose, ValueRange{ivs[3], ivs[2], ivs[0], ivs[1]} + ); + }); + + affine::buildAffineLoopNest( + rewriter, loc, transposeLower, outputUpper, transposeStep, [&](OpBuilder &builder, Location loc, ValueRange ivs) { - // Create stride loop. - builder.create( - loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{outputCol}, stripMap, 1, std::nullopt, - [&](OpBuilder &nestedBuilder, Location channelLoc, Value iv, ValueRange itrArgs) { - // Get element frome kernel. 
- Value kernelValue = builder.create( - loc, kernel, ValueRange{ivs[2], ivs[3], ivs[4], ivs[5]} - ); - // Coefficients handle, if kernel item == 0, skip compute - Value kernelNonZeroCond = buddy::zeroCond( - builder, loc, f32, kernelValue, buddy::indexToF32(builder, loc, c0) + // Load from origin + Value val = builder.create( + loc, output, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]} + ); + // Store to transpose + builder.create( + loc, val, outputTranspose, ValueRange{ivs[0], ivs[3], ivs[1], ivs[2]} + ); + } + ); + + // Cofficients Broadcast + affine::buildAffineLoopNest(rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create stride loop. + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{outputCol}, stripMap, 1, std::nullopt, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange itrArgs) { + // Get element from kernel. + Value kernelValue = builder.create( + loc, kernelTranspose, ValueRange{ivs[1], ivs[2], ivs[4], ivs[5]} + ); + // if kernel item == 0, skip compute + Value kernelNonZeroCond = buddy::zeroCond( + builder, loc, f32, kernelValue, buddy::indexToF32(builder, loc, c0) + ); + builder.create(loc, kernelNonZeroCond, + [&](OpBuilder &builder, Location loc) { + // Broadcast element of the kernel. + Value kernelVec = builder.create(loc, vectorTy32, kernelValue); + // Calculate the tail. + Value currCol = builder.create(loc, iv, cStride); + Value tail = builder.create(loc, outputCol, currCol); + Value tailCond = builder.create( + loc, arith::CmpIPredicate::sge, tail, cStride ); - builder.create(loc, kernelNonZeroCond, + // If the current column does not reach the tail. + builder.create(loc, tailCond, [&](OpBuilder &builder, Location loc) { - // Broadcast element of the kernel. - Value kernelVector = builder.create( - loc, vectorTy32, kernelValue + // Load input vector from memref. 
+ AffineExpr b1, c, m, n, k, j; + bindDims(ctx, b1, c, m, n, k, j); + AffineMap inputVectorMap = AffineMap::get(6, 0, {b1, c, m + n, k + j * stride}, ctx); + Value inputVec = builder.create( + loc, vectorTy32, inputTranspose, inputVectorMap, + ValueRange{ivs[0], ivs[2], ivs[3], ivs[4], ivs[5], iv} ); - // Calculate the tail. - Value currCol = nestedBuilder.create(loc, iv, cStride); - Value tail = nestedBuilder.create(loc, outputCol, currCol); - - Value inputRowTailIdx = builder.create(loc, ivs[1], ivs[2]); - //Value outputColTailIdx = builder.create(loc, iv, cStride); - Value inputColTailIdx = builder.create(loc, ivs[3], currCol); - Value tailMask = builder.create(loc, vectorMaskTy, tail); - // Define AffineMap (d0, d1, d2, d3) -> (d2) - AffineExpr d0, d1, d2, d3; - bindDims(ctx, d0, d1, d2, d3); - AffineMap tansposeMap = AffineMap::get(4, 0, {d2}, ctx); - SmallVector inBounds(1, true); - // Load input/output vector from memref. - Value inputVector = builder.create( - loc, vectorTy32, input, ValueRange{ivs[0], inputRowTailIdx, inputColTailIdx, ivs[4]}, - AffineMapAttr::get(tansposeMap), f0, tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + // Define AffineMap. + // The `outputVector` and `resultVector` share the same AffineMap. 
+ AffineExpr b2, f, x, y; + bindDims(ctx, b2, f, x, y); + AffineMap outputVecMap = AffineMap::get(4, 0, {b2, f, x, y * stride}, ctx); + Value outputVec = builder.create( + loc, vectorTy32, outputTranspose, outputVecMap, + ValueRange{ivs[0], ivs[1], ivs[3], iv} ); - Value outputVector = builder.create( - loc, vectorTy32, output, ValueRange{ivs[0], ivs[1], currCol, ivs[5]}, - AffineMapAttr::get(tansposeMap), f0, tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + // Multiply and Add + Value resultVector = builder.create(loc, inputVec, kernelVec, outputVec); + builder.create( + loc, resultVector, outputTranspose, outputVecMap, + ValueRange{ivs[0], ivs[1], ivs[3], iv} ); - - // Multiply input vector and kernel vector then Add OutputVector(FMA). - Value resultVector = builder.create( - loc, inputVector, kernelVector, outputVector + builder.create(loc); + }, + // Else branch (the current column reaches the tail) + [&](OpBuilder &builder, Location loc) { + Value tailMask = builder.create(loc, vectorMaskTy, tail); + // Calculate the index of the input and output. + Value inputRowTail = builder.create(loc, ivs[3], ivs[4]); + Value outputColTail = builder.create(loc, iv, cStride); + Value inputColTail = builder.create(loc, ivs[5], outputColTail); + // Masked load input and output. + Value maskedInputVec = builder.create( + loc, vectorTy32, inputTranspose, + ValueRange{ivs[0], ivs[2], inputRowTail, inputColTail}, tailMask, passThroughVec ); - // Store result vector to output. - builder.create( - loc, resultVector, output, ValueRange{ivs[0], ivs[1], currCol, ivs[5]}, - AffineMapAttr::get(tansposeMap), tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + Value maskedOutputVec = builder.create( + loc, vectorTy32, outputTranspose, + ValueRange{ivs[0], ivs[1], ivs[3], outputColTail}, tailMask, passThroughVec + ); + // FMA. + Value resultVec = builder.create(loc, maskedInputVec, kernelVec, maskedOutputVec); + + // Masked store the result to output. 
+ builder.create( + loc, outputTranspose, ValueRange{ivs[0], ivs[1], ivs[3], outputColTail}, + tailMask, resultVec ); builder.create(loc); }); - nestedBuilder.create(channelLoc); - } + nestedBuilder.create(loc); + }); + builder.create(nestedLoc); + } + ); + }); + + affine::buildAffineLoopNest( + rewriter, loc, transposeLower, outputUpper, transposeStep, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Load from transpose output + Value val = builder.create( + loc, outputTranspose, ValueRange{ivs[0], ivs[3], ivs[1], ivs[2]} + ); + // Store to origin + builder.create( + loc, val, output, ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]} ); - }); - // Remove the origin convolution operation - rewriter.eraseOp(op); + } + ); + + // Remove the origin convolution operation. + rewriter.eraseOp(op); +} + +void populateTransferReadPattern(Operation *op, int64_t stride, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + // Currently use f32 as the element type. + FloatType f32 = mlir::FloatType::getF32(ctx); + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + // Define `*Type`. + VectorType vectorTy32 = mlir::VectorType::get({stride}, f32); + VectorType vectorMaskTy = VectorType::get({stride}, i1); + // Create constant index. + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + Value cStride = rewriter.create(loc, stride); + Value f0 = rewriter.create( + loc, APFloat::getZero(f32.getFloatSemantics()), f32 + ); + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy32, f0); + // Get input, kernel and output. + Value input = op->getOperand(0); + Value kernel = op->getOperand(1); + Value output = op->getOperand(2); + // Create DimOp. 
+ Value kernelRow = rewriter.create(loc, kernel, c0); + Value kernelCol = rewriter.create(loc, kernel, c1); + Value outputRow = rewriter.create(loc, output, c1); + Value outputCol = rewriter.create(loc, output, c2); + Value batch = rewriter.create(loc, input, c0); + Value feature = rewriter.create(loc, kernel, c3); + Value channel = rewriter.create(loc, kernel, c2); + // Size of strip mining. + AffineExpr d0; + bindDims(ctx, d0); + AffineMap stripMap = AffineMap::get(1, 0, {d0.ceilDiv(stride)}, ctx); + SmallVector lowerBounds(6, c0); + SmallVector uperBounds{batch, outputRow, kernelRow, kernelCol, channel, feature}; + SmallVector steps(6, 1); + + // SmallVector transposeShape{batch, feature, outputRow, outputCol}; + MemRefType memType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic, + ShapedType::kDynamic, ShapedType::kDynamic}, f32); + Value outputTranspose = rewriter.create(loc, memType, ValueRange{batch, feature, outputRow, outputCol}); + + SmallVector lower(4, c0); + SmallVector upper{batch, outputRow, outputCol, feature}; + SmallVector storeStep(4, 1); + affine::buildAffineLoopNest( + rewriter, loc, lower, upper, storeStep, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Load from origin + Value val = builder.create(loc, output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + // Store to transpose + builder.create(loc, val, outputTranspose, + ValueRange{ivs[0], ivs[3], ivs[1], ivs[2]}); + } + ); + + affine::buildAffineLoopNest(rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create stride loop. + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{outputCol}, stripMap, 1, std::nullopt, + [&](OpBuilder &nestedBuilder, Location channelLoc, Value iv, ValueRange itrArgs) { + // Get element frome kernel. 
+ Value kernelValue = builder.create( + loc, kernel, ValueRange{ivs[2], ivs[3], ivs[4], ivs[5]} + ); + // Coefficients handle, if kernel item == 0, skip compute + Value kernelNonZeroCond = buddy::zeroCond( + builder, loc, f32, kernelValue, buddy::indexToF32(builder, loc, c0) + ); + builder.create(loc, kernelNonZeroCond, + [&](OpBuilder &builder, Location loc) { + // Broadcast element of the kernel. + Value kernelVector = builder.create( + loc, vectorTy32, kernelValue + ); + // Calculate the tail. + Value currCol = nestedBuilder.create(loc, iv, cStride); + Value tail = nestedBuilder.create(loc, outputCol, currCol); + + Value inputRowTailIdx = builder.create(loc, ivs[1], ivs[2]); + //Value outputColTailIdx = builder.create(loc, iv, cStride); + Value inputColTailIdx = builder.create(loc, ivs[3], currCol); + Value tailMask = builder.create(loc, vectorMaskTy, tail); + // Define AffineMap (d0, d1, d2, d3) -> (d2) + AffineExpr d0, d1, d2, d3; + bindDims(ctx, d0, d1, d2, d3); + AffineMap tansposeMap = AffineMap::get(4, 0, {d2}, ctx); + SmallVector inBounds(1, true); + // Load input/output vector from memref. + Value inputVector = builder.create( + loc, vectorTy32, input, ValueRange{ivs[0], inputRowTailIdx, inputColTailIdx, ivs[4]}, + AffineMapAttr::get(tansposeMap), f0, tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + ); + Value outputVector = builder.create(loc, vectorTy32, outputTranspose, + ValueRange{ivs[0], ivs[5], ivs[1], currCol}, tailMask, passThroughVec + ); + + // Multiply input vector and kernel vector then Add OutputVector(FMA). + Value resultVector = builder.create( + loc, inputVector, kernelVector, outputVector + ); + // Store result vector to output. 
+ // builder.create( + // loc, resultVector, output, ValueRange{ivs[0], ivs[1], currCol, ivs[5]}, + // AffineMapAttr::get(tansposeMap), tailMask, ArrayAttr::get(ctx, builder.getBoolAttr(true)) + // ); + builder.create( + loc, outputTranspose, ValueRange{ivs[0], ivs[5], ivs[1], currCol}, + tailMask, resultVector + ); + builder.create(loc); + }); + nestedBuilder.create(channelLoc); + } + ); + }); + + + affine::buildAffineLoopNest( + rewriter, loc, lower, upper, storeStep, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Load from transpose + Value val = builder.create(loc, outputTranspose, + ValueRange{ivs[0], ivs[3], ivs[1], ivs[2]}); + // Store to origin + builder.create(loc, val, output, + ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]}); + } + ); + // Remove the origin convolution operation + rewriter.eraseOp(op); +} + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +enum OptMode {Transpose, Transfer}; + +namespace { +class ConvBroadcastOptimizePattern : public ConversionPattern { +public: + explicit ConvBroadcastOptimizePattern(MLIRContext *context, int64_t strideParam, + OptMode modeParam) + : ConversionPattern(linalg::Conv2DNhwcHwcfOp::getOperationName(), 1, context) { + + stride = strideParam; + mode = modeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (mode == Transpose) { + populateTransposePattern(op, stride, rewriter); + } else { + populateTransferReadPattern(op, stride, rewriter); + } return success(); } private: int64_t stride; - ArrayRef tileSizes; + OptMode mode; }; } // end anonymous namespace @@ -176,9 +431,9 @@ class ConvBroadcastNhwcHwcfPass ConvBroadcastNhwcHwcfPass() = default; ConvBroadcastNhwcHwcfPass(const ConvBroadcastNhwcHwcfPass &) {} explicit ConvBroadcastNhwcHwcfPass(int64_t strideParam, - ArrayRef 
tileParam) { + OptMode modeParam) { stride = strideParam; - tile = tileParam; + mode = modeParam; } void runOnOperation() override; @@ -189,10 +444,15 @@ class ConvBroadcastNhwcHwcfPass } Option stride{*this, "stride", - llvm::cl::desc("Stride size."), + llvm::cl::desc("Transfer Read Stride size."), llvm::cl::init(32)}; - ListOption tile{*this, "tile-sizes", llvm::cl::desc("Tile sizes"), - llvm::cl::ZeroOrMore}; + Option mode{*this, "mode", + llvm::cl::desc("Broadcast Optmize mode"), + llvm::cl::values( + clEnumValN(Transpose, "transpose", "Transpose Data Layout"), + clEnumValN(Transfer, "transfer", "Use Transfer Read") + ), + llvm::cl::init(Transpose)}; }; } // end anonymous namespace @@ -209,7 +469,7 @@ void ConvBroadcastNhwcHwcfPass::runOnOperation() { target.addLegalOp(); RewritePatternSet patterns(context); - patterns.add(context, stride, tile); + patterns.add(context, stride, mode); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { From e2529d81aa3d4ddb47f4a5501d6acce7abad7d6b Mon Sep 17 00:00:00 2001 From: heyi Date: Sat, 7 Oct 2023 00:53:50 +0000 Subject: [PATCH 09/11] modify test file run command --- tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir b/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir index 89b534973d..01aaebaa63 100644 --- a/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir +++ b/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir @@ -3,7 +3,7 @@ // RUN: -convert-linalg-to-loops -convert-vector-to-scf -lower-affine \ // RUN: -convert-scf-to-cf -convert-vector-to-llvm \ // RUN: -finalize-memref-to-llvm -convert-arith-to-llvm \ -// RUN: -convert-func-to-llvm -convert-cf-to-llvm -reconcile-unrealized-cas \ +// RUN: -convert-func-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_lib \ // RUN: 
-shared-libs=%mlir_c_runner_utils_lib \ From bb2b6c453f0ecab6b4c5462af15c538834322b92 Mon Sep 17 00:00:00 2001 From: heyi Date: Mon, 9 Oct 2023 13:39:30 +0000 Subject: [PATCH 10/11] fix test file issue --- tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir b/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir index 01aaebaa63..1abd35ab67 100644 --- a/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir +++ b/tests/Conversion/conv2d_nhwc_hwcf-broadcast.mlir @@ -5,8 +5,8 @@ // RUN: -finalize-memref-to-llvm -convert-arith-to-llvm \ // RUN: -convert-func-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_runner_utils_lib \ -// RUN: -shared-libs=%mlir_c_runner_utils_lib \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s module { func.func private @printMemrefF32(memref<*xf32>) From 9626daf072a405e7713f4ad67067bc3dd551f774 Mon Sep 17 00:00:00 2001 From: Kingkiller Date: Mon, 30 Oct 2023 22:22:59 +0800 Subject: [PATCH 11/11] Add Polyhedral Tiling Optimization --- midend/include/Utils/TileSizeSelection.h | 27 ++ midend/lib/Conversion/CMakeLists.txt | 1 + .../PolyhedralOptimization/CMakeLists.txt | 4 + .../PolyhedralOptimization/TileAndFuse.cpp | 252 ++++++++++++++++++ .../PolyhedralOptimization/Tiling.cpp | 109 ++++++++ midend/lib/Utils/CMakeLists.txt | 1 + midend/lib/Utils/TileSizeSelection.cpp | 35 +++ 7 files changed, 429 insertions(+) create mode 100644 midend/include/Utils/TileSizeSelection.h create mode 100644 midend/lib/Conversion/PolyhedralOptimization/CMakeLists.txt create mode 100644 midend/lib/Conversion/PolyhedralOptimization/TileAndFuse.cpp create mode 100644 midend/lib/Conversion/PolyhedralOptimization/Tiling.cpp create mode 
100644 midend/lib/Utils/TileSizeSelection.cpp diff --git a/midend/include/Utils/TileSizeSelection.h b/midend/include/Utils/TileSizeSelection.h new file mode 100644 index 0000000000..08b1f23518 --- /dev/null +++ b/midend/include/Utils/TileSizeSelection.h @@ -0,0 +1,27 @@ +#ifndef INCLUDE_UTILS_TILESIZESELECTION_H_ +#define INCLUDE_UTILS_TILESIZESELECTION_H_ + +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" + +using namespace mlir; + +namespace buddy { +// Map a range to a SmallVector with element types deduced from the mapping. +template +auto map_to_vec(ContainerTy &&C, FuncTy &&F) { + return llvm::to_vector( + llvm::map_range(std::forward(C), std::forward(F))); +} + +template +auto map_to_vec(ContainerTy &&C, FuncTy &&F) { + return to_vector( + map_range(std::forward(C), std::forward(F))); +} + +void setSCFTileSizes(scf::SCFTilingOptions &options, TilingInterface consumerOp, + SmallVector tileSizes, + SmallVector tileScalableFlags); +} // namespace buddy + +#endif // INCLUDE_UTILS_TILESIZESELECTION_H_ \ No newline at end of file diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt index 4a1067a982..c4cc314dd1 100644 --- a/midend/lib/Conversion/CMakeLists.txt +++ b/midend/lib/Conversion/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(ConvOptimization) add_subdirectory(LowerVectorExp) add_subdirectory(LowerGemmini) add_subdirectory(LowerLinalgToGemmini) +add_subdirectory(PolyhedralOptimization) \ No newline at end of file diff --git a/midend/lib/Conversion/PolyhedralOptimization/CMakeLists.txt b/midend/lib/Conversion/PolyhedralOptimization/CMakeLists.txt new file mode 100644 index 0000000000..92f2b8f7f4 --- /dev/null +++ b/midend/lib/Conversion/PolyhedralOptimization/CMakeLists.txt @@ -0,0 +1,4 @@ +add_mlir_library(PolyhedralOptimization + Tiling.cpp + TileAndFuse.cpp + ) \ No newline at end of file diff --git a/midend/lib/Conversion/PolyhedralOptimization/TileAndFuse.cpp
b/midend/lib/Conversion/PolyhedralOptimization/TileAndFuse.cpp new file mode 100644 index 0000000000..413c7bbf14 --- /dev/null +++ b/midend/lib/Conversion/PolyhedralOptimization/TileAndFuse.cpp @@ -0,0 +1,252 @@ +#include "llvm/Support/Debug.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Iterators.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "Utils/TileSizeSelection.h" + +#define DEBUG_TYPE "polyhedral-tile-and-fuse" + +using namespace mlir; + +namespace { + +/// Starting from `op` walk all operands backwards to find all potentially +/// fuseable operations, i.e. 
operations that implement the `TilingInterface` +void collectTiledAndFusedOps(Operation *rootOp, + llvm::SmallDenseSet &result) { + SmallVector worklist; + worklist.push_back(rootOp); + result.insert(rootOp); + while (!worklist.empty()) + { + Operation *current = worklist.pop_back_val(); + for (OpOperand &operand : current->getOpOperands()) { + Operation *producer = operand.get().getDefiningOp(); + if (!producer || !isa(producer) || + result.count(producer)) + continue; + worklist.push_back(producer); + result.insert(producer); + } + } +} + +FailureOr +foldIfGeneratedFromPadding(RewriterBase &rewriter, tensor::PadOp untiledPadOp, + tensor::PadOp tiledPadOp) { + auto ifOp = dyn_cast(tiledPadOp->getParentOp()); + if (!ifOp) + return failure(); + Block *block = tiledPadOp->getBlock(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.inlineBlockBefore(block, ifOp, {}); + rewriter.replaceOp(ifOp, results); + rewriter.eraseOp(terminator); + return tiledPadOp; +} + +struct TileAndFusePass : public PassWrapper> { + TileAndFusePass(int64_t tilingLevel) { + this->tilingLevel = tilingLevel; + } + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileAndFusePass) + StringRef getArgument() const final { return "polyhedral-tile-and-fuse"; } + StringRef getDescription() const final { return "Tile and fuse"; } + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + } + void runOnOperation() override; + + int64_t tilingLevel; +}; + +LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp, + DominanceInfo &dominanceInfo, + scf::SCFTilingOptions options) { + llvm::SmallDenseSet originTiledAndFuseOps; + collectTiledAndFusedOps(rootOp, originTiledAndFuseOps); + auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) { + return originTiledAndFuseOps.count(user) || isa(user); + }; + + // 1. Tile the consumer. 
+ SmallVector yieldedValuesToOrigValues; + SmallVector tiledOps; + FailureOr tilingResult = + scf::tileUsingSCFForOp(rewriter, cast(rootOp), options); + if (failed(tilingResult)) { + return failure(); + } + auto forLoops = llvm::to_vector(llvm::map_range(tilingResult->loops, + [](Operation *op) { return cast(op); })); + yieldedValuesToOrigValues.append(rootOp->result_begin(), + rootOp->result_end()); + // A map from untiled value to scf.for iter_arg. The iter_arg is used for DPS + // init operand if they use the same init operand + llvm::DenseMap mapToIterArg; + + if (auto rootPadOp = dyn_cast(rootOp)) { + assert(tilingResult->tiledOps.size() == 1 && + "Expecting only one tiled op for tensor::PadOp"); + FailureOr replacementTiledOp = foldIfGeneratedFromPadding( + rewriter, rootPadOp, cast(tilingResult->tiledOps[0])); + if (!failed(replacementTiledOp)) { + tilingResult->tiledOps[0] = replacementTiledOp.value(); + } + } else if (auto dpsOp = dyn_cast(rootOp)) { + for (auto [init, iterArg] : llvm::zip_equal( + dpsOp.getDpsInitOperands(), + cast(forLoops.back()).getRegionIterArgs())) { + mapToIterArg[init->get()] = iterArg; + } + } + tiledOps.append(tilingResult->tiledOps); + + // 2. Tiling each operation results in generation of slices. The source of + // these slices could be producers that can be fused into the tiled loops by + // computing the slices of these producers in-place. This results in more + // slices created for operands of the "fused producer". This opens up more + // opportunities for fusion. Use a worklist to fuse greedily. 
+ auto addCandidateSlices = + [&](Operation *fusedOp, std::deque &candidates) { + for (OpOperand &operand : fusedOp->getOpOperands()) { + auto sliceOp = operand.get().getDefiningOp(); + if (!sliceOp) + continue; + candidates.push_back(sliceOp); + + auto dpsOp = dyn_cast(fusedOp); + if (!dpsOp) + continue; + + if (dpsOp.isDpsInit(&operand) && + mapToIterArg.contains(sliceOp.getSource())) { + rewriter.startRootUpdate(sliceOp); + sliceOp.getSourceMutable().assign(mapToIterArg[sliceOp.getSource()]); + rewriter.finalizeRootUpdate(sliceOp); + } + } + }; + + std::deque candidates; + addCandidateSlices(tilingResult->tiledOps.back(), candidates); + OpBuilder::InsertionGuard g(rewriter); + while (!candidates.empty()) + { + // Traverse the slices in BFS fashion. + tensor::ExtractSliceOp candidateSliceOp = candidates.front(); + candidates.pop_front(); + + // Materialize the slice of the producer in place. + std::optional fusedProducer = + scf::tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops); + if (!fusedProducer) + continue; + + // Check if the fused producer has other uses that require the value + // to be yielded from within the tiled loop. + OpResult untiledProducer = fusedProducer->origProducer; + if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) { + return !isIgnoredUser(user, forLoops.front()) && + !forLoops.front()->isAncestor(user); + })) { + scf::yieldReplacementForFusedProducer(rewriter, candidateSliceOp, + fusedProducer.value(), forLoops); + yieldedValuesToOrigValues.push_back(untiledProducer); + } + + // Add more fusion candidates to the worklist. 
+ for (auto tiledOp : fusedProducer->tiledOps) { + addCandidateSlices(tiledOp, candidates); + tiledOps.push_back(tiledOp); + } + } + + scf::ForOp outermostLoop = forLoops.front(); + for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) { + Value replacement = outermostLoop.getResult(index); + rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) { + return !isIgnoredUser(use.getOwner(), outermostLoop) && + dominanceInfo.properlyDominates(outermostLoop, use.getOwner()); + }); + } + + return success(); +} + +void TileAndFusePass::runOnOperation() { + MLIRContext *context = &getContext(); + auto funcOp = getOperation(); + + TilingInterface consumerOp; + funcOp.walk([&](TilingInterface op) { + // Find the next consumer op if it does not have loops. + if (op.getLoopIteratorTypes().empty()) + return WalkResult::advance(); + consumerOp = op; + return WalkResult::interrupt(); + }); + if (!consumerOp) { + LLVM_DEBUG(llvm::dbgs() << "No consumer op found, skip tiling\n"); + return; + } + + SmallVector tileSizes; + SmallVector tileScalableFlags; + + // todo: configure tile sizes and tile scalable flags + + if (llvm::all_of(tileSizes, [&](int64_t size) { return size == 0; })) { + LLVM_DEBUG(llvm::dbgs() << "All tile sizes are 0, skip tiling\n"); + return; + } + + scf::SCFTilingOptions options{}; + buddy::setSCFTileSizes(options, consumerOp, std::move(tileSizes), + std::move(tileScalableFlags)); + + IRRewriter rewriter(context); + DominanceInfo domainInfo(funcOp); + if (failed(applyTileAndFuse(rewriter, consumerOp, domainInfo, options))) { + LLVM_DEBUG(llvm::dbgs() << "Failed to tile and fuse\n"); + return signalPassFailure(); + } + + RewritePatternSet patterns = + linalg::getLinalgTilingCanonicalizationPatterns(context); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); + tensor::populateFoldTensorEmptyPatterns(patterns); + memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + context->getLoadedDialect() + 
->getCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + LLVM_DEBUG(llvm::dbgs() << "Failed to canonicalize\n"); + return signalPassFailure(); + } +} + +} // namespace + +namespace mlir { +namespace buddy { +void registerPolyhedralTileAndFusePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir \ No newline at end of file diff --git a/midend/lib/Conversion/PolyhedralOptimization/Tiling.cpp b/midend/lib/Conversion/PolyhedralOptimization/Tiling.cpp new file mode 100644 index 0000000000..3c749aacac --- /dev/null +++ b/midend/lib/Conversion/PolyhedralOptimization/Tiling.cpp @@ -0,0 +1,109 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Iterators.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/SmallVector.h" +#include "Utils/TileSizeSelection.h" + +#define DEBUG_TYPE "polyhedral-tiling" + +using namespace mlir; + +namespace { + +SmallVector getComputeOps(func::FuncOp funcOp) { + SmallVector computeOps; + funcOp.walk([&](Operation *op) { + if (isa(op)) { + computeOps.push_back(op); + } + }); + return computeOps; +} + +struct TilingPass : public PassWrapper> { + TilingPass(int64_t tilingLevel = -1) { + this->tilingLevel = tilingLevel; + } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TilingPass) + StringRef getArgument() const final { return "polyhedral-tiling"; } + StringRef getDescription() const final { + return "Tiling for 
polyhedral optimization"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override; + + int64_t tilingLevel; +}; + +void TilingPass::runOnOperation() { + if (tilingLevel == -1) { + LLVM_DEBUG(llvm::dbgs() << "tilingLevel is -1, skip tiling\n"); + return; + } + MLIRContext *context = &getContext(); + auto funcOp = getOperation(); + + SmallVector computeOps = getComputeOps(funcOp); + + for (auto computeOp : computeOps) { + auto op = cast(computeOp); + if (op.getLoopIteratorTypes().empty()) + continue; + + SmallVector tileSizes; + SmallVector tileScalableFlags; + + if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0;})) { + LLVM_DEBUG(llvm::dbgs() << "tileSizes are all 0, skip tiling\n"); + return; + } + IRRewriter rewriter(context); + scf::SCFTilingOptions options{}; + buddy::setSCFTileSizes(options, op, std::move(tileSizes), + std::move(tileScalableFlags)); + FailureOr tiledResults = + scf::tileUsingSCFForOp(rewriter, op, options); + if (failed(tiledResults)) { + continue; + } + rewriter.replaceOp(op, tiledResults->replacements); + } + + RewritePatternSet patterns = + linalg::getLinalgTilingCanonicalizationPatterns(context); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); + tensor::populateFoldTensorEmptyPatterns(patterns); + memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + context->getLoadedDialect() + ->getCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace + + +namespace mlir { +namespace buddy { +void registerPolyhedralTilingPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir \ No newline at end of file diff --git a/midend/lib/Utils/CMakeLists.txt b/midend/lib/Utils/CMakeLists.txt index 7d21a67657..b8937cc961 100644 --- a/midend/lib/Utils/CMakeLists.txt +++ b/midend/lib/Utils/CMakeLists.txt @@ -2,6 +2,7 @@ 
add_mlir_library(BuddyUtils Utils.cpp DIPUtils.cpp AffineTransformUtils.cpp + TileSizeSelection.cpp ) add_mlir_library(BuddyDIPUtils diff --git a/midend/lib/Utils/TileSizeSelection.cpp b/midend/lib/Utils/TileSizeSelection.cpp new file mode 100644 index 0000000000..c9d7da9ac4 --- /dev/null +++ b/midend/lib/Utils/TileSizeSelection.cpp @@ -0,0 +1,35 @@ +#include "Utils/TileSizeSelection.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +namespace buddy { +void setSCFTileSizes(scf::SCFTilingOptions &options, TilingInterface consumerOp, + SmallVector tileSizes, + SmallVector tileScalableFlags) { + int numLoops = consumerOp.getLoopIteratorTypes().size(); + tileSizes.resize(numLoops, 0); + tileScalableFlags.resize(numLoops, false); + if (!llvm::is_contained(tileSizes, 1)) { + // Non-scalable case: All constant tile sizes. + options.setTileSizes( + tileSizes); + // getAsIndexOpFoldResult(consumerOp.getContext(), tileSizes)); + } else { + // Scalable case: Multiply scalable tile sizes by a vector.vscale op. + options.setTileSizeComputationFunction( + [=](OpBuilder &builder, Operation *op) -> SmallVector { + auto loc = op->getLoc(); + return map_to_vec( + llvm::zip(tileSizes, tileScalableFlags), + [&](auto pair) -> Value { + auto [t, isScalable] = pair; + Value size = builder.create(loc, t); + if (isScalable) { + Value vscale = builder.create(loc); + size = builder.create(loc, size, vscale); + } + return size; + }); + }); + } +} +} // namespace buddy \ No newline at end of file