Skip to content

Commit

Permalink
[midend] Extend MatMulVectorization pattern to multiple types.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Sep 22, 2023
1 parent da71b6b commit d43b463
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 31 deletions.
3 changes: 3 additions & 0 deletions midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ add_mlir_library(MatMulOptimization
BatchMatMulOptimize.cpp
MatMulOptimize.cpp
MatMulVectorization.cpp

LINK_LIBS PUBLIC
BuddyUtils
)

add_mlir_library(BatchMatMulOptimization
Expand Down
22 changes: 11 additions & 11 deletions midend/lib/Conversion/MatMulOptimization/MatMulVectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <mlir/IR/Value.h>
#include <mlir/Pass/Pass.h>

#include "Utils/Utils.h"

using namespace mlir;
using namespace vector;

Expand All @@ -53,30 +55,26 @@ class MatMulVectorizationPattern : public ConversionPattern {
Value B = op->getOperand(1);
Value C = op->getOperand(2);
// Get shape of input and output
// ShapedType ATy = A.getType().cast<ShapedType>();
// Type eleTy = ATy.getElementType();
ShapedType ATy = A.getType().cast<ShapedType>();
Type eleTy = ATy.getElementType();
// ShapedType BTy = B.getType().cast<ShapedType>();
// ShapedType CTy = C.getType().cast<ShapedType>();

auto ctx = op->getContext();
// Currently use f32 as the element type.
// TODO: replace f32 with input type.
FloatType f32 = mlir::FloatType::getF32(ctx);
// Get i1 as the element type for mask vector.
IntegerType i1 = IntegerType::get(ctx, 1);
// Define `*Type`.
VectorType vectorTy = mlir::VectorType::get({vecSize}, f32);
VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy);
VectorType vectorMaskTy = VectorType::get({vecSize}, i1);
// Some constants.
const Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
const Value c1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
const Value step = rewriter.create<arith::ConstantIndexOp>(loc, vecSize);
const Value c0F32 = rewriter.create<arith::ConstantFloatOp>(
loc, APFloat::getZero(f32.getFloatSemantics()), f32);
// Create pass through vector.
Value c0F32Vec = rewriter.create<SplatOp>(loc, vectorTy, c0F32);
const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy);
Value passthruVec = rewriter.create<SplatOp>(loc, vectorTy, c0Ele);

// Create DimOp.
const Value aRow = rewriter.create<memref::DimOp>(loc, A, c0);
Expand Down Expand Up @@ -127,6 +125,8 @@ class MatMulVectorizationPattern : public ConversionPattern {
Value cVec = builder.create<affine::AffineVectorLoadOp>(
loc, vectorTy, C, CVectorMap, ValueRange{ivs[0], ivs[1], iv});
// FMA = Fused Multiply + Add
// FMAOp only supports floating point type input.
// TODO: Write a utils function for FMA to support both int and float.
Value resultVector = builder.create<FMAOp>(loc, aVec, bVec, cVec);
builder.create<affine::AffineVectorStoreOp>(
loc, resultVector, C, CVectorMap, ValueRange{ivs[0], ivs[1], iv});
Expand All @@ -141,10 +141,10 @@ class MatMulVectorizationPattern : public ConversionPattern {
// Masked load input and output.
Value bVecTail = builder.create<MaskedLoadOp>(
loc, vectorTy, B, ValueRange{ivs[0], bColIdxTail},
maskVec, c0F32Vec);
maskVec, passthruVec);
Value cVecTail = builder.create<MaskedLoadOp>(
loc, vectorTy, C, ValueRange{ivs[1], bColIdxTail},
maskVec, c0F32Vec);
maskVec, passthruVec);
// FMA.
Value resultVecTail =
builder.create<FMAOp>(loc, aVec, bVecTail, cVecTail);
Expand Down
77 changes: 57 additions & 20 deletions tests/Conversion/matmul-vectorization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,76 @@

module{
func.func private @printMemrefF32(memref<*xf32>)
func.func private @printMemrefF64(memref<*xf64>)

func.func @matmul(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>) {
// f32 test kernel for the MatMulVectorization pass.
// Accumulates %c += %a * %b on dynamically-shaped memrefs; linalg.matmul
// reads the existing contents of the outs operand, which is why the caller
// pre-fills %c with 1.0 and the FileCheck expectations read 5 (= 4*1 + 1).
func.func @matmul_f32(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>) {
  linalg.matmul
    ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
    outs(%c:memref<?x?xf32>)
  return
}

// f64 counterpart of @matmul_f32 — exercises the element-type-generic path
// added to MatMulVectorizationPattern (vector type and zero constant are now
// derived from the input element type instead of being hard-coded to f32).
// Accumulates %c += %a * %b; the caller pre-fills all operands with 1.0.
func.func @matmul_f64(%a : memref<?x?xf64>, %b : memref<?x?xf64>, %c : memref<?x?xf64>) {
  linalg.matmul
    ins(%a, %b: memref<?x?xf64>, memref<?x?xf64>)
    outs(%c:memref<?x?xf64>)
  return
}

func.func @main(){
// Set up dims.
%cM = arith.constant 4 : index
%cN = arith.constant 4 : index
%cK = arith.constant 4 : index

// -------------------------------------------------------------------------
// Test f32 as element type.
// -------------------------------------------------------------------------

// Set Init Value.
%cf1 = arith.constant 1.0 : f32
%cf1_32 = arith.constant 1.0 : f32

%A_f32 = memref.alloc(%cM, %cK) : memref<?x?xf32>
%B_f32 = memref.alloc(%cK, %cN) : memref<?x?xf32>
%C_f32 = memref.alloc(%cM, %cN) : memref<?x?xf32>

linalg.fill ins(%cf1_32 : f32) outs(%A_f32 : memref<?x?xf32>)
linalg.fill ins(%cf1_32 : f32) outs(%B_f32 : memref<?x?xf32>)
linalg.fill ins(%cf1_32 : f32) outs(%C_f32 : memref<?x?xf32>)

call @matmul_f32(%A_f32, %B_f32, %C_f32) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()

// Print output.
// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data =
// CHECK-NEXT: [
// CHECK-SAME: [5, 5, 5, 5],
// CHECK-NEXT: [5, 5, 5, 5],
// CHECK-NEXT: [5, 5, 5, 5],
// CHECK-NEXT: [5, 5, 5, 5]
// CHECK-SAME: ]
%print_C_f32 = memref.cast %C_f32 : memref<?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_C_f32) : (memref<*xf32>) -> ()

%A = memref.alloc(%cM, %cK) : memref<?x?xf32>
%B = memref.alloc(%cK, %cN) : memref<?x?xf32>
%C = memref.alloc(%cM, %cN) : memref<?x?xf32>
memref.dealloc %C_f32 : memref<?x?xf32>
memref.dealloc %B_f32 : memref<?x?xf32>
memref.dealloc %A_f32 : memref<?x?xf32>

linalg.fill
ins(%cf1 : f32)
outs(%A:memref<?x?xf32>)
// -------------------------------------------------------------------------
// Test f64 as element type.
// -------------------------------------------------------------------------

linalg.fill
ins(%cf1 : f32)
outs(%B:memref<?x?xf32>)
// Set Init Value.
%cf1_64 = arith.constant 1.0 : f64

linalg.fill
ins(%cf1 : f32)
outs(%C:memref<?x?xf32>)
%A_f64 = memref.alloc(%cM, %cK) : memref<?x?xf64>
%B_f64 = memref.alloc(%cK, %cN) : memref<?x?xf64>
%C_f64 = memref.alloc(%cM, %cN) : memref<?x?xf64>

call @matmul(%A, %B, %C) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
linalg.fill ins(%cf1_64 : f64) outs(%A_f64 : memref<?x?xf64>)
linalg.fill ins(%cf1_64 : f64) outs(%B_f64 : memref<?x?xf64>)
linalg.fill ins(%cf1_64 : f64) outs(%C_f64 : memref<?x?xf64>)

call @matmul_f64(%A_f64, %B_f64, %C_f64) : (memref<?x?xf64>, memref<?x?xf64>, memref<?x?xf64>) -> ()

// Print output.
// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data =
Expand All @@ -53,12 +89,13 @@ module{
// CHECK-NEXT: [5, 5, 5, 5],
// CHECK-NEXT: [5, 5, 5, 5]
// CHECK-SAME: ]
%print_C = memref.cast %C : memref<?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_C) : (memref<*xf32>) -> ()
%print_C_f64 = memref.cast %C_f64 : memref<?x?xf64> to memref<*xf64>
call @printMemrefF64(%print_C_f64) : (memref<*xf64>) -> ()

memref.dealloc %C_f64 : memref<?x?xf64>
memref.dealloc %B_f64 : memref<?x?xf64>
memref.dealloc %A_f64 : memref<?x?xf64>

memref.dealloc %C : memref<?x?xf32>
memref.dealloc %B : memref<?x?xf32>
memref.dealloc %A : memref<?x?xf32>
return
}
}

0 comments on commit d43b463

Please sign in to comment.