Skip to content

Commit

Permalink
[midend][examples] Fix type recognition bug in batchmatmul optimizati…
Browse files Browse the repository at this point in the history
…on and add int8 tests.
  • Loading branch information
EllisLambda committed Sep 18, 2023
1 parent db11165 commit 18eb761
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 12 deletions.
46 changes: 46 additions & 0 deletions examples/MLIRLinalg/linalg-batch-matmul-i8.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: buddy-opt -batchmatmul-optimize -verify-diagnostics -expand-strided-metadata -lower-affine -convert-vector-to-llvm -finalize-memref-to-llvm -convert-scf-to-cf -convert-linalg-to-llvm -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts %s \
// RUN: | mlir-cpu-runner -O0 -e buddy_batchmatmul_i8 \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

// Integration test for the -batchmatmul-optimize pass with i8 element types.
// Computes C += A * B for a batch of two 2x3 * 3x4 matmuls in i8 arithmetic;
// sums that exceed the i8 range wrap around (e.g. 14 + 146 = 160 -> -96),
// and the CHECK lines below pin those wrapped values.
memref.global "private" @A : memref<2x2x3xi8> = dense<[[[9, 4, 6],[2, 4, 0]],[[6, 3, 3],[0, 4, 7]]]>
memref.global "private" @B : memref<2x3x4xi8> = dense<[[[1, 3, 8, 0],[1, 8, 8, 7], [6, 9, 7, 9]],[[3, 8, 6, 8],[2, 7, 0, 6],[0, 4, 0, 4]]]>
// @C carries the initial accumulator values; batch_matmul adds A*B onto it.
memref.global "private" @C : memref<2x2x4xi8> = dense<[[[49, 12, 14, 82],[6, 38, 48, 28]],[[24, 81, 36, 78],[8, 56, 0, 52]]]>

// Runner-utils printer. Only the f32 printer is used here, which is why the
// i8 result is widened to f32 below before printing.
func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }

// Entry point invoked by mlir-cpu-runner. Returns a dummy f32; the RUN line
// omits -entry-point-result, presumably relying on the runner's f32 default —
// NOTE(review): confirm against the mlir-cpu-runner version in use.
func.func @buddy_batchmatmul_i8() -> f32{
%a = memref.get_global @A : memref<2x2x3xi8>
%b = memref.get_global @B : memref<2x3x4xi8>
%c = memref.get_global @C : memref<2x2x4xi8>

// C[n] += A[n] x B[n] for each batch n; this op is what the
// -batchmatmul-optimize pass in the RUN line rewrites.
linalg.batch_matmul
ins(%a, %b: memref<2x2x3xi8>, memref<2x3x4xi8>)
outs(%c: memref<2x2x4xi8>)

// Loop bounds for walking the 2x2x4 result.
%cst_0 = arith.constant 0 : index
%cst_1 = arith.constant 1 : index
%cst_2 = arith.constant 2 : index
%cst_4 = arith.constant 4 : index

// Element-wise signed widening of the i8 result into an f32 buffer so it can
// be printed with printMemrefF32 (sitofp preserves the wrapped signed values).
%c_f32 = memref.alloca() : memref<2x2x4xf32>
scf.for %i = %cst_0 to %cst_2 step %cst_1 {
scf.for %j = %cst_0 to %cst_2 step %cst_1 {
scf.for %k = %cst_0 to %cst_4 step %cst_1 {
%val_i8 = memref.load %c[%i, %j, %k] : memref<2x2x4xi8>
%val_f32 = arith.sitofp %val_i8 : i8 to f32
memref.store %val_f32, %c_f32[%i, %j, %k] : memref<2x2x4xf32>
}
}
}

%printed_c = memref.cast %c_f32 : memref<2x2x4xf32> to memref<*xf32>
call @printMemrefF32(%printed_c) : (memref<*xf32>) -> ()
// CHECK: {{Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 3 offset = 0 sizes = \[2, 2, 4\] strides = \[8, 4, 1\] data =}}
// CHECK{LITERAL}: [[[98, 125, -96, -92],
// CHECK{LITERAL}: [12, 76, 96, 56]],
// CHECK{LITERAL}: [[48, -94, 72, -100],
// CHECK{LITERAL}: [16, 112, 0, 104]]]
%zero = arith.constant 0.0 :f32
return %zero :f32
}
73 changes: 64 additions & 9 deletions examples/MLIRLinalg/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ linalg-matmul-optimize-run:
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batch-matmul-optimize-run:
@${BUDDY_OPT} linalg-matmul.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="step=64" \
@${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="vector-size=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
Expand All @@ -152,34 +152,89 @@ linalg-batch-matmul-optimize-run:
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

# Lower the f32 batch-matmul example to the LLVM dialect with upstream
# mlir-opt and write the result to log.mlir.
linalg-batch-matmul-lower:
	@${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
	-convert-func-to-llvm -reconcile-unrealized-casts \
	-o ./log.mlir

# Lower the f32 batch-matmul example and translate it to LLVM IR (log.ll).
linalg-batch-matmul-translate:
	@${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
	-convert-func-to-llvm -reconcile-unrealized-casts | \
	${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll

# JIT-run the f32 batch-matmul example with mlir-cpu-runner.
# NOTE(review): `-e main -entry-point-result=void` assumes
# linalg-batch-matmul-f32.mlir defines `main` returning void — confirm.
linalg-batch-matmul-run:
	@${MLIR_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
	-convert-func-to-llvm -reconcile-unrealized-casts | \
	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

# Apply only the buddy-opt batchmatmul-optimize pass (vector size 64) and
# dump the rewritten IR to log.mlir for inspection.
linalg-batch-matmul-optimize-lower:
	@${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
	-batchmatmul-optimize="vector-size=64" \
	-o ./log.mlir

# Optimize with batchmatmul-optimize, lower through the full pipeline, and
# translate to LLVM IR (log.ll).
linalg-batch-matmul-optimize-translate:
	@${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \
	-batchmatmul-optimize="vector-size=64" \
	-convert-linalg-to-loops \
	-expand-strided-metadata \
	-lower-affine \
	-convert-scf-to-cf \
	-convert-vector-to-llvm \
	-finalize-memref-to-llvm \
	-convert-arith-to-llvm \
	-convert-func-to-llvm \
	-reconcile-unrealized-casts | \
	${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll

# Optimize the int8 batch-matmul example with batchmatmul-optimize, lower it,
# and JIT-run it.
# NOTE(review): linalg-batch-matmul-i8.mlir (added in this change) defines
# `buddy_batchmatmul_i8` returning f32 — it has no `main` — so the runner
# entry point must match the function actually present in that file.
linalg-batch-matmul-i8-optimize-run:
	@${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
	-batchmatmul-optimize="vector-size=64" \
	-convert-linalg-to-loops \
	-expand-strided-metadata \
	-lower-affine \
	-convert-scf-to-cf \
	-convert-vector-to-llvm \
	-finalize-memref-to-llvm \
	-convert-arith-to-llvm \
	-convert-func-to-llvm \
	-reconcile-unrealized-casts | \
	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e buddy_batchmatmul_i8 -entry-point-result=f32 \
	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

# Lower the int8 batch-matmul example to the LLVM dialect with upstream
# mlir-opt and write the result to log.mlir.
linalg-batch-matmul-i8-lower:
	@${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
	-convert-func-to-llvm -reconcile-unrealized-casts \
	-o ./log.mlir

# Lower the int8 batch-matmul example and translate it to LLVM IR (log.ll).
linalg-batch-matmul-i8-translate:
	@${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
	-convert-func-to-llvm -reconcile-unrealized-casts | \
	${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll

# JIT-run the int8 batch-matmul example (without the optimize pass).
# NOTE(review): linalg-batch-matmul-i8.mlir (added in this change) defines
# `buddy_batchmatmul_i8` returning f32 — it has no `main` — so the runner
# entry point and result type must match that function.
linalg-batch-matmul-i8-run:
	@${MLIR_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
	-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
	-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
	-convert-func-to-llvm -reconcile-unrealized-casts | \
	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e buddy_batchmatmul_i8 -entry-point-result=f32 \
	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

# Apply only the buddy-opt batchmatmul-optimize pass (vector size 64) to the
# int8 example and dump the rewritten IR to log.mlir for inspection.
linalg-batch-matmul-i8-optimize-lower:
	@${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
	-batchmatmul-optimize="vector-size=64" \
	-o ./log.mlir

linalg-batch-matmul-i8-optimize-translate:
@${BUDDY_OPT} linalg-batch-matmul-i8.mlir ${MLIR_OPT_OPTIONS} \
-batchmatmul-optimize="vector-size=64" \
-convert-linalg-to-loops \
-expand-strided-metadata \
-lower-affine \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {

rewriter.create<affine::AffinePrefetchOp>(
loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()),
ArrayRef<Value>{ivBatch, c0, c0}, false, 3, true);
ArrayRef<Value>{ivBatch, M, K}, false, 3, true);
affine::buildAffineLoopNest(
rewriter, loc, {c0}, {K}, 1,
[&](OpBuilder &builder, Location loc, ValueRange ivRange) {
Expand Down Expand Up @@ -174,7 +174,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
rewriter.getContext()),
ValueRange{ivBatch, ivA_row, ivB_col});
Value result_vec;
if (A_elementType.isIntOrFloat() && 0) { // bug
if (A_elementType.isa<IntegerType>()) {
Value add_vec = builder.create<arith::MulIOp>(
loc, a_vec, b_vec);
result_vec = builder.create<arith::AddIOp>(
Expand Down Expand Up @@ -220,7 +220,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
b_col_idx_tail},
mask_vec, c0_dynamicType_vec);
Value result_vec_tail;
if (A_elementType.isIntOrFloat() && 0) { // bug
if (A_elementType.isa<IntegerType>()) {
Value add_vec = builder.create<arith::MulIOp>(
loc, a_vec, b_vec_tail);
result_vec_tail = builder.create<arith::AddIOp>(
Expand Down

0 comments on commit 18eb761

Please sign in to comment.