Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Gemmini Dialect] Gemmini Dialect enhancement on tiled_matmul #178

Merged
merged 51 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
f23cb7d
add support for lowering config_norm
Xinyu302 Jun 30, 2023
c81f0fd
finish matmul os but has bug
Xinyu302 Jul 3, 2023
4db570f
Fix compile-time errors
Xinyu302 Jul 3, 2023
b5651c0
fix bug & add test for tiled-matmul-os
Xinyu302 Jul 4, 2023
28e5d56
Clear comments
Xinyu302 Jul 4, 2023
b56ad99
fix bug in GemminiConfigNormOpLowering
Xinyu302 Jul 10, 2023
ec9b426
before merge patch of gemmini dialect
Xinyu302 Jul 10, 2023
54ffd2c
add act in ConfigSt lowering, solve wrong computing
Xinyu302 Jul 10, 2023
b7aa811
Add test for tile-matmul-ws-igelu
Xinyu302 Jul 10, 2023
d242b9c
change signature, use class variable to replace "sizeof" operator
Xinyu302 Jul 11, 2023
498eb87
add mvin2, mvin3 lowering
Xinyu302 Jul 11, 2023
85c656b
make gemminiMvinOffset a tool function with template to adapt mvin2 a…
Xinyu302 Jul 12, 2023
1f444d6
extend config_ld op
Xinyu302 Jul 12, 2023
5e69ee1
Fixed a spell error, but still don't know how to describe the paramet…
Xinyu302 Aug 2, 2023
8143b89
delete empty line
Xinyu302 Aug 2, 2023
aa73372
add filecheck for tile-matmul-os.mlir
Xinyu302 Aug 8, 2023
91ce9c8
add FILECHECK for tile-matmul-ws-igelu.mlir
Xinyu302 Aug 8, 2023
1898f4c
add detailed descrition generated by chatGPT.
Xinyu302 Aug 8, 2023
4a2ecd7
delete sizeof, using class attribute
Xinyu302 Aug 9, 2023
bd17696
add space in int_riscv_mvin
Xinyu302 Aug 10, 2023
5cfaab4
add description for config_norm op
Xinyu302 Aug 10, 2023
2e30b2b
delete scale_t and acc_t
Xinyu302 Aug 10, 2023
31cedc4
delete function inner
Xinyu302 Aug 10, 2023
a5780fc
delete define acc_t, DIM, ADDR_LEN
Xinyu302 Sep 5, 2023
2e11d1c
update to gemmini upstream and complete rectangle conv
Xinyu302 Jul 12, 2023
70e68ee
modify by gemmini upstream
Xinyu302 Sep 5, 2023
f7237aa
fix bug: tile_conv interface changes
Xinyu302 Sep 5, 2023
58f450f
modify tile_conv case in ciface.mlir and tile-conv.mlir; fix Legalize…
Xinyu302 Sep 6, 2023
b570358
fix writing error in ciface.mlir
Xinyu302 Sep 6, 2023
cebddcd
fix bug in RISCVInstrInfoBuddyExt.td and delete comments in LegalizeF…
Xinyu302 Sep 6, 2023
ceb85c0
fix bug in rectangle conv
Xinyu302 Sep 6, 2023
46eeb74
delete comment
Xinyu302 Sep 6, 2023
dd82aca
add test case tile-rect-conv.mlir
Xinyu302 Sep 6, 2023
5dc0b58
add blank lines
Xinyu302 Oct 24, 2023
501f28b
add a space
Xinyu302 Oct 24, 2023
6572223
modify comments in tile-conv.mlir and tile-rect-conv.mlir
Xinyu302 Oct 24, 2023
0f5125a
modify tile-rect-conv.mlir
Xinyu302 Oct 24, 2023
55fc0f3
change matrix shape
Xinyu302 Oct 26, 2023
95fa559
add relu and softmax test
Xinyu302 Oct 26, 2023
74436ed
fix test error
Xinyu302 Oct 26, 2023
c7ee7fb
add tile-conv-relu
Xinyu302 Oct 26, 2023
bc9b1e2
add conv-igelu and conv-softmax
Xinyu302 Oct 26, 2023
ab50d5b
print origin result of matmul
Xinyu302 Oct 27, 2023
3ef138f
fix small bug
Xinyu302 Oct 27, 2023
a7785ad
add print origin result for conv test
Xinyu302 Oct 27, 2023
10deee8
add layernorm test
Xinyu302 Oct 27, 2023
02a1291
fix bug in tile-matmul-ws-layernorm
Xinyu302 Oct 27, 2023
834e837
add layernorm and filecheck
Xinyu302 Oct 27, 2023
15ac6fd
delete useless var
Xinyu302 Oct 27, 2023
5ea3690
handle compiler warnings
Xinyu302 Oct 27, 2023
ae7ebcc
delete useless vars
Xinyu302 Oct 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
//
//===----------------------------------------------------------------------===//
let TargetPrefix = "riscv" in
def int_riscv_mvin : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>;
def int_riscv_mvin : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>;

let TargetPrefix = "riscv" in
def int_riscv_mvin2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>;

let TargetPrefix = "riscv" in
def int_riscv_mvin3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>;

Xinyu302 marked this conversation as resolved.
Show resolved Hide resolved
let TargetPrefix = "riscv" in
def int_riscv_mvout : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;
Expand All @@ -35,6 +41,9 @@ def int_riscv_config_st : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;
let TargetPrefix = "riscv" in
def int_riscv_config_ex : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;

let TargetPrefix = "riscv" in
def int_riscv_config_norm : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;

let TargetPrefix = "riscv" in
def int_riscv_preload : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>;

Expand Down
27 changes: 27 additions & 0 deletions backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ def MVIN : RVInstR<0b0000010, 0b011, OPC_CUSTOM_3, (outs),
let rd = 0;
}

let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in
def MVIN2 : RVInstR<0b0000001, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "mvin2","$rs1, $rs2"> {
let rd = 0;
}

let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in
def MVIN3 : RVInstR<0b0001110, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "mvin3","$rs1, $rs2"> {
let rd = 0;
}

let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in
def MVOUT : RVInstR<0b0000011, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "mvout","$rs1, $rs2">{
Expand Down Expand Up @@ -65,6 +77,12 @@ def CONFIG_EX : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs),
let rd = 0;
}

let Predicates = [HasBuddyExt] in
def CONFIG_NORM : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs),
(ins GPR:$rs1, GPR:$rs2), "config_norm", "$rs1, $rs2"> {
let rd = 0;
}

let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in
def PRELOAD : RVInstR<0b0000110, 0b011,OPC_CUSTOM_3,(outs),
(ins GPR:$rs1, GPR:$rs2), "preload", "$rs1, $rs2">{
Expand Down Expand Up @@ -164,6 +182,12 @@ def LOOP_CONV_WS_CONFIG6 : RVInstR<0b0010101, 0b011, OPC_CUSTOM_3, (outs),
let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvin GPR:$rs1, GPR:$rs2), (MVIN GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvin2 GPR:$rs1, GPR:$rs2), (MVIN2 GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvin3 GPR:$rs1, GPR:$rs2), (MVIN3 GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_mvout GPR:$rs1, GPR:$rs2), (MVOUT GPR:$rs1, GPR:$rs2)>;

Expand All @@ -179,6 +203,9 @@ def : Pat<(int_riscv_config_st GPR:$rs1, GPR:$rs2), (CONFIG_ST GPR:$rs1, GPR:$rs
let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_config_ex GPR:$rs1, GPR:$rs2), (CONFIG_EX GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_config_norm GPR:$rs1, GPR:$rs2), (CONFIG_NORM GPR:$rs1, GPR:$rs2)>;

let Predicates = [HasBuddyExt] in
def : Pat<(int_riscv_preload GPR:$rs1, GPR:$rs2), (PRELOAD GPR:$rs1, GPR:$rs2)>;

Expand Down
24 changes: 12 additions & 12 deletions examples/GemminiDialect/ciface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -127,53 +127,53 @@ func.func @linalg_conv6(%arg0 : memref<1x1x256x256xi8>, %arg1 : memref<1x1x13x13
func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8>, %bias: memref<1xi32>, %output: memref<64516x1xi8>) {
%outdim = arith.constant 254 : i64
%kernelDim = arith.constant 3 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv2
func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi8>, %bias: memref<1xi32>, %output: memref<63504x1xi8>) {
%outdim = arith.constant 252 : i64
%kernelDim = arith.constant 5 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv3
func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi8>, %bias: memref<1xi32>, %output: memref<62500x1xi8>) {
%outdim = arith.constant 250 : i64
%kernelDim = arith.constant 7 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv4
func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi8>, %bias: memref<1xi32>, %output: memref<61504x1xi8>) {
%outdim = arith.constant 248 : i64
%kernelDim = arith.constant 9 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv5
func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1xi8>, %bias: memref<1xi32>, %output: memref<60516x1xi8>) {
%outdim = arith.constant 246 : i64
%kernelDim = arith.constant 11 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 i64
return
}

// CHECK: llvm.func @_mlir_ciface_gemmini_conv6
func.func @gemmini_conv6(%input: memref<1x256x256x1xi8>, %weights: memref<169x1xi8>, %bias: memref<1xi32>, %output: memref<59536x1xi8>) {
%outdim = arith.constant 244 : i64
%kernelDim = arith.constant 13 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 i64
return
}

Expand Down
72 changes: 72 additions & 0 deletions examples/GemminiDialect/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,42 @@ tile-matmul-run:
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-matmul-os-run:
@${BUDDY_OPT} ./tile-matmul-os.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-matmul-ws-igelu-run:
@${BUDDY_OPT} ./tile-matmul-ws-igelu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-matmul-ws-relu-run:
@${BUDDY_OPT} ./tile-matmul-ws-relu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-matmul-ws-softmax-run:
@${BUDDY_OPT} ./tile-matmul-ws-softmax.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-conv-run:
@${BUDDY_OPT} ./tile-conv.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
Expand All @@ -85,6 +121,42 @@ tile-conv-run:
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-conv-igelu-run:
@${BUDDY_OPT} ./tile-conv-igelu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-conv-softmax-run:
@${BUDDY_OPT} ./tile-conv-softmax.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-conv-relu-run:
@${BUDDY_OPT} ./tile-conv-relu.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

tile-rect-conv-run:
@${BUDDY_OPT} ./tile-rect-conv.mlir -lower-gemmini | \
${BUDDY_TRANSLATE} --buddy-to-llvmir | \
${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \
-mattr=+buddyext,+D -float-abi=hard \
-o log.o
@riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out
@spike --extension=gemmini pk a.out

gemmini-linalg-matmul-run:
@${BUDDY_OPT} ./matmul.mlir \
-convert-linalg-to-gemmini \
Expand Down
39 changes: 39 additions & 0 deletions examples/GemminiDialect/tile-conv-igelu.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: buddy-opt %s \
// RUN: --lower-gemmini | \
// RUN: FileCheck %s

// batchSize = 1 inputDim = 5 inChannels = 1
memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]]]]>

// outChannels = 2 kernelDim = 3 inChannels = 1
memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2]]>

// outChannels = 2
memref.global "private" @bias : memref<2xi32> = dense<[1,1]>

func.func @main() -> i64 {
%0 = arith.constant 0 : i64
%3 = arith.constant 3 : i64
%input = memref.get_global @input : memref<1x5x5x1xi8>
%weight = memref.get_global @weight : memref<9x2xi8>
%bias = memref.get_global @bias : memref<2xi32>
%output = memref.alloc() : memref<9x2xi8>
// CHECK: "gemmini.intr.loop_conv_ws_config1"
// CHECK: "gemmini.intr.loop_conv_ws_config2"
// CHECK: "gemmini.intr.loop_conv_ws_config3"
// CHECK: "gemmini.intr.loop_conv_ws_config4"
// CHECK: "gemmini.intr.loop_conv_ws_config5"
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 3}:
Xinyu302 marked this conversation as resolved.
Show resolved Hide resolved
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>
return %0 : i64
}
39 changes: 39 additions & 0 deletions examples/GemminiDialect/tile-conv-relu.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: buddy-opt %s \
// RUN: --lower-gemmini | \
// RUN: FileCheck %s

// batchSize = 1 inputDim = 5 inChannels = 1
memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]]]]>

// outChannels = 2 kernelDim = 3 inChannels = 1
memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2]]>

// outChannels = 2
memref.global "private" @bias : memref<2xi32> = dense<[1,1]>

func.func @main() -> i64 {
%0 = arith.constant 0 : i64
%3 = arith.constant 3 : i64
%input = memref.get_global @input : memref<1x5x5x1xi8>
%weight = memref.get_global @weight : memref<9x2xi8>
%bias = memref.get_global @bias : memref<2xi32>
%output = memref.alloc() : memref<9x2xi8>
// CHECK: "gemmini.intr.loop_conv_ws_config1"
// CHECK: "gemmini.intr.loop_conv_ws_config2"
// CHECK: "gemmini.intr.loop_conv_ws_config3"
// CHECK: "gemmini.intr.loop_conv_ws_config4"
// CHECK: "gemmini.intr.loop_conv_ws_config5"
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 1}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>
return %0 : i64
}
39 changes: 39 additions & 0 deletions examples/GemminiDialect/tile-conv-softmax.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: buddy-opt %s \
// RUN: --lower-gemmini | \
// RUN: FileCheck %s

// batchSize = 1 inputDim = 5 inChannels = 1
memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]]]]>

// outChannels = 2 kernelDim = 3 inChannels = 1
memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2],
[-1, 2], [-1, 2], [-1, 2]]>

// outChannels = 2
memref.global "private" @bias : memref<2xi32> = dense<[1,1]>

func.func @main() -> i64 {
%0 = arith.constant 0 : i64
%3 = arith.constant 3 : i64
%input = memref.get_global @input : memref<1x5x5x1xi8>
%weight = memref.get_global @weight : memref<9x2xi8>
%bias = memref.get_global @bias : memref<2xi32>
%output = memref.alloc() : memref<9x2xi8>
// CHECK: "gemmini.intr.loop_conv_ws_config1"
// CHECK: "gemmini.intr.loop_conv_ws_config2"
// CHECK: "gemmini.intr.loop_conv_ws_config3"
// CHECK: "gemmini.intr.loop_conv_ws_config4"
// CHECK: "gemmini.intr.loop_conv_ws_config5"
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 4}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>
return %0 : i64
}
6 changes: 3 additions & 3 deletions examples/GemminiDialect/tile-conv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// RUN: --lower-gemmini | \
// RUN: FileCheck %s

// batchSize = 1 inputDIm = 5 inChannels = 2
// batchSize = 1 inputDim = 5 inChannels = 1
memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
[[1], [0], [-1], [0], [1]],
Expand Down Expand Up @@ -32,8 +32,8 @@ func.func @main() -> i64 {
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 {stride = 1}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>
return %0 : i64
}
Loading