Commit 398e381
[DIP] Merge remote-tracking branch 'upstream/main' into DIP_png
Guan-schoolmate committed Sep 22, 2023
2 parents 044864e + d43b463
Showing 13 changed files with 494 additions and 63 deletions.
46 changes: 46 additions & 0 deletions examples/MLIRLinalg/linalg-batch-matmul-i8.mlir
@@ -0,0 +1,46 @@
// RUN: buddy-opt -batchmatmul-optimize -verify-diagnostics -expand-strided-metadata -lower-affine -convert-vector-to-llvm -finalize-memref-to-llvm -convert-scf-to-cf -convert-linalg-to-llvm -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts %s \
// RUN: | mlir-cpu-runner -O0 -e buddy_batchmatmul_i8 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

memref.global "private" @A : memref<2x2x3xi8> = dense<[[[9, 4, 6],[2, 4, 0]],[[6, 3, 3],[0, 4, 7]]]>
memref.global "private" @B : memref<2x3x4xi8> = dense<[[[1, 3, 8, 0],[1, 8, 8, 7], [6, 9, 7, 9]],[[3, 8, 6, 8],[2, 7, 0, 6],[0, 4, 0, 4]]]>
memref.global "private" @C : memref<2x2x4xi8> = dense<[[[49, 12, 14, 82],[6, 38, 48, 28]],[[24, 81, 36, 78],[8, 56, 0, 52]]]>

func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }

func.func @buddy_batchmatmul_i8() -> f32 {
  %a = memref.get_global @A : memref<2x2x3xi8>
  %b = memref.get_global @B : memref<2x3x4xi8>
  %c = memref.get_global @C : memref<2x2x4xi8>

  linalg.batch_matmul
    ins(%a, %b : memref<2x2x3xi8>, memref<2x3x4xi8>)
    outs(%c : memref<2x2x4xi8>)

  %cst_0 = arith.constant 0 : index
  %cst_1 = arith.constant 1 : index
  %cst_2 = arith.constant 2 : index
  %cst_4 = arith.constant 4 : index

  %c_f32 = memref.alloca() : memref<2x2x4xf32>
  scf.for %i = %cst_0 to %cst_2 step %cst_1 {
    scf.for %j = %cst_0 to %cst_2 step %cst_1 {
      scf.for %k = %cst_0 to %cst_4 step %cst_1 {
        %val_i8 = memref.load %c[%i, %j, %k] : memref<2x2x4xi8>
        %val_f32 = arith.sitofp %val_i8 : i8 to f32
        memref.store %val_f32, %c_f32[%i, %j, %k] : memref<2x2x4xf32>
      }
    }
  }

  %printed_c = memref.cast %c_f32 : memref<2x2x4xf32> to memref<*xf32>
  call @printMemrefF32(%printed_c) : (memref<*xf32>) -> ()
  // CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 3 offset = 0 sizes = \[2, 2, 4\] strides = \[8, 4, 1\] data =}}
  // CHECK{LITERAL}: [[[98, 125, -96, -92],
  // CHECK{LITERAL}: [12, 76, 96, 56]],
  // CHECK{LITERAL}: [[48, -94, 72, -100],
  // CHECK{LITERAL}: [16, 112, 0, 104]]]
  %zero = arith.constant 0.0 : f32
  return %zero : f32
}
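A note on the expected values: linalg.batch_matmul on i8 operands multiplies and accumulates entirely in i8, so intermediate results wrap around modulo 256, which is why several CHECK values are negative even though every input is non-negative. A standalone sketch (not part of the commit; inputs copied from the globals above) reproducing the element C[0][0][2] = -96:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int8_t a[3] = {9, 4, 6}; // A[0][0][*]
  const int8_t b[3] = {8, 8, 7}; // B[0][*][2]
  int8_t acc = 14;               // initial value of C[0][0][2]
  for (int k = 0; k < 3; ++k)
    acc = static_cast<int8_t>(acc + a[k] * b[k]); // i8 wraparound accumulate
  std::printf("%d\n", acc); // prints -96: 14 + 146 = 160, and 160 - 256 = -96
  return 0;
}
```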
78 changes: 69 additions & 9 deletions examples/MLIRLinalg/makefile
@@ -137,8 +137,8 @@ linalg-matmul-optimize-run:
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batch-matmul-optimize-run:
- @${BUDDY_OPT} linalg-matmul.mlir ${MLIR_OPT_OPTIONS} \
- -batchmatmul-optimize="step-placeholder=64" \
+ @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
+ -batchmatmul-optimize="vector-size=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
@@ -152,34 +152,89 @@ linalg-batch-matmul-optimize-run:
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batch-matmul-lower:
- @${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+ @${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts \
-o ./log.mlir

linalg-batch-matmul-translate:
- @${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+ @${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts | \
${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll

linalg-batch-matmul-run:
- @${MLIR_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
+ @${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batch-matmul-optimize-lower:
- @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
- -batchmatmul-optimize="step-placeholder=64" \
+ @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
+ -batchmatmul-optimize="vector-size=64" \
-o ./log.mlir

linalg-batch-matmul-optimize-translate:
- @${BUDDY_OPT} linalg-batch-matmul.mlir ${MLIR_OPT_OPTIONS} \
- -batchmatmul-optimize="step-placeholder=64" \
+ @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
+ -batchmatmul-optimize="vector-size=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
-convert-scf-to-cf \
-convert-vector-to-llvm \
-finalize-memref-to-llvm \
-convert-arith-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll

linalg-batch-matmul-i8-optimize-run:
@${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="vector-size=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
-convert-scf-to-cf \
-convert-vector-to-llvm \
-finalize-memref-to-llvm \
-convert-arith-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batch-matmul-i8-lower:
@${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts \
-o ./log.mlir

linalg-batch-matmul-i8-translate:
@${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts | \
${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll

linalg-batch-matmul-i8-run:
@${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batch-matmul-i8-optimize-lower:
@${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="vector-size=64" \
-o ./log.mlir

linalg-batch-matmul-i8-optimize-translate:
@${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="vector-size=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
@@ -229,3 +284,8 @@ linalg-conv2d_nchw_fchw-optimize-run:
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-matmul-vectorization-lower:
@${BUDDY_OPT} linalg-matmul.mlir \
-matmul-vectorization \
-o log.mlir
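Taken together, the makefile changes rename the optimize option from step-placeholder to vector-size, retarget the existing batch-matmul rules at the renamed linalg-batch-matmul-f32.mlir source, and add i8 counterparts for each lower/translate/run variant. One caveat: the i8 run targets pass `-e main -entry-point-result=void`, but linalg-batch-matmul-i8.mlir defines only `buddy_batchmatmul_i8` (returning f32); the RUN line at the top of that file, which invokes `mlir-cpu-runner -e buddy_batchmatmul_i8`, is the entry point the FileCheck test actually exercises.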
5 changes: 0 additions & 5 deletions midend/include/Utils/DIPUtils.h
@@ -38,11 +38,6 @@ enum class DIP_OP { CORRELATION_2D, EROSION_2D, DILATION_2D };
// from lowering passes with appropriate messages.
enum class DIP_ERROR { INCONSISTENT_TYPES, UNSUPPORTED_TYPE, NO_ERROR };

- // Inserts a constant op with value 0 into a location `loc` based on type
- // `type`. Supported types are : f32, f64, integer types.
- Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
-                            Type elemTy);

// Inserts FMA operation into a given location `loc` based on type `type`.
// Note: FMA is done by Multiply and Add for integer types, because there is no
// dedicated FMA operation for them.
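The declaration removed here is not dropped from the project: as the next file shows, it moves into the shared buddy utilities header so that passes outside the DIP dialect can use it as well.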
5 changes: 5 additions & 0 deletions midend/include/Utils/Utils.h
@@ -25,6 +25,11 @@ using namespace mlir;

namespace buddy {

+ // Inserts a constant op with value 0 into a location `loc` based on type
+ // `type`. Supported types are : f32, f64, integer types.
+ Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
+                            Type elemTy);

// Function to test whether a value is equivalent to zero or not.
Value zeroCond(OpBuilder &builder, Location loc, Type elemType, Value value,
Value zeroElem);
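The definition itself is outside the hunks shown here (it presumably lives in the corresponding Utils implementation file). For orientation, a minimal sketch of what a helper with this signature does, mirroring the `rewriter.getZeroAttr(...)` usage visible in BatchMatMulOptimize.cpp below; the body is an assumption, not the committed code:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

namespace buddy {
// Sketch only: emit `arith.constant 0` of the given element type at `loc`.
// getZeroAttr covers the documented types: f32, f64, and integer types.
Value insertZeroConstantOp(MLIRContext *ctx, OpBuilder &builder, Location loc,
                           Type elemTy) {
  (void)ctx; // kept only to match the declared signature
  return builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(elemTy));
}
} // namespace buddy
```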
9 changes: 3 additions & 6 deletions midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp
@@ -795,8 +795,7 @@ class DIPTopHat2DOpLowering : public OpRewritePattern<dip::TopHat2DOp> {
VectorType vectorTy32 = VectorType::get({stride}, inElemTy);
IntegerType i1 = IntegerType::get(ctx, 1);
VectorType vectorMaskTy = VectorType::get({stride}, i1);
- Value zeroPaddingElem =
-     dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
+ Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

@@ -1001,8 +1000,7 @@ class DIPBottomHat2DOpLowering : public OpRewritePattern<dip::BottomHat2DOp> {
VectorType vectorTy32 = VectorType::get({stride}, inElemTy);
IntegerType i1 = IntegerType::get(ctx, 1);
VectorType vectorMaskTy = VectorType::get({stride}, i1);
- Value zeroPaddingElem =
-     dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
+ Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

@@ -1203,8 +1201,7 @@ class DIPMorphGrad2DOpLowering : public OpRewritePattern<dip::MorphGrad2DOp> {
VectorType vectorTy32 = VectorType::get({stride}, inElemTy);
IntegerType i1 = IntegerType::get(ctx, 1);
VectorType vectorMaskTy = VectorType::get({stride}, i1);
- Value zeroPaddingElem =
-     dip::insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
+ Value zeroPaddingElem = insertZeroConstantOp(ctx, rewriter, loc, inElemTy);
Value zeroPaddingVec =
rewriter.create<vector::BroadcastOp>(loc, vectorTy32, zeroPaddingElem);

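These three hunks are the call-site counterpart of the header move above: with `insertZeroConstantOp` now declared directly in namespace buddy (Utils.h) rather than buddy::dip (DIPUtils.h), the `dip::` qualifier is dropped and each call collapses onto one line.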
48 changes: 24 additions & 24 deletions midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp
@@ -51,10 +51,10 @@ namespace {
class BatchMatMulOptimizePattern : public ConversionPattern {
public:
explicit BatchMatMulOptimizePattern(MLIRContext *context,
- int64_t stepPlaceHolderParam)
+ int64_t affineVectorSizeParam)
: ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1,
context) {
- stepPlaceHolder = stepPlaceHolderParam;
+ affineVectorSize = affineVectorSizeParam;
}

LogicalResult
@@ -73,7 +73,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
const Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
const Value step = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(stepPlaceHolder));
+ loc, rewriter.getIndexAttr(affineVectorSize));
const AffineExpr d0 = rewriter.getAffineDimExpr(0);
const AffineExpr d1 = rewriter.getAffineDimExpr(1);
const AffineExpr d2 = rewriter.getAffineDimExpr(2);
@@ -82,7 +82,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
const Value c0_dynamicType = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(A_elementType));
const Value c0_dynamicType_vec = rewriter.create<vector::SplatOp>(
- loc, VectorType::get({stepPlaceHolder}, A_elementType), c0_dynamicType);
+ loc, VectorType::get({affineVectorSize}, A_elementType), c0_dynamicType);

// Dims
Value BATCH = rewriter.create<memref::DimOp>(loc, A, 0); // Batch size
@@ -122,7 +122,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {

rewriter.create<affine::AffinePrefetchOp>(
loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()),
- ArrayRef<Value>{ivBatch, c0, c0}, false, 3, true);
+ ArrayRef<Value>{ivBatch, M, K}, false, 3, true);
affine::buildAffineLoopNest(
rewriter, loc, {c0}, {K}, 1,
[&](OpBuilder &builder, Location loc, ValueRange ivRange) {
@@ -132,7 +132,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
[&](OpBuilder &builder, Location loc, ValueRange ivRange) {
Value ivA_row = ivRange.front();
Value applied_n = builder.create<affine::AffineApplyOp>(
- loc, AffineMap::get(1, 0, d0.ceilDiv(stepPlaceHolder)),
+ loc, AffineMap::get(1, 0, d0.ceilDiv(affineVectorSize)),
ValueRange{N});
affine::buildAffineLoopNest(
builder, loc, {c0}, {applied_n}, 1,
@@ -142,7 +142,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
loc, A, ValueRange{ivBatch, ivA_row, ivB_row});
Value a_vec = builder.create<vector::BroadcastOp>(
loc,
- VectorType::get({stepPlaceHolder}, A_elementType),
+ VectorType::get({affineVectorSize}, A_elementType),
a_ele);
Value b_col_cur =
builder.create<arith::MulIOp>(loc, ivB_col, step);
@@ -156,25 +156,25 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
Value b_vec =
builder.create<affine::AffineVectorLoadOp>(
loc,
- VectorType::get({stepPlaceHolder},
+ VectorType::get({affineVectorSize},
A_elementType),
B,
AffineMap::get(
- 3, 0, {d0, d1, d2 * stepPlaceHolder},
+ 3, 0, {d0, d1, d2 * affineVectorSize},
rewriter.getContext()),
ValueRange{ivBatch, ivB_row, ivB_col});
Value c_vec =
builder.create<affine::AffineVectorLoadOp>(
loc,
- VectorType::get({stepPlaceHolder},
+ VectorType::get({affineVectorSize},
A_elementType),
C,
AffineMap::get(
- 3, 0, {d0, d1, d2 * stepPlaceHolder},
+ 3, 0, {d0, d1, d2 * affineVectorSize},
rewriter.getContext()),
ValueRange{ivBatch, ivA_row, ivB_col});
Value result_vec;
- if (A_elementType.isIntOrFloat() && 0) { // bug
+ if (A_elementType.isa<IntegerType>()) {
Value add_vec = builder.create<arith::MulIOp>(
loc, a_vec, b_vec);
result_vec = builder.create<arith::AddIOp>(
@@ -186,7 +186,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
builder.create<affine::AffineVectorStoreOp>(
loc, result_vec, C,
AffineMap::get(3, 0,
- {d0, d1, d2 * stepPlaceHolder},
+ {d0, d1, d2 * affineVectorSize},
rewriter.getContext()),
ValueRange{ivBatch, ivA_row, ivB_col});
builder.create<scf::YieldOp>(loc);
@@ -195,7 +195,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
Value mask_vec =
builder.create<vector::CreateMaskOp>(
loc,
- VectorType::get({stepPlaceHolder},
+ VectorType::get({affineVectorSize},
rewriter.getI1Type()),
ValueRange{tail_len});
Value b_col_idx_tail =
@@ -204,7 +204,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
Value b_vec_tail =
builder.create<vector::MaskedLoadOp>(
loc,
- VectorType::get({stepPlaceHolder},
+ VectorType::get({affineVectorSize},
A_elementType),
B,
ValueRange{ivBatch, ivB_row,
@@ -213,14 +213,14 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
Value c_vec_tail =
builder.create<vector::MaskedLoadOp>(
loc,
- VectorType::get({stepPlaceHolder},
+ VectorType::get({affineVectorSize},
A_elementType),
C,
ValueRange{ivBatch, ivA_row,
b_col_idx_tail},
mask_vec, c0_dynamicType_vec);
Value result_vec_tail;
- if (A_elementType.isIntOrFloat() && 0) { // bug
+ if (A_elementType.isa<IntegerType>()) {
Value add_vec = builder.create<arith::MulIOp>(
loc, a_vec, b_vec_tail);
result_vec_tail = builder.create<arith::AddIOp>(
@@ -249,7 +249,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
}

private:
- int64_t stepPlaceHolder;
+ int64_t affineVectorSize;
};
} // end anonymous namespace

@@ -268,8 +268,8 @@ class BatchMatMulOptimizePass
StringRef getDescription() const final { return "BatchMatMul Optimization."; }
BatchMatMulOptimizePass() = default;
BatchMatMulOptimizePass(const BatchMatMulOptimizePass &) {}
- explicit BatchMatMulOptimizePass(int64_t stepPlaceHolderParam) {
-   stepPlaceHolder = stepPlaceHolderParam;
+ explicit BatchMatMulOptimizePass(int64_t affineVectorSizeParam) {
+   affineVectorSize = affineVectorSizeParam;
}

void runOnOperation() override;
@@ -279,9 +279,9 @@ class BatchMatMulOptimizePass
affine::AffineDialect, VectorDialect>();
}

- Option<int64_t> stepPlaceHolder{
-     *this, "step-placeholder",
-     llvm::cl::desc("Affine step placeholder size."), llvm::cl::init(64)};
+ Option<int64_t> affineVectorSize{
+     *this, "vector-size",
+     llvm::cl::desc("Affine Vector size."), llvm::cl::init(64)};
};
} // end anonymous namespace.

@@ -297,7 +297,7 @@ void BatchMatMulOptimizePass::runOnOperation() {
target.addLegalOp<linalg::FillOp>();

RewritePatternSet patterns(context);
- patterns.add<BatchMatMulOptimizePattern>(context, stepPlaceHolder);
+ patterns.add<BatchMatMulOptimizePattern>(context, affineVectorSize);

if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
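Beyond the option rename (`step-placeholder`/`stepPlaceHolder` to `vector-size`/`affineVectorSize`), two substantive fixes land in this file: the affine.prefetch now receives the full {M, K} extent instead of {c0, c0}, and the integer path, previously dead code behind `isIntOrFloat() && 0` and annotated `// bug`, is enabled via `isa<IntegerType>`. Reassembled from the hunks above, the main-loop dispatch now reads roughly as follows; this is a fragment, not a standalone compile unit, and the float branch is an assumption (the comment in DIPUtils.h notes that integer FMA must be emulated by multiply-and-add, so a `vector::FMAOp` on the float side is the natural reading of the elided context):

```cpp
Value result_vec;
if (A_elementType.isa<IntegerType>()) {
  // No integer FMA op exists: multiply, then add, with two's-complement
  // wraparound (which is what the i8 test's negative CHECK values expect).
  Value add_vec = builder.create<arith::MulIOp>(loc, a_vec, b_vec);
  result_vec = builder.create<arith::AddIOp>(loc, add_vec, c_vec);
} else {
  // Float branch (assumed): fused multiply-add over the whole vector.
  result_vec = builder.create<vector::FMAOp>(loc, a_vec, b_vec, c_vec);
}
builder.create<affine::AffineVectorStoreOp>(
    loc, result_vec, C,
    AffineMap::get(3, 0, {d0, d1, d2 * affineVectorSize},
                   rewriter.getContext()),
    ValueRange{ivBatch, ivA_row, ivB_col});
```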
