From 06aec762238190b96b8ed55228503ed5e8af3b69 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Fri, 8 Nov 2024 05:08:36 +0000 Subject: [PATCH 1/6] [midend/lib/Conversion/MatMulOptimization] add batchmatmul vectorization pass and add relevant examples and tests. --- .../batchmatmul-vectorization.mlir | 134 ++++++ examples/BuddyMatmul/makefile | 39 ++ .../BatchMatMulOptimize.cpp | 393 ++++++------------ .../Conversion/batchmatmul-vectorization.mlir | 40 ++ 4 files changed, 345 insertions(+), 261 deletions(-) create mode 100644 examples/BuddyMatmul/batchmatmul-vectorization.mlir create mode 100644 tests/Conversion/batchmatmul-vectorization.mlir diff --git a/examples/BuddyMatmul/batchmatmul-vectorization.mlir b/examples/BuddyMatmul/batchmatmul-vectorization.mlir new file mode 100644 index 0000000000..a46d59e9a8 --- /dev/null +++ b/examples/BuddyMatmul/batchmatmul-vectorization.mlir @@ -0,0 +1,134 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -lower-affine \ +// RUN: -convert-vector-to-scf \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -expand-strided-metadata \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func private @rtclock() -> f64 + + // CMK * CKN -> CMN + func.func @batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %vl_step = arith.constant 32 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.splat %cst : vector<32xf32> + %dim = memref.dim %arg0, %c0 : memref + %dim_1 = memref.dim %arg0, %c1 : memref + %dim_2 = memref.dim %arg1, %c1 : memref + %dim_3 = memref.dim %arg1, %c2 : memref + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + %dim_3_upbound_tmp = arith.subi %dim_3, %vl_step : index + %dim_3_upbound = arith.addi %dim_3_upbound_tmp, %c1 : index + + affine.for %arg3 = %c0 to %dim { // C + affine.prefetch %arg0[%arg3, %dim_1, %dim_2], read, locality<3>, data : memref + affine.for %arg4 = %c0 to %dim_1 { // M + // Perform the vectorization body. 
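+        // The outer scf.for steps through the N dimension in chunks of
+        // `vl_step` and yields the next start index; the inner scf.for
+        // reduces over K, carrying the 32-lane accumulator in its iter_args.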
+ %iter_idx = scf.for %arg5 = %c0 to %dim_3_upbound + step %vl_step iter_args(%iter_init = %c0) -> (index) { // N + %1 = vector.load %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> + %iter_vec = scf.for %arg6 = %c0 to %dim_2 step %c1 + iter_args(%iter_vec0 = %1) -> (vector<32xf32>) { // K + %5 = memref.load %arg0[%arg3, %arg4, %arg6] : memref + %6 = vector.broadcast %5 : f32 to vector<32xf32> + %4 = vector.load %arg1[%arg3, %arg6, %arg5] : memref, vector<32xf32> + %8 = vector.fma %6, %4, %iter_vec0 : vector<32xf32> + scf.yield %8 : vector<32xf32> + } + vector.store %iter_vec, %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> + %arg5_next = arith.addi %arg5, %vl_step : index + scf.yield %arg5_next : index + } + // Compute the tail size and Process the remaining elements + // using masked vector operations. + %tail_size = arith.subi %dim_3, %iter_idx : index + %mask = vector.create_mask %tail_size : vector<32xi1> + %1 = vector.maskedload %arg2[%arg3, %arg4, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %iter_vec = scf.for %arg6 = %c0 to %dim_2 step %c1 + iter_args(%iter_vec0 = %1) -> (vector<32xf32>) { // K + %5 = vector.maskedload %arg1[%arg3, %arg6, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %6 = memref.load %arg0[%arg3, %arg4, %arg6] : memref + %7 = vector.broadcast %6 : f32 to vector<32xf32> + %9 = vector.fma %7, %5, %iter_vec0 : vector<32xf32> + scf.yield %9 : vector<32xf32> + } + vector.maskedstore %arg2[%arg3, %arg4, %iter_idx], %mask, %iter_vec : memref, vector<32xi1>, vector<32xf32> + } + } + return + } + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2] : memref + } + } + } + return %0 : memref + } + + func.func @main(){ + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c576 = arith.constant 576 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %f0 = arith.constant 0.0 : f32 + %f2 = arith.constant 2.0 : f32 + %f3 = arith.constant 3.0 : f32 + + %m0 = call @alloc_f32(%c1, %c1, %c576, %f2) : (index, index, index, f32) -> memref + %m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref + %m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref + + call @batch_matmul(%m0, %m1, %m2) : (memref, memref, memref) -> () + + %printed_m2 = memref.cast %m2 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data = + // CHECK-NEXT: [ + // CHECK: [ + // CHECK: [3456{{(, 3456)*}}] + call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> () + + %m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref + %m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref + %m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref + + call @batch_matmul(%m3, %m4, %m5) : (memref, memref, memref) -> () + + %printed_m5 = memref.cast %m5 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data = + // CHECK-NEXT: [ + // CHECK: [ + // CHECK: [6144{{(, 
6144)*}}] + call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> () + + return + } +} diff --git a/examples/BuddyMatmul/makefile b/examples/BuddyMatmul/makefile index 0940d608da..1a0655e9fc 100644 --- a/examples/BuddyMatmul/makefile +++ b/examples/BuddyMatmul/makefile @@ -18,7 +18,29 @@ MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.dylib MTRIPLE := x86_64-apple-darwin endif +linalg-batchmatmul-lower: + @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ + -batchmatmul-optimize \ + -o ./log.mlir + linalg-batchmatmul-f32-run: + @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -expand-strided-metadata \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-batchmatmul-f32-vectorization-run: @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ -batchmatmul-optimize \ -convert-linalg-to-affine-loops \ @@ -53,3 +75,20 @@ linalg-matmul-transpose-b-f32-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +batchmatmul-vectorization-run: + @${BUDDY_OPT} ./batchmatmul-vectorization.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -expand-strided-metadata \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index 6cedaa1655..b940a86715 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -18,6 +18,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -25,6 +26,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/ValueRange.h" #include "llvm/ADT/ArrayRef.h" #include @@ -32,6 +34,10 @@ #include #include #include +#include +#include +#include +#include #include #include #include @@ -40,7 +46,6 @@ using namespace mlir; using namespace vector; -using namespace affine; //===----------------------------------------------------------------------===// // Rewrite Pattern @@ -51,292 +56,158 @@ namespace { class BatchMatMulOptimizePattern : public ConversionPattern { public: explicit BatchMatMulOptimizePattern(MLIRContext *context, - int64_t affineVectorSizeParam) + int64_t vecSizeParam) : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, context) { - affineVectorSize = affineVectorSizeParam; + vecSize = vecSizeParam; } LogicalResult matchAndRewrite(Operation *op, ArrayRef /*operands*/, 
ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); + auto ctx = op->getContext(); // Retrieve input tensors A, B, and C. Value A = op->getOperand(0); Value B = op->getOperand(1); Value C = op->getOperand(2); + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + VectorType vectorMaskTy = mlir::VectorType::get({vecSize}, i1); // Acquire the element type of input tensors. Type elementType = A.getType().cast().getElementType(); + VectorType vectorTy = mlir::VectorType::get({vecSize}, elementType); - // Define constants. - const Value zeroIndex = - rewriter.create(loc, rewriter.getIndexAttr(0)); const AffineExpr d0 = rewriter.getAffineDimExpr(0); const AffineExpr d1 = rewriter.getAffineDimExpr(1); const AffineExpr d2 = rewriter.getAffineDimExpr(2); - const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); - const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); - const Value zeroElementType = rewriter.create( + // Define constants. + const Value c0 = rewriter.create(loc, 0); + const Value c1 = rewriter.create(loc, 1); + const Value c2 = rewriter.create(loc, 2); + const Value vl_step = rewriter.create(loc, vecSize); + const Value zero = rewriter.create( loc, rewriter.getZeroAttr(elementType)); - const Value zeroElementTypeVec = rewriter.create( - loc, VectorType::get({affineVectorSize}, elementType), zeroElementType); - - // Get dimensions of input tensors. - Value batch = rewriter.create(loc, A, 0); - Value aRow = rewriter.create(loc, A, 1); - Value bCol = rewriter.create(loc, B, 2); - Value bRow = rewriter.create(loc, B, 1); - - // Calculate the length of the tail, which might not fit in a vector. - Value tailLength = rewriter.create( - loc, AffineMap::get(1, 0, d0 % affineVectorSize), ValueRange{bCol}); - // Generate a mask vector based on the tail length. - Value maskVector = rewriter.create( - loc, VectorType::get({affineVectorSize}, rewriter.getI1Type()), - ValueRange{tailLength}); + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy, zero); - SmallVector reducedValues = llvm::to_vector<4>( - llvm::map_range(ArrayRef{}, - [](const LoopReduction &red) { return red.value; })); - - // Apply the column of matrix B. - Value appliedColOfB = rewriter.create( - loc, AffineMap::get(1, 0, d0.ceilDiv(affineVectorSize)), - ValueRange{bCol}); - - // Create the primary parallel batch level loop. - AffineParallelOp parallelBatchLoop = - rewriter.create( - loc, ValueRange(reducedValues).getTypes(), ValueRange{batch}, - ArrayRef{ - rewriter.getNamedAttr("lowerBoundsGroups", - rewriter.getI32TensorAttr({1})), - rewriter.getNamedAttr("upperBoundsGroups", - rewriter.getI32TensorAttr({1})), - rewriter.getNamedAttr( - "lowerBoundsMap", - AffineMapAttr::get(AffineMap::get(0, 0, {zeroAffine}, - rewriter.getContext()))), - rewriter.getNamedAttr("upperBoundsMap", - AffineMapAttr::get(AffineMap::get( - 1, 0, {d0}, rewriter.getContext()))), - rewriter.getNamedAttr("reductions", rewriter.getArrayAttr({})), - rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({1}))}); - - // Create the loop body for the parallel loop. - Block *loopBody = new Block(); - rewriter.setInsertionPointToStart(loopBody); - loopBody->addArgument(rewriter.getIndexType(), loc); - Value loopVarBatchIdx = loopBody->getArguments()[0]; - - // Prefetching data from tensor 'A' for better cache utilization. 
- rewriter.create( - loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()), - ArrayRef{loopVarBatchIdx, aRow, bRow}, false, 3, true); + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, A, c0); + Value aRow = rewriter.create(loc, A, c1); + Value bCol = rewriter.create(loc, B, c2); + Value bRow = rewriter.create(loc, B, c1); + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + Value upperBound_tmp = rewriter.create(loc, bCol, vl_step); + Value upperBound = rewriter.create(loc, upperBound_tmp, c1); affine::buildAffineLoopNest( - rewriter, loc, {zeroIndex}, {appliedColOfB}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarColOfB = ivRange.front(); - - // Compile time branch detection. - if (C.getType().cast().isDynamicDim(2) or - C.getType().cast().getDimSize(2) % affineVectorSize != - 0) { - - // Depending on the position, use either full vectors or tail - // vectors. - affine::AffineIfOp branchingOp = builder.create( - loc, - IntegerSet::get( - 1, 1, {d0 * -affineVectorSize + s0 - affineVectorSize}, - {false}), - ValueRange{loopVarColOfB, bCol}, true); - - // Branch handling full vector operations. - OpBuilder trueBranchBuilder = branchingOp.getThenBodyBuilder(); - affine::buildAffineLoopNest( - trueBranchBuilder, loc, {zeroIndex}, {bRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - AffineMap::get(3, 0, {d0, d1, d2 * affineVectorSize}, - rewriter.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfB, - loopVarColOfB}); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, - ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarRowOfB}); - Value aVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), - aElement); - Value cVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - Value computedVec; - - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - if (isa(elementType)) { - Value mulVec = - builder.create(loc, aVec, bVec); - computedVec = - builder.create(loc, mulVec, cVec); - } else { - computedVec = builder.create( - loc, aVec, bVec, cVec); - } - builder.create( - loc, computedVec, C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - }); - }); - - // Branch handling operations on the tail. 
- OpBuilder falseBranchBuilder = branchingOp.getElseBodyBuilder(); - affine::buildAffineLoopNest( - falseBranchBuilder, loc, {zeroIndex}, {bRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value tailIdxColOfB = builder.create( - loc, AffineMap::get(1, 0, d0 * affineVectorSize), - ValueRange{loopVarColOfB}); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - ValueRange{loopVarBatchIdx, loopVarRowOfB, tailIdxColOfB}, - maskVector, zeroElementTypeVec); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, - ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarRowOfB}); - Value aVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), - aElement); - Value cVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), C, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - tailIdxColOfB}, - maskVector, zeroElementTypeVec); - Value computedVec; - - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - if (isa(elementType)) { - Value mulVec = - builder.create(loc, aVec, bVec); - computedVec = - builder.create(loc, mulVec, cVec); - } else { - computedVec = builder.create( - loc, aVec, bVec, cVec); - } - builder.create( - loc, C, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - tailIdxColOfB}, - maskVector, computedVec); - }); - }); - } else { - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {bRow}, 1, - [&](OpBuilder &builder, Location loc, ValueRange ivRange) { - Value loopVarRowOfB = ivRange.front(); - Value bVec = builder.create( - loc, VectorType::get({affineVectorSize}, elementType), B, - AffineMap::get(3, 0, {d0, d1, d2 * affineVectorSize}, - rewriter.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfB, - loopVarColOfB}); - affine::buildAffineLoopNest( - builder, loc, {zeroIndex}, {aRow}, 1, - [&](OpBuilder &builder, Location loc, - ValueRange ivRange) { - Value loopVarRowOfA = ivRange.front(); - Value aElement = builder.create( - loc, A, - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarRowOfB}); - Value aVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), - aElement); - Value cVec = builder.create( - loc, - VectorType::get({affineVectorSize}, elementType), C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - Value computedVec; - - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - if (isa(elementType)) { - Value mulVec = - builder.create(loc, aVec, bVec); - computedVec = - builder.create(loc, mulVec, cVec); - } else { - computedVec = builder.create( - loc, aVec, bVec, cVec); - } - builder.create( - loc, computedVec, C, - AffineMap::get(3, 0, - {d0, d1, d2 * affineVectorSize}, - builder.getContext()), - ValueRange{loopVarBatchIdx, loopVarRowOfA, - loopVarColOfB}); - }); - }); - } + rewriter, loc, {c0}, {batch}, /*Step=*/1, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Prefetching data from tensor 'A' for better cache utilization. 
+ builder.create( + loc, A, AffineMap::get(3, 0, {d0, d1, d2}, ctx), + ArrayRef{ivs[0], aRow, bRow}, false, 3, true); + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{aRow}, builder.getDimIdentityMap(), + /*Step=*/1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs0) { + auto iter_idx = builder.create( + loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv2, + ValueRange itrArgs0) { + Value cVec = builder.create( + loc, vectorTy, C, ValueRange{ivs[0], iv1, iv2}); + auto iter_vec = nestedBuilder.create( + nestedLoc, c0, bRow, /*Step=*/c1, ValueRange{cVec}, + [&](OpBuilder &builder, Location loc, Value iv3, + ValueRange itrArgs1) { + Value aValue = builder.create( + loc, elementType, A, + ValueRange{ivs[0], iv1, iv3}); + Value aVec = builder.create( + loc, vectorTy, aValue); + Value bVec = builder.create( + loc, vectorTy, B, ValueRange{ivs[0], iv3, iv2}); + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. + Value computedVec; + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, itrArgs1[0]); + } else { + computedVec = builder.create( + loc, aVec, bVec, itrArgs1[0]); + } + builder.create(loc, computedVec); + }); + nestedBuilder.create( + nestedLoc, iter_vec.getResult(0), C, + ValueRange{ivs[0], iv1, iv2}); + Value idx = nestedBuilder.create( + nestedLoc, iv2, vl_step); + nestedBuilder.create(nestedLoc, idx); + }); + // Compute the tail size and Process the remaining elements + // using masked vector operations. + Value idx = iter_idx.getResult(0); + Value tailSize = builder.create(loc, bCol, idx); + // Create mask according to the tail. + Value tailMask = + builder.create(loc, vectorMaskTy, tailSize); + Value maskedCVec = builder.create( + loc, vectorTy, C, ValueRange{ivs[0], iv1, idx}, tailMask, + passThroughVec); + auto iter_vec = builder.create( + loc, c0, bRow, /*Step=*/c1, ValueRange{maskedCVec}, + [&](OpBuilder &builder, Location loc, Value iv3, + ValueRange itrArgs1) { + Value aValue = builder.create( + loc, A, ValueRange{ivs[0], iv1, iv3}); + Value aVec = builder.create( + loc, vectorTy, aValue); + Value maskedBVec = builder.create( + loc, vectorTy, B, ValueRange{ivs[0], iv3, idx}, + tailMask, passThroughVec); + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. + Value computedVec; + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, maskedBVec); + computedVec = builder.create( + loc, mulVec, itrArgs1[0]); + } else { + computedVec = builder.create( + loc, aVec, maskedBVec, itrArgs1[0]); + } + builder.create(loc, computedVec); + }); + builder.create(loc, C, + ValueRange{ivs[0], iv1, idx}, + tailMask, iter_vec.getResult(0)); + builder.create(loc); + }); }); - - rewriter.create(loc); - - // Finalize the loop and erase the original operation. 
- parallelBatchLoop.getRegion().push_back(loopBody); - rewriter.setInsertionPointAfter(parallelBatchLoop); - rewriter.eraseOp(op); return success(); } private: - int64_t affineVectorSize; + int64_t vecSize; }; } // end anonymous namespace @@ -355,8 +226,8 @@ class BatchMatMulOptimizePass StringRef getDescription() const final { return "BatchMatMul Optimization."; } BatchMatMulOptimizePass() = default; BatchMatMulOptimizePass(const BatchMatMulOptimizePass &) {} - explicit BatchMatMulOptimizePass(int64_t affineVectorSizeParam) { - affineVectorSize = affineVectorSizeParam; + explicit BatchMatMulOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; } void runOnOperation() override; @@ -366,9 +237,9 @@ class BatchMatMulOptimizePass affine::AffineDialect, VectorDialect>(); } - Option affineVectorSize{*this, "vector-size", - llvm::cl::desc("Affine Vector size."), - llvm::cl::init(64)}; + Option vecSize{*this, "vector-size", + llvm::cl::desc("Affine Vector size."), + llvm::cl::init(32)}; }; } // end anonymous namespace. @@ -384,7 +255,7 @@ void BatchMatMulOptimizePass::runOnOperation() { target.addLegalOp(); RewritePatternSet patterns(context); - patterns.add(context, affineVectorSize); + patterns.add(context, vecSize); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/tests/Conversion/batchmatmul-vectorization.mlir b/tests/Conversion/batchmatmul-vectorization.mlir new file mode 100644 index 0000000000..47864ec031 --- /dev/null +++ b/tests/Conversion/batchmatmul-vectorization.mlir @@ -0,0 +1,40 @@ +// RUN: buddy-opt -batchmatmul-optimize %s | FileCheck %s + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim) { +// CHECK-NEXT: affine.prefetch %arg0[%arg3, %dim_0, %dim_2], read, locality<3>, data : memref +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_0) { +// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %2 step %c32 iter_args(%arg6 = %c0) -> (index) { +// CHECK-NEXT: %8 = vector.load %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %9 = scf.for %arg7 = %c0 to %dim_2 step %c1 iter_args(%arg8 = %8) -> (vector<32xf32>) { +// CHECK-NEXT: %11 = memref.load %arg0[%arg3, %arg4, %arg7] : memref +// CHECK-NEXT: %12 = vector.broadcast %11 : f32 to vector<32xf32> +// CHECK-NEXT: %13 = vector.load %arg1[%arg3, %arg7, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %14 = vector.fma %12, %13, %arg8 : vector<32xf32> +// CHECK-NEXT: scf.yield %14 : vector<32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: vector.store %9, %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %10 = arith.addi %arg5, %c32 : index +// CHECK-NEXT: scf.yield %10 : index +// CHECK-NEXT: } +// CHECK-NEXT: %4 = arith.subi %dim_1, %3 : index +// CHECK-NEXT: %5 = vector.create_mask %4 : vector<32xi1> +// CHECK-NEXT: %6 = vector.maskedload %arg2[%arg3, %arg4, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK-NEXT: %7 = scf.for %arg5 = %c0 to %dim_2 step %c1 iter_args(%arg6 = %6) -> (vector<32xf32>) { +// CHECK-NEXT: %8 = memref.load %arg0[%arg3, %arg4, %arg5] : memref +// CHECK-NEXT: %9 = vector.broadcast %8 : f32 to vector<32xf32> +// CHECK-NEXT: %10 = vector.maskedload %arg1[%arg3, %arg5, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK-NEXT: %11 = vector.fma %9, %10, %arg6 : vector<32xf32> +// CHECK-NEXT: scf.yield %11 : vector<32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: vector.maskedstore %arg2[%arg3, %arg4, %3], %5, 
%7 : memref, vector<32xi1>, vector<32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } + +func.func @batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.batch_matmul + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return +} From 1381f5874f9e613a593008f4b394032ed327387b Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Sun, 24 Nov 2024 05:35:19 +0000 Subject: [PATCH 2/6] [midend/lib/Conversion/MatMulOptimization] Fix batchmatmul vectorization pass and examples. --- .../BatchMatMulOptimize.cpp | 80 ++++++++++--------- .../Conversion/batchmatmul-vectorization.mlir | 42 +++++----- 2 files changed, 66 insertions(+), 56 deletions(-) diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index b940a86715..79a055324c 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -162,44 +162,52 @@ class BatchMatMulOptimizePattern : public ConversionPattern { }); // Compute the tail size and Process the remaining elements // using masked vector operations. - Value idx = iter_idx.getResult(0); - Value tailSize = builder.create(loc, bCol, idx); - // Create mask according to the tail. - Value tailMask = - builder.create(loc, vectorMaskTy, tailSize); - Value maskedCVec = builder.create( - loc, vectorTy, C, ValueRange{ivs[0], iv1, idx}, tailMask, - passThroughVec); - auto iter_vec = builder.create( - loc, c0, bRow, /*Step=*/c1, ValueRange{maskedCVec}, - [&](OpBuilder &builder, Location loc, Value iv3, - ValueRange itrArgs1) { - Value aValue = builder.create( - loc, A, ValueRange{ivs[0], iv1, iv3}); - Value aVec = builder.create( - loc, vectorTy, aValue); - Value maskedBVec = builder.create( - loc, vectorTy, B, ValueRange{ivs[0], iv3, idx}, + builder.create( + loc, iter_idx.getResult(0), bCol, /*Step=*/vl_step, + std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, + ValueRange itrArgs) { + Value idx = iter_idx.getResult(0); + Value tailSize = + builder.create(loc, bCol, idx); + // Create mask according to the tail. + Value tailMask = builder.create( + loc, vectorMaskTy, tailSize); + Value maskedCVec = builder.create( + loc, vectorTy, C, ValueRange{ivs[0], iv1, idx}, tailMask, passThroughVec); - // Compute the result vector either through integer - // multiplication and addition or fused multiply-add - // based on the element type. - Value computedVec; - if (isa(elementType)) { - Value mulVec = builder.create( - loc, aVec, maskedBVec); - computedVec = builder.create( - loc, mulVec, itrArgs1[0]); - } else { - computedVec = builder.create( - loc, aVec, maskedBVec, itrArgs1[0]); - } - builder.create(loc, computedVec); + auto iter_vec = builder.create( + loc, c0, bRow, /*Step=*/c1, ValueRange{maskedCVec}, + [&](OpBuilder &builder, Location loc, Value iv3, + ValueRange itrArgs1) { + Value aValue = builder.create( + loc, A, ValueRange{ivs[0], iv1, iv3}); + Value aVec = builder.create( + loc, vectorTy, aValue); + Value maskedBVec = builder.create( + loc, vectorTy, B, ValueRange{ivs[0], iv3, idx}, + tailMask, passThroughVec); + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. 
+ Value computedVec; + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, maskedBVec); + computedVec = builder.create( + loc, mulVec, itrArgs1[0]); + } else { + computedVec = builder.create( + loc, aVec, maskedBVec, itrArgs1[0]); + } + builder.create(loc, computedVec); + }); + builder.create( + loc, C, ValueRange{ivs[0], iv1, idx}, tailMask, + iter_vec.getResult(0)); + builder.create(loc); }); - builder.create(loc, C, - ValueRange{ivs[0], iv1, idx}, - tailMask, iter_vec.getResult(0)); - builder.create(loc); + builder.create(loc); }); }); rewriter.eraseOp(op); diff --git a/tests/Conversion/batchmatmul-vectorization.mlir b/tests/Conversion/batchmatmul-vectorization.mlir index 47864ec031..070c836f7a 100644 --- a/tests/Conversion/batchmatmul-vectorization.mlir +++ b/tests/Conversion/batchmatmul-vectorization.mlir @@ -6,29 +6,31 @@ // CHECK-NEXT: affine.prefetch %arg0[%arg3, %dim_0, %dim_2], read, locality<3>, data : memref // CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_0) { // CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %2 step %c32 iter_args(%arg6 = %c0) -> (index) { -// CHECK-NEXT: %8 = vector.load %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> -// CHECK-NEXT: %9 = scf.for %arg7 = %c0 to %dim_2 step %c1 iter_args(%arg8 = %8) -> (vector<32xf32>) { -// CHECK-NEXT: %11 = memref.load %arg0[%arg3, %arg4, %arg7] : memref -// CHECK-NEXT: %12 = vector.broadcast %11 : f32 to vector<32xf32> -// CHECK-NEXT: %13 = vector.load %arg1[%arg3, %arg7, %arg5] : memref, vector<32xf32> -// CHECK-NEXT: %14 = vector.fma %12, %13, %arg8 : vector<32xf32> -// CHECK-NEXT: scf.yield %14 : vector<32xf32> +// CHECK-NEXT: %4 = vector.load %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %5 = scf.for %arg7 = %c0 to %dim_2 step %c1 iter_args(%arg8 = %4) -> (vector<32xf32>) { +// CHECK-NEXT: %7 = memref.load %arg0[%arg3, %arg4, %arg7] : memref +// CHECK-NEXT: %8 = vector.broadcast %7 : f32 to vector<32xf32> +// CHECK-NEXT: %9 = vector.load %arg1[%arg3, %arg7, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %10 = vector.fma %8, %9, %arg8 : vector<32xf32> +// CHECK-NEXT: scf.yield %10 : vector<32xf32> // CHECK-NEXT: } -// CHECK-NEXT: vector.store %9, %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> -// CHECK-NEXT: %10 = arith.addi %arg5, %c32 : index -// CHECK-NEXT: scf.yield %10 : index +// CHECK-NEXT: vector.store %5, %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %6 = arith.addi %arg5, %c32 : index +// CHECK-NEXT: scf.yield %6 : index // CHECK-NEXT: } -// CHECK-NEXT: %4 = arith.subi %dim_1, %3 : index -// CHECK-NEXT: %5 = vector.create_mask %4 : vector<32xi1> -// CHECK-NEXT: %6 = vector.maskedload %arg2[%arg3, %arg4, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> -// CHECK-NEXT: %7 = scf.for %arg5 = %c0 to %dim_2 step %c1 iter_args(%arg6 = %6) -> (vector<32xf32>) { -// CHECK-NEXT: %8 = memref.load %arg0[%arg3, %arg4, %arg5] : memref -// CHECK-NEXT: %9 = vector.broadcast %8 : f32 to vector<32xf32> -// CHECK-NEXT: %10 = vector.maskedload %arg1[%arg3, %arg5, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> -// CHECK-NEXT: %11 = vector.fma %9, %10, %arg6 : vector<32xf32> -// CHECK-NEXT: scf.yield %11 : vector<32xf32> +// CHECK-NEXT: scf.for %arg5 = %3 to %dim_1 step %c32 { +// CHECK-NEXT: %4 = arith.subi %dim_1, %3 : index +// CHECK-NEXT: %5 = vector.create_mask %4 : vector<32xi1> +// CHECK-NEXT: %6 = vector.maskedload %arg2[%arg3, %arg4, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> 
+// CHECK-NEXT: %7 = scf.for %arg6 = %c0 to %dim_2 step %c1 iter_args(%arg7 = %6) -> (vector<32xf32>) { +// CHECK-NEXT: %8 = memref.load %arg0[%arg3, %arg4, %arg6] : memref +// CHECK-NEXT: %9 = vector.broadcast %8 : f32 to vector<32xf32> +// CHECK-NEXT: %10 = vector.maskedload %arg1[%arg3, %arg6, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK-NEXT: %11 = vector.fma %9, %10, %arg7 : vector<32xf32> +// CHECK-NEXT: scf.yield %11 : vector<32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: vector.maskedstore %arg2[%arg3, %arg4, %3], %5, %7 : memref, vector<32xi1>, vector<32xf32> // CHECK-NEXT: } -// CHECK-NEXT: vector.maskedstore %arg2[%arg3, %arg4, %3], %5, %7 : memref, vector<32xi1>, vector<32xf32> // CHECK-NEXT: } // CHECK-NEXT: } From 9976e4b3aaa703d0fbd38857fcc5f34675e621d6 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Mon, 25 Nov 2024 05:58:42 +0000 Subject: [PATCH 3/6] [midend/lib/Conversion/MatMulOptimization] Fix batchmatmul vectorization pass and examples. --- .../batchmatmul-vectorization.mlir | 27 ++++++------ .../BatchMatMulOptimize.cpp | 17 ++++---- .../Conversion/batchmatmul-vectorization.mlir | 43 ++++++++++--------- 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/examples/BuddyMatmul/batchmatmul-vectorization.mlir b/examples/BuddyMatmul/batchmatmul-vectorization.mlir index a46d59e9a8..1931f8c3cd 100644 --- a/examples/BuddyMatmul/batchmatmul-vectorization.mlir +++ b/examples/BuddyMatmul/batchmatmul-vectorization.mlir @@ -40,9 +40,9 @@ module { %dim_3_upbound_tmp = arith.subi %dim_3, %vl_step : index %dim_3_upbound = arith.addi %dim_3_upbound_tmp, %c1 : index - affine.for %arg3 = %c0 to %dim { // C + affine.for %arg3 = %c0 to %dim { // C affine.prefetch %arg0[%arg3, %dim_1, %dim_2], read, locality<3>, data : memref - affine.for %arg4 = %c0 to %dim_1 { // M + affine.for %arg4 = %c0 to %dim_1 { // M // Perform the vectorization body. %iter_idx = scf.for %arg5 = %c0 to %dim_3_upbound step %vl_step iter_args(%iter_init = %c0) -> (index) { // N @@ -62,17 +62,20 @@ module { // Compute the tail size and Process the remaining elements // using masked vector operations. 
%tail_size = arith.subi %dim_3, %iter_idx : index - %mask = vector.create_mask %tail_size : vector<32xi1> - %1 = vector.maskedload %arg2[%arg3, %arg4, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> - %iter_vec = scf.for %arg6 = %c0 to %dim_2 step %c1 - iter_args(%iter_vec0 = %1) -> (vector<32xf32>) { // K - %5 = vector.maskedload %arg1[%arg3, %arg6, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> - %6 = memref.load %arg0[%arg3, %arg4, %arg6] : memref - %7 = vector.broadcast %6 : f32 to vector<32xf32> - %9 = vector.fma %7, %5, %iter_vec0 : vector<32xf32> - scf.yield %9 : vector<32xf32> + %3 = arith.cmpi sgt, %tail_size, %c0 : index + scf.if %3 { + %mask = vector.create_mask %tail_size : vector<32xi1> + %1 = vector.maskedload %arg2[%arg3, %arg4, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %iter_vec = scf.for %arg6 = %c0 to %dim_2 step %c1 + iter_args(%iter_vec0 = %1) -> (vector<32xf32>) { // K + %5 = vector.maskedload %arg1[%arg3, %arg6, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %6 = memref.load %arg0[%arg3, %arg4, %arg6] : memref + %7 = vector.broadcast %6 : f32 to vector<32xf32> + %9 = vector.fma %7, %5, %iter_vec0 : vector<32xf32> + scf.yield %9 : vector<32xf32> + } + vector.maskedstore %arg2[%arg3, %arg4, %iter_idx], %mask, %iter_vec : memref, vector<32xi1>, vector<32xf32> } - vector.maskedstore %arg2[%arg3, %arg4, %iter_idx], %mask, %iter_vec : memref, vector<32xi1>, vector<32xf32> } } return diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index 79a055324c..df4f7bd49f 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -162,14 +162,13 @@ class BatchMatMulOptimizePattern : public ConversionPattern { }); // Compute the tail size and Process the remaining elements // using masked vector operations. - builder.create( - loc, iter_idx.getResult(0), bCol, /*Step=*/vl_step, - std::nullopt, - [&](OpBuilder &builder, Location loc, Value iv, - ValueRange itrArgs) { - Value idx = iter_idx.getResult(0); - Value tailSize = - builder.create(loc, bCol, idx); + Value idx = iter_idx.getResult(0); + Value tailSize = builder.create(loc, bCol, idx); + Value tailCond = rewriter.create( + loc, arith::CmpIPredicate::sge, tailSize, c0); + // If the current column does not reach the tail. + builder.create( + loc, tailCond, [&](OpBuilder &builder, Location loc) { // Create mask according to the tail. 
Value tailMask = builder.create( loc, vectorMaskTy, tailSize); @@ -207,7 +206,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { iter_vec.getResult(0)); builder.create(loc); }); - builder.create(loc); + builder.create(loc); }); }); rewriter.eraseOp(op); diff --git a/tests/Conversion/batchmatmul-vectorization.mlir b/tests/Conversion/batchmatmul-vectorization.mlir index 070c836f7a..7f77dbebca 100644 --- a/tests/Conversion/batchmatmul-vectorization.mlir +++ b/tests/Conversion/batchmatmul-vectorization.mlir @@ -6,30 +6,31 @@ // CHECK-NEXT: affine.prefetch %arg0[%arg3, %dim_0, %dim_2], read, locality<3>, data : memref // CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_0) { // CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %2 step %c32 iter_args(%arg6 = %c0) -> (index) { -// CHECK-NEXT: %4 = vector.load %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> -// CHECK-NEXT: %5 = scf.for %arg7 = %c0 to %dim_2 step %c1 iter_args(%arg8 = %4) -> (vector<32xf32>) { -// CHECK-NEXT: %7 = memref.load %arg0[%arg3, %arg4, %arg7] : memref -// CHECK-NEXT: %8 = vector.broadcast %7 : f32 to vector<32xf32> -// CHECK-NEXT: %9 = vector.load %arg1[%arg3, %arg7, %arg5] : memref, vector<32xf32> -// CHECK-NEXT: %10 = vector.fma %8, %9, %arg8 : vector<32xf32> -// CHECK-NEXT: scf.yield %10 : vector<32xf32> +// CHECK-NEXT: %6 = vector.load %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %7 = scf.for %arg7 = %c0 to %dim_2 step %c1 iter_args(%arg8 = %6) -> (vector<32xf32>) { +// CHECK-NEXT: %9 = memref.load %arg0[%arg3, %arg4, %arg7] : memref +// CHECK-NEXT: %10 = vector.broadcast %9 : f32 to vector<32xf32> +// CHECK-NEXT: %11 = vector.load %arg1[%arg3, %arg7, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %12 = vector.fma %10, %11, %arg8 : vector<32xf32> +// CHECK-NEXT: scf.yield %12 : vector<32xf32> // CHECK-NEXT: } -// CHECK-NEXT: vector.store %5, %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> -// CHECK-NEXT: %6 = arith.addi %arg5, %c32 : index -// CHECK-NEXT: scf.yield %6 : index +// CHECK-NEXT: vector.store %7, %arg2[%arg3, %arg4, %arg5] : memref, vector<32xf32> +// CHECK-NEXT: %8 = arith.addi %arg5, %c32 : index +// CHECK-NEXT: scf.yield %8 : index // CHECK-NEXT: } -// CHECK-NEXT: scf.for %arg5 = %3 to %dim_1 step %c32 { -// CHECK-NEXT: %4 = arith.subi %dim_1, %3 : index -// CHECK-NEXT: %5 = vector.create_mask %4 : vector<32xi1> -// CHECK-NEXT: %6 = vector.maskedload %arg2[%arg3, %arg4, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> -// CHECK-NEXT: %7 = scf.for %arg6 = %c0 to %dim_2 step %c1 iter_args(%arg7 = %6) -> (vector<32xf32>) { -// CHECK-NEXT: %8 = memref.load %arg0[%arg3, %arg4, %arg6] : memref -// CHECK-NEXT: %9 = vector.broadcast %8 : f32 to vector<32xf32> -// CHECK-NEXT: %10 = vector.maskedload %arg1[%arg3, %arg6, %3], %5, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> -// CHECK-NEXT: %11 = vector.fma %9, %10, %arg7 : vector<32xf32> -// CHECK-NEXT: scf.yield %11 : vector<32xf32> +// CHECK-NEXT: %4 = arith.subi %dim_1, %3 : index +// CHECK-NEXT: %5 = arith.cmpi sge, %4, %c0 : index +// CHECK-NEXT: scf.if %5 { +// CHECK-NEXT: %6 = vector.create_mask %4 : vector<32xi1> +// CHECK-NEXT: %7 = vector.maskedload %arg2[%arg3, %arg4, %3], %6, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK-NEXT: %8 = scf.for %arg5 = %c0 to %dim_2 step %c1 iter_args(%arg6 = %7) -> (vector<32xf32>) { +// CHECK-NEXT: %9 = memref.load %arg0[%arg3, %arg4, %arg5] : memref +// CHECK-NEXT: %10 = vector.broadcast %9 : f32 to 
vector<32xf32> +// CHECK-NEXT: %11 = vector.maskedload %arg1[%arg3, %arg5, %3], %6, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK-NEXT: %12 = vector.fma %10, %11, %arg6 : vector<32xf32> +// CHECK-NEXT: scf.yield %12 : vector<32xf32> // CHECK-NEXT: } -// CHECK-NEXT: vector.maskedstore %arg2[%arg3, %arg4, %3], %5, %7 : memref, vector<32xi1>, vector<32xf32> +// CHECK-NEXT: vector.maskedstore %arg2[%arg3, %arg4, %3], %6, %8 : memref, vector<32xi1>, vector<32xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } From 91917e752d0294138b228f17825119205079eed0 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Tue, 26 Nov 2024 09:07:02 +0000 Subject: [PATCH 4/6] [midend/lib/Conversion/MatMulOptimization] Fix batchmatmul vectorization pass and examples. --- .../lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp | 2 +- tests/Conversion/batchmatmul-vectorization.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index df4f7bd49f..318fd57524 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -165,7 +165,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value idx = iter_idx.getResult(0); Value tailSize = builder.create(loc, bCol, idx); Value tailCond = rewriter.create( - loc, arith::CmpIPredicate::sge, tailSize, c0); + loc, arith::CmpIPredicate::sgt, tailSize, c0); // If the current column does not reach the tail. builder.create( loc, tailCond, [&](OpBuilder &builder, Location loc) { diff --git a/tests/Conversion/batchmatmul-vectorization.mlir b/tests/Conversion/batchmatmul-vectorization.mlir index 7f77dbebca..d53a27b10d 100644 --- a/tests/Conversion/batchmatmul-vectorization.mlir +++ b/tests/Conversion/batchmatmul-vectorization.mlir @@ -19,7 +19,7 @@ // CHECK-NEXT: scf.yield %8 : index // CHECK-NEXT: } // CHECK-NEXT: %4 = arith.subi %dim_1, %3 : index -// CHECK-NEXT: %5 = arith.cmpi sge, %4, %c0 : index +// CHECK-NEXT: %5 = arith.cmpi sgt, %4, %c0 : index // CHECK-NEXT: scf.if %5 { // CHECK-NEXT: %6 = vector.create_mask %4 : vector<32xi1> // CHECK-NEXT: %7 = vector.maskedload %arg2[%arg3, %arg4, %3], %6, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> From a383d338464a8bbbeb6fbf424c0965e08e9c2fc4 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Thu, 19 Dec 2024 09:04:49 +0000 Subject: [PATCH 5/6] [midend/lib/Conversion/MatMulOptimization] add time computation batchmatmul examples. 
---
 examples/BuddyMatmul/linalg-batchmatmul-f32.mlir | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
index 58c9142398..5b256f6d8d 100644
--- a/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
+++ b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
@@ -18,6 +18,7 @@
 // RUN: | FileCheck %s

 func.func private @printMemrefF32(memref<*xf32>)
+func.func private @rtclock() -> f64

 func.func @batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   linalg.batch_matmul
@@ -68,7 +69,9 @@ func.func @main(){
   %m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
   %m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>

+  %t_start = call @rtclock() : () -> f64
   call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  %t_end = call @rtclock() : () -> f64

   %printed_m5 = memref.cast %m5 : memref<?x?x?xf32> to memref<*xf32>

@@ -78,5 +81,8 @@
   // CHECK: [6144{{(, 6144)*}}]
   call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> ()

+  %time = arith.subf %t_end, %t_start : f64
+  vector.print %time : f64
+
   return
 }

From 16d2fed14d5dcb0d211795959e727790bc6f64d4 Mon Sep 17 00:00:00 2001
From: FloatingcloudKnight <1348185166@qq.com>
Date: Wed, 25 Dec 2024 08:33:01 +0000
Subject: [PATCH 6/6] [examples/BuddyMatmul][midend/lib/Conversion/MatMulOptimization]
 Add execution-time measurement and rework the batchmatmul vectorization.

examples/BuddyMatmul/: add operator execution-time measurement to
linalg-batchmatmul-f32.mlir, add the handwritten
batchmatmul-vectorization.mlir example, and update the related makefile
targets.
midend/lib/Conversion/MatMulOptimization/: update the vectorization
implementation in BatchMatMulOptimize.cpp to use the standard Buddy
approach.
tests/Conversion/: add the batchmatmul-vectorization.mlir test file.

Measured runtime drops from 0.0043869 s before vectorization to
0.00018692 s after vectorization with a vector size of 128 (about
23.5x), and to 0.000138998 s with a vector size of 256 (about 31.6x).
Compared with the previous version of BatchMatMulOptimize.cpp, runtime
improves from 0.000586033 s to 0.000358105 s (about 1.5x); the previous
implementation could even result in a slowdown.
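
The vector length is exposed as a pass option (`vector-size`, default 32
after this series), so the 128- and 256-lane measurements above are
presumably obtained by overriding it. Assuming the standard mlir-opt
pass-option syntax, a sketch of such an invocation (same pipeline as the
makefile target, only the option differs):

  buddy-opt ./linalg-batchmatmul-f32.mlir \
    -batchmatmul-optimize="vector-size=256" \
    -convert-linalg-to-affine-loops \
    -lower-affine \
    -convert-vector-to-scf \
    -convert-scf-to-cf \
    -convert-vector-to-llvm \
    -convert-math-to-llvm \
    -convert-math-to-libm \
    -convert-arith-to-llvm \
    -convert-func-to-llvm \
    -expand-strided-metadata \
    -finalize-memref-to-llvm \
    -reconcile-unrealized-casts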
--- .../batchmatmul-vectorization.mlir | 18 +++++++-------- .../BuddyMatmul/linalg-batchmatmul-f32.mlir | 23 +++++++------------ examples/BuddyMatmul/makefile | 10 ++++---- 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/examples/BuddyMatmul/batchmatmul-vectorization.mlir b/examples/BuddyMatmul/batchmatmul-vectorization.mlir index 1931f8c3cd..8d3b30c44e 100644 --- a/examples/BuddyMatmul/batchmatmul-vectorization.mlir +++ b/examples/BuddyMatmul/batchmatmul-vectorization.mlir @@ -40,6 +40,7 @@ module { %dim_3_upbound_tmp = arith.subi %dim_3, %vl_step : index %dim_3_upbound = arith.addi %dim_3_upbound_tmp, %c1 : index + %t_start = call @rtclock() : () -> f64 affine.for %arg3 = %c0 to %dim { // C affine.prefetch %arg0[%arg3, %dim_1, %dim_2], read, locality<3>, data : memref affine.for %arg4 = %c0 to %dim_1 { // M @@ -78,6 +79,11 @@ module { } } } + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + vector.print %time : f64 return } func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg4: f32) -> memref { @@ -108,29 +114,21 @@ module { %m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref %m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref - call @batch_matmul(%m0, %m1, %m2) : (memref, memref, memref) -> () - - %printed_m2 = memref.cast %m2 : memref to memref<*xf32> - // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data = // CHECK-NEXT: [ // CHECK: [ // CHECK: [3456{{(, 3456)*}}] - call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> () + call @batch_matmul(%m0, %m1, %m2) : (memref, memref, memref) -> () %m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref %m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref %m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref - call @batch_matmul(%m3, %m4, %m5) : (memref, memref, memref) -> () - - %printed_m5 = memref.cast %m5 : memref to memref<*xf32> - // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data = // CHECK-NEXT: [ // CHECK: [ // CHECK: [6144{{(, 6144)*}}] - call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> () + call @batch_matmul(%m3, %m4, %m5) : (memref, memref, memref) -> () return } diff --git a/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir index 5b256f6d8d..478ec6f0c4 100644 --- a/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir +++ b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir @@ -21,9 +21,15 @@ func.func private @printMemrefF32(memref<*xf32>) func.func private @rtclock() -> f64 func.func @batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + %t_start = call @rtclock() : () -> f64 linalg.batch_matmul ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + vector.print %time : f64 return } @@ -55,34 +61,21 @@ func.func @main(){ %m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref %m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref - call 
@batch_matmul(%m0, %m1, %m2) : (memref, memref, memref) -> () - - %printed_m2 = memref.cast %m2 : memref to memref<*xf32> - // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data = // CHECK-NEXT: [ // CHECK: [ // CHECK: [3456{{(, 3456)*}}] - call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> () + call @batch_matmul(%m0, %m1, %m2) : (memref, memref, memref) -> () %m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref %m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref %m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref - %t_start = call @rtclock() : () -> f64 - call @batch_matmul(%m3, %m4, %m5) : (memref, memref, memref) -> () - %t_end = call @rtclock() : () -> f64 - - %printed_m5 = memref.cast %m5 : memref to memref<*xf32> - // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data = // CHECK-NEXT: [ // CHECK: [ // CHECK: [6144{{(, 6144)*}}] - call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> () - - %time = arith.subf %t_end, %t_start : f64 - vector.print %time : f64 + call @batch_matmul(%m3, %m4, %m5) : (memref, memref, memref) -> () return } diff --git a/examples/BuddyMatmul/makefile b/examples/BuddyMatmul/makefile index 1a0655e9fc..1cc5bb9dd9 100644 --- a/examples/BuddyMatmul/makefile +++ b/examples/BuddyMatmul/makefile @@ -18,11 +18,6 @@ MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.dylib MTRIPLE := x86_64-apple-darwin endif -linalg-batchmatmul-lower: - @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ - -batchmatmul-optimize \ - -o ./log.mlir - linalg-batchmatmul-f32-run: @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ -convert-linalg-to-loops \ @@ -40,6 +35,11 @@ linalg-batchmatmul-f32-run: ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} +linalg-batchmatmul-vectorization-lower: + @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ + -batchmatmul-optimize \ + -o ./log.mlir + linalg-batchmatmul-f32-vectorization-run: @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ -batchmatmul-optimize \
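
With the updated makefile, the three code paths touched by this series can be
compared from examples/BuddyMatmul; this sketch assumes the usual Buddy build
setup (the BUDDY_OPT and MLIR_CPU_RUNNER variables resolved at the top of the
makefile). After patch 6, each run prints the result memref and the measured
runtime:

  cd examples/BuddyMatmul
  make linalg-batchmatmul-f32-run                  # linalg.batch_matmul lowered to scalar loops
  make linalg-batchmatmul-f32-vectorization-run    # lowered through -batchmatmul-optimize
  make batchmatmul-vectorization-run               # handwritten vectorized kernel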