From 9c5e724c50014c3870c62483a6ef3e8173efd5b3 Mon Sep 17 00:00:00 2001 From: EllisLambda Date: Tue, 12 Sep 2023 17:50:21 +0800 Subject: [PATCH 1/5] [midend][examples] Correct parameter and variable `stepPlaceholder` in batchmatmul optimization. --- examples/MLIRLinalg/makefile | 6 +-- .../BatchMatMulOptimize.cpp | 42 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index 6b377e577d..7c3a073c45 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -138,7 +138,7 @@ linalg-matmul-optimize-run: linalg-batch-matmul-optimize-run: @${BUDDY_OPT} linalg-matmul.mlir ${MLIR_OPT_OPTIONS} \ - -batchmatmul-optimize="step-placeholder=64" \ + -batchmatmul-optimize="vector-size=64" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ @@ -174,12 +174,12 @@ linalg-batch-matmul-run: linalg-batch-matmul-optimize-lower: @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ - -batchmatmul-optimize="step-placeholder=64" \ + -batchmatmul-optimize="vector-size=64" \ -o ./log.mlir linalg-batch-matmul-optimize-translate: @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ - -batchmatmul-optimize="step-placeholder=64" \ + -batchmatmul-optimize="vector-size=64" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index 9b3924b7d8..c09de4718f 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -51,10 +51,10 @@ namespace { class BatchMatMulOptimizePattern : public ConversionPattern { public: explicit BatchMatMulOptimizePattern(MLIRContext *context, - int64_t stepPlaceHolderParam) + int64_t affineVectorSizeParam) : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, context) { - stepPlaceHolder = stepPlaceHolderParam; + affineVectorSize = affineVectorSizeParam; } LogicalResult @@ -73,7 +73,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { const Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); const Value step = rewriter.create( - loc, rewriter.getIndexAttr(stepPlaceHolder)); + loc, rewriter.getIndexAttr(affineVectorSize)); const AffineExpr d0 = rewriter.getAffineDimExpr(0); const AffineExpr d1 = rewriter.getAffineDimExpr(1); const AffineExpr d2 = rewriter.getAffineDimExpr(2); @@ -82,7 +82,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { const Value c0_dynamicType = rewriter.create( loc, rewriter.getZeroAttr(A_elementType)); const Value c0_dynamicType_vec = rewriter.create( - loc, VectorType::get({stepPlaceHolder}, A_elementType), c0_dynamicType); + loc, VectorType::get({affineVectorSize}, A_elementType), c0_dynamicType); // Dims Value BATCH = rewriter.create(loc, A, 0); // Batch size @@ -132,7 +132,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { [&](OpBuilder &builder, Location loc, ValueRange ivRange) { Value ivA_row = ivRange.front(); Value applied_n = builder.create( - loc, AffineMap::get(1, 0, d0.ceilDiv(stepPlaceHolder)), + loc, AffineMap::get(1, 0, d0.ceilDiv(affineVectorSize)), ValueRange{N}); affine::buildAffineLoopNest( builder, loc, {c0}, {applied_n}, 1, @@ -142,7 +142,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { loc, A, ValueRange{ivBatch, ivA_row, ivB_row}); Value a_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, A_elementType), + VectorType::get({affineVectorSize}, A_elementType), a_ele); Value b_col_cur = builder.create(loc, ivB_col, step); @@ -156,21 +156,21 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value b_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineVectorSize}, A_elementType), B, AffineMap::get( - 3, 0, {d0, d1, d2 * stepPlaceHolder}, + 3, 0, {d0, d1, d2 * affineVectorSize}, rewriter.getContext()), ValueRange{ivBatch, ivB_row, ivB_col}); Value c_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineVectorSize}, A_elementType), C, AffineMap::get( - 3, 0, {d0, d1, d2 * stepPlaceHolder}, + 3, 0, {d0, d1, d2 * affineVectorSize}, rewriter.getContext()), ValueRange{ivBatch, ivA_row, ivB_col}); Value result_vec; @@ -186,7 +186,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { builder.create( loc, result_vec, C, AffineMap::get(3, 0, - {d0, d1, d2 * stepPlaceHolder}, + {d0, d1, d2 * affineVectorSize}, rewriter.getContext()), ValueRange{ivBatch, ivA_row, ivB_col}); builder.create(loc); @@ -195,7 +195,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value mask_vec = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineVectorSize}, rewriter.getI1Type()), ValueRange{tail_len}); Value b_col_idx_tail = @@ -204,7 +204,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value b_vec_tail = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineVectorSize}, A_elementType), B, ValueRange{ivBatch, ivB_row, @@ -213,7 +213,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { Value c_vec_tail = builder.create( loc, - VectorType::get({stepPlaceHolder}, + VectorType::get({affineVectorSize}, A_elementType), C, ValueRange{ivBatch, ivA_row, @@ -249,7 +249,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { } private: - int64_t stepPlaceHolder; + int64_t affineVectorSize; }; } // end anonymous namespace @@ -268,8 +268,8 @@ class BatchMatMulOptimizePass StringRef getDescription() const final { return "BatchMatMul Optimization."; } BatchMatMulOptimizePass() = default; BatchMatMulOptimizePass(const BatchMatMulOptimizePass &) {} - explicit BatchMatMulOptimizePass(int64_t stepPlaceHolderParam) { - stepPlaceHolder = stepPlaceHolderParam; + explicit BatchMatMulOptimizePass(int64_t affineVectorSizeParam) { + affineVectorSize = affineVectorSizeParam; } void runOnOperation() override; @@ -279,9 +279,9 @@ class BatchMatMulOptimizePass affine::AffineDialect, VectorDialect>(); } - Option stepPlaceHolder{ - *this, "step-placeholder", - llvm::cl::desc("Affine step placeholder size."), llvm::cl::init(64)}; + Option affineVectorSize{ + *this, "vector-size", + llvm::cl::desc("Affine Vector size."), llvm::cl::init(64)}; }; } // end anonymous namespace. @@ -297,7 +297,7 @@ void BatchMatMulOptimizePass::runOnOperation() { target.addLegalOp(); RewritePatternSet patterns(context); - patterns.add(context, stepPlaceHolder); + patterns.add(context, affineVectorSize); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); From 71fdb5d775e2e2554d16017283b477364652cf10 Mon Sep 17 00:00:00 2001 From: EllisLambda Date: Wed, 20 Sep 2023 00:35:51 +0800 Subject: [PATCH 2/5] [midend][examples] Fix type recognition bug in batchmatmul optimization and add int8 tests. --- ...tmul.mlir => linalg-batch-matmul-f32.mlir} | 0 .../MLIRLinalg/linalg-batch-matmul-i8.mlir | 46 +++++++++++++ examples/MLIRLinalg/makefile | 67 +++++++++++++++++-- .../BatchMatMulOptimize.cpp | 6 +- 4 files changed, 110 insertions(+), 9 deletions(-) rename examples/MLIRLinalg/{linalg-batch-matmul.mlir => linalg-batch-matmul-f32.mlir} (100%) create mode 100644 examples/MLIRLinalg/linalg-batch-matmul-i8.mlir diff --git a/examples/MLIRLinalg/linalg-batch-matmul.mlir b/examples/MLIRLinalg/linalg-batch-matmul-f32.mlir similarity index 100% rename from examples/MLIRLinalg/linalg-batch-matmul.mlir rename to examples/MLIRLinalg/linalg-batch-matmul-f32.mlir diff --git a/examples/MLIRLinalg/linalg-batch-matmul-i8.mlir b/examples/MLIRLinalg/linalg-batch-matmul-i8.mlir new file mode 100644 index 0000000000..7b39258fcf --- /dev/null +++ b/examples/MLIRLinalg/linalg-batch-matmul-i8.mlir @@ -0,0 +1,46 @@ +// RUN: buddy-opt -batchmatmul-optimize -verify-diagnostics -expand-strided-metadata -lower-affine -convert-vector-to-llvm -finalize-memref-to-llvm -convert-scf-to-cf -convert-linalg-to-llvm -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts %s \ +// RUN: | mlir-cpu-runner -O0 -e buddy_batchmatmul_i8 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +memref.global "private" @A : memref<2x2x3xi8> = dense<[[[9, 4, 6],[2, 4, 0]],[[6, 3, 3],[0, 4, 7]]]> +memref.global "private" @B : memref<2x3x4xi8> = dense<[[[1, 3, 8, 0],[1, 8, 8, 7], [6, 9, 7, 9]],[[3, 8, 6, 8],[2, 7, 0, 6],[0, 4, 0, 4]]]> +memref.global "private" @C : memref<2x2x4xi8> = dense<[[[49, 12, 14, 82],[6, 38, 48, 28]],[[24, 81, 36, 78],[8, 56, 0, 52]]]> + +func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface } + +func.func @buddy_batchmatmul_i8() -> f32{ + %a = memref.get_global @A : memref<2x2x3xi8> + %b = memref.get_global @B : memref<2x3x4xi8> + %c = memref.get_global @C : memref<2x2x4xi8> + + linalg.batch_matmul + ins(%a, %b: memref<2x2x3xi8>, memref<2x3x4xi8>) + outs(%c: memref<2x2x4xi8>) + + %cst_0 = arith.constant 0 : index + %cst_1 = arith.constant 1 : index + %cst_2 = arith.constant 2 : index + %cst_4 = arith.constant 4 : index + + %c_f32 = memref.alloca() : memref<2x2x4xf32> + scf.for %i = %cst_0 to %cst_2 step %cst_1 { + scf.for %j = %cst_0 to %cst_2 step %cst_1 { + scf.for %k = %cst_0 to %cst_4 step %cst_1 { + %val_i8 = memref.load %c[%i, %j, %k] : memref<2x2x4xi8> + %val_f32 = arith.sitofp %val_i8 : i8 to f32 + memref.store %val_f32, %c_f32[%i, %j, %k] : memref<2x2x4xf32> + } + } + } + + %printed_c = memref.cast %c_f32 : memref<2x2x4xf32> to memref<*xf32> + call @printMemrefF32(%printed_c) : (memref<*xf32>) -> () + // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 3 offset = 0 sizes = \[2, 2, 4\] strides = \[8, 4, 1\] data =}} + // CHECK{LITERAL}: [[[98, 125, -96, -92], + // CHECK{LITERAL}: [12, 76, 96, 56]], + // CHECK{LITERAL}: [[48, -94, 72, -100], + // CHECK{LITERAL}: [16, 112, 0, 104]]] + %zero = arith.constant 0.0 :f32 + return %zero :f32 +} diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index 7c3a073c45..c606b92112 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -137,7 +137,7 @@ linalg-matmul-optimize-run: -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} linalg-batch-matmul-optimize-run: - @${BUDDY_OPT} linalg-matmul.mlir ${MLIR_OPT_OPTIONS} \ + @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-optimize="vector-size=64" \ -convert-linalg-to-loops \ -expand-strided-metadata \ @@ -152,33 +152,88 @@ linalg-batch-matmul-optimize-run: -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} linalg-batch-matmul-lower: - @${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ + @${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ -convert-func-to-llvm -reconcile-unrealized-casts \ -o ./log.mlir linalg-batch-matmul-translate: - @${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ + @${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ -convert-func-to-llvm -reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll linalg-batch-matmul-run: - @${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ + @${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ -convert-func-to-llvm -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} linalg-batch-matmul-optimize-lower: - @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ + @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-optimize="vector-size=64" \ -o ./log.mlir linalg-batch-matmul-optimize-translate: - @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \ + @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-optimize="vector-size=64" \ + -convert-linalg-to-loops \ + -expand-strided-metadata \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +linalg-batch-matmul-i8-optimize-run: + @${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-optimize="vector-size=64" \ + -convert-linalg-to-loops \ + -expand-strided-metadata \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-batch-matmul-i8-lower: + @${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \ + -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts \ + -o ./log.mlir + +linalg-batch-matmul-i8-translate: + @${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \ + -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +linalg-batch-matmul-i8-run: + @${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \ + -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-batch-matmul-i8-optimize-lower: + @${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-optimize="vector-size=64" \ + -o ./log.mlir + +linalg-batch-matmul-i8-optimize-translate: + @${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-optimize="vector-size=64" \ -convert-linalg-to-loops \ -expand-strided-metadata \ diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index c09de4718f..334c6c4023 100644 --- a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -122,7 +122,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { rewriter.create( loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()), - ArrayRef{ivBatch, c0, c0}, false, 3, true); + ArrayRef{ivBatch, M, K}, false, 3, true); affine::buildAffineLoopNest( rewriter, loc, {c0}, {K}, 1, [&](OpBuilder &builder, Location loc, ValueRange ivRange) { @@ -174,7 +174,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { rewriter.getContext()), ValueRange{ivBatch, ivA_row, ivB_col}); Value result_vec; - if (A_elementType.isIntOrFloat() && 0) { // bug + if (A_elementType.isa()) { Value add_vec = builder.create( loc, a_vec, b_vec); result_vec = builder.create( @@ -220,7 +220,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { b_col_idx_tail}, mask_vec, c0_dynamicType_vec); Value result_vec_tail; - if (A_elementType.isIntOrFloat() && 0) { // bug + if (A_elementType.isa()) { Value add_vec = builder.create( loc, a_vec, b_vec_tail); result_vec_tail = builder.create( From 07a89704357c40bfaefb7f6ccc4e6e39963cff2c Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Thu, 21 Sep 2023 09:36:30 +0000 Subject: [PATCH 3/5] [midend] Add matmul vectorization pass. Co-authored-by: breeze <1627211374@qq.com> --- examples/MLIRLinalg/makefile | 5 + .../MatMulOptimization/CMakeLists.txt | 1 + .../MatMulVectorization.cpp | 220 ++++++++++++++++++ tests/Conversion/matmul-vectorization.mlir | 64 +++++ tools/buddy-opt/buddy-opt.cpp | 3 + 5 files changed, 293 insertions(+) create mode 100644 midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp create mode 100644 tests/Conversion/matmul-vectorization.mlir diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index c606b92112..5fba1b9b6b 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -284,3 +284,8 @@ linalg-conv2d_nchw_fchw-optimize-run: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ -convert-func-to-llvm -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-matmul-vectorization-lower: + @${BUDDY_OPT} linalg-matmul.mlir \ + -matmul-vectorization \ + -o log.mlir diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt index b096961ef2..860f5e8add 100644 --- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library(MatMulOptimization BatchMatMulOptimize.cpp MatMulOptimize.cpp + MatMulVectorization.cpp ) add_mlir_library(BatchMatMulOptimization diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp new file mode 100644 index 0000000000..5849c6a54c --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp @@ -0,0 +1,220 @@ +//===- MatMulVectorization.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the matmul vectorization. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class MatMulVectorizationPattern : public ConversionPattern { +public: + explicit MatMulVectorizationPattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::MatmulOp::getOperationName(), 1, context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + // Get input A, B, C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + // Get shape of input and output + // ShapedType ATy = A.getType().cast(); + // Type eleTy = ATy.getElementType(); + // ShapedType BTy = B.getType().cast(); + // ShapedType CTy = C.getType().cast(); + + auto ctx = op->getContext(); + // Currently use f32 as the element type. + // TODO: replace f32 with input type. + FloatType f32 = mlir::FloatType::getF32(ctx); + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + // Define `*Type`. + VectorType vectorTy = mlir::VectorType::get({vecSize}, f32); + VectorType vectorMaskTy = VectorType::get({vecSize}, i1); + // Some constants. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value step = rewriter.create(loc, vecSize); + const Value c0F32 = rewriter.create( + loc, APFloat::getZero(f32.getFloatSemantics()), f32); + // Create pass through vector. + Value c0F32Vec = rewriter.create(loc, vectorTy, c0F32); + + // Create DimOp. + const Value aRow = rewriter.create(loc, A, c0); + // This algorithm does not use the column A index. + // const Value aCol = rewriter.create(loc, A, c1); + const Value bRow = rewriter.create(loc, B, c0); + const Value bCol = rewriter.create(loc, B, c1); + // Size of vector type. + AffineExpr d0; + bindDims(ctx, d0); + AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); + SmallVector lowerBounds(2, c0); + SmallVector uperBounds{bRow, aRow}; + SmallVector steps(2, /*Value=*/1); + // clang-format off + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create loop based on vector size. + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{bCol}, vecTailMap, /*Step=*/1, std::nullopt, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + // Load element and broadcast to vector. + Value aEle = builder.create( + loc, A, ValueRange{ivs[1], ivs[0]}); + Value aVec = builder.create(loc, vectorTy, aEle); + // Check tail. + AffineExpr m, n, k; + bindDims(ctx, m, n, k); + AffineMap BVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); + AffineExpr x, y, z; + bindDims(ctx, x, y, z); + AffineMap CVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {y, z * vecSize}, ctx); + // Calculate the tail. + Value bColCur = builder.create(loc, iv, step); + Value tailLen = builder.create(loc, bCol, bColCur); + Value tailFlag = rewriter.create( + loc, arith::CmpIPredicate::sge, tailLen, step); + // If the current column does not reach the tail. + builder.create(loc, tailFlag, + [&](OpBuilder &builder, Location loc) { + Value bVec = builder.create( + loc, vectorTy, B, BVectorMap, ValueRange{ivs[0], ivs[1], iv}); + Value cVec = builder.create( + loc, vectorTy, C, CVectorMap, ValueRange{ivs[0], ivs[1], iv}); + // FMA = Fused Multiply + Add + Value resultVector = builder.create(loc, aVec, bVec, cVec); + builder.create( + loc, resultVector, C, CVectorMap, ValueRange{ivs[0], ivs[1], iv}); + builder.create(loc); + }, + // The else branch (the current column reaches the tail). + [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value maskVec = builder.create( + loc, vectorMaskTy, tailLen); + Value bColIdxTail = builder.create(loc, iv, step); + // Masked load input and output. + Value bVecTail = builder.create( + loc, vectorTy, B, ValueRange{ivs[0], bColIdxTail}, + maskVec, c0F32Vec); + Value cVecTail = builder.create( + loc, vectorTy, C, ValueRange{ivs[1], bColIdxTail}, + maskVec, c0F32Vec); + // FMA. + Value resultVecTail = + builder.create(loc, aVec, bVecTail, cVecTail); + builder.create( + loc, C, ValueRange{ivs[1], bColIdxTail}, maskVec, resultVecTail); + builder.create(loc); + }); + builder.create(loc); + }); + }); + // clang-format on + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// MatMulVectorizationPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg matmul operations to mixture of +/// Affine + Vector operations. +namespace { +class MatMulVectorizationPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulVectorizationPass) + StringRef getArgument() const final { return "matmul-vectorization"; } + StringRef getDescription() const final { return "MatMul Vectorization."; } + MatMulVectorizationPass() = default; + MatMulVectorizationPass(const MatMulVectorizationPass &) {} + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + Option vecSize{*this, "vector-size", + llvm::cl::desc("Specify vector type size."), + llvm::cl::init(32)}; +}; +} // end anonymous namespace. + +void MatMulVectorizationPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerMatMulVectorizationPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/tests/Conversion/matmul-vectorization.mlir b/tests/Conversion/matmul-vectorization.mlir new file mode 100644 index 0000000000..0b713013ca --- /dev/null +++ b/tests/Conversion/matmul-vectorization.mlir @@ -0,0 +1,64 @@ +// RUN: buddy-opt %s \ +// RUN: -matmul-vectorization="vector-size=64" \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module{ + func.func private @printMemrefF32(memref<*xf32>) + + func.func @matmul(%a : memref, %b : memref, %c : memref) { + linalg.matmul + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + + func.func @main(){ + // Set up dims. + %cM = arith.constant 4 : index + %cN = arith.constant 4 : index + %cK = arith.constant 4 : index + + // Set Init Value. + %cf1 = arith.constant 1.0 : f32 + + %A = memref.alloc(%cM, %cK) : memref + %B = memref.alloc(%cK, %cN) : memref + %C = memref.alloc(%cM, %cN) : memref + + linalg.fill + ins(%cf1 : f32) + outs(%A:memref) + + linalg.fill + ins(%cf1 : f32) + outs(%B:memref) + + linalg.fill + ins(%cf1 : f32) + outs(%C:memref) + + call @matmul(%A, %B, %C) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C = memref.cast %C : memref to memref<*xf32> + call @printMemrefF32(%print_C) : (memref<*xf32>) -> () + + memref.dealloc %C : memref + memref.dealloc %B : memref + memref.dealloc %A : memref + return + } +} diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 8bde8b5711..c906af8ff3 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -56,6 +56,8 @@ void registerLowerDAPPass(); void registerLowerRVVPass(); void registerBatchMatMulOptimizePass(); void registerMatMulOptimizePass(); +void registerMatMulVectorizationPass(); + void registerConvOptimizePass(); void registerLowerVectorExpPass(); void registerLowerGemminiPass(); @@ -81,6 +83,7 @@ int main(int argc, char **argv) { // Register Several Optimize Pass. mlir::buddy::registerMatMulOptimizePass(); + mlir::buddy::registerMatMulVectorizationPass(); mlir::buddy::registerBatchMatMulOptimizePass(); mlir::buddy::registerConvOptimizePass(); From da71b6be2fe823824aef0b6a9a36ae06d21b157c Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Fri, 22 Sep 2023 08:25:24 +0800 Subject: [PATCH 4/5] [midend] Move insertZeroConstantOp from dip-utils to utils. --- midend/include/Utils/DIPUtils.h | 5 ----- midend/include/Utils/Utils.h | 5 +++++ .../lib/Conversion/LowerDIP/LowerDIPPass.cpp | 9 +++------ midend/lib/Utils/DIPUtils.cpp | 19 ------------------- midend/lib/Utils/Utils.cpp | 19 +++++++++++++++++++ 5 files changed, 27 insertions(+), 30 deletions(-) diff --git a/midend/include/Utils/DIPUtils.h b/midend/include/Utils/DIPUtils.h index 8f4978ae6e..7c80bd36c1 100644 --- a/midend/include/Utils/DIPUtils.h +++ b/midend/include/Utils/DIPUtils.h @@ -38,11 +38,6 @@ enum class DIP_OP { CORRELATION_2D, EROSION_2D, DILATION_2D }; // from lowering passes with appropriate messages. enum class DIP_ERROR { INCONSISTENT_TYPES, UNSUPPORTED_TYPE, NO_ERROR }; -// Inserts a constant op with value 0 into a location `loc` based on type -// `type`. Supported types are : f32, f64, integer types. -Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc, - Type elemTy); - // Inserts FMA operation into a given location `loc` based on type `type`. // Note: FMA is done by Multiply and Add for integer types, because there is no // dedicated FMA operation for them. diff --git a/midend/include/Utils/Utils.h b/midend/include/Utils/Utils.h index 04b3ce9681..60590a317b 100644 --- a/midend/include/Utils/Utils.h +++ b/midend/include/Utils/Utils.h @@ -25,6 +25,11 @@ using namespace mlir; namespace buddy { +// Inserts a constant op with value 0 into a location `loc` based on type +// `type`. Supported types are : f32, f64, integer types. +Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc, + Type elemTy); + // Function to test whether a value is equivalent to zero or not. Value zeroCond(OpBuilder &builder, Location loc, Type elemType, Value value, Value zeroElem); diff --git a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp index 93a21d7f61..66b6d69ad2 100644 --- a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp +++ b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp @@ -795,8 +795,7 @@ class DIPTopHat2DOpLowering : public OpRewritePattern { VectorType vectorTy32 = VectorType::get({stride}, inElemTy); IntegerType i1 = IntegerType::get(ctx, 1); VectorType vectorMaskTy = VectorType::get({stride}, i1); - Value zeroPaddingElem = - dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy); + Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy); Value zeroPaddingVec = rewriter.create(loc, vectorTy32, zeroPaddingElem); @@ -1001,8 +1000,7 @@ class DIPBottomHat2DOpLowering : public OpRewritePattern { VectorType vectorTy32 = VectorType::get({stride}, inElemTy); IntegerType i1 = IntegerType::get(ctx, 1); VectorType vectorMaskTy = VectorType::get({stride}, i1); - Value zeroPaddingElem = - dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy); + Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy); Value zeroPaddingVec = rewriter.create(loc, vectorTy32, zeroPaddingElem); @@ -1203,8 +1201,7 @@ class DIPMorphGrad2DOpLowering : public OpRewritePattern { VectorType vectorTy32 = VectorType::get({stride}, inElemTy); IntegerType i1 = IntegerType::get(ctx, 1); VectorType vectorMaskTy = VectorType::get({stride}, i1); - Value zeroPaddingElem = - dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy); + Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy); Value zeroPaddingVec = rewriter.create(loc, vectorTy32, zeroPaddingElem); diff --git a/midend/lib/Utils/DIPUtils.cpp b/midend/lib/Utils/DIPUtils.cpp index ee31d4f40d..0e71af0f1f 100644 --- a/midend/lib/Utils/DIPUtils.cpp +++ b/midend/lib/Utils/DIPUtils.cpp @@ -184,25 +184,6 @@ DIP_ERROR checkDIPCommonTypes(DIPOP op, const std::vector &args) { return DIP_ERROR::NO_ERROR; } -// Inserts a constant op with value 0 into a location `loc` based on type -// `type`. Supported types are : f32, f64, integer types. -Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc, - Type elemTy) { - Value op = {}; - auto bitWidth = elemTy.getIntOrFloatBitWidth(); - if (elemTy.isF32() || elemTy.isF64()) { - FloatType type = - elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx); - auto zero = APFloat::getZero(type.getFloatSemantics()); - op = builder.create(loc, zero, type); - } else if (elemTy.isInteger(bitWidth)) { - IntegerType type = IntegerType::get(ctx, bitWidth); - op = builder.create(loc, 0, type); - } - - return op; -} - // Inserts FMA operation into a given location `loc` based on type `type`. // Note: FMA is done by Multiply and Add for integer types, because there is no // dedicated FMA operation for them. diff --git a/midend/lib/Utils/Utils.cpp b/midend/lib/Utils/Utils.cpp index 4e355ed019..90d627e20d 100644 --- a/midend/lib/Utils/Utils.cpp +++ b/midend/lib/Utils/Utils.cpp @@ -39,6 +39,25 @@ using namespace mlir; namespace buddy { +// Inserts a constant op with value 0 into a location `loc` based on type +// `type`. Supported types are : f32, f64, integer types. +Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc, + Type elemTy) { + Value op = {}; + auto bitWidth = elemTy.getIntOrFloatBitWidth(); + if (elemTy.isF32() || elemTy.isF64()) { + FloatType type = + elemTy.isF32() ? FloatType::getF32(ctx) : FloatType::getF64(ctx); + auto zero = APFloat::getZero(type.getFloatSemantics()); + op = builder.create(loc, zero, type); + } else if (elemTy.isInteger(bitWidth)) { + IntegerType type = IntegerType::get(ctx, bitWidth); + op = builder.create(loc, 0, type); + } + + return op; +} + // Function to test whether a value is equivalent to zero or not. Value zeroCond(OpBuilder &builder, Location loc, Type elemType, Value value, Value zeroElem) { From d43b4633666b2a66ee98e2cf8d34c0c76a2d0c0f Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Fri, 22 Sep 2023 09:19:08 +0800 Subject: [PATCH 5/5] [midend] Extend MatMulVectorization pattern to multiple types. --- .../MatMulOptimization/CMakeLists.txt | 3 + .../MatMulVectorization.cpp | 22 +++--- tests/Conversion/matmul-vectorization.mlir | 77 ++++++++++++++----- 3 files changed, 71 insertions(+), 31 deletions(-) diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt index 860f5e8add..4f28fa8a44 100644 --- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt @@ -2,6 +2,9 @@ add_mlir_library(MatMulOptimization BatchMatMulOptimize.cpp MatMulOptimize.cpp MatMulVectorization.cpp + + LINK_LIBS PUBLIC + BuddyUtils ) add_mlir_library(BatchMatMulOptimization diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp index 5849c6a54c..4b2c87eb26 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp @@ -27,6 +27,8 @@ #include #include +#include "Utils/Utils.h" + using namespace mlir; using namespace vector; @@ -53,19 +55,16 @@ class MatMulVectorizationPattern : public ConversionPattern { Value B = op->getOperand(1); Value C = op->getOperand(2); // Get shape of input and output - // ShapedType ATy = A.getType().cast(); - // Type eleTy = ATy.getElementType(); + ShapedType ATy = A.getType().cast(); + Type eleTy = ATy.getElementType(); // ShapedType BTy = B.getType().cast(); // ShapedType CTy = C.getType().cast(); auto ctx = op->getContext(); - // Currently use f32 as the element type. - // TODO: replace f32 with input type. - FloatType f32 = mlir::FloatType::getF32(ctx); // Get i1 as the element type for mask vector. IntegerType i1 = IntegerType::get(ctx, 1); // Define `*Type`. - VectorType vectorTy = mlir::VectorType::get({vecSize}, f32); + VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); VectorType vectorMaskTy = VectorType::get({vecSize}, i1); // Some constants. const Value c0 = @@ -73,10 +72,9 @@ class MatMulVectorizationPattern : public ConversionPattern { const Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); const Value step = rewriter.create(loc, vecSize); - const Value c0F32 = rewriter.create( - loc, APFloat::getZero(f32.getFloatSemantics()), f32); // Create pass through vector. - Value c0F32Vec = rewriter.create(loc, vectorTy, c0F32); + const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); + Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); // Create DimOp. const Value aRow = rewriter.create(loc, A, c0); @@ -127,6 +125,8 @@ class MatMulVectorizationPattern : public ConversionPattern { Value cVec = builder.create( loc, vectorTy, C, CVectorMap, ValueRange{ivs[0], ivs[1], iv}); // FMA = Fused Multiply + Add + // FMAOp only supports floating point type input. + // TODO: Write a utils function for FMA to support both int and float. Value resultVector = builder.create(loc, aVec, bVec, cVec); builder.create( loc, resultVector, C, CVectorMap, ValueRange{ivs[0], ivs[1], iv}); @@ -141,10 +141,10 @@ class MatMulVectorizationPattern : public ConversionPattern { // Masked load input and output. Value bVecTail = builder.create( loc, vectorTy, B, ValueRange{ivs[0], bColIdxTail}, - maskVec, c0F32Vec); + maskVec, passthruVec); Value cVecTail = builder.create( loc, vectorTy, C, ValueRange{ivs[1], bColIdxTail}, - maskVec, c0F32Vec); + maskVec, passthruVec); // FMA. Value resultVecTail = builder.create(loc, aVec, bVecTail, cVecTail); diff --git a/tests/Conversion/matmul-vectorization.mlir b/tests/Conversion/matmul-vectorization.mlir index 0b713013ca..5a7e9b5fbb 100644 --- a/tests/Conversion/matmul-vectorization.mlir +++ b/tests/Conversion/matmul-vectorization.mlir @@ -10,40 +10,76 @@ module{ func.func private @printMemrefF32(memref<*xf32>) + func.func private @printMemrefF64(memref<*xf64>) - func.func @matmul(%a : memref, %b : memref, %c : memref) { + func.func @matmul_f32(%a : memref, %b : memref, %c : memref) { linalg.matmul ins(%a, %b: memref, memref) outs(%c:memref) return } + func.func @matmul_f64(%a : memref, %b : memref, %c : memref) { + linalg.matmul + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + func.func @main(){ // Set up dims. %cM = arith.constant 4 : index %cN = arith.constant 4 : index %cK = arith.constant 4 : index + // ------------------------------------------------------------------------- + // Test f32 as element type. + // ------------------------------------------------------------------------- + // Set Init Value. - %cf1 = arith.constant 1.0 : f32 + %cf1_32 = arith.constant 1.0 : f32 + + %A_f32 = memref.alloc(%cM, %cK) : memref + %B_f32 = memref.alloc(%cK, %cN) : memref + %C_f32 = memref.alloc(%cM, %cN) : memref + + linalg.fill ins(%cf1_32 : f32) outs(%A_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%B_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%C_f32 : memref) + + call @matmul_f32(%A_f32, %B_f32, %C_f32) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C_f32 = memref.cast %C_f32 : memref to memref<*xf32> + call @printMemrefF32(%print_C_f32) : (memref<*xf32>) -> () - %A = memref.alloc(%cM, %cK) : memref - %B = memref.alloc(%cK, %cN) : memref - %C = memref.alloc(%cM, %cN) : memref + memref.dealloc %C_f32 : memref + memref.dealloc %B_f32 : memref + memref.dealloc %A_f32 : memref - linalg.fill - ins(%cf1 : f32) - outs(%A:memref) + // ------------------------------------------------------------------------- + // Test f64 as element type. + // ------------------------------------------------------------------------- - linalg.fill - ins(%cf1 : f32) - outs(%B:memref) + // Set Init Value. + %cf1_64 = arith.constant 1.0 : f64 - linalg.fill - ins(%cf1 : f32) - outs(%C:memref) + %A_f64 = memref.alloc(%cM, %cK) : memref + %B_f64 = memref.alloc(%cK, %cN) : memref + %C_f64 = memref.alloc(%cM, %cN) : memref - call @matmul(%A, %B, %C) : (memref, memref, memref) -> () + linalg.fill ins(%cf1_64 : f64) outs(%A_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%B_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%C_f64 : memref) + + call @matmul_f64(%A_f64, %B_f64, %C_f64) : (memref, memref, memref) -> () // Print output. // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = @@ -53,12 +89,13 @@ module{ // CHECK-NEXT: [5, 5, 5, 5], // CHECK-NEXT: [5, 5, 5, 5] // CHECK-SAME: ] - %print_C = memref.cast %C : memref to memref<*xf32> - call @printMemrefF32(%print_C) : (memref<*xf32>) -> () + %print_C_f64 = memref.cast %C_f64 : memref to memref<*xf64> + call @printMemrefF64(%print_C_f64) : (memref<*xf64>) -> () + + memref.dealloc %C_f64 : memref + memref.dealloc %B_f64 : memref + memref.dealloc %A_f64 : memref - memref.dealloc %C : memref - memref.dealloc %B : memref - memref.dealloc %A : memref return } }