diff --git a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td index a474ca956b..c0cac18044 100644 --- a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td +++ b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td @@ -18,7 +18,13 @@ // //===----------------------------------------------------------------------===// let TargetPrefix = "riscv" in -def int_riscv_mvin : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>; +def int_riscv_mvin : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvin2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvin3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; let TargetPrefix = "riscv" in def int_riscv_mvout : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; @@ -35,6 +41,9 @@ def int_riscv_config_st : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; let TargetPrefix = "riscv" in def int_riscv_config_ex : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; +let TargetPrefix = "riscv" in +def int_riscv_config_norm : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + let TargetPrefix = "riscv" in def int_riscv_preload : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; diff --git a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td index a45a8ff808..adc172ab2f 100644 --- a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td +++ b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td @@ -35,6 +35,18 @@ def MVIN : RVInstR<0b0000010, 0b011, OPC_CUSTOM_3, (outs), let rd = 0; } +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN2 : RVInstR<0b0000001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin2","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN3 : RVInstR<0b0001110, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin3","$rs1, $rs2"> { + let rd = 0; +} + let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in def MVOUT : RVInstR<0b0000011, 0b011, OPC_CUSTOM_3, (outs), (ins GPR:$rs1, GPR:$rs2), "mvout","$rs1, $rs2">{ @@ -65,6 +77,12 @@ def CONFIG_EX : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs), let rd = 0; } +let Predicates = [HasBuddyExt] in +def CONFIG_NORM : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs), + (ins GPR:$rs1, GPR:$rs2), "config_norm", "$rs1, $rs2"> { + let rd = 0; +} + let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in def PRELOAD : RVInstR<0b0000110, 0b011,OPC_CUSTOM_3,(outs), (ins GPR:$rs1, GPR:$rs2), "preload", "$rs1, $rs2">{ @@ -164,6 +182,12 @@ def LOOP_CONV_WS_CONFIG6 : RVInstR<0b0010101, 0b011, OPC_CUSTOM_3, (outs), let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_mvin GPR:$rs1, GPR:$rs2), (MVIN GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin2 GPR:$rs1, GPR:$rs2), (MVIN2 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin3 GPR:$rs1, GPR:$rs2), (MVIN3 GPR:$rs1, GPR:$rs2)>; + let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_mvout GPR:$rs1, GPR:$rs2), (MVOUT GPR:$rs1, GPR:$rs2)>; @@ -179,6 +203,9 @@ def : Pat<(int_riscv_config_st GPR:$rs1, GPR:$rs2), (CONFIG_ST GPR:$rs1, GPR:$rs let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_config_ex GPR:$rs1, GPR:$rs2), (CONFIG_EX GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_config_norm GPR:$rs1, 
GPR:$rs2), (CONFIG_NORM GPR:$rs1, GPR:$rs2)>; + let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_preload GPR:$rs1, GPR:$rs2), (PRELOAD GPR:$rs1, GPR:$rs2)>; diff --git a/examples/GemminiDialect/ciface.mlir b/examples/GemminiDialect/ciface.mlir index 004992f2a3..e45b6bed28 100644 --- a/examples/GemminiDialect/ciface.mlir +++ b/examples/GemminiDialect/ciface.mlir @@ -127,8 +127,8 @@ func.func @linalg_conv6(%arg0 : memref<1x1x256x256xi8>, %arg1 : memref<1x1x13x13 func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8>, %bias: memref<1xi32>, %output: memref<64516x1xi8>) { %outdim = arith.constant 254 : i64 %kernelDim = arith.constant 3 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : + memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 i64 return } @@ -136,8 +136,8 @@ func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8 func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi8>, %bias: memref<1xi32>, %output: memref<63504x1xi8>) { %outdim = arith.constant 252 : i64 %kernelDim = arith.constant 5 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : + memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 i64 return } @@ -145,8 +145,8 @@ func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi8>, %bias: memref<1xi32>, %output: memref<62500x1xi8>) { %outdim = arith.constant 250 : i64 %kernelDim = arith.constant 7 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : + memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 i64 return } @@ -154,8 +154,8 @@ func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi8>, %bias: memref<1xi32>, %output: memref<61504x1xi8>) { %outdim = arith.constant 248 : i64 %kernelDim = arith.constant 9 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : + memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 i64 return } @@ -163,8 +163,8 @@ func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1xi8>, %bias: memref<1xi32>, %output: memref<60516x1xi8>) { %outdim = arith.constant 246 : i64 %kernelDim = arith.constant 11 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 + gemmini.tile_conv %input %weights %bias %output %outdim 
%outdim %kernelDim {stride = 1} : + memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 i64 return } @@ -172,8 +172,8 @@ func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1x func.func @gemmini_conv6(%input: memref<1x256x256x1xi8>, %weights: memref<169x1xi8>, %bias: memref<1xi32>, %output: memref<59536x1xi8>) { %outdim = arith.constant 244 : i64 %kernelDim = arith.constant 13 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : + memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 i64 return } diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 7bc1c3dc22..cba84b780a 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -76,6 +76,51 @@ tile-matmul-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-matmul-os-run: + @${BUDDY_OPT} ./tile-matmul-os.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-matmul-ws-igelu-run: + @${BUDDY_OPT} ./tile-matmul-ws-igelu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-matmul-ws-relu-run: + @${BUDDY_OPT} ./tile-matmul-ws-relu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-matmul-ws-softmax-run: + @${BUDDY_OPT} ./tile-matmul-ws-softmax.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-matmul-ws-layernorm-run: + @${BUDDY_OPT} ./tile-matmul-ws-layernorm.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-conv-run: @${BUDDY_OPT} ./tile-conv.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ @@ -85,6 +130,51 @@ tile-conv-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-conv-igelu-run: + @${BUDDY_OPT} ./tile-conv-igelu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-conv-softmax-run: + @${BUDDY_OPT} ./tile-conv-softmax.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + 
-o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-conv-relu-run: + @${BUDDY_OPT} ./tile-conv-relu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-conv-layernorm-run: + @${BUDDY_OPT} ./tile-conv-layernorm.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-rect-conv-run: + @${BUDDY_OPT} ./tile-rect-conv.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + gemmini-linalg-matmul-run: @${BUDDY_OPT} ./matmul.mlir \ -convert-linalg-to-gemmini \ diff --git a/examples/GemminiDialect/tile-conv-igelu.mlir b/examples/GemminiDialect/tile-conv-igelu.mlir new file mode 100644 index 0000000000..9eb34f1f5f --- /dev/null +++ b/examples/GemminiDialect/tile-conv-igelu.mlir @@ -0,0 +1,52 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 3}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} diff --git a/examples/GemminiDialect/tile-conv-layernorm.mlir 
b/examples/GemminiDialect/tile-conv-layernorm.mlir new file mode 100644 index 0000000000..1d440c125d --- /dev/null +++ b/examples/GemminiDialect/tile-conv-layernorm.mlir @@ -0,0 +1,52 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 2}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} diff --git a/examples/GemminiDialect/tile-conv-relu.mlir b/examples/GemminiDialect/tile-conv-relu.mlir new file mode 100644 index 0000000000..99198f583a --- /dev/null +++ b/examples/GemminiDialect/tile-conv-relu.mlir @@ -0,0 +1,52 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: 
"gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} diff --git a/examples/GemminiDialect/tile-conv-softmax.mlir b/examples/GemminiDialect/tile-conv-softmax.mlir new file mode 100644 index 0000000000..67a63d5f4e --- /dev/null +++ b/examples/GemminiDialect/tile-conv-softmax.mlir @@ -0,0 +1,52 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 4}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} diff --git a/examples/GemminiDialect/tile-conv.mlir b/examples/GemminiDialect/tile-conv.mlir index 42f0085ce3..9ac91acffc 100644 --- a/examples/GemminiDialect/tile-conv.mlir +++ b/examples/GemminiDialect/tile-conv.mlir @@ -2,7 +2,7 @@ // RUN: --lower-gemmini | \ // RUN: FileCheck %s 
-// batchSize = 1 inputDIm = 5 inChannels = 2 +// batchSize = 1 inputDim = 5 inChannels = 1 memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1]], @@ -32,8 +32,8 @@ func.func @main() -> i64 { // CHECK: "gemmini.intr.loop_conv_ws_config6" // CHECK: "gemmini.intr.loop_conv_ws" // CHECK: "gemmini.intr.flush" - gemmini.tile_conv %input %weight %bias %output %3 %3 {stride = 1}: - memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 gemmini.print %output : memref<9x2xi8> return %0 : i64 } diff --git a/examples/GemminiDialect/tile-matmul-os.mlir b/examples/GemminiDialect/tile-matmul-os.mlir new file mode 100644 index 0000000000..120a44654d --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-os.mlir @@ -0,0 +1,37 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.alloc() {alignment = 16} : memref<64x64xi8> + %bArray = memref.alloc() {alignment = 16}: memref<64x64xi8> + %cArray = memref.alloc() {alignment = 16}: memref<64x64xi8> + %dArray = memref.alloc() {alignment = 64} : memref<64x64xi32> + %dim = memref.dim %aArray, %c0 : memref<64x64xi8> + scf.for %i = %c0 to %dim step %c1 { + scf.for %j = %c0 to %dim step %c1 { + memref.store %i1I8, %aArray[%i, %j] : memref<64x64xi8> + memref.store %i1I8, %bArray[%i, %j] : memref<64x64xi8> + memref.store %i2I32, %dArray[%i, %j] : memref<64x64xi32> + } + } + + gemmini.print %aArray : memref<64x64xi8> + gemmini.print %bArray : memref<64x64xi8> + gemmini.print %dArray : memref<64x64xi32> + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.mvin" + // CHECK: "gemmini.intr.preload" + // CHECK: "gemmini.intr.compute_preloaded" + // CHECK: "gemmini.intr.compute_accumulated" + // CHECK: "gemmini.intr.mvout" + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=0} : memref<64x64xi8> memref<64x64xi8> memref<64x64xi8> memref <64x64xi32> + gemmini.print %cArray : memref<64x64xi8> + return %i0 : i8 +} diff --git a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir new file mode 100644 index 0000000000..0edf6428bd --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir @@ -0,0 +1,49 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 1, 0], [1, -1, 1, 0, 0], [-1, 0, 1, -1, 1], [1, 0, 0, 1, 0], [-1, 0, 0, -1, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[1, -1, 0, 0, 1], [1, 0, -1, 0, -1], [-1, -1, 0, -1, 1], [-1, 0, 0, 1, 0], [1, 0, 0, -1, 0]]> + + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : 
memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> + } + } + + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.config_norm" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=3, bertScale=0.8:f32}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + return %i0 : i8 +} diff --git a/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir b/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir new file mode 100644 index 0000000000..cf3529c28b --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir @@ -0,0 +1,48 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 0, 0], [0, -1, 1, 0, 1], [-1, 0, -1, -1, 0], [-1, 0, 0, 1, 0], [0, 0, 0, 0, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[-1, 0, 1, 0, -1], [1, -1, 1, 0, -1], [-1, -1, -1, 1, 1], [-1, 1, 0, -1, 1], [-1, 0, 1, 1, 1]]> + + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> + } + } + + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=2}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + return %i0 : i8 +} diff --git a/examples/GemminiDialect/tile-matmul-ws-relu.mlir b/examples/GemminiDialect/tile-matmul-ws-relu.mlir new file mode 100644 index 0000000000..f461950615 --- /dev/null +++ 
b/examples/GemminiDialect/tile-matmul-ws-relu.mlir @@ -0,0 +1,48 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 1, 0], [1, -1, 1, 0, 0], [-1, 0, 1, -1, 1], [1, 0, 0, 1, 0], [-1, 0, 0, -1, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[1, -1, 0, 0, 1], [1, 0, -1, 0, -1], [-1, -1, 0, -1, 1], [-1, 0, 0, 1, 0], [1, 0, 0, -1, 0]]> + + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> + } + } + + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + return %i0 : i8 +} diff --git a/examples/GemminiDialect/tile-matmul-ws-softmax.mlir b/examples/GemminiDialect/tile-matmul-ws-softmax.mlir new file mode 100644 index 0000000000..c81bccceac --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-ws-softmax.mlir @@ -0,0 +1,49 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 1, 0], [1, -1, 1, 0, 0], [-1, 0, 1, -1, 1], [1, 0, 0, 1, 0], [-1, 0, 0, -1, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[1, -1, 0, 0, 1], [1, 0, -1, 0, -1], [-1, -1, 0, -1, 1], [-1, 0, 0, 1, 0], [1, 0, 0, -1, 0]]> + + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> + } + } + + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> 
memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.config_norm" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=4, bertScale=0.05:f32}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + return %i0 : i8 +} diff --git a/examples/GemminiDialect/tile-rect-conv.mlir b/examples/GemminiDialect/tile-rect-conv.mlir new file mode 100644 index 0000000000..e982b3ad6f --- /dev/null +++ b/examples/GemminiDialect/tile-rect-conv.mlir @@ -0,0 +1,42 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputRowDim = 5 inputColDim = 10 inChannels = 1 +memref.global "private" @input : memref<1x5x10x1xi8> = dense<[[[[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[1, 2], [1, 2], [1, 2], + [1, 2], [1, 2], [1, 2], + [1, 2], [1, 2], [1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %8 = arith.constant 8 : i64 + + %input = memref.get_global @input : memref<1x5x10x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<24x2xi8> + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.mvin3" + // CHECK: "gemmini.intr.mvin" + // CHECK: "gemmini.intr.mvin2" + // CHECK: "gemmini.intr.preload" + // CHECK: "gemmini.intr.compute_preloaded" + // CHECK: "gemmini.intr.compute_accumulated" + gemmini.tile_conv %input %weight %bias %output %3 %8 %3 {stride = 1}: + memref<1x5x10x1xi8> memref<9x2xi8> memref<2xi32> memref<24x2xi8> i64 i64 i64 + gemmini.print %output : memref<24x2xi8> + return %0 : i64 +} diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index df7393d3e9..f2569a0c15 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -76,7 +76,9 @@ def ConfigLdOp : Gemmini_Op<"config_ld"> { let arguments = (ins I64:$stride, DefaultValuedAttr:$scale, DefaultValuedAttr:$shrunk, - DefaultValuedAttr:$id); + DefaultValuedAttr:$id, + DefaultValuedAttr:$block_mvin_stride, + DefaultValuedAttr:$pixel_repeats); let assemblyFormat = "$stride attr-dict `:` type($stride)"; } @@ -116,6 +118,28 @@ def ConfigExOp : Gemmini_Op<"config_ex"> { let assemblyFormat = "attr-dict"; } +def ConfigNormOp : Gemmini_Op<"config_norm"> { + let summary = "ConfigNormOp configures normalize pipeline"; + let description = [{ + ConfigNormOp configures normalize pipeline + -qConst: A constant value used for quantization 
during normalization. + -qConstType: Defines the type of the qConst. + -setStatsIdOnly: A flag to indicate if only the StatsId should be set. + -actMsb: The most significant bit of the activation function id; the config_ex act field only carries the low two bits. + -StatsId: The identifier of the normalization statistics set being configured. + -igeluQb: The quantized 'b' constant of the i-GELU approximation. + -igeluQc: The quantized 'c' constant of the i-GELU approximation. + }]; + let arguments = (ins DefaultValuedAttr:$qConst, + DefaultValuedAttr:$qConstType, + DefaultValuedAttr:$setStatsIdOnly, + DefaultValuedAttr:$actMsb, + DefaultValuedAttr:$StatsId, + DefaultValuedAttr:$igeluQb, + DefaultValuedAttr:$igeluQc); + let assemblyFormat = "attr-dict"; +} + def MvinOp : Gemmini_Op<"mvin"> { let summary = "Load operation"; let description = [{ @@ -128,6 +152,32 @@ def MvinOp : Gemmini_Op<"mvin"> { let assemblyFormat = "$input $addr attr-dict `:` type($input) type($addr)"; } +def Mvin2Op : Gemmini_Op<"mvin2"> { + let summary = "Load operation"; + let description = [{ + Similar to MvinOp, but lowers to the mvin2 intrinsic. + Move data from main memory to scratchpad + - MemRef to load in. + (including DRAM address, number of columns, number of rows) + - Local scratchpad or accumulator address. + }]; + let arguments = (ins MemRefRankOf<[AnyType], [2]>:$input, I64:$addr); + let assemblyFormat = "$input $addr attr-dict `:` type($input) type($addr)"; +} + +def Mvin3Op : Gemmini_Op<"mvin3"> { + let summary = "Load operation"; + let description = [{ + Similar to MvinOp and Mvin2Op, but lowers to the mvin3 intrinsic. + Move data from main memory to scratchpad + - MemRef to load in. + (including DRAM address, number of columns, number of rows) + - Local scratchpad or accumulator address. + }]; + let arguments = (ins MemRefRankOf<[AnyType], [2]>:$input, I64:$addr); + let assemblyFormat = "$input $addr attr-dict `:` type($input) type($addr)"; +} + def MvoutOp : Gemmini_Op<"mvout"> { let summary = "Store operation"; let description = [{ @@ -238,7 +288,8 @@ def TileMatMulOp : Gemmini_Op<"tile_matmul"> { DefaultValuedAttr:$bTranspose, DefaultValuedAttr:$fullC, DefaultValuedAttr:$lowD, - DefaultValuedAttr:$weightA); + DefaultValuedAttr:$weightA, + DefaultValuedAttr:$dataflow); let assemblyFormat = [{ $aArray $bArray $cArray $dArray attr-dict `:` type($aArray) type($bArray) type($cArray) type($dArray) @@ -254,7 +305,7 @@ def TileConvOp : Gemmini_Op<"tile_conv"> { MemRefRankOf<[AnyType], [2]>:$weights, MemRefRankOf<[AnyType], [1]>:$bias, MemRefRankOf<[AnyType], [2]>:$output, - I64:$outDim, I64:$kernelDim, + I64:$outRowDim, I64:$outColDim, I64:$kernelDim, DefaultValuedAttr:$scale, DefaultValuedAttr:$stride, DefaultValuedAttr:$inputDilation, DefaultValuedAttr:$kernelDilation, DefaultValuedAttr:$padding, DefaultValuedAttr:$act, DefaultValuedAttr:$poolSize, DefaultValuedAttr:$poolStride, DefaultValuedAttr:$poolPadding); let assemblyFormat = [{ - $input $weights $bias $output $outDim $kernelDim attr-dict `:` type($input) - type($weights) type($bias) type($output) type($outDim) type($kernelDim) + $input $weights $bias $output $outRowDim $outColDim $kernelDim attr-dict `:` type($input) + type($weights) type($bias) type($output) type($outRowDim) type($outColDim) type($kernelDim) }]; } @@ -291,6 +342,12 @@ class Gemmini_IntrOpBase traits = []> : def Gemmini_Mvin_IntrOp : Gemmini_IntrOpBase<"mvin">, Arguments<(ins LLVM_Type, LLVM_Type)>; +def Gemmini_Mvin2_IntrOp : Gemmini_IntrOpBase<"mvin2">, + Arguments<(ins LLVM_Type, LLVM_Type)>; + +def Gemmini_Mvin3_IntrOp : Gemmini_IntrOpBase<"mvin3">, + Arguments<(ins LLVM_Type,
LLVM_Type)>; + def Gemmini_Mvout_IntrOp : Gemmini_IntrOpBase<"mvout">, Arguments<(ins LLVM_Type, LLVM_Type)>; @@ -303,7 +360,10 @@ def Gemmini_ConifgLd_IntrOp : Gemmini_IntrOpBase<"config_ld">, def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, +def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, + Arguments<(ins LLVM_Type, LLVM_Type)>; + +def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">, Arguments<(ins LLVM_Type, LLVM_Type)>; def Gemmini_Preload_IntrOp : Gemmini_IntrOpBase<"preload">, diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index f0908b4887..86d27cbd9e 100644 --- a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -24,13 +24,18 @@ #define CONFIG_LD 1 #define CONFIG_ST 2 #define CONFIG_EX 0 +#define CONFIG_BERT 3 + +#define GARBAGE_ADDR ((uint32_t)(-1)) +#define OUTPUT_STATIONARY 0 +#define WEIGHT_STATIONARY 1 + +#define MVIN_SCALE_IDENTITY 1.0 #define ACC_SCALE_IDENTITY 1.0 #define BANK_NUM 4 #define BANK_ROWS 4096 #define ACC_ROWS 1024 #define MAX_BYTES 64 -#define MAX_BLOCK_LEN (MAX_BYTES/(DIM*1)) -#define MAX_BLOCK_LEN_ACC (MAX_BYTES/(DIM*4)) #define HAS_FIRST_LAYER_OPTIMIZATIONS typedef uint32_t acc_scale_t_bits; diff --git a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp index 19cc8c0608..5930d33705 100644 --- a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp +++ b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp @@ -182,7 +182,7 @@ class Conv2DNchwFchwLowering loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(weightsShape[2])); rewriter.create( - loc, inputMat, weightsMat, bias, outputMat, outDim, kernelDim, + loc, inputMat, weightsMat, bias, outputMat, outDim, outDim, kernelDim, llvm::APFloat(float(1.0)), strides, dilations); rewriter.eraseOp(convOp); loopIvs0.clear(); @@ -309,7 +309,7 @@ class Conv2DNhwcHwcfLowering attr = rewriter.getI64IntegerAttr(kernelShape[1]); kernelDim = rewriter.create(loc, attr); rewriter.create( - loc, input, kernelMat, bias, outputMat, outDim, kernelDim, + loc, input, kernelMat, bias, outputMat, outDim, outDim, kernelDim, llvm::APFloat(float(1.0)), strides, dilations); // after the conv operation is completed, the data in outputmat needs to be // transferred into output. 
diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index de54cc8d44..b8d81bc4d5 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -62,6 +62,38 @@ scale_t_bits scale_t_to_scale_t_bits(scale_t x) { un.f = x; return un.b; } + +template +void gemminiMvinOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, + const size_t cols, const size_t rows, int64_t addrLen, + ConversionPatternRewriter &rewriter) { + Location loc = mem.getLoc(); + Value offsetOp = rewriter.create( + loc, rewriter.getI64IntegerAttr(offset)); + IntegerType i64Type = rewriter.getI64Type(); + Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); + uint64_t spadAddrInt = (uint64_t)rows << (addrLen + 16) | + (uint64_t)cols << addrLen | (uint64_t) SpAddr; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.create(loc, configPtr, spad); +} + +void gemminiMvoutOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, + const size_t cols, const size_t rows, int64_t addrLen, + ConversionPatternRewriter &rewriter) { + Location loc = mem.getLoc(); + Value offsetOp = rewriter.create( + loc, rewriter.getI64IntegerAttr(offset)); + IntegerType i64Type = rewriter.getI64Type(); + Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); + uint64_t spadAddrInt = (uint64_t)rows << (addrLen + 16) | + (uint64_t)cols << addrLen | (uint64_t) SpAddr; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.create(loc, configPtr, spad); +} + }; // namespace template @@ -116,11 +148,12 @@ struct GemminiConfigStLowering : public ConvertOpToLLVMPattern { int stride = getNumberFromValue(strideValue); float scale = configStOp.getScale().convertToFloat(); Location loc = configStOp.getLoc(); + uint64_t rs1 = ((uint64_t)configStOp.getActivation() << 2) | CONFIG_ST; uint64_t arg = (uint64_t)acc_scale_t_to_acc_scale_t_bits((acc_scale_t)scale) << 32 | (uint32_t)stride; Value value1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(CONFIG_ST)); + loc, rewriter.getI64IntegerAttr(rs1)); Value value2 = rewriter.create( loc, rewriter.getI64IntegerAttr(arg)); rewriter.replaceOpWithNewOp(configStOp, value1, value2); @@ -135,8 +168,10 @@ struct GemminiConfigLdLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value rs2Value = configLdOp.getStride(); float scale = configLdOp.getScale().convertToFloat(); + uint64_t blockMvinStride = configLdOp.getBlockMvinStride(); + uint64_t pixelRepeats = configLdOp.getPixelRepeats(); uint64_t rs1 = (uint64_t)scale_t_to_scale_t_bits(scale) << 32 | - ((uint64_t)16 << 16) | (uint64_t)1 << 8 | + (blockMvinStride << 16) | pixelRepeats << 8 | configLdOp.getId() << 3 | configLdOp.getShrunk() << 2 | CONFIG_LD; Location loc = configLdOp.getLoc(); @@ -173,6 +208,28 @@ struct GemminiConfigExLowering : public ConvertOpToLLVMPattern { } }; +struct GemminiConfigNormLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ConfigNormOp configNormOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = configNormOp.getLoc(); + uint64_t rs1 = (((uint64_t) ((uint32_t)configNormOp.getQConst())) << 32) | + (configNormOp.getQConstType() & 1) << 18 | + (configNormOp.getSetStatsIdOnly() & 1) 
<< 17 | + (configNormOp.getActMsb() & 1) << 16 | + configNormOp.getStatsId() << 8 | CONFIG_BERT; + uint64_t rs2 = (((uint64_t) ((uint32_t)configNormOp.getIgeluQc())) << 32) | ((uint64_t) ((uint32_t)configNormOp.getIgeluQb())); + Value rs1Value = rewriter.create( + loc, rewriter.getI64IntegerAttr(rs1)); + Value rs2Value = rewriter.create( + loc, rewriter.getI64IntegerAttr(rs2)); + rewriter.replaceOpWithNewOp(configNormOp, rs1Value, + rs2Value); + return success(); + } +}; + struct GemminiMvinLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiMvinLowering(LLVMTypeConverter &typeConverter, @@ -206,6 +263,72 @@ struct GemminiMvinLowering : public ConvertOpToLLVMPattern { int64_t addrLen; }; +struct GemminiMvin2Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit GemminiMvin2Lowering(LLVMTypeConverter &typeConverter, + int64_t addrLen) + : ConvertOpToLLVMPattern(typeConverter), addrLen(addrLen) {} + LogicalResult + matchAndRewrite(Mvin2Op mvin2Op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = mvin2Op.getInput(); + Location loc = input.getLoc(); + MemRefType memRefType = + mvin2Op.getOperandTypes().front().dyn_cast(); + llvm::ArrayRef memRefShape = memRefType.getShape(); + TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); + Value extractOp = rewriter.create( + loc, resultType, input); + IntegerType i64Type = rewriter.getI64Type(); + Value indexCastOp = + rewriter.create(loc, i64Type, extractOp); + Value spadAddrValue = mvin2Op.getAddr(); + uint64_t number = getNumberFromValue(spadAddrValue); + uint64_t spadAddrInt = (uint64_t)memRefShape[0] << (addrLen + 16) | + (uint64_t)memRefShape[1] << addrLen | number; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.replaceOpWithNewOp(mvin2Op, indexCastOp, spad); + return success(); + } + +private: + int64_t addrLen; +}; + +struct GemminiMvin3Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit GemminiMvin3Lowering(LLVMTypeConverter &typeConverter, + int64_t addrLen) + : ConvertOpToLLVMPattern(typeConverter), addrLen(addrLen) {} + LogicalResult + matchAndRewrite(Mvin3Op mvin3Op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = mvin3Op.getInput(); + Location loc = input.getLoc(); + MemRefType memRefType = + mvin3Op.getOperandTypes().front().dyn_cast(); + llvm::ArrayRef memRefShape = memRefType.getShape(); + TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); + Value extractOp = rewriter.create( + loc, resultType, input); + IntegerType i64Type = rewriter.getI64Type(); + Value indexCastOp = + rewriter.create(loc, i64Type, extractOp); + Value spadAddrValue = mvin3Op.getAddr(); + uint64_t number = getNumberFromValue(spadAddrValue); + uint64_t spadAddrInt = (uint64_t)memRefShape[0] << (addrLen + 16) | + (uint64_t)memRefShape[1] << addrLen | number; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.replaceOpWithNewOp(mvin3Op, indexCastOp, spad); + return success(); + } + +private: + int64_t addrLen; +}; + struct GemminiMvoutLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiMvoutLowering(LLVMTypeConverter &typeConverter, @@ -285,7 +408,7 @@ struct GemminiPreloadLowering : public ConvertOpToLLVMPattern { Value bdCols = 
preloadOp.getBdCols(); Value bdRows = preloadOp.getBdRows(); Value cCols = preloadOp.getCCols(); - Value cRows = preloadOp.getBdRows(); + Value cRows = preloadOp.getCRows(); Location loc = preloadOp.getLoc(); uint64_t bdAddrInt = getNumberFromValue(bdAddr); uint64_t cAddrInt = getNumberFromValue(cAddr); @@ -429,19 +552,171 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, rs1Value, rs2Value); } - void inner(Value &a, Value &b, Value &pre, Value &out, scale_t aScaleFactor, - scale_t bScaleFactor, scale_acc_t dScaleFactor, size_t i, size_t j, - size_t k, size_t padI, size_t padJ, size_t padK, size_t strideA, - size_t strideB, size_t strideD, size_t strideC, bool aTranspose, - bool bTranspose, bool fullC, bool lowD, bool noBias, - bool repeatingBias, int act, TileMatMulOp &tileMatMulOp, - ConversionPatternRewriter &rewriter) const { - - gemminiLoopWs(i, j, k, padI, padJ, padK, a, b, pre, out, strideA, strideB, + void spTiledMatmulWs(Value &a, Value &b, Value &d, Value &c, + scale_t aScaleFactor, scale_t bScaleFactor, + scale_acc_t dScaleFactor, size_t i, size_t j, size_t k, + size_t padI, size_t padJ, size_t padK, size_t strideA, + size_t strideB, size_t strideD, size_t strideC, + bool aTranspose, bool bTranspose, bool fullC, bool lowD, + bool noBias, bool repeatingBias, int act, + TileMatMulOp &tileMatMulOp, + ConversionPatternRewriter &rewriter) const { + + gemminiLoopWs(i, j, k, padI, padJ, padK, a, b, d, c, strideA, strideB, repeatingBias ? 0 : strideD, strideC, aTranspose, bTranspose, fullC, lowD, !noBias, act, tileMatMulOp, rewriter); } + // Tiling functions + void spTiledMatmulOs(Value &a, Value &b, Value &d, Value &c, + scale_t aScaleFactor, scale_t bScaleFactor, + scale_acc_t dScaleFactor, size_t i, size_t j, size_t k, + size_t padI, size_t padJ, size_t padK, size_t strideA, + size_t strideB, size_t strideD, size_t strideC, + bool aTranspose, bool bTranspose, bool fullC, bool lowD, + bool noBias, bool repeatingBias, int act, + TileMatMulOp &tileMatMulOp, + ConversionPatternRewriter &rewriter) const { + const uint32_t aSpAddrStart = 0; + const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - k * j * dim; + const uint32_t dSpAddrStart = 1 << (addrLen - 1); + const uint32_t cSpAddrStart = + (3 << (addrLen - 2)) | (fullC << (addrLen - 3)); + + const size_t maxBlockLen = MAX_BYTES / (dim * 1); + const size_t maxBlockLenAcc = MAX_BYTES / (dim * 4); + + const int aBlocks = k <= maxBlockLen ? k : maxBlockLen; + const int bBlocks = j <= maxBlockLen ? j : maxBlockLen; + const int dBlocks = j <= maxBlockLenAcc ? j : maxBlockLenAcc; + + Location loc = a.getLoc(); + bool dAddrNull = llvm::dyn_cast(d.getDefiningOp()) && getNumberFromValue(d) == 0; + bool cAddrNull = llvm::dyn_cast(c.getDefiningOp()) && getNumberFromValue(c) == 0; + + // Move-in D + if (!dAddrNull && !noBias) { + const size_t dStride = repeatingBias ? 0 : strideD * sizeOfAccT; + Value strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(dStride)); + rewriter.create(loc, strideValue, + llvm::APFloat((float)dScaleFactor)); + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t j0 = 0; j0 < j; j0 += dBlocks) { + const size_t biasRow = repeatingBias ? 0 : i0; + const size_t offset = (biasRow * strideD + j0) * dim * sizeOfAccT; + const uint32_t dSpAddrAcc = dSpAddrStart + (i0 * j + j0) * dim; + const size_t blocks = j0 + dBlocks <= j ? dBlocks : j - j0; + const size_t cols = blocks * dim - (j0 + blocks >= j ? padJ : 0); + const size_t rows = dim - (i0 == i - 1 ? 
padI : 0); + gemminiMvinOffset(d, offset, dSpAddrAcc, cols, rows, addrLen, rewriter); + } + } + } + + // Move-in B + Value strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideB)); + rewriter.create(loc, strideValue, + llvm::APFloat((float)bScaleFactor)); + for (size_t j0 = 0; j0 < j; j0 += bBlocks) { + for (size_t k0 = 0; k0 < k; k0++) { + const size_t offset = (k0 * strideB + j0) * dim * sizeOfElemT; + const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * dim; + const size_t blocks = j0 + bBlocks <= j ? bBlocks : j - j0; + const size_t cols = blocks * dim - (j0 + blocks >= j ? padJ : 0); + const size_t rows = dim - (k0 == k - 1 ? padK : 0); + gemminiMvinOffset(b, offset, bSpAddr, cols, rows, addrLen, rewriter); + } + } + + // Move-in A + strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideA)); + rewriter.create(loc, strideValue, + llvm::APFloat((float)aScaleFactor)); + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t k0 = 0; k0 < k; k0 += aBlocks) { + const size_t offset = (i0 * strideA + k0) * dim * sizeOfElemT; + const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * dim; + const size_t blocks = k0 + aBlocks <= k ? aBlocks : k - k0; + const size_t cols = blocks * dim - (k0 + blocks >= k ? padK : 0); + const size_t rows = dim - (i0 == i - 1 ? padI : 0); + gemminiMvinOffset(a, offset, aSpAddr, cols, rows, addrLen, rewriter); + } + } + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t j0 = 0; j0 < j; j0++) { + const uint32_t cSpAddr = cSpAddrStart + (i0 * j + j0) * dim; + for (size_t k0 = 0; k0 < k; k0++) { + + const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * dim; + const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * dim; + + uint32_t outSpAddr = k0 == k - 1 ? cSpAddr : GARBAGE_ADDR; + + // If we're not using a bias, then we want to overwrite what's in the + // accumulator, rather than writing over it + + int noBiasNewMatrix = noBias && !dAddrNull && k0 == k - 1; + if (noBiasNewMatrix) { + outSpAddr &= ~(1 << (addrLen - 2)); + } + + const size_t aCols = dim - (k0 == k - 1 ? padK : 0); + const size_t aRows = dim - (i0 == i - 1 ? padI : 0); + const size_t bCols = dim - (j0 == j - 1 ? padJ : 0); + const size_t bRows = dim - (k0 == k - 1 ? padK : 0); + const size_t cCols = dim - (j0 == j - 1 ? padJ : 0); + const size_t cRows = dim - (i0 == i - 1 ? 
padI : 0); + + Value aColsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aCols)); + Value aRowsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aRows)); + Value bColsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(bCols)); + Value bRowsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(bRows)); + Value cColsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cCols)); + Value cRowsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cRows)); + + Value aSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aSpAddr)); + Value bSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(bSpAddr)); + Value outSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(outSpAddr)); + + Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); + Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + + rewriter.create(loc, garbageAddrOp, outSpAddrOp, dimOp, + dimOp, cRowsOp, cColsOp); + + if (k0 == 0) { // First iteration + rewriter.create(loc, aSpAddrOp, bSpAddrOp, aRowsOp, aColsOp, bRowsOp, bColsOp); + + } else { // All other iterations + rewriter.create(loc, aSpAddrOp, bSpAddrOp, aRowsOp, aColsOp, bRowsOp, bColsOp); + } + } + } + } + // Move-out C + if (!cAddrNull) { + const size_t sizeof_C = fullC ? sizeOfAccT : sizeOfElemT; + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t j0 = 0; j0 < j; j0++) { + const size_t offset = (i0 *strideC + j0)*dim*sizeof_C; + const uint32_t cSpAddr = cSpAddrStart + (i0 *j + j0)*dim; + + const size_t cCols = dim - (j0 == j - 1 ? padJ : 0); + const size_t cRows = dim - (i0 == i - 1 ? padI : 0); + + gemminiMvoutOffset(c, offset, cSpAddr, cCols, cRows, addrLen, rewriter); + } + } + } + } + void tiledMatmulOuter(size_t dimI, size_t dimJ, size_t dimK, Value &A, Value &B, Value &D, Value &C, size_t strideA, size_t strideB, size_t strideD, size_t strideC, @@ -450,7 +725,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { size_t tileK, int act, acc_scale_t scale, acc_scale_t bertScale, bool repeatingBias, bool aTranspose, bool bTranspose, bool fullC, bool lowD, - uint8_t weightA, TileMatMulOp &tileMatMulOp, + uint8_t weightA, int dataflow, TileMatMulOp &tileMatMulOp, ConversionPatternRewriter &rewriter) const { const size_t dimIPadded = (dimI / dim + (dimI % dim != 0)) * dim; const size_t dimJPadded = (dimJ / dim + (dimJ % dim != 0)) * dim; @@ -475,7 +750,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { const size_t sizeofC = fullC ?
sizeOfAccT : sizeOfElemT; Location loc = tileMatMulOp.getLoc(); llvm::APFloat accScaleIdentity((float)ACC_SCALE_IDENTITY); - rewriter.create(loc, /*dataflow = */ 1, /*sysAct = */ act & 3, + rewriter.create(loc, /*dataflow = */ dataflow, /*sysAct = */ act & 3, /* sysShift = */ 0, accScaleIdentity); Value strideValue = rewriter.create( loc, rewriter.getI64IntegerAttr(strideC * sizeofC)); @@ -493,6 +768,33 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64IntegerAttr(strideD * sizeofD)); rewriter.create(loc, strideValue, llvm::APFloat((float)dScaleFactor), lowD, 2); + + /* + Add config norm op + */ + if (act == IGELU) { + const float sqrt_2 = 1.41421356237; + const float S = bertScale; + const float S_erf = (-0.2888 * ((S*S) / 2)); + + const uint32_t qb = -1.769 / (S / sqrt_2); + const uint32_t qc = 1.0 / S_erf; + rewriter.create(loc, 0, 0, 0, 0,0, qb, qc); + } + + if (act == SOFTMAX) { + const float a = 0.3585; + const float b = 1.353; + const float c = 0.344; + + const uint32_t qln2 = (int) (0.693147 / bertScale); + const uint32_t qln2_inv = 65536 / qln2; + const uint32_t qb = b / bertScale; + const uint32_t qc = c / (a*bertScale*bertScale); + rewriter.create(loc, qln2, 0, 0, 1, 0, qb, qc); + rewriter.create(loc, qln2_inv, 1, 0, 1, 0, qb, qc); + } + for (size_t i0 = 0; i0 < I0; i0++) for (size_t j0 = 0; j0 < J0; j0++) for (size_t k0 = 0; k0 < K0; k0++) { @@ -569,10 +871,17 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { b = rewriter.create(loc, rewriter.getI64Type(), B, offsetValue); } - inner(a, b, pre, out, aScaleFactor, bScaleFactor, dScaleFactor, i, j, - k, padI, padJ, padK, strideA, strideB, strideD, strideC, - aTranspose, bTranspose, fullC, lowD, noBias, repeatingBias, act, - tileMatMulOp, rewriter); + if (dataflow == OUTPUT_STATIONARY) { + spTiledMatmulOs(a, b, pre, out, aScaleFactor, bScaleFactor, dScaleFactor, i, j, + k, padI, padJ, padK, strideA, strideB, strideD, strideC, + aTranspose, bTranspose, fullC, lowD, noBias, repeatingBias, act, + tileMatMulOp, rewriter); + } else { // WS + spTiledMatmulWs(a, b, pre, out, aScaleFactor, bScaleFactor, dScaleFactor, i, j, + k, padI, padJ, padK, strideA, strideB, strideD, strideC, + aTranspose, bTranspose, fullC, lowD, noBias, repeatingBias, act, + tileMatMulOp, rewriter); + } } IntegerAttr flushAttr = rewriter.getI64IntegerAttr(0); Value flushValue = rewriter.create( @@ -593,9 +902,9 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiTileMatMulLowering(LLVMTypeConverter &typeConverter, - int64_t dim, size_t sizeOfElemT, + int64_t dim, int64_t addrLen, size_t sizeOfElemT, size_t sizeOfAccT) - : ConvertOpToLLVMPattern(typeConverter), dim(dim), + : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen), sizeOfElemT(sizeOfElemT), sizeOfAccT(sizeOfAccT) {} LogicalResult @@ -700,7 +1009,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { size_t tileI, tileJ, tileK; if (act == LAYERNORM || act == SOFTMAX) { tileI = 1; - tileJ = dimJPaded | dim; + tileJ = dimJPaded / dim; tileK = 1; } else { tileI = dimIPaded / dim < dbMaxTileIJ ? 
dimIPaded / dim : dbMaxTileIJ; @@ -744,18 +1053,21 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { #undef dbMatsInAcc #undef dbMaxTileIJ #undef dbMaxTileK + int dataflow = tileMatMulOp.getDataflow(); + tiledMatmulOuter(dimI, dimJ, dimK, aArrayindexCastOp, bArrayindexCastOp, dArrayindexCastOp, cArrayindexCastOp, strideA, strideB, strideD, strideC, aScaleFactor, bScaleFactor, dScaleFactor, tileI, tileJ, tileK, act, scale, bertScale, repeatingBias, - aTranspose, bTranspose, fullC, lowD, weightA, tileMatMulOp, + aTranspose, bTranspose, fullC, lowD, weightA, dataflow, tileMatMulOp, rewriter); return success(); }; private: int64_t dim; + int64_t addrLen; size_t sizeOfElemT; size_t sizeOfAccT; }; @@ -830,10 +1142,13 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, rs1Value, rs2Value); } - void spTiledConv(int batchSize, int inDim, int inChannels, int outChannels, - int outDim, int poolOutDim, int stride, int padding, - int kernelDim, int kernelDilation, int poolSize, - int poolStride, int poolPadding, int batches, int porows, + void spTiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, + int poolOutRowDim, int poolOutColDim, + int stride, int padding, int kernelDim, int kernelDilation, + int inStride, int weightStride, int outStride, + int poolSize, int poolStride, int poolPadding, + int batches, int porows, int pocols, int pochs, int krows, int kcols, int kchs, int lpad, int rpad, int upad, int dpad, int plpad, int prpad, int pupad, int pdpad, Value &input, Value &weights, @@ -842,7 +1157,10 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { bool transWeight1203, bool transWeight0132, bool noBias, bool noPool, bool downsample, bool inputDilated, bool dw, TileConvOp &tileConvOp, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter + ) const { + + Location loc = tileConvOp.getLoc(); if (dw) { kchs = 1; pochs = 1; @@ -850,8 +1168,30 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int orows = porows * poolStride + poolSize - 1 - pupad - pdpad; const int ocols = pocols * poolStride + poolSize - 1 - plpad - prpad; + const int ochs = pochs; + + // Calculate image dimensions + // Note: "irows" and "icols" includes padding + const int dilatedKrows = krows + (kernelDilation - 1)*(krows - 1); + const int dilatedKcols = kcols + (kernelDilation - 1)*(kcols - 1); + int irows = orows * stride + dilatedKrows - 1; + int icols = ocols * stride + dilatedKcols - 1; + int irowsUnpadded = irows - upad - dpad; + int icolsUnpadded = icols - lpad - rpad; + const int ichs = kchs; +#define UNDILATED(x) ((inputDilated) ? 
(((x)+1)/2) : (x)) + + if (inputDilated) { + irowsUnpadded = (irowsUnpadded+1)/2; + icolsUnpadded = (icolsUnpadded+1)/2; + + irows = irowsUnpadded + UNDILATED(upad) + UNDILATED(dpad); + icols = icolsUnpadded + UNDILATED(lpad) + UNDILATED(rpad); + } + + #ifdef HAS_FIRST_LAYER_OPTIMIZATIONS const bool transposed = transOutput1203 || transInput3120 || transWeight1203 || transWeight0132; @@ -864,19 +1204,326 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { #else const int maxPixelsPerRow = 1; #endif - gemminiLoopConvWs( - batchSize, inDim, inChannels, outChannels, outDim, poolOutDim, stride, - padding, kernelDim, kernelDilation, poolSize, poolStride, poolPadding, - batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, - dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, - input, noBias, noPool, downsample, wrot180, inputDilated, act, - transOutput1203, transWeight1203, transWeight0132, transInput3120, - maxPixelsPerRow, dw, tileConvOp, rewriter); + // Calculate spad address offsets + const int outChannelsPerBank = ochs / dim + (ochs % dim != 0); + const int inChannelsPerBank = kchs / dim + (kchs % dim != 0); + const int bRows = transWeight0132 ? inChannelsPerBank * kcols * krows * ochs : + outChannelsPerBank * kcols * krows * kchs; + + static uint32_t dSpAddrRow = 0; + static uint32_t cSpAddrRow = 0; + + const uint32_t aSpAddrStart = 0; + const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - bRows; + const uint32_t dSpAddrStart = (1 << (addrLen - 1)) + dSpAddrRow; + const uint32_t cSpAddrStart = (3 << (addrLen - 2)) + cSpAddrRow; + + if (bias != 0) { + dSpAddrRow = (dSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; + } + + if (output != 0) { + cSpAddrRow = (cSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; + } + if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) { + gemminiLoopConvWs( + batchSize, inRowDim, inChannels, outChannels, outRowDim, + poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, + poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, + kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, + ocols, weights, output, bias, input, noBias, noPool, downsample, + wrot180, inputDilated, act, transOutput1203, transWeight1203, + transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, + rewriter); + return; + } + if (!noPool) { + llvm::outs() << "Pooling with rectangular convolutions is currently not supported.\n"; + return; + } + // Only rectangular convolutions will use the following C code + // mvin bias + const size_t maxBlockLen = MAX_BYTES / (dim * 1); + const size_t maxBlockLenAcc = MAX_BYTES / (dim * 4); + if (bias != NULL) { + // TODO we probably don't need quite this many nested loops for this part + const int maxOchsPerMvin = ochs < (int)(maxBlockLenAcc * dim) ? ochs : + maxBlockLenAcc * dim; + Value zeroValue = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + rewriter.create(loc, zeroValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 2, batches * orows * ocols); + for (int b = 0; b < batches; b++) + for (int orow = 0; orow < orows; orow++) + for (int ocol = 0; ocol < ocols; ocol += dim) { + const int I = ocols - ocol > dim ? dim : ocols - ocol; + for (int och = 0; och < ochs; och += maxOchsPerMvin) { + const int J = ochs - och > maxOchsPerMvin ? 
maxOchsPerMvin : ochs - och; + const uint32_t dSpAddr = dSpAddrStart + (och / dim) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; + if (noBias) { + gemminiMvinOffset(zeroValue, 0 * sizeOfAccT, dSpAddr, J, I, addrLen, rewriter); + } else { + gemminiMvinOffset(bias, och * sizeOfAccT, dSpAddr, J, I, addrLen, rewriter); + } + } + } + } + // mvin input + if (input != NULL){ + int maxChsPerMvin = ichs < (int)(maxBlockLen * dim) ? ichs : + maxBlockLen * dim; + if (transInput3120) { + maxChsPerMvin = batches < (int)(maxBlockLen * dim) ? batches : + maxBlockLen * dim; + } + const int dramStride = transInput3120 ? + batchSize * sizeOfElemT : + inChannels * sizeOfElemT; + const int spadStride = transInput3120 ? + ichs * (irows >> downsample) * (icols >> downsample) : + batches * (irows >> downsample) * (icols >> downsample); + Value strideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride << downsample)); + rewriter.create(loc, strideValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 0, spadStride, maxPixelsPerRow); + const int b_it = transInput3120 ? maxChsPerMvin : 1; + const int ich_it = transInput3120 ? 1 : maxChsPerMvin; + for (int b = 0; b < batches; b += b_it) + for (int irow = -UNDILATED(upad); irow < irowsUnpadded + UNDILATED(dpad); irow += 1 + downsample) { + const int irowPadded = irow + UNDILATED(upad); + for (int icol = -UNDILATED(lpad); icol < icolsUnpadded + UNDILATED(rpad);) { + // TODO There might be some unnecessary mvins here at the edge of the image + int I = icolsUnpadded - icol > (dim << downsample) ? + (dim << downsample) : icolsUnpadded - icol; + if (icol < 0) { + I = -icol > dim ? dim : -icol; + } else if (icol >= icolsUnpadded) { + I = icolsUnpadded + UNDILATED(rpad) - icol > dim ? dim : icolsUnpadded + UNDILATED(rpad) - icol; + } + const int icolPadded = icol + UNDILATED(lpad); + for (int ich = 0; ich < ichs; ich += ich_it) { + int K = ichs - ich > maxChsPerMvin ? maxChsPerMvin : ichs - ich; + if (transInput3120) { + K = batches - b > maxChsPerMvin ? maxChsPerMvin : batches - b; + } +#define DS(x) ((x) >> (downsample)) + uint32_t aSpAddr = aSpAddrStart + (ich / dim) * batches * DS(irows) * DS(icols) + b * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); + if (transInput3120) { + aSpAddr = aSpAddrStart + (b / dim) * ichs * DS(irows) * DS(icols) + ich * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); + } + const bool is_zeros = irow < 0 || irow >= irowsUnpadded || icol < 0 || icol >= icolsUnpadded; + size_t offset = (b*inRowDim*inColDim + irow*inColDim + icol) * inStride + ich; + Value memAddr = input; + if (is_zeros) { + memAddr = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + offset = 0; + } else if (transInput3120) { + offset = (ich*inRowDim*inColDim + irow*inColDim + icol) * batchSize + b; + } + gemminiMvinOffset(memAddr, offset * sizeOfElemT, aSpAddr, K, I >> downsample, addrLen, rewriter); + } + icol += I; + } + } + } + // mvin weights + if (weights != NULL) { + int max_chs_per_mvin = ochs < (int)(maxBlockLen * dim) ? ochs : + maxBlockLen * dim; + if (transWeight0132) { + max_chs_per_mvin = kchs < (int)(maxBlockLen * dim) ? kchs : + maxBlockLen * dim; + } + size_t dramStride = weightStride * sizeOfElemT; + if (dw) { + dramStride = sizeOfElemT; + } else if (transWeight1203) { + dramStride = kernelDim * kernelDim * outChannels * sizeOfElemT; + } else if (transWeight0132) { + dramStride = inChannels * sizeOfElemT; + } + const size_t spadBlockStride = transWeight0132 ? 
+ krows * kcols * ochs : krows * kcols * kchs; + Value dramStrideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride)); + rewriter.create(loc, dramStrideValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 1, spadBlockStride); + + const size_t och_it = transWeight0132 ? dim : max_chs_per_mvin; + const size_t kch_it = transWeight0132 ? max_chs_per_mvin : dim; + for (int och = 0; och < ochs; och += och_it) { + for (int krow = 0; krow < krows; krow++) + for (int kcol = 0; kcol < kcols; kcol++) + for (int kch = 0; kch < kchs; kch += kch_it) { + int K = kchs - kch > dim ? dim : kchs - kch; + int J = ochs - och > max_chs_per_mvin ? max_chs_per_mvin : ochs - och; + if (transWeight0132) { + K = ochs - och > dim ? dim : ochs - och; + J = kchs - kch > max_chs_per_mvin ? max_chs_per_mvin : kchs - kch; + } + uint32_t bSpAddr = bSpAddrStart + (och / dim) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch; + if (transWeight0132) { + bSpAddr = bSpAddrStart + (kch / dim) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och; + } + size_t offset = (krow*kernelDim*inChannels + kcol*inChannels + kch) * weightStride + och; + if (dw) { + offset = krow * kernelDim + kcol; + } else if (transWeight1203) { + offset = (kch * kernelDim * kernelDim + krow * kernelDim + kcol) * outChannels + och; + } else if (transWeight0132) { + offset = (krow * kernelDim * outChannels + kcol * outChannels + och) * inChannels + kch; + } + gemminiMvinOffset(weights, offset * sizeOfElemT, bSpAddr, J, K, addrLen, rewriter); + } + } + } + // Compute + { + const int b_it = transInput3120 ? dim : 1; + const int ocol_it = transInput3120 ? 1 : (dim << inputDilated); + if (transInput3120) { + rewriter.create( + loc, /*dataflow = */ OUTPUT_STATIONARY, /*act = */ 0, /*shift = */ 0, + /*scale = */ llvm::APFloat((float)0), /*cStride = */ orows * ocols, + /*aStride = */ irows * icols, + /*aTranspose = */ 0, /*bTranspose*/ 0, + /*setOnlyStrides = */ true); + } + for (int och = 0; och < ochs; och += dim) { + for (int krow = 0; krow < krows; krow++) { + for (int kcol = 0; kcol < kcols; kcol += maxPixelsPerRow) { + for (int kch = 0; kch < kchs; kch += dim) { + bool newWeights = true; + for (int b = 0; b < batches; b += b_it) { + for (int orow = 0; orow < orows; orow++) { + // Skip some kernel rows due to input-dilation + if (inputDilated && + ((krow * kernelDilation + orow * stride - upad) % 2 != + 0)) { + continue; + } + for (int ocol = 0; ocol < ocols;) { + // Skip some cols dimensions due to input-dilation + if (inputDilated && + ((kcol + ocol * stride - lpad) % 2 != 0)) { + ocol++; + continue; + } + int irow = orow * stride + krow * kernelDilation; + int icol = ocol * stride + kcol * kernelDilation; + if (inputDilated) { + irow = (irow + 1) / 2; + icol = (icol + 1) / 2; + } + const int pixels = kcols - kcol > maxPixelsPerRow + ? maxPixelsPerRow + : kcols - kcol; + const uint32_t cSpAddr = + cSpAddrStart + + (och / dim) * batches * orows * ocols + + b * orows * ocols + orow * ocols + ocol; + // Over here, construct a new matrix + // + // Let us assume that we only ever operate on + // one pixel in one row. + // Thus, krows == kcols == 1 + // + // Then, for every set of I, J, and K values + // - I = ocols + // - J = ochs + // - K = kchs + int I = UNDILATED(ocols - ocol > (dim << inputDilated) + ? (dim << inputDilated) + : ocols - ocol); + const int J = ochs - och > dim ? dim : ochs - och; + const int K = + pixels * (kchs - kch > dim ? dim : kchs - kch); + if (transInput3120) { + I = batches - b > dim ? 
dim : batches - b; + } + uint32_t aSpAddr = + aSpAddrStart + + (kch / dim) * batches * DS(irows) * DS(icols) + + b * DS(irows) * DS(icols) + DS(irow) * DS(icols) + + DS(icol); + if (transInput3120) { + aSpAddr = aSpAddrStart + + (b / dim) * kchs * DS(irows) * DS(icols) + + kch * DS(irows) * DS(icols) + + DS(irow) * DS(icols) + DS(icol); + } + const int krow_ = wrot180 ? krows - krow - 1 : krow; + const int kcol_ = wrot180 ? kcols - kcol - 1 : kcol; + uint32_t bSpAddr = + bSpAddrStart + (och / dim) * krows * kcols * kchs + + krow_ * kcols * kchs + kcol_ * kchs + kch; + if (transWeight0132) { + bSpAddr = bSpAddrStart + + (kch / dim) * krows * kcols * ochs + + krow_ * kcols * ochs + kcol_ * ochs + och; + } + const uint32_t perSpAddr = + newWeights ? bSpAddr : GARBAGE_ADDR; + + Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); + Value iOp = rewriter.create(loc, rewriter.getI64IntegerAttr(I)); + Value jOp = rewriter.create(loc, rewriter.getI64IntegerAttr(J)); + Value kOp = rewriter.create(loc, rewriter.getI64IntegerAttr(K)); + Value perSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(perSpAddr)); + Value aSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aSpAddr)); + Value cSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cSpAddr)); + + rewriter.create(loc, perSpAddrOp, cSpAddrOp, kOp, jOp, iOp, jOp); + if (newWeights) { + rewriter.create(loc, aSpAddrOp, garbageAddrOp, iOp, kOp, iOp, jOp); + } else { + rewriter.create(loc, aSpAddrOp, garbageAddrOp, iOp, kOp, iOp, jOp); + } + ocol += ocol_it; + newWeights = false; + } + } + } + } + } + } + } + } +#undef DS +#undef UNDILATED + // mvout output + if (output != NULL) { + if (noPool) { + for (int b = 0; b < batches; b++) + for (int orow = 0; orow < orows; orow++) + for (int ocol = 0; ocol < ocols; ocol += dim) { + const int I = ocols - ocol > dim ? dim : ocols - ocol; + for (int och = 0; och < ochs; och += dim) { + const int J = ochs - och > dim ? 
dim : ochs - och; + const uint32_t cSpAddr = + cSpAddrStart + (och / dim) * batches * orows * ocols + + b * orows * ocols + orow * ocols + ocol; + size_t outOffset = + (b * outRowDim * outColDim + + orow * outColDim + ocol) * + outStride + + och; + if (transOutput1203) { + outOffset = + (orow * outColDim * batchSize + ocol * batchSize + + b) * + outChannels + + och; + } + gemminiMvoutOffset(output, outOffset * sizeOfElemT, cSpAddr, J, I, addrLen, rewriter); + } + } + } else { + printf("Pooling with rectangular convolutions is currently not supported.\n"); + exit(1); + } + } } - void tiledConv(int batchSize, int inDim, int inChannels, int outChannels, - int outDim, int stride, int inputDilation, int kernelDilation, - int padding, int kernelDim, bool wrot180, bool transOutput1203, + void tiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, int outChannels, + int outRowDim, int outColDim, int stride, int inputDilation, int kernelDilation, + int padding, int kernelDim, + int inStride, int weightStride, int outStride, + bool wrot180, bool transOutput1203, bool transInput3120, bool transWeight1203, bool transWeight0132, int batches, int porows, int pocols, int pochs, int krows, int kcols, int kchs, const Value &input, @@ -891,7 +1538,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { poolStride = 1; poolPadding = 0; } - const bool downsample = stride == 2 && kernelDim == 1 && inDim % 2 == 0 && + const bool downsample = stride == 2 && kernelDim == 1 && inRowDim % 2 == 0 && inColDim % 2 == 0 && padding == 0 && noPool && inputDilation == 1 && !transInput3120; const int inputDilated = inputDilation == 2; @@ -903,18 +1550,24 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64IntegerAttr(stDramStride)); rewriter.create(loc, strideValue, act, llvm::APFloat(scale)); rewriter.create( - loc, /*dataflow = */ 1, /*act = */ 0, /*shift = */ 0, + loc, /*dataflow = */ WEIGHT_STATIONARY, /*act = */ 0, /*shift = */ 0, /*scale = */ llvm::APFloat((float)0), /*cStride = */ inputDilation, /*aStride = */ stride >> downsample, /*aTranspose = */ transInput3120, /*bTranspose*/ transWeight0132, /*setOnlyStrides = */ false); - const int poolOutDim = - (outDim + 2 * poolPadding - poolSize) / poolStride + 1; - const int dilatedInDim = inDim + (inputDilation - 1) * (inDim - 1); + const int poolOutRowDim = + (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutColDim = + (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int dilatedInRowDim = inRowDim + (inputDilation - 1) * (inRowDim - 1); + const int dilatedInColDim = inColDim + (inputDilation - 1) * (inColDim - 1); + + int porowEnd = poolOutRowDim; + for (int b = 0; b < batchSize; b += batches) { - for (int porow = 0; porow < poolOutDim; porow += porows) { + for (int porow = 0; porow < porowEnd; porow += porows) { const int orow = porow * poolStride - poolPadding; - for (int pocol = 0; pocol < poolOutDim; pocol += pocols) { + for (int pocol = 0; pocol < poolOutColDim; pocol += pocols) { const int ocol = pocol * poolStride - poolPadding; for (int poch = 0; poch < outChannels; poch += pochs) { for (int krow = 0; krow < kernelDim; krow += krows) { @@ -929,8 +1582,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { for (int kch = 0; kch < inChannels; kch += kchs) { TypedAttr offsetAttr = - rewriter.getI64IntegerAttr(((b * poolOutDim * poolOutDim + - porow * poolOutDim + pocol) * + rewriter.getI64IntegerAttr(((b * poolOutRowDim * poolOutColDim + + porow 
* poolOutColDim + pocol) * outChannels + poch) * sizeOfElemT); @@ -941,7 +1594,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { offsetValue); if (transOutput1203) { offsetAttr = rewriter.getI64IntegerAttr( - ((porow * poolOutDim * batchSize + pocol * batchSize + + ((porow * poolOutColDim * batchSize + pocol * batchSize + b) * outChannels + poch) * @@ -972,9 +1625,9 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int batches_ = batchSize - b > batches ? batches : batchSize - b; const int porows_ = - poolOutDim - porow > porows ? porows : poolOutDim - porow; + poolOutRowDim - porow > porows ? porows : poolOutRowDim - porow; const int pocols_ = - poolOutDim - pocol > pocols ? pocols : poolOutDim - pocol; + poolOutColDim - pocol > pocols ? pocols : poolOutColDim - pocol; const int pochs_ = outChannels - poch > pochs ? pochs : outChannels - poch; const int krows_ = @@ -989,10 +1642,10 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int plpad = ocol < 0 ? -ocol : 0; const int prpad = - ocol + ocols_ > outDim ? ocol + ocols_ - outDim : 0; + ocol + ocols_ > outColDim ? ocol + ocols_ - outColDim : 0; const int pupad = orow < 0 ? -orow : 0; const int pdpad = - orow + orows_ > outDim ? orow + orows_ - outDim : 0; + orow + orows_ > outRowDim ? orow + orows_ - outRowDim : 0; const int dilatedKrows_ = krows_ + (kernelDilation - 1) * (krows_ - 1); @@ -1005,12 +1658,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { (orows_ - pupad - pdpad) * stride + dilatedKrows_ - 1; int lpad = icol < 0 ? -icol : 0; - int rpad = icol + icols_ > dilatedInDim - ? icol + icols_ - dilatedInDim + int rpad = icol + icols_ > dilatedInColDim + ? icol + icols_ - dilatedInColDim : 0; int upad = irow < 0 ? -irow : 0; - int dpad = irow + irows_ > dilatedInDim - ? irow + irows_ - dilatedInDim + int dpad = irow + irows_ > dilatedInRowDim + ? 
irow + irows_ - dilatedInRowDim : 0; if (inputDilated) { @@ -1063,8 +1716,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { offsetValue); } offsetAttr = rewriter.getI64IntegerAttr( - ((b * inDim * inDim + - ((irow + upad) >> inputDilated) * inDim + + ((b * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + ((icol + lpad) >> inputDilated)) * inChannels + kch) * @@ -1076,8 +1729,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { offsetValue); if (transInput3120) { offsetAttr = rewriter.getI64IntegerAttr( - ((kch * inDim * inDim + - ((irow + upad) >> inputDilated) * inDim + + ((kch * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + ((icol + lpad) >> inputDilated)) * batchSize + b) * @@ -1087,9 +1740,11 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { input, offsetValue); } - spTiledConv(batchSize, inDim, inChannels, outChannels, outDim, - poolOutDim, stride, padding, kernelDim, - kernelDilation, poolSize, poolStride, poolPadding, + spTiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim, outColDim, + poolOutRowDim, poolOutColDim, stride, padding, kernelDim, + kernelDilation, + inStride, weightStride, outStride, + poolSize, poolStride, poolPadding, batches_, porows_, pocols_, pochs_, krows_, kcols_, kchs_, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, in, weightsSlice, out, bias_, @@ -1153,9 +1808,9 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiTileConvLowering(LLVMTypeConverter &typeConverter, - int64_t dim, size_t sizeOfElemT, + int64_t dim, int64_t addrLen, size_t sizeOfElemT, size_t sizeOfAccT) - : ConvertOpToLLVMPattern(typeConverter), dim(dim), + : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen), sizeOfElemT(sizeOfElemT), sizeOfAccT(sizeOfAccT) {} LogicalResult matchAndRewrite(TileConvOp tileConvOp, OpAdaptor adaptor, @@ -1165,29 +1820,19 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { Value weights = tileConvOp.getWeights(); Value bias = tileConvOp.getBias(); MemRefType inputType = input.getType().dyn_cast(); - MemRefType outputType = output.getType().dyn_cast(); - MemRefType weightsType = weights.getType().dyn_cast(); MemRefType biasType = bias.getType().dyn_cast(); ArrayRef inputShape = inputType.getShape(); - ArrayRef outputShape = outputType.getShape(); - ArrayRef weightsShape = weightsType.getShape(); ArrayRef biasShape = biasType.getShape(); - // inDim - if (inputShape[1] != inputShape[2]) { - llvm::outs() << "inDim error.\n"; - return failure(); - } - // outChannels - if (biasShape[0] != outputShape[1] || biasShape[0] != weightsShape[1]) { - llvm::outs() << "outChannels error.\n"; - return failure(); - } - Value outDimValue = tileConvOp.getOutDim(); - int outDim = getNumberFromValue(outDimValue); + + Value outRowDimValue = tileConvOp.getOutRowDim(); + int outRowDim = getNumberFromValue(outRowDimValue); + Value outColDimValue = tileConvOp.getOutColDim(); + int outColDim = getNumberFromValue(outColDimValue); Value kernelDimValue = tileConvOp.getKernelDim(); int kernelDim = getNumberFromValue(kernelDimValue); int batchSize = inputShape[0]; - int inDim = inputShape[1]; + int inRowDim = inputShape[1]; + int inColDim = inputShape[2]; int inChannels = inputShape[3]; int outChannels = biasShape[0]; int stride = tileConvOp.getStride(); @@ -1228,13 +1873,15 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { poolStride = 
1; poolPadding = 0; } - const int poolOutDim = - (outDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutRowDim = + (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutColDim = + (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; const bool downsample = stride == 2 && kernelDim == 1 && padding == 0 && - noPool && inDim % 2 == 0; - int args[] = {batchSize, poolOutDim, poolOutDim, outChannels, + noPool && inRowDim % 2 == 0 && inColDim % 2 == 0; + int args[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels, kernelDim, kernelDim, inChannels}; - const int maxArgs[] = {batchSize, poolOutDim, poolOutDim, outChannels, + const int maxArgs[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels, kernelDim, kernelDim, inChannels}; const int orowsIdx = 1; const int ocolsIdx = 2; @@ -1342,9 +1989,15 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int krows = args[4]; const int kcols = args[5]; const int kchs = args[6]; - tiledConv(batchSize, inDim, inChannels, outChannels, outDim, stride, - inputDilation, kernelDilation, padding, kernelDim, wrot180, - transOutput1203, transInput3120, transWeight1203, transWeight0132, + + const int inStride = inChannels; + const int outStride = outChannels; + const int weightStride = outChannels; + tiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim, outColDim, + stride, + inputDilation, kernelDilation, padding, kernelDim, + inStride, weightStride, outStride, + wrot180, transOutput1203, transInput3120, transWeight1203, transWeight0132, batches, orows, ocols, ochs, krows, kcols, kchs, inputIndexCastOp, weightsIndexCastOp, biasIndexCastOp, outputIndexCastOp, act, scale, poolSize, noPool ? 0 : poolStride, poolPadding, tileConvOp, @@ -1354,6 +2007,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { private: int64_t dim; + int64_t addrLen; + size_t sizeOfElemT; size_t sizeOfAccT; }; @@ -1368,15 +2023,18 @@ void mlir::populateGemminiLegalizeForLLVMExportPatterns( patterns.add(converter); patterns.add(converter); patterns.add(converter, addrLen); + patterns.add(converter, addrLen); + patterns.add(converter, addrLen); patterns.add(converter, addrLen); patterns.add(converter); + patterns.add(converter); patterns.add(converter, dim, addrLen); patterns.add(converter, addrLen); patterns.add(converter, addrLen); patterns.add(converter, addrLen); - patterns.add(converter, dim, sizeOfElemT, + patterns.add(converter, dim, addrLen, sizeOfElemT, sizeOfAccT); - patterns.add(converter, dim, sizeOfElemT, + patterns.add(converter, dim, addrLen, sizeOfElemT, sizeOfAccT); } @@ -1384,15 +2042,15 @@ void mlir::configureGemminiegalizeForExportTarget( LLVMConversionTarget &target) { target.addLegalOp< Flush_IntrOp, ConfigSt_IntrOp, ConifgLd_IntrOp, ConfigEX_IntrOp, - Mvin_IntrOp, Mvout_IntrOp, Preload_IntrOp, ComputePreloaded_IntrOp, + Mvin_IntrOp, Mvin2_IntrOp, Mvin3_IntrOp, Mvout_IntrOp, Preload_IntrOp, ComputePreloaded_IntrOp, ComputeAccumulated_IntrOp, LoopWsConfigBounds_IntrOp, LoopWsConfigAddrsAB_IntrOp, LoopWsConfigAddrsDC_IntrOp, LoopWsConfigStridesAB_IntrOp, LoopWsConfigStridesDC_IntrOp, LoopWs_IntrOp, LoopConvWsConfig1_IntrOp, LoopConvWsConfig2_IntrOp, LoopConvWsConfig3_IntrOp, LoopConvWsConfig4_IntrOp, - LoopConvWsConfig5_IntrOp, LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp>(); - target.addIllegalOp(); + target.addIllegalOp(); + TileConvOp, ConfigNormOp>(); }
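Illustrative note (not part of the patch): the "// Move-in B" / "// Move-in A" loops added above compute, for every sub-tile, how many dim-wide column blocks one mvin covers and how much of the last block is padding. A minimal standalone C++ sketch of that arithmetic follows; the tile sizes (dim = 16, j = 5, k = 3, bBlocks = 4) and the pad values are arbitrary example numbers, not taken from the patch.

#include <cstddef>
#include <cstdio>

int main() {
  // Example tile: j x k sub-tiles of a dim x dim array, where padJ columns and
  // padK rows of the final sub-tiles are padding (all values are arbitrary).
  const size_t dim = 16, j = 5, k = 3, bBlocks = 4;
  const size_t padJ = 6, padK = 10;

  // Same arithmetic as the Move-in B loop above: each mvin covers up to
  // bBlocks consecutive column blocks, and padding is subtracted only when the
  // block run reaches the right edge (cols) or the last row block (rows).
  for (size_t j0 = 0; j0 < j; j0 += bBlocks) {
    for (size_t k0 = 0; k0 < k; k0++) {
      const size_t blocks = j0 + bBlocks <= j ? bBlocks : j - j0;
      const size_t cols = blocks * dim - (j0 + blocks >= j ? padJ : 0);
      const size_t rows = dim - (k0 == k - 1 ? padK : 0);
      std::printf("j0=%zu k0=%zu -> blocks=%zu cols=%zu rows=%zu\n",
                  j0, k0, blocks, cols, rows);
    }
  }
  return 0;
}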
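Illustrative note (not part of the patch): the scratchpad/accumulator base addresses computed in the rewritten spTiledConv, and the accumulate-bit trick used in the matmul tile routine, both encode the target memory in the top bits of an addrLen-bit address. The sketch below uses placeholder parameters (addrLen = 32, BANK_NUM = 4, BANK_ROWS = 4096 and the tile sizes are assumptions; the real values come from the Gemmini configuration passed to the lowering), and it omits the rotating dSpAddrRow/cSpAddrRow double-buffer offsets that the patch adds.

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed placeholder parameters.
  const int addrLen = 32;
  const uint32_t BANK_NUM = 4, BANK_ROWS = 4096;
  const int krows = 3, kcols = 3, kchs = 16, ochs = 16, dim = 16;
  const bool transWeight0132 = false;

  // Same formulas as the rectangular-convolution path in spTiledConv.
  const int outChannelsPerBank = ochs / dim + (ochs % dim != 0);
  const int inChannelsPerBank = kchs / dim + (kchs % dim != 0);
  const int bRows = transWeight0132 ? inChannelsPerBank * kcols * krows * ochs
                                    : outChannelsPerBank * kcols * krows * kchs;

  const uint32_t aSpAddrStart = 0;                            // input: spad row 0
  const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - bRows; // weights: spad end
  const uint32_t dSpAddrStart = 1u << (addrLen - 1);          // bias: accumulator
  const uint32_t cSpAddrStart = 3u << (addrLen - 2);          // output: accumulator,
                                                              // accumulate bit set
  std::printf("A=%#x B=%#x D=%#x C=%#x\n", (unsigned)aSpAddrStart,
              (unsigned)bSpAddrStart, (unsigned)dSpAddrStart,
              (unsigned)cSpAddrStart);

  // The matmul tile routine clears bit (addrLen - 2) when there is no bias, so
  // the first write overwrites the accumulator instead of accumulating onto it.
  uint32_t outSpAddr = cSpAddrStart;
  outSpAddr &= ~(1u << (addrLen - 2));
  std::printf("overwrite form of C base: %#x\n", (unsigned)outSpAddr);
  return 0;
}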
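Illustrative note (not part of the patch): tiledMatmulOuter rounds every matmul dimension up to a multiple of the PE-array size, and for LAYERNORM/SOFTMAX the patch pins the tile shape so one whole output row of blocks is resident (this is the dimJPaded | dim -> dimJPaded / dim fix above). A small sketch of that arithmetic; padTo is a hypothetical helper and dim = 16 plus the matrix sizes are example values.

#include <cstddef>
#include <cstdio>

// Round n up to a multiple of dim, as tiledMatmulOuter does for dimI/J/K.
static size_t padTo(size_t n, size_t dim) {
  return (n / dim + (n % dim != 0)) * dim;
}

int main() {
  const size_t dim = 16;
  const size_t dimI = 100, dimJ = 70, dimK = 33;
  const size_t dimIPadded = padTo(dimI, dim);  // 112
  const size_t dimJPadded = padTo(dimJ, dim);  // 80
  const size_t dimKPadded = padTo(dimK, dim);  // 48
  std::printf("padded: %zu x %zu x %zu\n", dimIPadded, dimJPadded, dimKPadded);

  // LAYERNORM/SOFTMAX tile shape: a single row of tiles spanning all of J.
  const size_t tileI = 1;
  const size_t tileJ = dimJPadded / dim;  // 5
  const size_t tileK = 1;
  std::printf("norm-op tiles: %zu x %zu x %zu\n", tileI, tileJ, tileK);
  return 0;
}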
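Illustrative note (not part of the patch): the CONFIG_NORM operands emitted for act == IGELU and act == SOFTMAX in GemminiTileMatMulLowering are plain host-side arithmetic on the BERT activation scale. The sketch below mirrors that arithmetic; the helper names igeluConstants/softmaxConstants and the example scale 0.05f are made up for illustration, and the casts through int32_t are added here to keep the narrowing well-defined (the patch relies on the implicit float-to-uint32_t conversion).

#include <cstdint>
#include <cstdio>

// i-GELU branch: same arithmetic as the `act == IGELU` hunk above.
static void igeluConstants(float bertScale, uint32_t &qb, uint32_t &qc) {
  const float sqrt_2 = 1.41421356237f;
  const float S = bertScale;
  const float S_erf = -0.2888f * ((S * S) / 2);
  qb = static_cast<uint32_t>(static_cast<int32_t>(-1.769f / (S / sqrt_2)));
  qc = static_cast<uint32_t>(static_cast<int32_t>(1.0f / S_erf));
}

// Softmax branch: same arithmetic as the `act == SOFTMAX` hunk above; the
// lowering emits two config_norm ops, differing only in qln2 vs. qln2_inv.
static void softmaxConstants(float bertScale, uint32_t &qln2, uint32_t &qln2Inv,
                             uint32_t &qb, uint32_t &qc) {
  const float a = 0.3585f, b = 1.353f, c = 0.344f;
  qln2 = static_cast<uint32_t>(0.693147f / bertScale);
  qln2Inv = 65536 / qln2;
  qb = static_cast<uint32_t>(b / bertScale);
  qc = static_cast<uint32_t>(c / (a * bertScale * bertScale));
}

int main() {
  uint32_t qb, qc, qln2, qln2Inv;
  igeluConstants(0.05f, qb, qc); // 0.05f is an arbitrary example scale
  std::printf("igelu:   qb=%u qc=%u\n", (unsigned)qb, (unsigned)qc);
  softmaxConstants(0.05f, qln2, qln2Inv, qb, qc);
  std::printf("softmax: qln2=%u qln2_inv=%u qb=%u qc=%u\n", (unsigned)qln2,
              (unsigned)qln2Inv, (unsigned)qb, (unsigned)qc);
  return 0;
}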
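Illustrative note (not part of the patch): the spatial bookkeeping at the top of the rewritten spTiledConv derives the output tile extent (orows/ocols), the dilated kernel extent, and the padded/unpadded input extent, then halves the moved-in rows and columns when the input is 2x dilated (the UNDILATED macro). The sketch below redoes that arithmetic with arbitrary example tile parameters; none of the concrete numbers come from the patch.

#include <cstdio>

int main() {
  // Example tile parameters (arbitrary, for illustration only).
  const int porows = 4, pocols = 4, poolSize = 1, poolStride = 1;
  const int pupad = 0, pdpad = 0, plpad = 0, prpad = 0;
  const int upad = 1, dpad = 1, lpad = 1, rpad = 1;
  const int krows = 3, kcols = 3, kernelDilation = 1, stride = 1;
  const bool inputDilated = true;

  // Same formulas as the top of spTiledConv ("irows"/"icols" include padding).
  const int orows = porows * poolStride + poolSize - 1 - pupad - pdpad;
  const int ocols = pocols * poolStride + poolSize - 1 - plpad - prpad;
  const int dilatedKrows = krows + (kernelDilation - 1) * (krows - 1);
  const int dilatedKcols = kcols + (kernelDilation - 1) * (kcols - 1);
  int irows = orows * stride + dilatedKrows - 1;
  int icols = ocols * stride + dilatedKcols - 1;
  int irowsUnpadded = irows - upad - dpad;
  int icolsUnpadded = icols - lpad - rpad;

  // With 2x input dilation every other input element is an implicit zero, so
  // only half the rows/cols have to be moved in (the UNDILATED(x) macro).
  auto undilated = [&](int x) { return inputDilated ? (x + 1) / 2 : x; };
  if (inputDilated) {
    irowsUnpadded = (irowsUnpadded + 1) / 2;
    icolsUnpadded = (icolsUnpadded + 1) / 2;
    irows = irowsUnpadded + undilated(upad) + undilated(dpad);
    icols = icolsUnpadded + undilated(lpad) + undilated(rpad);
  }
  std::printf("orows=%d ocols=%d irows=%d icols=%d (unpadded %dx%d)\n",
              orows, ocols, irows, icols, irowsUnpadded, icolsUnpadded);
  return 0;
}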
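Illustrative note (not part of the patch): the weight move-in loop in spTiledConv computes a different DRAM element offset for each supported weight layout (default, depthwise, transWeight1203, transWeight0132). The sketch below collects those offset formulas into one hypothetical helper, weightOffset, with the same parameter names as the patch; the concrete arguments in main are arbitrary.

#include <cstddef>
#include <cstdio>

// DRAM element offset of weight (krow, kcol, kch, och) for the layouts handled
// by the weight mvin loop above (formulas copied from the patch).
static size_t weightOffset(int krow, int kcol, int kch, int och, int kernelDim,
                           int inChannels, int outChannels, int weightStride,
                           bool dw, bool transWeight1203, bool transWeight0132) {
  if (dw)                // depthwise: a single kernelDim x kernelDim plane
    return krow * kernelDim + kcol;
  if (transWeight1203)   // layout (kch, krow, kcol, och)
    return (kch * kernelDim * kernelDim + krow * kernelDim + kcol) * outChannels + och;
  if (transWeight0132)   // layout (krow, kcol, och, kch)
    return (krow * kernelDim * outChannels + kcol * outChannels + och) * inChannels + kch;
  // default layout (krow, kcol, kch, och), with channel stride weightStride
  return (krow * kernelDim * inChannels + kcol * inChannels + kch) * weightStride + och;
}

int main() {
  const size_t off = weightOffset(/*krow=*/1, /*kcol=*/2, /*kch=*/3, /*och=*/4,
                                  /*kernelDim=*/3, /*inChannels=*/8,
                                  /*outChannels=*/16, /*weightStride=*/16,
                                  /*dw=*/false, /*transWeight1203=*/false,
                                  /*transWeight0132=*/false);
  std::printf("default-layout offset = %zu elements\n", off);
  return 0;
}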
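Illustrative note (not part of the patch): the patch generalizes the pooled output size and the stride-2 "downsample" fast-path predicate from a single square inDim/outDim to separate row and column dimensions. The sketch below evaluates both with arbitrary example shapes; the shape numbers are assumptions, not values from the patch.

#include <cstdio>

int main() {
  // Example convolution shape (arbitrary).
  const int inRowDim = 226, inColDim = 114;
  const int outRowDim = 112, outColDim = 56;
  const int stride = 2, kernelDim = 1, padding = 0, inputDilation = 1;
  const bool noPool = true, transInput3120 = false;
  const int poolSize = 1, poolStride = 1, poolPadding = 0;  // noPool defaults

  // Same formulas as tiledConv / matchAndRewrite above.
  const int poolOutRowDim = (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1;
  const int poolOutColDim = (outColDim + 2 * poolPadding - poolSize) / poolStride + 1;

  // A stride-2, 1x1, unpadded convolution can be lowered as a strided mvin
  // ("downsample"), but only when both spatial input dims are even.
  const bool downsample = stride == 2 && kernelDim == 1 && inRowDim % 2 == 0 &&
                          inColDim % 2 == 0 && padding == 0 && noPool &&
                          inputDilation == 1 && !transInput3120;

  std::printf("poolOut = %d x %d, downsample = %d\n",
              poolOutRowDim, poolOutColDim, downsample);
  return 0;
}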