From f23cb7d04117c339d7e4ac95fa3e4f570d0e3fcd Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 30 Jun 2023 22:12:22 +0800 Subject: [PATCH 01/51] add support for lowering config_norm handle act in func tiledMatmulOuter --- .../llvm/IR/IntrinsicsRISCVBuddyExt.td | 3 + .../Target/RISCV/RISCVInstrInfoBuddyExt.td | 9 +++ midend/include/Dialect/Gemmini/Gemmini.td | 21 +++++- midend/include/Dialect/Gemmini/Transform.h | 3 + .../Transforms/LegalizeForLLVMExport.cpp | 74 ++++++++++++++++++- 5 files changed, 107 insertions(+), 3 deletions(-) diff --git a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td index a474ca956b..9e96a395f7 100644 --- a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td +++ b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td @@ -35,6 +35,9 @@ def int_riscv_config_st : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; let TargetPrefix = "riscv" in def int_riscv_config_ex : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; +let TargetPrefix = "riscv" in +def int_riscv_config_norm : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; + let TargetPrefix = "riscv" in def int_riscv_preload : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; diff --git a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td index a45a8ff808..12cdd74f5a 100644 --- a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td +++ b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td @@ -65,6 +65,12 @@ def CONFIG_EX : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs), let rd = 0; } +let Predicates = [HasBuddyExt] in +def CONFIG_NORM : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs), + (ins GPR:$rs1, GPR:$rs2), "config_norm", "$rs1, $rs2"> { + let rd = 0; +} + let hasSideEffects = 1, mayLoad = 1, mayStore =1, Predicates = [HasBuddyExt] in def PRELOAD : RVInstR<0b0000110, 0b011,OPC_CUSTOM_3,(outs), (ins GPR:$rs1, GPR:$rs2), "preload", "$rs1, $rs2">{ @@ -179,6 +185,9 @@ def : 
Pat<(int_riscv_config_st GPR:$rs1, GPR:$rs2), (CONFIG_ST GPR:$rs1, GPR:$rs let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_config_ex GPR:$rs1, GPR:$rs2), (CONFIG_EX GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_config_norm GPR:$rs1, GPR:$rs2), (CONFIG_NORM GPR:$rs1, GPR:$rs2)>; + let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_preload GPR:$rs1, GPR:$rs2), (PRELOAD GPR:$rs1, GPR:$rs2)>; diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index df7393d3e9..ec7c2685b0 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -116,6 +116,22 @@ def ConfigExOp : Gemmini_Op<"config_ex"> { let assemblyFormat = "attr-dict"; } +def ConfigNormOp : Gemmini_Op<"config_norm"> { + let summary = "ConfigNormOp configures TODO pipeline"; + let description = [{ + ConfigNormOp configures TODO pipeline + }]; + let arguments = (ins DefaultValuedAttr:$qConst, + DefaultValuedAttr:$qConstType, + DefaultValuedAttr:$setStatsIdOnly, + DefaultValuedAttr:$actMsb, + DefaultValuedAttr:$StatsId, + DefaultValuedAttr:$iguluQb, + DefaultValuedAttr:$iguluQc); + let assemblyFormat = "attr-dict"; + +} + def MvinOp : Gemmini_Op<"mvin"> { let summary = "Load operation"; let description = [{ @@ -303,7 +319,10 @@ def Gemmini_ConifgLd_IntrOp : Gemmini_IntrOpBase<"config_ld">, def Gemmini_ConfigSt_IntrOp : Gemmini_IntrOpBase<"config_st">, Arguments<(ins LLVM_Type, LLVM_Type)>; -def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, +def Gemmini_ConfigEX_IntrOp : Gemmini_IntrOpBase<"config_ex">, + Arguments<(ins LLVM_Type, LLVM_Type)>; + +def Gemmini_ConfigNorm_IntrOp : Gemmini_IntrOpBase<"config_norm">, Arguments<(ins LLVM_Type, LLVM_Type)>; def Gemmini_Preload_IntrOp : Gemmini_IntrOpBase<"preload">, diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index f0908b4887..a186409aca 100644 --- 
a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -24,6 +24,9 @@ #define CONFIG_LD 1 #define CONFIG_ST 2 #define CONFIG_EX 0 +#define CONFIG_BERT 3 +#define DIM 16 +#define ADDR_LEN 32 #define ACC_SCALE_IDENTITY 1.0 #define BANK_NUM 4 #define BANK_ROWS 4096 diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index de54cc8d44..d6f9e24f83 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -173,6 +173,46 @@ struct GemminiConfigExLowering : public ConvertOpToLLVMPattern { } }; +struct GemminiConfigNormOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ConfigNormOp configNormOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = configNormOp.getLoc(); + // ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, + // (((uint64_t) ((uint32_t) q_const)) << 32) | ((q_const_type & 1) << 18) | ((set_stats_id_only & 1) << 17) | ((act_msb & 1) << 16) | ((uint64_t)stat_id << 8) | CONFIG_BERT, ((uint64_t)((uint32_t)(igelu_qc)) << 32) | ((uint64_t)((uint32_t)(igelu_qb))), k_CONFIG) + uint64_t rs1 = (uint64_t )((uint32_t )configNormOp.getQConst() << 32) | + (configNormOp.getQConstType() & 1) << 18 | + (configNormOp.getSetStatsIdOnly() & 1) << 17 | + (configNormOp.getActMsb() & 1) << 16 | + configNormOp.getStatsId() << 8 | CONFIG_BERT; + uint64_t rs2 = ((uint64_t)((uint32_t)(configNormOp.getIguluQc())) << 32) | ((uint64_t)((uint32_t)(configNormOp.getIguluQb()))); + Value rs1Value = rewriter.create( + loc, rewriter.getI64IntegerAttr(rs1)); + Value rs2Value = rewriter.create( + loc, rewriter.getI64IntegerAttr(rs2)); + rewriter.replaceOpWithNewOp(configNormOp, rs1Value, + rs2Value); + // float scale = configNormOp.getSysAccScale().convertToFloat(); 
+ // uint64_t rs1 = + // configNormOp. + // uint64_t rs1 = + // (uint64_t)acc_scale_t_to_acc_scale_t_bits(scale) << 32 | + // configNormOp.getQ() << 16 | configNormOp.getBTranspose() << 9 | + // configNormOp.getATranspose() << 8 | configNormOp.getSetOnlyStrides() << 7 | + // configNormOp.getSysAct() << 3 | configNormOp.getDataflow() << 2 | CONFIG_EX; + + // uint64_t rs2 = configNormOp.getCStride() << 48 | configNormOp.getSysShift(); + // Value rs1Value = rewriter.create( + // loc, rewriter.getI64IntegerAttr(rs1)); + // Value rs2Value = rewriter.create( + // loc, rewriter.getI64IntegerAttr(rs2)); + // rewriter.replaceOpWithNewOp(configNormOp, rs1Value, + // rs2Value); + return success(); + } +}; + struct GemminiMvinLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiMvinLowering(LLVMTypeConverter &typeConverter, @@ -493,6 +533,35 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64IntegerAttr(strideD * sizeofD)); rewriter.create(loc, strideValue, llvm::APFloat((float)dScaleFactor), lowD, 2); + + /* + Add config norm op + */ + + // acc_scale_t => acc_scale_t + if (act == IGELU) { + const acc_scale_t sqrt_2 = 1.41421356237; + const acc_scale_t S = bertScale; + const acc_scale_t S_erf = (-0.2888 * ((S*S) / 2)); + + const acc_t qb = -1.769 / (S / sqrt_2); + const acc_t qc = 1.0 / S_erf; + rewriter.create(loc, 0, 0, 0, 0,0, qb, qc); + } + + if (act == SOFTMAX) { + const scale_t a = 0.3585; + const scale_t b = 1.353; + const scale_t c = 0.344; + + const acc_t qln2 = (int) (0.693147 / bertScale); + const acc_t qln2_inv = 65536 / qln2; + const acc_t qb = b / bertScale; + const acc_t qc = c / (a*bertScale*bertScale); + rewriter.create(loc, qln2, 0, 0, 1, 0, qb, qc); + rewriter.create(loc, qln2_inv, 1, 0, 1, 0, qb, qc); + } + for (size_t i0 = 0; i0 < I0; i0++) for (size_t j0 = 0; j0 < J0; j0++) for (size_t k0 = 0; k0 < K0; k0++) { @@ -1370,6 +1439,7 @@ void 
mlir::populateGemminiLegalizeForLLVMExportPatterns( patterns.add(converter, addrLen); patterns.add(converter, addrLen); patterns.add(converter); + patterns.add(converter); patterns.add(converter, dim, addrLen); patterns.add(converter, addrLen); patterns.add(converter, addrLen); @@ -1390,9 +1460,9 @@ void mlir::configureGemminiegalizeForExportTarget( LoopWsConfigStridesAB_IntrOp, LoopWsConfigStridesDC_IntrOp, LoopWs_IntrOp, LoopConvWsConfig1_IntrOp, LoopConvWsConfig2_IntrOp, LoopConvWsConfig3_IntrOp, LoopConvWsConfig4_IntrOp, - LoopConvWsConfig5_IntrOp, LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp>(); + LoopConvWsConfig5_IntrOp, LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp, ConfigNorm_IntrOp>(); target.addIllegalOp(); + TileConvOp, ConfigNormOp>(); } From c81f0fd3e59885f11ca42c5e98259c1d409280c4 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 3 Jul 2023 21:59:22 +0800 Subject: [PATCH 02/51] finish matmul os but has bug --- midend/include/Dialect/Gemmini/Gemmini.td | 3 +- midend/include/Dialect/Gemmini/Transform.h | 5 + .../Transforms/LegalizeForLLVMExport.cpp | 218 ++++++++++++++++-- 3 files changed, 201 insertions(+), 25 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index ec7c2685b0..c0f4a7583f 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -254,7 +254,8 @@ def TileMatMulOp : Gemmini_Op<"tile_matmul"> { DefaultValuedAttr:$bTranspose, DefaultValuedAttr:$fullC, DefaultValuedAttr:$lowD, - DefaultValuedAttr:$weightA); + DefaultValuedAttr:$weightA, + DefaultValuedAttr:$dataflow); let assemblyFormat = [{ $aArray $bArray $cArray $dArray attr-dict `:` type($aArray) type($bArray) type($cArray) type($dArray) diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index a186409aca..fba87712ce 100644 --- a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -25,6 +25,11 
@@ #define CONFIG_ST 2 #define CONFIG_EX 0 #define CONFIG_BERT 3 + +#define GARBAGE_ADDR ((uint32_t)(-1)) +#define OUTPUT_STATIONARY 0 +#define WEIGHT_STATIONARY 1 + #define DIM 16 #define ADDR_LEN 32 #define ACC_SCALE_IDENTITY 1.0 diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index d6f9e24f83..540abbd1db 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -193,22 +193,6 @@ struct GemminiConfigNormOpLowering : public ConvertOpToLLVMPattern loc, rewriter.getI64IntegerAttr(rs2)); rewriter.replaceOpWithNewOp(configNormOp, rs1Value, rs2Value); - // float scale = configNormOp.getSysAccScale().convertToFloat(); - // uint64_t rs1 = - // configNormOp. - // uint64_t rs1 = - // (uint64_t)acc_scale_t_to_acc_scale_t_bits(scale) << 32 | - // configNormOp.getQ() << 16 | configNormOp.getBTranspose() << 9 | - // configNormOp.getATranspose() << 8 | configNormOp.getSetOnlyStrides() << 7 | - // configNormOp.getSysAct() << 3 | configNormOp.getDataflow() << 2 | CONFIG_EX; - - // uint64_t rs2 = configNormOp.getCStride() << 48 | configNormOp.getSysShift(); - // Value rs1Value = rewriter.create( - // loc, rewriter.getI64IntegerAttr(rs1)); - // Value rs2Value = rewriter.create( - // loc, rewriter.getI64IntegerAttr(rs2)); - // rewriter.replaceOpWithNewOp(configNormOp, rs1Value, - // rs2Value); return success(); } }; @@ -469,8 +453,185 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, rs1Value, rs2Value); } - void inner(Value &a, Value &b, Value &pre, Value &out, scale_t aScaleFactor, - scale_t bScaleFactor, scale_acc_t dScaleFactor, size_t i, size_t j, + void spTiledMatmulWs(Value &a, Value &b, Value &d, Value &c, + scale_t aScaleFactor, scale_t bScaleFactor, + scale_acc_t dScaleFactor, size_t i, size_t j, size_t k, + size_t padI, size_t padJ, size_t padK, 
size_t strideA, + size_t strideB, size_t strideD, size_t strideC, + bool aTranspose, bool bTranspose, bool fullC, bool lowD, + bool noBias, bool repeatingBias, int act, + TileMatMulOp &tileMatMulOp, + ConversionPatternRewriter &rewriter) const { + + gemminiLoopWs(i, j, k, padI, padJ, padK, a, b, d, c, strideA, strideB, + repeatingBias ? 0 : strideD, strideC, aTranspose, bTranspose, + fullC, lowD, !noBias, act, tileMatMulOp, rewriter); + } + + // Tiling functions + void spTiledMatmulOs(Value &a, Value &b, Value &d, Value &c, + scale_t aScaleFactor, scale_t bScaleFactor, + scale_acc_t dScaleFactor, size_t i, size_t j, size_t k, + size_t padI, size_t padJ, size_t padK, size_t strideA, + size_t strideB, size_t strideD, size_t strideC, + bool aTranspose, bool bTranspose, bool fullC, bool lowD, + bool noBias, bool repeatingBias, int act, + TileMatMulOp &tileMatMulOp, + ConversionPatternRewriter &rewriter) const { + const uint32_t aSpAddrStart = 0; + const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - k * j * DIM; + const uint32_t dSpAddrStart = 1 << (ADDR_LEN - 1); + const uint32_t cSpAddrStart = + (3 << (ADDR_LEN - 2)) | (fullC << (ADDR_LEN - 3)); + + const int aBlocks = k <= MAX_BLOCK_LEN ? k : MAX_BLOCK_LEN; + const int bBlocks = j <= MAX_BLOCK_LEN ? j : MAX_BLOCK_LEN; + const int dBlocks = j <= MAX_BLOCK_LEN_ACC ? j : MAX_BLOCK_LEN_ACC; + + Location loc = a.getLoc(); + uint64_t dAddrInt = getNumberFromValue(d); + + // Move-in D + if (dAddrInt != 0 && !noBias) { + const size_t dStride = repeatingBias ? 0 : strideD * sizeof(acc_t); + Value strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(dStride)); + rewriter.create(loc, strideValue, + llvm::APFloat((float)dScaleFactor)); + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t j0 = 0; j0 < j; j0 += dBlocks) { + const size_t biasRow = repeatingBias ? 
0 : i0; +// const acc_t *const dDramAddr = +// (acc_t *)d + (biasRow * strideD + j0) * DIM; + const size_t offset = (biasRow * strideD + j0) * DIM * sizeof (acc_t); + const uint32_t dSpAddrAcc = dSpAddrStart + (i0 * j + j0) * DIM; + + const size_t blocks = j0 + dBlocks <= j ? dBlocks : j - j0; + + const size_t cols = blocks * DIM - (j0 + blocks >= j ? padJ : 0); + const size_t rows = DIM - (i0 == i - 1 ? padI : 0); + + gemminiMvinOffset(d, offset, dSpAddrAcc, cols, rows, rewriter); + } + } + } + + // Move-in B + Value strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideB)); + rewriter.create(loc, strideValue, + llvm::APFloat((float)bScaleFactor)); + // gemmini_extended_config_ld(strideB * sizeof(elem_t), bScaleFactor); + for (size_t j0 = 0; j0 < j; j0 += bBlocks) { + for (size_t k0 = 0; k0 < k; k0++) { +// const elem_t *const B_dram_addr = B + (k0 * strideB + j0) * DIM; + const size_t offset = (k0 * strideB + j0) * DIM * sizeof (elem_t); + const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * DIM; + const size_t blocks = j0 + bBlocks <= j ? bBlocks : j - j0; + const size_t cols = blocks * DIM - (j0 + blocks >= j ? padJ : 0); + const size_t rows = DIM - (k0 == k - 1 ? padK : 0); + gemminiMvinOffset(b, offset, bSpAddr, cols, rows, rewriter); +// gemmini_extended_mvin(B_dram_addr, bSpAddr, cols, rows); + } + } + + // Move-in A + strideValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideA)); + rewriter.create(loc, strideValue, + llvm::APFloat((float)aScaleFactor)); + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t k0 = 0; k0 < k; k0 += aBlocks) { +// const elem_t *const A_dram_addr = A + (i0 * strideA + k0) * DIM; + const size_t offset = (i0 * strideA + k0) * DIM * sizeof (elem_t); + const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * DIM; + const size_t blocks = k0 + aBlocks <= k ? aBlocks : k - k0; + const size_t cols = blocks * DIM - (k0 + blocks >= k ? padK : 0); + const size_t rows = DIM - (i0 == i - 1 ? 
padI : 0); + gemminiMvinOffset(a, offset, aSpAddr, cols, rows, rewriter); +// gemmini_extended_mvin(A_dram_addr, aSpAddr, cols, rows); + } + } + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t j0 = 0; j0 < j; j0++) { + const uint32_t cSpAddr = cSpAddrStart + (i0 * j + j0) * DIM; + + for (size_t k0 = 0; k0 < k; k0++) { + + const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * DIM; + const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * DIM; + + uint32_t outSpAddr = k0 == k - 1 ? cSpAddr : GARBAGE_ADDR; + + // If we're not using a bias, then we want to overwrite what's in the + // accumulator, rather than writing over it + + int noBiasNewMatrix = noBias && dAddrInt != 0 && k0 == k - 1; + if (noBiasNewMatrix) { + outSpAddr &= ~(1 << (ADDR_LEN - 2)); + } + + const size_t aCols = DIM - (k0 == k - 1 ? padK : 0); + const size_t aRows = DIM - (i0 == i - 1 ? padI : 0); + const size_t bCols = DIM - (j0 == j - 1 ? padJ : 0); + const size_t bRows = DIM - (k0 == k - 1 ? padK : 0); + const size_t cCols = DIM - (j0 == j - 1 ? padJ : 0); + const size_t cRows = DIM - (i0 == i - 1 ? 
padI : 0); + + Value aColsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aCols)); + Value aRowsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aRows)); + Value bColsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(bCols)); + Value bRowsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(bRows)); + Value cColsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cCols)); + Value cRowsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cRows)); + + Value aSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aSpAddr)); + Value bSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(bSpAddr)); + Value outSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(outSpAddr)); + + + Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); + + Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(DIM)); + +// gemmini_extended_preload(GARBAGE_ADDR, outSpAddr, DIM, DIM, cCols, +// cRows); + rewriter.create(loc, garbageAddrOp, outSpAddrOp, dimOp, + dimOp, cRowsOp, cColsOp); + + if (k0 == 0) { // First iteration + rewriter.create(loc, aSpAddrOp, bSpAddrOp, aRowsOp, aColsOp, bRowsOp, bColsOp); + + } else { // All other iterations + rewriter.create(loc, aSpAddrOp, bSpAddrOp, aRowsOp, aColsOp, bRowsOp, bColsOp); + } + } + } + } + } + + void gemminiMvinOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, + const size_t cols, const size_t rows, + ConversionPatternRewriter &rewriter) const{ + Location loc = mem.getLoc(); + Value offsetOp = rewriter.create( + loc, rewriter.getI64IntegerAttr(offset)); + IntegerType i64Type = rewriter.getI64Type(); + Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); + Value spadAddrValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(SpAddr)); + uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | + (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.create(loc, 
configPtr, spad); + + } + + void inner(Value &a, Value &b, Value &pre, Value &out, scale_t aScaleFactor, scale_t bScaleFactor, scale_acc_t dScaleFactor, size_t i, size_t j, size_t k, size_t padI, size_t padJ, size_t padK, size_t strideA, size_t strideB, size_t strideD, size_t strideC, bool aTranspose, bool bTranspose, bool fullC, bool lowD, bool noBias, @@ -490,7 +651,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { size_t tileK, int act, acc_scale_t scale, acc_scale_t bertScale, bool repeatingBias, bool aTranspose, bool bTranspose, bool fullC, bool lowD, - uint8_t weightA, TileMatMulOp &tileMatMulOp, + uint8_t weightA, int dataflow, TileMatMulOp &tileMatMulOp, ConversionPatternRewriter &rewriter) const { const size_t dimIPadded = (dimI / dim + (dimI % dim != 0)) * dim; const size_t dimJPadded = (dimJ / dim + (dimJ % dim != 0)) * dim; @@ -638,10 +799,17 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { b = rewriter.create(loc, rewriter.getI64Type(), B, offsetValue); } - inner(a, b, pre, out, aScaleFactor, bScaleFactor, dScaleFactor, i, j, - k, padI, padJ, padK, strideA, strideB, strideD, strideC, - aTranspose, bTranspose, fullC, lowD, noBias, repeatingBias, act, - tileMatMulOp, rewriter); + if (dataflow == OUTPUT_STATIONARY) { + spTiledMatmulOs(a, b, pre, out, aScaleFactor, bScaleFactor, dScaleFactor, i, j, + k, padI, padJ, padK, strideA, strideB, strideD, strideC, + aTranspose, bTranspose, fullC, lowD, noBias, repeatingBias, act, + tileMatMulOp, rewriter); + } else { // WS + spTiledMatmulWs(a, b, pre, out, aScaleFactor, bScaleFactor, dScaleFactor, i, j, + k, padI, padJ, padK, strideA, strideB, strideD, strideC, + aTranspose, bTranspose, fullC, lowD, noBias, repeatingBias, act, + tileMatMulOp, rewriter); + } } IntegerAttr flushAttr = rewriter.getI64IntegerAttr(0); Value flushValue = rewriter.create( @@ -813,12 +981,14 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { #undef dbMatsInAcc #undef dbMaxTileIJ 
#undef dbMaxTileK + int dataflow = tileMatMulOp.getDataflow(); + tiledMatmulOuter(dimI, dimJ, dimK, aArrayindexCastOp, bArrayindexCastOp, dArrayindexCastOp, cArrayindexCastOp, strideA, strideB, strideD, strideC, aScaleFactor, bScaleFactor, dScaleFactor, tileI, tileJ, tileK, act, scale, bertScale, repeatingBias, - aTranspose, bTranspose, fullC, lowD, weightA, tileMatMulOp, + aTranspose, bTranspose, fullC, lowD, weightA, dataflow, tileMatMulOp, rewriter); return success(); }; From 4db570f532eb6fefae6d6444d252bb6d889dbd62 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 4 Jul 2023 00:09:23 +0800 Subject: [PATCH 03/51] Fix compile-time errors --- .../Transforms/LegalizeForLLVMExport.cpp | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 540abbd1db..33affefa98 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -489,10 +489,12 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { const int dBlocks = j <= MAX_BLOCK_LEN_ACC ? j : MAX_BLOCK_LEN_ACC; Location loc = a.getLoc(); - uint64_t dAddrInt = getNumberFromValue(d); + bool dAddrNull = llvm::dyn_cast(d.getDefiningOp()) && getNumberFromValue(d) == 0; + bool cAddrNull = llvm::dyn_cast(c.getDefiningOp()) && getNumberFromValue(c) == 0; +// uint64_t dAddrInt = getNumberFromValue(d); // Move-in D - if (dAddrInt != 0 && !noBias) { + if (!dAddrNull && !noBias) { const size_t dStride = repeatingBias ? 
0 : strideD * sizeof(acc_t); Value strideValue = rewriter.create( loc, rewriter.getI64IntegerAttr(dStride)); @@ -569,7 +571,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { // If we're not using a bias, then we want to overwrite what's in the // accumulator, rather than writing over it - int noBiasNewMatrix = noBias && dAddrInt != 0 && k0 == k - 1; + int noBiasNewMatrix = noBias && !dAddrNull && k0 == k - 1; if (noBiasNewMatrix) { outSpAddr &= ~(1 << (ADDR_LEN - 2)); } @@ -611,6 +613,22 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { } } } + // Move-out C + if (!cAddrNull) { + const size_t sizeof_C = fullC ? sizeof(acc_t) : sizeof(elem_t); + + for (size_t i0 = 0; i0 < i; i0++) { + for (size_t j0 = 0; j0 < j; j0++) { + const size_t offset = (i0 *strideC + j0)*DIM*sizeof_C; + const uint32_t cSpAddr = cSpAddrStart + (i0 *j + j0)*DIM; + + const size_t cCols = DIM - (j0 == j - 1 ? padJ : 0); + const size_t cRows = DIM - (i0 == j - 1 ? padI : 0); + + gemminiMvoutOffset(c, offset, cSpAddr, cCols, cRows, rewriter); + } + } + } } void gemminiMvinOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, @@ -628,7 +646,23 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { Value spad = rewriter.create( loc, rewriter.getI64IntegerAttr(spadAddrInt)); rewriter.create(loc, configPtr, spad); + } + void gemminiMvoutOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, + const size_t cols, const size_t rows, + ConversionPatternRewriter &rewriter) const{ + Location loc = mem.getLoc(); + Value offsetOp = rewriter.create( + loc, rewriter.getI64IntegerAttr(offset)); + IntegerType i64Type = rewriter.getI64Type(); + Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); + Value spadAddrValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(SpAddr)); + uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | + (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; + Value spad = rewriter.create( + 
loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.create(loc, configPtr, spad); } void inner(Value &a, Value &b, Value &pre, Value &out, scale_t aScaleFactor, scale_t bScaleFactor, scale_acc_t dScaleFactor, size_t i, size_t j, From b5651c0c4b7588d8f0ee162c730beda1cfda48b1 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 4 Jul 2023 16:56:54 +0800 Subject: [PATCH 04/51] fix bug & add test for tiled-matmul-os --- examples/GemminiDialect/makefile | 9 +++++++ examples/GemminiDialect/tile-matmul-os.mlir | 26 +++++++++++++++++++ .../Transforms/LegalizeForLLVMExport.cpp | 2 +- 3 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 examples/GemminiDialect/tile-matmul-os.mlir diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 7bc1c3dc22..3b96e8487d 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -76,6 +76,15 @@ tile-matmul-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-matmul-os-run: + @${BUDDY_OPT} ./tile-matmul-os.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-conv-run: @${BUDDY_OPT} ./tile-conv.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/tile-matmul-os.mlir b/examples/GemminiDialect/tile-matmul-os.mlir new file mode 100644 index 0000000000..576c8f9443 --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-os.mlir @@ -0,0 +1,26 @@ +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.alloc() {alignment = 16} : memref<64x64xi8> + %bArray = memref.alloc() {alignment 
= 16}: memref<64x64xi8> + %cArray = memref.alloc() {alignment = 16}: memref<64x64xi8> + %dArray = memref.alloc() {alignment = 64} : memref<64x64xi32> + %dim = memref.dim %aArray, %c0 : memref<64x64xi8> + scf.for %i = %c0 to %dim step %c1 { + scf.for %j = %c0 to %dim step %c1 { + memref.store %i1I8, %aArray[%i, %j] : memref<64x64xi8> + memref.store %i1I8, %bArray[%i, %j] : memref<64x64xi8> + memref.store %i2I32, %dArray[%i, %j] : memref<64x64xi32> + } + } + gemmini.print %aArray : memref<64x64xi8> + gemmini.print %bArray : memref<64x64xi8> + gemmini.print %dArray : memref<64x64xi32> + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=0} : memref<64x64xi8> memref<64x64xi8> memref<64x64xi8> memref <64x64xi32> + gemmini.print %cArray : memref<64x64xi8> + return %i0 : i8 +} diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 33affefa98..bf5d7fcbdb 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -710,7 +710,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { const size_t sizeofC = fullC ? 
sizeOfAccT : sizeOfElemT; Location loc = tileMatMulOp.getLoc(); llvm::APFloat accScaleIdentity((float)ACC_SCALE_IDENTITY); - rewriter.create(loc, /*dataflow = */ 1, /*sysAct = */ act & 3, + rewriter.create(loc, /*dataflow = */ dataflow, /*sysAct = */ act & 3, /* sysShift = */ 0, accScaleIdentity); Value strideValue = rewriter.create( loc, rewriter.getI64IntegerAttr(strideC * sizeofC)); From 28e5d562470426f2dd5cc16c26eeacc7c106ce92 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 4 Jul 2023 18:36:14 +0800 Subject: [PATCH 05/51] Clear comments --- .../Transforms/LegalizeForLLVMExport.cpp | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index bf5d7fcbdb..def3d2ac94 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -179,8 +179,6 @@ struct GemminiConfigNormOpLowering : public ConvertOpToLLVMPattern matchAndRewrite(ConfigNormOp configNormOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = configNormOp.getLoc(); - // ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, - // (((uint64_t) ((uint32_t) q_const)) << 32) | ((q_const_type & 1) << 18) | ((set_stats_id_only & 1) << 17) | ((act_msb & 1) << 16) | ((uint64_t)stat_id << 8) | CONFIG_BERT, ((uint64_t)((uint32_t)(igelu_qc)) << 32) | ((uint64_t)((uint32_t)(igelu_qb))), k_CONFIG) uint64_t rs1 = (uint64_t )((uint32_t )configNormOp.getQConst() << 32) | (configNormOp.getQConstType() & 1) << 18 | (configNormOp.getSetStatsIdOnly() & 1) << 17 | @@ -491,7 +489,6 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { Location loc = a.getLoc(); bool dAddrNull = llvm::dyn_cast(d.getDefiningOp()) && getNumberFromValue(d) == 0; bool cAddrNull = llvm::dyn_cast(c.getDefiningOp()) && getNumberFromValue(c) == 0; -// uint64_t dAddrInt = 
getNumberFromValue(d); // Move-in D if (!dAddrNull && !noBias) { @@ -504,16 +501,11 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { for (size_t i0 = 0; i0 < i; i0++) { for (size_t j0 = 0; j0 < j; j0 += dBlocks) { const size_t biasRow = repeatingBias ? 0 : i0; -// const acc_t *const dDramAddr = -// (acc_t *)d + (biasRow * strideD + j0) * DIM; const size_t offset = (biasRow * strideD + j0) * DIM * sizeof (acc_t); const uint32_t dSpAddrAcc = dSpAddrStart + (i0 * j + j0) * DIM; - const size_t blocks = j0 + dBlocks <= j ? dBlocks : j - j0; - const size_t cols = blocks * DIM - (j0 + blocks >= j ? padJ : 0); const size_t rows = DIM - (i0 == i - 1 ? padI : 0); - gemminiMvinOffset(d, offset, dSpAddrAcc, cols, rows, rewriter); } } @@ -524,17 +516,14 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64IntegerAttr(strideB)); rewriter.create(loc, strideValue, llvm::APFloat((float)bScaleFactor)); - // gemmini_extended_config_ld(strideB * sizeof(elem_t), bScaleFactor); for (size_t j0 = 0; j0 < j; j0 += bBlocks) { for (size_t k0 = 0; k0 < k; k0++) { -// const elem_t *const B_dram_addr = B + (k0 * strideB + j0) * DIM; const size_t offset = (k0 * strideB + j0) * DIM * sizeof (elem_t); const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * DIM; const size_t blocks = j0 + bBlocks <= j ? bBlocks : j - j0; const size_t cols = blocks * DIM - (j0 + blocks >= j ? padJ : 0); const size_t rows = DIM - (k0 == k - 1 ? 
padK : 0); gemminiMvinOffset(b, offset, bSpAddr, cols, rows, rewriter); -// gemmini_extended_mvin(B_dram_addr, bSpAddr, cols, rows); } } @@ -546,21 +535,18 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { for (size_t i0 = 0; i0 < i; i0++) { for (size_t k0 = 0; k0 < k; k0 += aBlocks) { -// const elem_t *const A_dram_addr = A + (i0 * strideA + k0) * DIM; const size_t offset = (i0 * strideA + k0) * DIM * sizeof (elem_t); const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * DIM; const size_t blocks = k0 + aBlocks <= k ? aBlocks : k - k0; const size_t cols = blocks * DIM - (k0 + blocks >= k ? padK : 0); const size_t rows = DIM - (i0 == i - 1 ? padI : 0); gemminiMvinOffset(a, offset, aSpAddr, cols, rows, rewriter); -// gemmini_extended_mvin(A_dram_addr, aSpAddr, cols, rows); } } for (size_t i0 = 0; i0 < i; i0++) { for (size_t j0 = 0; j0 < j; j0++) { const uint32_t cSpAddr = cSpAddrStart + (i0 * j + j0) * DIM; - for (size_t k0 = 0; k0 < k; k0++) { const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * DIM; @@ -594,13 +580,9 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { Value bSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(bSpAddr)); Value outSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(outSpAddr)); - Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); - Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(DIM)); -// gemmini_extended_preload(GARBAGE_ADDR, outSpAddr, DIM, DIM, cCols, -// cRows); rewriter.create(loc, garbageAddrOp, outSpAddrOp, dimOp, dimOp, cRowsOp, cColsOp); @@ -732,8 +714,6 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { /* Add config norm op */ - - // acc_scale_t => acc_scale_t if (act == IGELU) { const acc_scale_t sqrt_2 = 1.41421356237; const acc_scale_t S = bertScale; From b56ad9901e4980162565a141bd3eb479baa70e9a Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 10 Jul 2023 17:07:40 +0800 Subject: [PATCH 06/51] fix bug 
in GemminiConfigNormOpLowering --- .../Gemmini/Transforms/LegalizeForLLVMExport.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index def3d2ac94..71510f6425 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -179,12 +179,14 @@ struct GemminiConfigNormOpLowering : public ConvertOpToLLVMPattern matchAndRewrite(ConfigNormOp configNormOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = configNormOp.getLoc(); - uint64_t rs1 = (uint64_t )((uint32_t )configNormOp.getQConst() << 32) | + // ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, + // (((uint64_t) ((uint32_t) q_const)) << 32) | ((q_const_type & 1) << 18) | ((set_stats_id_only & 1) << 17) | ((act_msb & 1) << 16) | ((uint64_t)stat_id << 8) | CONFIG_BERT, ((uint64_t)((uint32_t)(igelu_qc)) << 32) | ((uint64_t)((uint32_t)(igelu_qb))), k_CONFIG) + uint64_t rs1 = (((uint64_t) (configNormOp.getQConst())) << 32) | (configNormOp.getQConstType() & 1) << 18 | (configNormOp.getSetStatsIdOnly() & 1) << 17 | (configNormOp.getActMsb() & 1) << 16 | configNormOp.getStatsId() << 8 | CONFIG_BERT; - uint64_t rs2 = ((uint64_t)((uint32_t)(configNormOp.getIguluQc())) << 32) | ((uint64_t)((uint32_t)(configNormOp.getIguluQb()))); + uint64_t rs2 = (((uint64_t) configNormOp.getIguluQc()) << 32) | ((uint64_t) ((uint32_t)configNormOp.getIguluQb())); Value rs1Value = rewriter.create( loc, rewriter.getI64IntegerAttr(rs1)); Value rs2Value = rewriter.create( @@ -621,8 +623,6 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64IntegerAttr(offset)); IntegerType i64Type = rewriter.getI64Type(); Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); - Value spadAddrValue = rewriter.create( - loc, 
rewriter.getI64IntegerAttr(SpAddr)); uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; Value spad = rewriter.create( @@ -638,8 +638,6 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64IntegerAttr(offset)); IntegerType i64Type = rewriter.getI64Type(); Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); - Value spadAddrValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(SpAddr)); uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; Value spad = rewriter.create( From ec9b42689d83d20e877628c4a2ee1b098d97484c Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 10 Jul 2023 20:33:32 +0800 Subject: [PATCH 07/51] before merge patch of gemmini dialect --- .../lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 71510f6425..34ad5cdf5d 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -181,12 +181,12 @@ struct GemminiConfigNormOpLowering : public ConvertOpToLLVMPattern Location loc = configNormOp.getLoc(); // ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, // (((uint64_t) ((uint32_t) q_const)) << 32) | ((q_const_type & 1) << 18) | ((set_stats_id_only & 1) << 17) | ((act_msb & 1) << 16) | ((uint64_t)stat_id << 8) | CONFIG_BERT, ((uint64_t)((uint32_t)(igelu_qc)) << 32) | ((uint64_t)((uint32_t)(igelu_qb))), k_CONFIG) - uint64_t rs1 = (((uint64_t) (configNormOp.getQConst())) << 32) | + uint64_t rs1 = (((uint64_t) ((uint32_t)configNormOp.getQConst())) << 32) | (configNormOp.getQConstType() & 1) << 18 | (configNormOp.getSetStatsIdOnly() & 1) << 17 | (configNormOp.getActMsb() & 1) << 16 | configNormOp.getStatsId() << 8 | 
CONFIG_BERT; - uint64_t rs2 = (((uint64_t) configNormOp.getIguluQc()) << 32) | ((uint64_t) ((uint32_t)configNormOp.getIguluQb())); + uint64_t rs2 = (((uint64_t) ((uint32_t)configNormOp.getIguluQc())) << 32) | ((uint64_t) ((uint32_t)configNormOp.getIguluQb())); Value rs1Value = rewriter.create( loc, rewriter.getI64IntegerAttr(rs1)); Value rs2Value = rewriter.create( From 54ffd2cca09b31ade9075c1298ddede6184c4519 Mon Sep 17 00:00:00 2001 From: yxy Date: Mon, 10 Jul 2023 23:45:26 +0800 Subject: [PATCH 08/51] add act in ConfigSt lowering, solve wrong computing --- midend/include/Dialect/Gemmini/Transform.h | 2 ++ .../lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index fba87712ce..915ddc875d 100644 --- a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -46,6 +46,8 @@ typedef float acc_scale_t; typedef uint32_t scale_t_bits; typedef float scale_t; typedef int32_t scale_acc_t; +typedef int32_t acc_t; +typedef int8_t elem_t; namespace mlir { diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 34ad5cdf5d..3308b316fb 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -116,11 +116,12 @@ struct GemminiConfigStLowering : public ConvertOpToLLVMPattern { int stride = getNumberFromValue(strideValue); float scale = configStOp.getScale().convertToFloat(); Location loc = configStOp.getLoc(); + uint64_t rs1 = ((uint64_t)configStOp.getActivation() << 2) | CONFIG_ST; uint64_t arg = (uint64_t)acc_scale_t_to_acc_scale_t_bits((acc_scale_t)scale) << 32 | (uint32_t)stride; Value value1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(CONFIG_ST)); + loc, rewriter.getI64IntegerAttr(rs1)); 
Value value2 = rewriter.create( loc, rewriter.getI64IntegerAttr(arg)); rewriter.replaceOpWithNewOp(configStOp, value1, value2); From b7aa81171268e293c377a29931bae3888d74ca7a Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 11 Jul 2023 00:14:31 +0800 Subject: [PATCH 09/51] Add test for tile-matmul-ws-igelu --- examples/GemminiDialect/makefile | 9 ++++++ .../GemminiDialect/tile-matmul-ws-igelu.mlir | 32 +++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 examples/GemminiDialect/tile-matmul-ws-igelu.mlir diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 3b96e8487d..3c5925d886 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -85,6 +85,15 @@ tile-matmul-os-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-matmul-ws-igelu-run: + @${BUDDY_OPT} ./tile-matmul-ws-igelu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-conv-run: @${BUDDY_OPT} ./tile-conv.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir new file mode 100644 index 0000000000..92f8671606 --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir @@ -0,0 +1,32 @@ +memref.global "private" @g1 : memref<3x3xi8> = dense<[[1, 0, 0], [1, -1, 1], [-1, 0, 1]]> +memref.global "private" @g2 : memref<3x3xi8> = dense<[[1, -1, 0], [1, 0, -1], [-1, -1, 0]]> + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = 
arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<3x3xi8> + %bArray = memref.get_global @g2 : memref<3x3xi8> + %cArray = memref.alloc() : memref<3x3xi8> + %dArray = memref.alloc() : memref<3x3xi32> + %dim_I = memref.dim %aArray, %c0 : memref<3x3xi8> + %dim_J = memref.dim %bArray, %c1 : memref<3x3xi8> + %dim_K = memref.dim %aArray, %c1 : memref<3x3xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<3x3xi32> + } + } + gemmini.print %aArray : memref<3x3xi8> + gemmini.print %bArray : memref<3x3xi8> + // gemmini.print %dArray : memref<3x3xi32> + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=3, bertScale=0.8:f32}: memref<3x3xi8> memref<3x3xi8> memref<3x3xi8> memref<3x3xi32> + gemmini.print %cArray : memref<3x3xi8> + return %i0 : i8 +} From d242b9ca0823fe4cb7964da8fa3a2e6c4175e0fe Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 11 Jul 2023 21:05:49 +0800 Subject: [PATCH 10/51] change signature, use class variable to replace "sizeof" operator --- .../Gemmini/Transforms/LegalizeForLLVMExport.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 3308b316fb..dc3c2aec5e 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -174,14 +174,12 @@ struct GemminiConfigExLowering : public ConvertOpToLLVMPattern { } }; -struct GemminiConfigNormOpLowering : public ConvertOpToLLVMPattern { +struct GemminiConfigNormLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ConfigNormOp configNormOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = configNormOp.getLoc(); - // 
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, - // (((uint64_t) ((uint32_t) q_const)) << 32) | ((q_const_type & 1) << 18) | ((set_stats_id_only & 1) << 17) | ((act_msb & 1) << 16) | ((uint64_t)stat_id << 8) | CONFIG_BERT, ((uint64_t)((uint32_t)(igelu_qc)) << 32) | ((uint64_t)((uint32_t)(igelu_qb))), k_CONFIG) uint64_t rs1 = (((uint64_t) ((uint32_t)configNormOp.getQConst())) << 32) | (configNormOp.getQConstType() & 1) << 18 | (configNormOp.getSetStatsIdOnly() & 1) << 17 | @@ -495,7 +493,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { // Move-in D if (!dAddrNull && !noBias) { - const size_t dStride = repeatingBias ? 0 : strideD * sizeof(acc_t); + const size_t dStride = repeatingBias ? 0 : strideD * sizeOfAccT; Value strideValue = rewriter.create( loc, rewriter.getI64IntegerAttr(dStride)); rewriter.create(loc, strideValue, @@ -600,7 +598,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { } // Move-out C if (!cAddrNull) { - const size_t sizeof_C = fullC ? sizeof(acc_t) : sizeof(elem_t); + const size_t sizeof_C = fullC ? 
sizeOfAccT : sizeOfElemT; for (size_t i0 = 0; i0 < i; i0++) { for (size_t j0 = 0; j0 < j; j0++) { @@ -1622,7 +1620,7 @@ void mlir::populateGemminiLegalizeForLLVMExportPatterns( patterns.add(converter, addrLen); patterns.add(converter, addrLen); patterns.add(converter); - patterns.add(converter); + patterns.add(converter); patterns.add(converter, dim, addrLen); patterns.add(converter, addrLen); patterns.add(converter, addrLen); From 498eb87fc4e323b8e01b28b5dad45179e2f1faa4 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 11 Jul 2023 21:58:13 +0800 Subject: [PATCH 11/51] add mvin2, mvin3 lowering --- .../llvm/IR/IntrinsicsRISCVBuddyExt.td | 6 ++ .../Target/RISCV/RISCVInstrInfoBuddyExt.td | 12 ++++ midend/include/Dialect/Gemmini/Gemmini.td | 32 +++++++++ .../Transforms/LegalizeForLLVMExport.cpp | 72 ++++++++++++++++++- 4 files changed, 120 insertions(+), 2 deletions(-) diff --git a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td index 9e96a395f7..51fffb6878 100644 --- a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td +++ b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td @@ -20,6 +20,12 @@ let TargetPrefix = "riscv" in def int_riscv_mvin : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>; +let TargetPrefix = "riscv" in +def int_riscv_mvin2 : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>; + +let TargetPrefix = "riscv" in +def int_riscv_mvin3 : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>; + let TargetPrefix = "riscv" in def int_riscv_mvout : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; diff --git a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td index 12cdd74f5a..b5491d5045 100644 --- a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td +++ b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td @@ -35,6 +35,18 @@ def MVIN : RVInstR<0b0000010, 0b011, OPC_CUSTOM_3, (outs), let rd = 0; } +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, 
Predicates = [HasBuddyExt] in +def MVIN2 : RVInstR<0b0000001, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin","$rs1, $rs2"> { + let rd = 0; +} + +let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in +def MVIN3 : RVInstR<0b0001110, 0b011, OPC_CUSTOM_3, (outs), + (ins GPR:$rs1, GPR:$rs2), "mvin","$rs1, $rs2"> { + let rd = 0; +} + let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in def MVOUT : RVInstR<0b0000011, 0b011, OPC_CUSTOM_3, (outs), (ins GPR:$rs1, GPR:$rs2), "mvout","$rs1, $rs2">{ diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index c0f4a7583f..2fc9718e78 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -144,6 +144,32 @@ def MvinOp : Gemmini_Op<"mvin"> { let assemblyFormat = "$input $addr attr-dict `:` type($input) type($addr)"; } +def Mvin2Op : Gemmini_Op<"mvin2"> { + let summary = "Load operation"; + let description = [{ + Similar to Mvin + Move data from main memory to scratchpad + - MemRef to load in. + (including DRAM address, number of columns, number of rows) + - Local scratchpad or accumulator address. + }]; + let arguments = (ins MemRefRankOf<[AnyType], [2]>:$input, I64:$addr); + let assemblyFormat = "$input $addr attr-dict `:` type($input) type($addr)"; +} + +def Mvin3Op : Gemmini_Op<"mvin3"> { + let summary = "Load operation"; + let description = [{ + Similar to Mvin and Mvin2 + Move data from main memory to scratchpad + - MemRef to load in. + (including DRAM address, number of columns, number of rows) + - Local scratchpad or accumulator address. 
+ }]; + let arguments = (ins MemRefRankOf<[AnyType], [2]>:$input, I64:$addr); + let assemblyFormat = "$input $addr attr-dict `:` type($input) type($addr)"; +} + def MvoutOp : Gemmini_Op<"mvout"> { let summary = "Store operation"; let description = [{ @@ -308,6 +334,12 @@ class Gemmini_IntrOpBase traits = []> : def Gemmini_Mvin_IntrOp : Gemmini_IntrOpBase<"mvin">, Arguments<(ins LLVM_Type, LLVM_Type)>; +def Gemmini_Mvin2_IntrOp : Gemmini_IntrOpBase<"mvin2">, + Arguments<(ins LLVM_Type, LLVM_Type)>; + +def Gemmini_Mvin3_IntrOp : Gemmini_IntrOpBase<"mvin3">, + Arguments<(ins LLVM_Type, LLVM_Type)>; + def Gemmini_Mvout_IntrOp : Gemmini_IntrOpBase<"mvout">, Arguments<(ins LLVM_Type, LLVM_Type)>; diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index dc3c2aec5e..647800d7be 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -229,6 +229,72 @@ struct GemminiMvinLowering : public ConvertOpToLLVMPattern { int64_t addrLen; }; +struct GemminiMvin2Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit GemminiMvin2Lowering(LLVMTypeConverter &typeConverter, + int64_t addrLen) + : ConvertOpToLLVMPattern(typeConverter), addrLen(addrLen) {} + LogicalResult + matchAndRewrite(Mvin2Op mvin2Op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = mvin2Op.getInput(); + Location loc = input.getLoc(); + MemRefType memRefType = + mvin2Op.getOperandTypes().front().dyn_cast(); + llvm::ArrayRef memRefShape = memRefType.getShape(); + TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); + Value extractOp = rewriter.create( + loc, resultType, input); + IntegerType i64Type = rewriter.getI64Type(); + Value indexCastOp = + rewriter.create(loc, i64Type, extractOp); + Value spadAddrValue = mvin2Op.getAddr(); + 
uint64_t number = getNumberFromValue(spadAddrValue); + uint64_t spadAddrInt = (uint64_t)memRefShape[0] << (addrLen + 16) | + (uint64_t)memRefShape[1] << addrLen | number; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.replaceOpWithNewOp(mvin2Op, indexCastOp, spad); + return success(); + } + +private: + int64_t addrLen; +}; + +struct GemminiMvin3Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + explicit GemminiMvin3Lowering(LLVMTypeConverter &typeConverter, + int64_t addrLen) + : ConvertOpToLLVMPattern(typeConverter), addrLen(addrLen) {} + LogicalResult + matchAndRewrite(Mvin3Op mvin3Op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = mvin3Op.getInput(); + Location loc = input.getLoc(); + MemRefType memRefType = + mvin3Op.getOperandTypes().front().dyn_cast(); + llvm::ArrayRef memRefShape = memRefType.getShape(); + TypeRange resultType = mlir::TypeRange(rewriter.getIndexType()); + Value extractOp = rewriter.create( + loc, resultType, input); + IntegerType i64Type = rewriter.getI64Type(); + Value indexCastOp = + rewriter.create(loc, i64Type, extractOp); + Value spadAddrValue = mvin3Op.getAddr(); + uint64_t number = getNumberFromValue(spadAddrValue); + uint64_t spadAddrInt = (uint64_t)memRefShape[0] << (addrLen + 16) | + (uint64_t)memRefShape[1] << addrLen | number; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.replaceOpWithNewOp(mvin3Op, indexCastOp, spad); + return success(); + } + +private: + int64_t addrLen; +}; + struct GemminiMvoutLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiMvoutLowering(LLVMTypeConverter &typeConverter, @@ -1618,6 +1684,8 @@ void mlir::populateGemminiLegalizeForLLVMExportPatterns( patterns.add(converter); patterns.add(converter); patterns.add(converter, addrLen); + patterns.add(converter, 
addrLen); + patterns.add(converter, addrLen); patterns.add(converter, addrLen); patterns.add(converter); patterns.add(converter); @@ -1635,14 +1703,14 @@ void mlir::configureGemminiegalizeForExportTarget( LLVMConversionTarget &target) { target.addLegalOp< Flush_IntrOp, ConfigSt_IntrOp, ConifgLd_IntrOp, ConfigEX_IntrOp, - Mvin_IntrOp, Mvout_IntrOp, Preload_IntrOp, ComputePreloaded_IntrOp, + Mvin_IntrOp, Mvin2_IntrOp, Mvin3_IntrOp, Mvout_IntrOp, Preload_IntrOp, ComputePreloaded_IntrOp, ComputeAccumulated_IntrOp, LoopWsConfigBounds_IntrOp, LoopWsConfigAddrsAB_IntrOp, LoopWsConfigAddrsDC_IntrOp, LoopWsConfigStridesAB_IntrOp, LoopWsConfigStridesDC_IntrOp, LoopWs_IntrOp, LoopConvWsConfig1_IntrOp, LoopConvWsConfig2_IntrOp, LoopConvWsConfig3_IntrOp, LoopConvWsConfig4_IntrOp, LoopConvWsConfig5_IntrOp, LoopConvWsConfig6_IntrOp, LoopConvWs_IntrOp, ConfigNorm_IntrOp>(); - target.addIllegalOp(); From 85c656bc09466646e2f4b07d4f4dfb7a65dfa53c Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 12 Jul 2023 20:02:16 +0800 Subject: [PATCH 12/51] make gemminiMvinOffset a tool function with template to adapt mvin2 and mvin3 add MVIN_SCALE_IDENTITY --- midend/include/Dialect/Gemmini/Transform.h | 1 + .../Transforms/LegalizeForLLVMExport.cpp | 62 ++++++++++--------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index 915ddc875d..3c4127e029 100644 --- a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -33,6 +33,7 @@ #define DIM 16 #define ADDR_LEN 32 #define ACC_SCALE_IDENTITY 1.0 +#define MVIN_SCALE_IDENTITY 1.0 #define BANK_NUM 4 #define BANK_ROWS 4096 #define ACC_ROWS 1024 diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 647800d7be..3f75ac0351 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ 
b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -62,6 +62,38 @@ scale_t_bits scale_t_to_scale_t_bits(scale_t x) { un.f = x; return un.b; } + +template +void gemminiMvinOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, + const size_t cols, const size_t rows, + ConversionPatternRewriter &rewriter) { + Location loc = mem.getLoc(); + Value offsetOp = rewriter.create( + loc, rewriter.getI64IntegerAttr(offset)); + IntegerType i64Type = rewriter.getI64Type(); + Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); + uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | + (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.create(loc, configPtr, spad); +} + +void gemminiMvoutOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, + const size_t cols, const size_t rows, + ConversionPatternRewriter &rewriter) { + Location loc = mem.getLoc(); + Value offsetOp = rewriter.create( + loc, rewriter.getI64IntegerAttr(offset)); + IntegerType i64Type = rewriter.getI64Type(); + Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); + uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | + (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; + Value spad = rewriter.create( + loc, rewriter.getI64IntegerAttr(spadAddrInt)); + rewriter.create(loc, configPtr, spad); +} + }; // namespace template @@ -680,35 +712,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { } } - void gemminiMvinOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, - const size_t cols, const size_t rows, - ConversionPatternRewriter &rewriter) const{ - Location loc = mem.getLoc(); - Value offsetOp = rewriter.create( - loc, rewriter.getI64IntegerAttr(offset)); - IntegerType i64Type = rewriter.getI64Type(); - Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); - uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 
16) | - (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; - Value spad = rewriter.create( - loc, rewriter.getI64IntegerAttr(spadAddrInt)); - rewriter.create(loc, configPtr, spad); - } - void gemminiMvoutOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, - const size_t cols, const size_t rows, - ConversionPatternRewriter &rewriter) const{ - Location loc = mem.getLoc(); - Value offsetOp = rewriter.create( - loc, rewriter.getI64IntegerAttr(offset)); - IntegerType i64Type = rewriter.getI64Type(); - Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); - uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | - (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; - Value spad = rewriter.create( - loc, rewriter.getI64IntegerAttr(spadAddrInt)); - rewriter.create(loc, configPtr, spad); - } void inner(Value &a, Value &b, Value &pre, Value &out, scale_t aScaleFactor, scale_t bScaleFactor, scale_acc_t dScaleFactor, size_t i, size_t j, size_t k, size_t padI, size_t padJ, size_t padK, size_t strideA, @@ -1219,7 +1223,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64IntegerAttr(stDramStride)); rewriter.create(loc, strideValue, act, llvm::APFloat(scale)); rewriter.create( - loc, /*dataflow = */ 1, /*act = */ 0, /*shift = */ 0, + loc, /*dataflow = */ WEIGHT_STATIONARY, /*act = */ 0, /*shift = */ 0, /*scale = */ llvm::APFloat((float)0), /*cStride = */ inputDilation, /*aStride = */ stride >> downsample, /*aTranspose = */ transInput3120, /*bTranspose*/ transWeight0132, From 1f444d6e041914eb912308ea30909f60502c33fa Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 12 Jul 2023 20:22:59 +0800 Subject: [PATCH 13/51] extend config_ld op --- midend/include/Dialect/Gemmini/Gemmini.td | 4 +++- .../lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index 
2fc9718e78..7551864292 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -76,7 +76,9 @@ def ConfigLdOp : Gemmini_Op<"config_ld"> { let arguments = (ins I64:$stride, DefaultValuedAttr:$scale, DefaultValuedAttr:$shrunk, - DefaultValuedAttr:$id); + DefaultValuedAttr:$id, + DefaultValuedAttr:$block_mvin_stride, + DefaultValuedAttr:$pixel_repeats); let assemblyFormat = "$stride attr-dict `:` type($stride)"; } diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 3f75ac0351..17404c1656 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -168,8 +168,10 @@ struct GemminiConfigLdLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Value rs2Value = configLdOp.getStride(); float scale = configLdOp.getScale().convertToFloat(); + uint64_t blockMvinStride = configLdOp.getBlockMvinStride(); + uint64_t pixelRepeats = configLdOp.getPixelRepeats(); uint64_t rs1 = (uint64_t)scale_t_to_scale_t_bits(scale) << 32 | - ((uint64_t)16 << 16) | (uint64_t)1 << 8 | + (blockMvinStride << 16) | pixelRepeats << 8 | configLdOp.getId() << 3 | configLdOp.getShrunk() << 2 | CONFIG_LD; Location loc = configLdOp.getLoc(); From 5e69ee1b23c84537d4468553388d2e5cbd526c52 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 3 Aug 2023 00:28:14 +0800 Subject: [PATCH 14/51] Fixed a spell error, but still don't know how to describe the parameters of ConfigNormOp --- midend/include/Dialect/Gemmini/Gemmini.td | 15 +++++++++++---- .../Gemmini/Transforms/LegalizeForLLVMExport.cpp | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index 7551864292..9f720926e8 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ 
b/midend/include/Dialect/Gemmini/Gemmini.td @@ -119,17 +119,24 @@ def ConfigExOp : Gemmini_Op<"config_ex"> { } def ConfigNormOp : Gemmini_Op<"config_norm"> { - let summary = "ConfigNormOp configures TODO pipeline"; + let summary = "ConfigNormOp configures normalize pipeline"; let description = [{ - ConfigNormOp configures TODO pipeline + ConfigNormOp configures normalize pipeline + -qConst: + -qConstType: + -setStatsIdOnly: + -actMsg: + -StatsId: + -igeluQb: + -igeluQc: }]; let arguments = (ins DefaultValuedAttr:$qConst, DefaultValuedAttr:$qConstType, DefaultValuedAttr:$setStatsIdOnly, DefaultValuedAttr:$actMsb, DefaultValuedAttr:$StatsId, - DefaultValuedAttr:$iguluQb, - DefaultValuedAttr:$iguluQc); + DefaultValuedAttr:$igeluQb, + DefaultValuedAttr:$igeluQc); let assemblyFormat = "attr-dict"; } diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 17404c1656..0d739d4acb 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -219,7 +219,7 @@ struct GemminiConfigNormLowering : public ConvertOpToLLVMPattern { (configNormOp.getSetStatsIdOnly() & 1) << 17 | (configNormOp.getActMsb() & 1) << 16 | configNormOp.getStatsId() << 8 | CONFIG_BERT; - uint64_t rs2 = (((uint64_t) ((uint32_t)configNormOp.getIguluQc())) << 32) | ((uint64_t) ((uint32_t)configNormOp.getIguluQb())); + uint64_t rs2 = (((uint64_t) ((uint32_t)configNormOp.getIgeluQc())) << 32) | ((uint64_t) ((uint32_t)configNormOp.getIgeluQb())); Value rs1Value = rewriter.create( loc, rewriter.getI64IntegerAttr(rs1)); Value rs2Value = rewriter.create( From 8143b89c0c366d01461eb28de6d4ee96520c4dbe Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 3 Aug 2023 00:29:26 +0800 Subject: [PATCH 15/51] delete empty line --- midend/include/Dialect/Gemmini/Gemmini.td | 1 - 1 file changed, 1 deletion(-) diff --git 
a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index 9f720926e8..e3045e43fb 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -138,7 +138,6 @@ def ConfigNormOp : Gemmini_Op<"config_norm"> { DefaultValuedAttr:$igeluQb, DefaultValuedAttr:$igeluQc); let assemblyFormat = "attr-dict"; - } def MvinOp : Gemmini_Op<"mvin"> { From aa733729d7614e4b406bfae5162dff075d2b4ffa Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 9 Aug 2023 00:23:30 +0800 Subject: [PATCH 16/51] add filecheck for tile-matmul-os.mlir --- examples/GemminiDialect/tile-matmul-os.mlir | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/GemminiDialect/tile-matmul-os.mlir b/examples/GemminiDialect/tile-matmul-os.mlir index 576c8f9443..dafcabef72 100644 --- a/examples/GemminiDialect/tile-matmul-os.mlir +++ b/examples/GemminiDialect/tile-matmul-os.mlir @@ -1,3 +1,7 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + func.func @main() -> i8 { %i0 = arith.constant 0 : i8 %i1I8 = arith.constant 1 : i8 @@ -20,6 +24,12 @@ func.func @main() -> i8 { gemmini.print %aArray : memref<64x64xi8> gemmini.print %bArray : memref<64x64xi8> gemmini.print %dArray : memref<64x64xi32> + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.mvin" + // CHECK: "gemmini.intr.preload" + // CHECK: "gemmini.intr.compute_preloaded" + // CHECK: "gemmini.intr.compute_accumulated" + // CHECK: "gemmini.intr.mvout" gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=0} : memref<64x64xi8> memref<64x64xi8> memref<64x64xi8> memref <64x64xi32> gemmini.print %cArray : memref<64x64xi8> return %i0 : i8 From 91ce9c8e910616b49c7210d10e603592fb93e870 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 9 Aug 2023 00:32:44 +0800 Subject: [PATCH 17/51] add FILECHECK for tile-matmul-ws-igelu.mlir --- .../GemminiDialect/tile-matmul-ws-igelu.mlir | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 
deletion(-) diff --git a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir index 92f8671606..78db052800 100644 --- a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir +++ b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir @@ -1,3 +1,7 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + memref.global "private" @g1 : memref<3x3xi8> = dense<[[1, 0, 0], [1, -1, 1], [-1, 0, 1]]> memref.global "private" @g2 : memref<3x3xi8> = dense<[[1, -1, 0], [1, 0, -1], [-1, -1, 0]]> @@ -25,7 +29,17 @@ func.func @main() -> i8 { } gemmini.print %aArray : memref<3x3xi8> gemmini.print %bArray : memref<3x3xi8> - // gemmini.print %dArray : memref<3x3xi32> + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.config_norm" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=3, bertScale=0.8:f32}: memref<3x3xi8> memref<3x3xi8> memref<3x3xi8> memref<3x3xi32> gemmini.print %cArray : memref<3x3xi8> return %i0 : i8 From 1898f4c740623998b05c17251db25a9cd811f814 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 9 Aug 2023 00:40:54 +0800 Subject: [PATCH 18/51] add detailed descrition generated by chatGPT. 
--- midend/include/Dialect/Gemmini/Gemmini.td | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index e3045e43fb..e998a5d548 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -122,13 +122,13 @@ def ConfigNormOp : Gemmini_Op<"config_norm"> { let summary = "ConfigNormOp configures normalize pipeline"; let description = [{ ConfigNormOp configures normalize pipeline - -qConst: - -qConstType: - -setStatsIdOnly: - -actMsg: - -StatsId: - -igeluQb: - -igeluQc: + -qConst: A constant value used for quantization during normalization. + -qConstType: Defines the type or format of the qConst. + -setStatsIdOnly: A flag to indicate if only the StatsId should be set, without applying any other normalization parameters. + -actMsg: A message or alert related to the normalization activity. + -StatsId: An identifier associated with the statistics or metrics of the normalization process. + -igeluQb: A parameter related to the Inverse Gaussian Error Linear Unit (IGELU) function for quantization. Specifies the 'b' value. + -igeluQc: Another parameter related to the IGELU function for quantization. Specifies the 'c' value. 
}]; let arguments = (ins DefaultValuedAttr:$qConst, DefaultValuedAttr:$qConstType, From 4a2ecd7196d7bbd6d093f3bc0af88b7f28ab2e2c Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 9 Aug 2023 15:28:46 +0800 Subject: [PATCH 19/51] delete sizeof, using class attribute delete useless define --- midend/include/Dialect/Gemmini/Transform.h | 2 -- .../Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index 3c4127e029..738a62d13c 100644 --- a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -33,7 +33,6 @@ #define DIM 16 #define ADDR_LEN 32 #define ACC_SCALE_IDENTITY 1.0 -#define MVIN_SCALE_IDENTITY 1.0 #define BANK_NUM 4 #define BANK_ROWS 4096 #define ACC_ROWS 1024 @@ -48,7 +47,6 @@ typedef uint32_t scale_t_bits; typedef float scale_t; typedef int32_t scale_acc_t; typedef int32_t acc_t; -typedef int8_t elem_t; namespace mlir { diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 0d739d4acb..69cf071d5c 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -602,7 +602,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { for (size_t i0 = 0; i0 < i; i0++) { for (size_t j0 = 0; j0 < j; j0 += dBlocks) { const size_t biasRow = repeatingBias ? 0 : i0; - const size_t offset = (biasRow * strideD + j0) * DIM * sizeof (acc_t); + const size_t offset = (biasRow * strideD + j0) * DIM * sizeOfAccT; const uint32_t dSpAddrAcc = dSpAddrStart + (i0 * j + j0) * DIM; const size_t blocks = j0 + dBlocks <= j ? dBlocks : j - j0; const size_t cols = blocks * DIM - (j0 + blocks >= j ? 
padJ : 0); @@ -619,7 +619,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { llvm::APFloat((float)bScaleFactor)); for (size_t j0 = 0; j0 < j; j0 += bBlocks) { for (size_t k0 = 0; k0 < k; k0++) { - const size_t offset = (k0 * strideB + j0) * DIM * sizeof (elem_t); + const size_t offset = (k0 * strideB + j0) * DIM * sizeOfElemT; const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * DIM; const size_t blocks = j0 + bBlocks <= j ? bBlocks : j - j0; const size_t cols = blocks * DIM - (j0 + blocks >= j ? padJ : 0); @@ -636,7 +636,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { for (size_t i0 = 0; i0 < i; i0++) { for (size_t k0 = 0; k0 < k; k0 += aBlocks) { - const size_t offset = (i0 * strideA + k0) * DIM * sizeof (elem_t); + const size_t offset = (i0 * strideA + k0) * DIM * sizeOfElemT; const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * DIM; const size_t blocks = k0 + aBlocks <= k ? aBlocks : k - k0; const size_t cols = blocks * DIM - (k0 + blocks >= k ? 
padK : 0); From bd17696b58cc01b9f03235b3ce47ca1bb59ed471 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 10 Aug 2023 15:32:43 +0800 Subject: [PATCH 20/51] add space in int_riscv_mvin --- backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td index 51fffb6878..c0cac18044 100644 --- a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td +++ b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td @@ -18,13 +18,13 @@ // //===----------------------------------------------------------------------===// let TargetPrefix = "riscv" in -def int_riscv_mvin : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>; +def int_riscv_mvin : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; let TargetPrefix = "riscv" in -def int_riscv_mvin2 : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>; +def int_riscv_mvin2 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; let TargetPrefix = "riscv" in -def int_riscv_mvin3 : Intrinsic<[],[llvm_i64_ty, llvm_i64_ty],[]>; +def int_riscv_mvin3 : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty],[]>; let TargetPrefix = "riscv" in def int_riscv_mvout : Intrinsic<[], [llvm_i64_ty, llvm_i64_ty], []>; From 5cfaab45c5d9e27ac0909c707f02f557d518ae80 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 10 Aug 2023 15:35:42 +0800 Subject: [PATCH 21/51] add description for config_norm op --- midend/include/Dialect/Gemmini/Gemmini.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index e998a5d548..e73889fe61 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -123,11 +123,11 @@ def ConfigNormOp : Gemmini_Op<"config_norm"> { let description = [{ ConfigNormOp configures normalize pipeline -qConst: A constant value used for quantization during normalization. 
- -qConstType: Defines the type or format of the qConst. - -setStatsIdOnly: A flag to indicate if only the StatsId should be set, without applying any other normalization parameters. - -actMsg: A message or alert related to the normalization activity. + -qConstType: Defines the type of the qConst. + -setStatsIdOnly: A flag to indicate if only the StatsId should be set. + -actMsg: A message related to the normalization activity. -StatsId: An identifier associated with the statistics or metrics of the normalization process. - -igeluQb: A parameter related to the Inverse Gaussian Error Linear Unit (IGELU) function for quantization. Specifies the 'b' value. + -igeluQb: A parameter related to the IGELU function for quantization. Specifies the 'b' value. -igeluQc: Another parameter related to the IGELU function for quantization. Specifies the 'c' value. }]; let arguments = (ins DefaultValuedAttr:$qConst, From 2e30b2b79cd2852afe654a0de0428f5f29b4cf26 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 10 Aug 2023 15:47:50 +0800 Subject: [PATCH 22/51] delete scale_t and acc_t --- .../Transforms/LegalizeForLLVMExport.cpp | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 69cf071d5c..a3d3ff45b1 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -784,24 +784,24 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { Add config norm op */ if (act == IGELU) { - const acc_scale_t sqrt_2 = 1.41421356237; - const acc_scale_t S = bertScale; - const acc_scale_t S_erf = (-0.2888 * ((S*S) / 2)); + const float sqrt_2 = 1.41421356237; + const float S = bertScale; + const float S_erf = (-0.2888 * ((S*S) / 2)); - const acc_t qb = -1.769 / (S / sqrt_2); - const acc_t qc = 1.0 / S_erf; + const uint32_t qb = -1.769 / 
(S / sqrt_2); + const uint32_t qc = 1.0 / S_erf; rewriter.create(loc, 0, 0, 0, 0,0, qb, qc); } if (act == SOFTMAX) { - const scale_t a = 0.3585; - const scale_t b = 1.353; - const scale_t c = 0.344; - - const acc_t qln2 = (int) (0.693147 / bertScale); - const acc_t qln2_inv = 65536 / qln2; - const acc_t qb = b / bertScale; - const acc_t qc = c / (a*bertScale*bertScale); + const float a = 0.3585; + const float b = 1.353; + const float c = 0.344; + + const uint32_t qln2 = (int) (0.693147 / bertScale); + const uint32_t qln2_inv = 65536 / qln2; + const uint32_t qb = b / bertScale; + const uint32_t qc = c / (a*bertScale*bertScale); rewriter.create(loc, qln2, 0, 0, 1, 0, qb, qc); rewriter.create(loc, qln2_inv, 1, 0, 1, 0, qb, qc); } From 31cedc4204dd004f273bea30a2fbd21f653de978 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 10 Aug 2023 15:48:29 +0800 Subject: [PATCH 23/51] delete function inner --- .../Gemmini/Transforms/LegalizeForLLVMExport.cpp | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index a3d3ff45b1..f62a12f3ac 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -713,21 +713,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { } } } - - - - void inner(Value &a, Value &b, Value &pre, Value &out, scale_t aScaleFactor, scale_t bScaleFactor, scale_acc_t dScaleFactor, size_t i, size_t j, - size_t k, size_t padI, size_t padJ, size_t padK, size_t strideA, - size_t strideB, size_t strideD, size_t strideC, bool aTranspose, - bool bTranspose, bool fullC, bool lowD, bool noBias, - bool repeatingBias, int act, TileMatMulOp &tileMatMulOp, - ConversionPatternRewriter &rewriter) const { - - gemminiLoopWs(i, j, k, padI, padJ, padK, a, b, pre, out, strideA, strideB, - repeatingBias ? 
0 : strideD, strideC, aTranspose, bTranspose, - fullC, lowD, !noBias, act, tileMatMulOp, rewriter); - } - + void tiledMatmulOuter(size_t dimI, size_t dimJ, size_t dimK, Value &A, Value &B, Value &D, Value &C, size_t strideA, size_t strideB, size_t strideD, size_t strideC, From a5780fc246851d631ea882753b6e06717fbc77e4 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 5 Sep 2023 17:38:59 +0800 Subject: [PATCH 24/51] delete define acc_t, DIM, ADDR_LEN --- midend/include/Dialect/Gemmini/Transform.h | 5 - .../Transforms/LegalizeForLLVMExport.cpp | 97 ++++++++++--------- 2 files changed, 51 insertions(+), 51 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index 738a62d13c..a6b3ebff98 100644 --- a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -30,15 +30,11 @@ #define OUTPUT_STATIONARY 0 #define WEIGHT_STATIONARY 1 -#define DIM 16 -#define ADDR_LEN 32 #define ACC_SCALE_IDENTITY 1.0 #define BANK_NUM 4 #define BANK_ROWS 4096 #define ACC_ROWS 1024 #define MAX_BYTES 64 -#define MAX_BLOCK_LEN (MAX_BYTES/(DIM*1)) -#define MAX_BLOCK_LEN_ACC (MAX_BYTES/(DIM*4)) #define HAS_FIRST_LAYER_OPTIMIZATIONS typedef uint32_t acc_scale_t_bits; @@ -46,7 +42,6 @@ typedef float acc_scale_t; typedef uint32_t scale_t_bits; typedef float scale_t; typedef int32_t scale_acc_t; -typedef int32_t acc_t; namespace mlir { diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index f62a12f3ac..02db531032 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -65,30 +65,30 @@ scale_t_bits scale_t_to_scale_t_bits(scale_t x) { template void gemminiMvinOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, - const size_t cols, const size_t rows, + const size_t cols, const size_t rows, int64_t addrLen, 
ConversionPatternRewriter &rewriter) { Location loc = mem.getLoc(); Value offsetOp = rewriter.create( loc, rewriter.getI64IntegerAttr(offset)); IntegerType i64Type = rewriter.getI64Type(); Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); - uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | - (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; + uint64_t spadAddrInt = (uint64_t)rows << (addrLen + 16) | + (uint64_t)cols << addrLen | (uint64_t) SpAddr; Value spad = rewriter.create( loc, rewriter.getI64IntegerAttr(spadAddrInt)); rewriter.create(loc, configPtr, spad); } void gemminiMvoutOffset(const Value &mem, const size_t offset, const uint32_t SpAddr, - const size_t cols, const size_t rows, + const size_t cols, const size_t rows, int64_t addrLen, ConversionPatternRewriter &rewriter) { Location loc = mem.getLoc(); Value offsetOp = rewriter.create( loc, rewriter.getI64IntegerAttr(offset)); IntegerType i64Type = rewriter.getI64Type(); Value configPtr = rewriter.create(loc, i64Type, mem, offsetOp); - uint64_t spadAddrInt = (uint64_t)rows << (ADDR_LEN + 16) | - (uint64_t)cols << ADDR_LEN | (uint64_t) SpAddr; + uint64_t spadAddrInt = (uint64_t)rows << (addrLen + 16) | + (uint64_t)cols << addrLen | (uint64_t) SpAddr; Value spad = rewriter.create( loc, rewriter.getI64IntegerAttr(spadAddrInt)); rewriter.create(loc, configPtr, spad); @@ -578,14 +578,17 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { TileMatMulOp &tileMatMulOp, ConversionPatternRewriter &rewriter) const { const uint32_t aSpAddrStart = 0; - const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - k * j * DIM; - const uint32_t dSpAddrStart = 1 << (ADDR_LEN - 1); + const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - k * j * dim; + const uint32_t dSpAddrStart = 1 << (addrLen - 1); const uint32_t cSpAddrStart = - (3 << (ADDR_LEN - 2)) | (fullC << (ADDR_LEN - 3)); + (3 << (addrLen - 2)) | (fullC << (addrLen - 3)); + + const size_t maxBlockLen = MAX_BYTES / (dim * 1); + const 
size_t maxBlockLenAcc = MAX_BYTES / (dim * 4); - const int aBlocks = k <= MAX_BLOCK_LEN ? k : MAX_BLOCK_LEN; - const int bBlocks = j <= MAX_BLOCK_LEN ? j : MAX_BLOCK_LEN; - const int dBlocks = j <= MAX_BLOCK_LEN_ACC ? j : MAX_BLOCK_LEN_ACC; + const int aBlocks = k <= maxBlockLen ? k : maxBlockLen; + const int bBlocks = j <= maxBlockLen ? j : maxBlockLen; + const int dBlocks = j <= maxBlockLenAcc ? j : maxBlockLenAcc; Location loc = a.getLoc(); bool dAddrNull = llvm::dyn_cast(d.getDefiningOp()) && getNumberFromValue(d) == 0; @@ -602,12 +605,12 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { for (size_t i0 = 0; i0 < i; i0++) { for (size_t j0 = 0; j0 < j; j0 += dBlocks) { const size_t biasRow = repeatingBias ? 0 : i0; - const size_t offset = (biasRow * strideD + j0) * DIM * sizeOfAccT; - const uint32_t dSpAddrAcc = dSpAddrStart + (i0 * j + j0) * DIM; + const size_t offset = (biasRow * strideD + j0) * dim * sizeOfAccT; + const uint32_t dSpAddrAcc = dSpAddrStart + (i0 * j + j0) * dim; const size_t blocks = j0 + dBlocks <= j ? dBlocks : j - j0; - const size_t cols = blocks * DIM - (j0 + blocks >= j ? padJ : 0); - const size_t rows = DIM - (i0 == i - 1 ? padI : 0); - gemminiMvinOffset(d, offset, dSpAddrAcc, cols, rows, rewriter); + const size_t cols = blocks * dim - (j0 + blocks >= j ? padJ : 0); + const size_t rows = dim - (i0 == i - 1 ? padI : 0); + gemminiMvinOffset(d, offset, dSpAddrAcc, cols, rows, addrLen, rewriter); } } } @@ -619,12 +622,12 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { llvm::APFloat((float)bScaleFactor)); for (size_t j0 = 0; j0 < j; j0 += bBlocks) { for (size_t k0 = 0; k0 < k; k0++) { - const size_t offset = (k0 * strideB + j0) * DIM * sizeOfElemT; - const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * DIM; + const size_t offset = (k0 * strideB + j0) * dim * sizeOfElemT; + const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * dim; const size_t blocks = j0 + bBlocks <= j ? 
bBlocks : j - j0; - const size_t cols = blocks * DIM - (j0 + blocks >= j ? padJ : 0); - const size_t rows = DIM - (k0 == k - 1 ? padK : 0); - gemminiMvinOffset(b, offset, bSpAddr, cols, rows, rewriter); + const size_t cols = blocks * dim - (j0 + blocks >= j ? padJ : 0); + const size_t rows = dim - (k0 == k - 1 ? padK : 0); + gemminiMvinOffset(b, offset, bSpAddr, cols, rows, addrLen, rewriter); } } @@ -636,22 +639,22 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { for (size_t i0 = 0; i0 < i; i0++) { for (size_t k0 = 0; k0 < k; k0 += aBlocks) { - const size_t offset = (i0 * strideA + k0) * DIM * sizeOfElemT; - const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * DIM; + const size_t offset = (i0 * strideA + k0) * dim * sizeOfElemT; + const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * dim; const size_t blocks = k0 + aBlocks <= k ? aBlocks : k - k0; - const size_t cols = blocks * DIM - (k0 + blocks >= k ? padK : 0); - const size_t rows = DIM - (i0 == i - 1 ? padI : 0); - gemminiMvinOffset(a, offset, aSpAddr, cols, rows, rewriter); + const size_t cols = blocks * dim - (k0 + blocks >= k ? padK : 0); + const size_t rows = dim - (i0 == i - 1 ? padI : 0); + gemminiMvinOffset(a, offset, aSpAddr, cols, rows, addrLen, rewriter); } } for (size_t i0 = 0; i0 < i; i0++) { for (size_t j0 = 0; j0 < j; j0++) { - const uint32_t cSpAddr = cSpAddrStart + (i0 * j + j0) * DIM; + const uint32_t cSpAddr = cSpAddrStart + (i0 * j + j0) * dim; for (size_t k0 = 0; k0 < k; k0++) { - const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * DIM; - const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * DIM; + const uint32_t aSpAddr = aSpAddrStart + (i0 * k + k0) * dim; + const uint32_t bSpAddr = bSpAddrStart + (k0 * j + j0) * dim; uint32_t outSpAddr = k0 == k - 1 ? 
cSpAddr : GARBAGE_ADDR; @@ -660,15 +663,15 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { int noBiasNewMatrix = noBias && !dAddrNull && k0 == k - 1; if (noBiasNewMatrix) { - outSpAddr &= ~(1 << (ADDR_LEN - 2)); + outSpAddr &= ~(1 << (addrLen - 2)); } - const size_t aCols = DIM - (k0 == k - 1 ? padK : 0); - const size_t aRows = DIM - (i0 == i - 1 ? padI : 0); - const size_t bCols = DIM - (j0 == j - 1 ? padJ : 0); - const size_t bRows = DIM - (k0 == k - 1 ? padK : 0); - const size_t cCols = DIM - (j0 == j - 1 ? padJ : 0); - const size_t cRows = DIM - (i0 == i - 1 ? padI : 0); + const size_t aCols = dim - (k0 == k - 1 ? padK : 0); + const size_t aRows = dim - (i0 == i - 1 ? padI : 0); + const size_t bCols = dim - (j0 == j - 1 ? padJ : 0); + const size_t bRows = dim - (k0 == k - 1 ? padK : 0); + const size_t cCols = dim - (j0 == j - 1 ? padJ : 0); + const size_t cRows = dim - (i0 == i - 1 ? padI : 0); Value aColsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aCols)); Value aRowsOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aRows)); @@ -682,7 +685,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { Value outSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(outSpAddr)); Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); - Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(DIM)); + Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); rewriter.create(loc, garbageAddrOp, outSpAddrOp, dimOp, dimOp, cRowsOp, cColsOp); @@ -702,13 +705,13 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { for (size_t i0 = 0; i0 < i; i0++) { for (size_t j0 = 0; j0 < j; j0++) { - const size_t offset = (i0 *strideC + j0)*DIM*sizeof_C; - const uint32_t cSpAddr = cSpAddrStart + (i0 *j + j0)*DIM; + const size_t offset = (i0 *strideC + j0)*dim*sizeof_C; + const uint32_t cSpAddr = cSpAddrStart + (i0 *j + j0)*dim; - const size_t cCols = DIM - (j0 == j - 1 ? 
padJ : 0); - const size_t cRows = DIM - (i0 == j - 1 ? padI : 0); + const size_t cCols = dim - (j0 == j - 1 ? padJ : 0); + const size_t cRows = dim - (i0 == j - 1 ? padI : 0); - gemminiMvoutOffset(c, offset, cSpAddr, cCols, cRows, rewriter); + gemminiMvoutOffset(c, offset, cSpAddr, cCols, cRows, addrLen, rewriter); } } } @@ -899,9 +902,9 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiTileMatMulLowering(LLVMTypeConverter &typeConverter, - int64_t dim, size_t sizeOfElemT, + int64_t dim, int64_t addrLen, size_t sizeOfElemT, size_t sizeOfAccT) - : ConvertOpToLLVMPattern(typeConverter), dim(dim), + : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen), sizeOfElemT(sizeOfElemT), sizeOfAccT(sizeOfAccT) {} LogicalResult @@ -1064,6 +1067,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { private: int64_t dim; + int64_t addrLen; size_t sizeOfElemT; size_t sizeOfAccT; }; @@ -1662,6 +1666,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { private: int64_t dim; + size_t sizeOfElemT; size_t sizeOfAccT; }; @@ -1685,7 +1690,7 @@ void mlir::populateGemminiLegalizeForLLVMExportPatterns( patterns.add(converter, addrLen); patterns.add(converter, addrLen); patterns.add(converter, addrLen); - patterns.add(converter, dim, sizeOfElemT, + patterns.add(converter, dim, addrLen, sizeOfElemT, sizeOfAccT); patterns.add(converter, dim, sizeOfElemT, sizeOfAccT); From 2e11d1cfbffc5352935953b5a10451d6f8a1bd17 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 13 Jul 2023 00:51:49 +0800 Subject: [PATCH 25/51] update to gemmini upstream and complete rectangle conv --- .../Transforms/LegalizeForLLVMExport.cpp | 377 +++++++++++++++++- 1 file changed, 369 insertions(+), 8 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 02db531032..5ac7bd4236 
100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1142,10 +1142,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, rs1Value, rs2Value); } - void spTiledConv(int batchSize, int inDim, int inChannels, int outChannels, - int outDim, int poolOutDim, int stride, int padding, - int kernelDim, int kernelDilation, int poolSize, - int poolStride, int poolPadding, int batches, int porows, + void spTiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, + int outChannels, int outRowDim, int outColDim, + int poolOutRowDim, int poolOutColDim, + int stride, int padding, int kernelDim, int kernelDilation, + int poolSize, int poolStride, int poolPadding, + int batches, int porows, int pocols, int pochs, int krows, int kcols, int kchs, int lpad, int rpad, int upad, int dpad, int plpad, int prpad, int pupad, int pdpad, Value &input, Value &weights, @@ -1154,7 +1156,10 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { bool transWeight1203, bool transWeight0132, bool noBias, bool noPool, bool downsample, bool inputDilated, bool dw, TileConvOp &tileConvOp, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + int inStride = 1, int weightStride = 1, int outStride = 1 + ) const { + Location loc = tileConvOp.getLoc(); if (dw) { kchs = 1; pochs = 1; @@ -1162,8 +1167,30 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int orows = porows * poolStride + poolSize - 1 - pupad - pdpad; const int ocols = pocols * poolStride + poolSize - 1 - plpad - prpad; + const int ochs = pochs; + + // Calculate image dimensions + // Note: "irows" and "icols" includes padding + const int dilatedKrows = krows + (kernelDilation - 1)*(krows - 1); + const int dilatedKcols = kcols + (kernelDilation - 1)*(kcols - 1); + int irows = orows * stride + dilatedKrows - 1; + int icols = ocols * 
stride + dilatedKcols - 1; + int irowsUnpadded = irows - upad - dpad; + int icolsUnpadded = icols - lpad - rpad; + const int ichs = kchs; +#define UNDILATED(x) ((inputDilated) ? (((x)+1)/2) : (x)) + + if (inputDilated) { + irowsUnpadded = (irowsUnpadded+1)/2; + icolsUnpadded = (icolsUnpadded+1)/2; + + irows = irowsUnpadded + UNDILATED(upad) + UNDILATED(dpad); + icols = icolsUnpadded + UNDILATED(lpad) + UNDILATED(rpad); + } + + #ifdef HAS_FIRST_LAYER_OPTIMIZATIONS const bool transposed = transOutput1203 || transInput3120 || transWeight1203 || transWeight0132; @@ -1176,14 +1203,348 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { #else const int maxPixelsPerRow = 1; #endif + // Calculate spad address offsets + const int outChannelsPerBank = ochs / DIM + (ochs % DIM != 0); + const int inChannelsPerBank = kchs / DIM + (kchs % DIM != 0); + const int bRows = transWeight0132 ? inChannelsPerBank * kcols * krows * ochs : + outChannelsPerBank * kcols * krows * kchs; + + static uint32_t dSpAddrRow = 0; + static uint32_t cSpAddrRow = 0; + + const uint32_t aSpAddrStart = 0; + const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - bRows; + const uint32_t dSpAddrStart = (1 << (ADDR_LEN - 1)) + dSpAddrRow; + const uint32_t cSpAddrStart = (3 << (ADDR_LEN - 2)) + cSpAddrRow; + + if (bias != 0) { + dSpAddrRow = (dSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; + } + + if (output != 0) { + cSpAddrRow = (cSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; + } gemminiLoopConvWs( - batchSize, inDim, inChannels, outChannels, outDim, poolOutDim, stride, + batchSize, inRowDim, inChannels, outChannels, outRowDim, poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, input, noBias, noPool, downsample, wrot180, inputDilated, act, transOutput1203, transWeight1203, transWeight0132, transInput3120, maxPixelsPerRow, dw, 
tileConvOp, rewriter); + if (!noPool) { + // TODO: Exit here; how to abort from this point is still unknown + // printf("Pooling with rectangular convolutions is currently not supported.\n"); + // exit(1); + } + // Only rectangular convolutions will use the following C code + // mvin bias + if (bias != NULL) { + // TODO we probably don't need quite this many nested loops for this part + const int maxOchsPerMvin = ochs < MAX_BLOCK_LEN_ACC * DIM ? ochs : + MAX_BLOCK_LEN_ACC * DIM; + Value zeroValue = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + // rewriter.create(loc, zeroValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, ) + // TODO: the configLd op is not sufficient here; block_mvinStride and pixel_repeats need to be added, which should not be hard +// gemmini_extended4_config_ld(); + rewriter.create(loc, zeroValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 2, batches * orows * ocols); + for (int b = 0; b < batches; b++) + for (int orow = 0; orow < orows; orow++) + for (int ocol = 0; ocol < ocols; ocol += DIM) { + const int I = ocols - ocol > DIM ? DIM : ocols - ocol; + for (int och = 0; och < ochs; och += maxOchsPerMvin) { + const int J = ochs - och > maxOchsPerMvin ? maxOchsPerMvin : ochs - och; + const uint32_t dSpAddr = dSpAddrStart + (och / DIM) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; +// const acc_t * bias_dram_addr = noBias ? NULL : bias + och; +// gemmini_extended_mvin3(bias_dram_addr, +// dSpAddr, +// J, I); + if (noBias) { +// Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + gemminiMvinOffset(zeroValue, 0, dSpAddr, J, I, rewriter); + } else { + gemminiMvinOffset(bias, och, dSpAddr, J, I, rewriter); + } + } + } + } + // mvin input + if (input != NULL){ + int maxChsPerMvin = ichs < MAX_BLOCK_LEN * DIM ? ichs : + MAX_BLOCK_LEN * DIM; + if (transInput3120) { + maxChsPerMvin = batches < MAX_BLOCK_LEN * DIM ? batches : + MAX_BLOCK_LEN * DIM; + } + const int dramStride = transInput3120 ?
+ batchSize * sizeof(elem_t) : + inChannels * sizeof(elem_t); + const int spadStride = transInput3120 ? + ichs * (irows >> downsample) * (icols >> downsample) : + batches * (irows >> downsample) * (icols >> downsample); + Value strideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride << downsample)); + rewriter.create(loc, strideValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 0, spadStride, maxPixelsPerRow); +// gemmini_extended5_config_ld(dramStride << downsample, MVIN_SCALE_IDENTITY, false, spadStride, maxPixelsPerRow, 0); + const int b_it = transInput3120 ? maxChsPerMvin : 1; + const int ich_it = transInput3120 ? 1 : maxChsPerMvin; + for (int b = 0; b < batches; b += b_it) + for (int irow = -UNDILATED(upad); irow < irowsUnpadded + UNDILATED(dpad); irow += 1 + downsample) { + const int irowPadded = irow + UNDILATED(upad); + for (int icol = -UNDILATED(lpad); icol < icolsUnpadded + UNDILATED(rpad);) { + // TODO There might be some unnecessary mvins here at the edge of the image + int I = icolsUnpadded - icol > (DIM << downsample) ? + (DIM << downsample) : icolsUnpadded - icol; + if (icol < 0) { + I = -icol > DIM ? DIM : -icol; + } else if (icol >= icolsUnpadded) { + I = icolsUnpadded + UNDILATED(rpad) - icol > DIM ? DIM : icolsUnpadded + UNDILATED(rpad) - icol; + } + const int icolPadded = icol + UNDILATED(lpad); + for (int ich = 0; ich < ichs; ich += ich_it) { + int K = ichs - ich > maxChsPerMvin ? maxChsPerMvin : ichs - ich; + if (transInput3120) { + K = batches - b > maxChsPerMvin ? 
maxChsPerMvin : batches - b; + } +#define DS(x) ((x) >> (downsample)) + uint32_t aSpAddr = aSpAddrStart + (ich / DIM) * batches * DS(irows) * DS(icols) + b * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); + if (transInput3120) { + aSpAddr = aSpAddrStart + (b / DIM) * ichs * DS(irows) * DS(icols) + ich * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); + } + const bool is_zeros = irow < 0 || irow >= irowsUnpadded || icol < 0 || icol >= icolsUnpadded; +// const elem_t * in = input + (b*inRowDim*inColDim + irow*inColDim + icol) * inStride + ich; + size_t offset = (b*inRowDim*inColDim + irow*inColDim + icol) * inStride + ich; + Value memAddr = input; + if (is_zeros) { + memAddr = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + offset = 0; +// in = NULL; + + } else if (transInput3120) { + offset = (ich*inRowDim*inColDim + irow*inColDim + icol) * batchSize + b; + } +// gemmini_extended_mvin(in, +// aSpAddr, +// K, I >> downsample); + gemminiMvinOffset(memAddr, offset, aSpAddr, K, I >> downsample, rewriter); + } + icol += I; + } + } + } + // mvin weights + if (weights != NULL) { + int max_chs_per_mvin = ochs < MAX_BLOCK_LEN * DIM ? ochs : + MAX_BLOCK_LEN * DIM; + if (transWeight0132) { + max_chs_per_mvin = kchs < MAX_BLOCK_LEN * DIM ? kchs : + MAX_BLOCK_LEN * DIM; + } + size_t dramStride = weightStride * sizeof(elem_t); + if (dw) { + dramStride = sizeof(elem_t); + } else if (transWeight1203) { + dramStride = kernelDim * kernelDim * outChannels * sizeof(elem_t); + } else if (transWeight0132) { + dramStride = inChannels * sizeof(elem_t); + } + const size_t spadBlockStride = transWeight0132 ? 
+ krows * kcols * ochs : krows * kcols * kchs; +// gemmini_extended4_config_ld(dramStride, MVIN_SCALE_IDENTITY, false, +// spadBlockStride, 1); + Value dramStrideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride)); + rewriter.create(loc, dramStrideValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 2, spadBlockStride); + + const size_t och_it = transWeight0132 ? DIM : max_chs_per_mvin; + const size_t kch_it = transWeight0132 ? max_chs_per_mvin : DIM; + for (int och = 0; och < ochs; och += och_it) { + for (int krow = 0; krow < krows; krow++) + for (int kcol = 0; kcol < kcols; kcol++) + for (int kch = 0; kch < kchs; kch += kch_it) { + int K = kchs - kch > DIM ? DIM : kchs - kch; + int J = ochs - och > max_chs_per_mvin ? max_chs_per_mvin : ochs - och; + if (transWeight0132) { + K = ochs - och > DIM ? DIM : ochs - och; + J = kchs - kch > max_chs_per_mvin ? max_chs_per_mvin : kchs - kch; + } + uint32_t bSpAddr = bSpAddrStart + (och / DIM) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch; + if (transWeight0132) { + bSpAddr = bSpAddrStart + (kch / DIM) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och; + } + size_t offset = (krow*kernelDim*inChannels + kcol*inChannels + kch) * weightStride + och; + if (dw) { + offset = krow * kernelDim + kcol; + } else if (transWeight1203) { + offset = (kch * kernelDim * kernelDim + krow * kernelDim + kcol) * outChannels + och; + } else if (transWeight0132) { + offset = (krow * kernelDim * outChannels + kcol * outChannels + och) * inChannels + kch; + } + gemminiMvinOffset(weights, offset, bSpAddr, J, K, rewriter); +// gemmini_extended_mvin2(w, bSpAddr, J, K); + } + } + } + // Compute + { + const int b_it = transInput3120 ? DIM : 1; + const int ocol_it = transInput3120 ? 
1 : (DIM << inputDilated); + if (transInput3120) { + rewriter.create( + loc, /*dataflow = */ OUTPUT_STATIONARY, /*act = */ 0, /*shift = */ 0, + /*scale = */ llvm::APFloat((float)0), /*cStride = */ orows * ocols, + /*aStride = */ irows * icols, + /*aTranspose = */ 0, /*bTranspose*/ 0, + /*setOnlyStrides = */ true); +// gemmini_extended3_config_ex(0, 0, 0, 0, orows * ocols, irows * icols, 0, +// 0, true); + } + for (int och = 0; och < ochs; och += DIM) { + for (int krow = 0; krow < krows; krow++) { + for (int kcol = 0; kcol < kcols; kcol += maxPixelsPerRow) { + for (int kch = 0; kch < kchs; kch += DIM) { + bool newWeights = true; + for (int b = 0; b < batches; b += b_it) { + for (int orow = 0; orow < orows; orow++) { + // Skip some kernel rows due to input-dilation + if (inputDilated && + ((krow * kernelDilation + orow * stride - upad) % 2 != + 0)) { + continue; + } + for (int ocol = 0; ocol < ocols;) { + // Skip some cols dimensions due to input-dilation + if (inputDilated && + ((kcol + ocol * stride - lpad) % 2 != 0)) { + ocol++; + continue; + } + int irow = orow * stride + krow * kernelDilation; + int icol = ocol * stride + kcol * kernelDilation; + if (inputDilated) { + irow = (irow + 1) / 2; + icol = (icol + 1) / 2; + } + const int pixels = kcols - kcol > maxPixelsPerRow + ? maxPixelsPerRow + : kcols - kcol; + const uint32_t cSpAddr = + cSpAddrStart + + (och / DIM) * batches * orows * ocols + + b * orows * ocols + orow * ocols + ocol; + // Over here, construct a new matrix + // + // Let us assume that we only ever operate on + // one pixel in one row. + // Thus, krows == kcols == 1 + // + // Then, for every set of I, J, and K values + // - I = ocols + // - J = ochs + // - K = kchs + int I = UNDILATED(ocols - ocol > (DIM << inputDilated) + ? (DIM << inputDilated) + : ocols - ocol); + const int J = ochs - och > DIM ? DIM : ochs - och; + const int K = + pixels * (kchs - kch > DIM ? DIM : kchs - kch); + if (transInput3120) { + I = batches - b > DIM ? 
DIM : batches - b; + } + uint32_t aSpAddr = + aSpAddrStart + + (kch / DIM) * batches * DS(irows) * DS(icols) + + b * DS(irows) * DS(icols) + DS(irow) * DS(icols) + + DS(icol); + if (transInput3120) { + aSpAddr = aSpAddrStart + + (b / DIM) * kchs * DS(irows) * DS(icols) + + kch * DS(irows) * DS(icols) + + DS(irow) * DS(icols) + DS(icol); + } + const int krow_ = wrot180 ? krows - krow - 1 : krow; + const int kcol_ = wrot180 ? kcols - kcol - 1 : kcol; + uint32_t bSpAddr = + bSpAddrStart + (och / DIM) * krows * kcols * kchs + + krow_ * kcols * kchs + kcol_ * kchs + kch; + if (transWeight0132) { + bSpAddr = bSpAddrStart + + (kch / DIM) * krows * kcols * ochs + + krow_ * kcols * ochs + kcol_ * ochs + och; + } + const uint32_t perSpAddr = + newWeights ? bSpAddr : GARBAGE_ADDR; + + Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); + Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(DIM)); + Value iOp = rewriter.create(loc, rewriter.getI64IntegerAttr(I)); + Value jOp = rewriter.create(loc, rewriter.getI64IntegerAttr(J)); + Value kOp = rewriter.create(loc, rewriter.getI64IntegerAttr(K)); + Value perSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(perSpAddr)); + Value aSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aSpAddr)); + Value cSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cSpAddr)); + + + + + + // perform matmul +// gemmini_extended_preload(perSpAddr, cSpAddr, J, K, J, +// I); + rewriter.create(loc, perSpAddrOp, cSpAddrOp, jOp, + kOp, jOp, iOp); + if (newWeights) { +// gemmini_extended_compute_preloaded( +// aSpAddr, GARBAGE_ADDR, K, I, J, I); + rewriter.create(loc, aSpAddrOp, garbageAddrOp, kOp, iOp, jOp, iOp); + } else { +// gemmini_extended_compute_accumulated( +// aSpAddr, GARBAGE_ADDR, K, I, J, I); + rewriter.create(loc, aSpAddrOp, garbageAddrOp, kOp, iOp, jOp, iOp); + } + ocol += ocol_it; + newWeights = false; + } + } + } + } + } + } + } + } +#undef DS +#undef UNDILATED + // mvout 
output + if (output != NULL) { + if (noPool) { + for (int b = 0; b < batches; b++) + for (int orow = 0; orow < orows; orow++) + for (int ocol = 0; ocol < ocols; ocol += DIM) { + const int I = ocols - ocol > DIM ? DIM : ocols - ocol; + for (int och = 0; och < ochs; och += DIM) { + const int J = ochs - och > DIM ? DIM : ochs - och; + const uint32_t cSpAddr = + cSpAddrStart + (och / DIM) * batches * orows * ocols + + b * orows * ocols + orow * ocols + ocol; + size_t outOffset = + (b * outRowDim * outColDim + + orow * outColDim + ocol) * + outStride + + och; + if (transOutput1203) { + outOffset = + (orow * outColDim * batchSize + ocol * batchSize + + b) * + outChannels + + och; + } + gemminiMvoutOffset(output, outOffset, cSpAddr, J, I, rewriter); +// gemmini_extended_mvout(out, cSpAddr, J, I); + } + } + } else { + printf("Pooling with rectangular convolutions is currently not supported.\n"); + exit(1); + } + } } void tiledConv(int batchSize, int inDim, int inChannels, int outChannels, @@ -1399,8 +1760,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { input, offsetValue); } - spTiledConv(batchSize, inDim, inChannels, outChannels, outDim, - poolOutDim, stride, padding, kernelDim, + spTiledConv(batchSize, inDim, inDim, inChannels, outChannels, outDim, outDim, + poolOutDim, poolOutDim, stride, padding, kernelDim, kernelDilation, poolSize, poolStride, poolPadding, batches_, porows_, pocols_, pochs_, krows_, kcols_, kchs_, lpad, rpad, upad, dpad, plpad, From 70e68eeb544b8a6833a69cf7093325eab702837b Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 6 Sep 2023 01:46:22 +0800 Subject: [PATCH 26/51] modify by gemmini upstream --- midend/include/Dialect/Gemmini/Gemmini.td | 6 +- midend/include/Dialect/Gemmini/Transform.h | 1 + .../Transforms/LegalizeForLLVMExport.cpp | 270 +++++++++++------- 3 files changed, 163 insertions(+), 114 deletions(-) diff --git a/midend/include/Dialect/Gemmini/Gemmini.td b/midend/include/Dialect/Gemmini/Gemmini.td index 
e73889fe61..f2569a0c15 100644 --- a/midend/include/Dialect/Gemmini/Gemmini.td +++ b/midend/include/Dialect/Gemmini/Gemmini.td @@ -305,7 +305,7 @@ def TileConvOp : Gemmini_Op<"tile_conv"> { MemRefRankOf<[AnyType], [2]>:$weights, MemRefRankOf<[AnyType], [1]>:$bias, MemRefRankOf<[AnyType], [2]>:$output, - I64:$outDim, I64:$kernelDim, + I64:$outRowDim, I64:$outColDim, I64:$kernelDim, DefaultValuedAttr:$scale, DefaultValuedAttr:$stride, DefaultValuedAttr:$inputDilation, @@ -321,8 +321,8 @@ def TileConvOp : Gemmini_Op<"tile_conv"> { DefaultValuedAttr:$poolStride, DefaultValuedAttr:$poolPadding); let assemblyFormat = [{ - $input $weights $bias $output $outDim $kernelDim attr-dict `:` type($input) - type($weights) type($bias) type($output) type($outDim) type($kernelDim) + $input $weights $bias $output $outRowDim $outColDim $kernelDim attr-dict `:` type($input) + type($weights) type($bias) type($output) type($outRowDim) type($outColDim) type($kernelDim) }]; } diff --git a/midend/include/Dialect/Gemmini/Transform.h b/midend/include/Dialect/Gemmini/Transform.h index a6b3ebff98..86d27cbd9e 100644 --- a/midend/include/Dialect/Gemmini/Transform.h +++ b/midend/include/Dialect/Gemmini/Transform.h @@ -30,6 +30,7 @@ #define OUTPUT_STATIONARY 0 #define WEIGHT_STATIONARY 1 +#define MVIN_SCALE_IDENTITY 1.0 #define ACC_SCALE_IDENTITY 1.0 #define BANK_NUM 4 #define BANK_ROWS 4096 diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 5ac7bd4236..1597356723 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -43,6 +43,12 @@ int64_t getNumberFromValue(Value &value) { .getInt(); } +int ceil_divide_int(int a, int b){ + int c = (a % b == 0) ? 
((int)(a/b)) :(((int)(a/b)) + 1); + if(a < b) c = 1; + return c; +} + acc_scale_t_bits acc_scale_t_to_acc_scale_t_bits(acc_scale_t x) { union { acc_scale_t_bits b; @@ -1146,6 +1152,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { int outChannels, int outRowDim, int outColDim, int poolOutRowDim, int poolOutColDim, int stride, int padding, int kernelDim, int kernelDilation, + int inStride, int weightStride, int outStride, int poolSize, int poolStride, int poolPadding, int batches, int porows, int pocols, int pochs, int krows, int kcols, int kchs, @@ -1156,8 +1163,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { bool transWeight1203, bool transWeight0132, bool noBias, bool noPool, bool downsample, bool inputDilated, bool dw, TileConvOp &tileConvOp, - ConversionPatternRewriter &rewriter, - int inStride = 1, int weightStride = 1, int outStride = 1 + ConversionPatternRewriter &rewriter ) const { Location loc = tileConvOp.getLoc(); if (dw) { @@ -1204,8 +1210,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int maxPixelsPerRow = 1; #endif // Calculate spad address offsets - const int outChannelsPerBank = ochs / DIM + (ochs % DIM != 0); - const int inChannelsPerBank = kchs / DIM + (kchs % DIM != 0); + const int outChannelsPerBank = ochs / dim + (ochs % dim != 0); + const int inChannelsPerBank = kchs / dim + (kchs % dim != 0); const int bRows = transWeight0132 ? 
inChannelsPerBank * kcols * krows * ochs : outChannelsPerBank * kcols * krows * kchs; @@ -1214,8 +1220,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const uint32_t aSpAddrStart = 0; const uint32_t bSpAddrStart = BANK_NUM * BANK_ROWS - bRows; - const uint32_t dSpAddrStart = (1 << (ADDR_LEN - 1)) + dSpAddrRow; - const uint32_t cSpAddrStart = (3 << (ADDR_LEN - 2)) + cSpAddrRow; + const uint32_t dSpAddrStart = (1 << (addrLen - 1)) + dSpAddrRow; + const uint32_t cSpAddrStart = (3 << (addrLen - 2)) + cSpAddrRow; if (bias != 0) { dSpAddrRow = (dSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; @@ -1239,10 +1245,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } // Only rectangular convolutions will use the following C code // mvin bias + const size_t maxBlockLen = MAX_BYTES / (dim * 1); + const size_t maxBlockLenAcc = MAX_BYTES / (dim * 4); if (bias != NULL) { // TODO we probably don't need quite this many nested loops for this part - const int maxOchsPerMvin = ochs < MAX_BLOCK_LEN_ACC * DIM ? ochs : - MAX_BLOCK_LEN_ACC * DIM; + const int maxOchsPerMvin = ochs < maxBlockLenAcc * dim ? ochs : + maxBlockLenAcc * dim; Value zeroValue = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); // rewriter.create(loc, zeroValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, ) // TODO: configLd op 这里不够用,需要加block_mvinStride和pixel_repeats,应该不难 @@ -1250,35 +1258,35 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, zeroValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 2, batches * orows * ocols); for (int b = 0; b < batches; b++) for (int orow = 0; orow < orows; orow++) - for (int ocol = 0; ocol < ocols; ocol += DIM) { - const int I = ocols - ocol > DIM ? DIM : ocols - ocol; + for (int ocol = 0; ocol < ocols; ocol += dim) { + const int I = ocols - ocol > dim ? dim : ocols - ocol; for (int och = 0; och < ochs; och += maxOchsPerMvin) { const int J = ochs - och > maxOchsPerMvin ? 
maxOchsPerMvin : ochs - och; - const uint32_t dSpAddr = dSpAddrStart + (och / DIM) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; + const uint32_t dSpAddr = dSpAddrStart + (och / dim) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; // const acc_t * bias_dram_addr = noBias ? NULL : bias + och; // gemmini_extended_mvin3(bias_dram_addr, // dSpAddr, // J, I); if (noBias) { // Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - gemminiMvinOffset(zeroValue, 0, dSpAddr, J, I, rewriter); + gemminiMvinOffset(zeroValue, 0, dSpAddr, J, I, addrLen, rewriter); } else { - gemminiMvinOffset(bias, och, dSpAddr, J, I, rewriter); + gemminiMvinOffset(bias, och, dSpAddr, J, I, addrLen, rewriter); } } } } // mvin input if (input != NULL){ - int maxChsPerMvin = ichs < MAX_BLOCK_LEN * DIM ? ichs : - MAX_BLOCK_LEN * DIM; + int maxChsPerMvin = ichs < maxBlockLen * dim ? ichs : + maxBlockLen * dim; if (transInput3120) { - maxChsPerMvin = batches < MAX_BLOCK_LEN * DIM ? batches : - MAX_BLOCK_LEN * DIM; + maxChsPerMvin = batches < maxBlockLen * dim ? batches : + maxBlockLen * dim; } const int dramStride = transInput3120 ? - batchSize * sizeof(elem_t) : - inChannels * sizeof(elem_t); + batchSize * sizeOfElemT : + inChannels * sizeOfElemT; const int spadStride = transInput3120 ? ichs * (irows >> downsample) * (icols >> downsample) : batches * (irows >> downsample) * (icols >> downsample); @@ -1292,12 +1300,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int irowPadded = irow + UNDILATED(upad); for (int icol = -UNDILATED(lpad); icol < icolsUnpadded + UNDILATED(rpad);) { // TODO There might be some unnecessary mvins here at the edge of the image - int I = icolsUnpadded - icol > (DIM << downsample) ? - (DIM << downsample) : icolsUnpadded - icol; + int I = icolsUnpadded - icol > (dim << downsample) ? + (dim << downsample) : icolsUnpadded - icol; if (icol < 0) { - I = -icol > DIM ? DIM : -icol; + I = -icol > dim ? 
dim : -icol; } else if (icol >= icolsUnpadded) { - I = icolsUnpadded + UNDILATED(rpad) - icol > DIM ? DIM : icolsUnpadded + UNDILATED(rpad) - icol; + I = icolsUnpadded + UNDILATED(rpad) - icol > dim ? dim : icolsUnpadded + UNDILATED(rpad) - icol; } const int icolPadded = icol + UNDILATED(lpad); for (int ich = 0; ich < ichs; ich += ich_it) { @@ -1306,9 +1314,9 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { K = batches - b > maxChsPerMvin ? maxChsPerMvin : batches - b; } #define DS(x) ((x) >> (downsample)) - uint32_t aSpAddr = aSpAddrStart + (ich / DIM) * batches * DS(irows) * DS(icols) + b * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); + uint32_t aSpAddr = aSpAddrStart + (ich / dim) * batches * DS(irows) * DS(icols) + b * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); if (transInput3120) { - aSpAddr = aSpAddrStart + (b / DIM) * ichs * DS(irows) * DS(icols) + ich * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); + aSpAddr = aSpAddrStart + (b / dim) * ichs * DS(irows) * DS(icols) + ich * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); } const bool is_zeros = irow < 0 || irow >= irowsUnpadded || icol < 0 || icol >= icolsUnpadded; // const elem_t * in = input + (b*inRowDim*inColDim + irow*inColDim + icol) * inStride + ich; @@ -1325,7 +1333,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { // gemmini_extended_mvin(in, // aSpAddr, // K, I >> downsample); - gemminiMvinOffset(memAddr, offset, aSpAddr, K, I >> downsample, rewriter); + gemminiMvinOffset(memAddr, offset, aSpAddr, K, I >> downsample, addrLen, rewriter); } icol += I; } @@ -1333,19 +1341,19 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } // mvin weights if (weights != NULL) { - int max_chs_per_mvin = ochs < MAX_BLOCK_LEN * DIM ? ochs : - MAX_BLOCK_LEN * DIM; + int max_chs_per_mvin = ochs < maxBlockLen * dim ? 
ochs : + maxBlockLen * dim; if (transWeight0132) { - max_chs_per_mvin = kchs < MAX_BLOCK_LEN * DIM ? kchs : - MAX_BLOCK_LEN * DIM; + max_chs_per_mvin = kchs < maxBlockLen * dim ? kchs : + maxBlockLen * dim; } - size_t dramStride = weightStride * sizeof(elem_t); + size_t dramStride = weightStride * sizeOfElemT; if (dw) { - dramStride = sizeof(elem_t); + dramStride = sizeOfElemT; } else if (transWeight1203) { - dramStride = kernelDim * kernelDim * outChannels * sizeof(elem_t); + dramStride = kernelDim * kernelDim * outChannels * sizeOfElemT; } else if (transWeight0132) { - dramStride = inChannels * sizeof(elem_t); + dramStride = inChannels * sizeOfElemT; } const size_t spadBlockStride = transWeight0132 ? krows * kcols * ochs : krows * kcols * kchs; @@ -1354,21 +1362,21 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { Value dramStrideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride)); rewriter.create(loc, dramStrideValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 2, spadBlockStride); - const size_t och_it = transWeight0132 ? DIM : max_chs_per_mvin; - const size_t kch_it = transWeight0132 ? max_chs_per_mvin : DIM; + const size_t och_it = transWeight0132 ? dim : max_chs_per_mvin; + const size_t kch_it = transWeight0132 ? max_chs_per_mvin : dim; for (int och = 0; och < ochs; och += och_it) { for (int krow = 0; krow < krows; krow++) for (int kcol = 0; kcol < kcols; kcol++) for (int kch = 0; kch < kchs; kch += kch_it) { - int K = kchs - kch > DIM ? DIM : kchs - kch; + int K = kchs - kch > dim ? dim : kchs - kch; int J = ochs - och > max_chs_per_mvin ? max_chs_per_mvin : ochs - och; if (transWeight0132) { - K = ochs - och > DIM ? DIM : ochs - och; + K = ochs - och > dim ? dim : ochs - och; J = kchs - kch > max_chs_per_mvin ? 
max_chs_per_mvin : kchs - kch; } - uint32_t bSpAddr = bSpAddrStart + (och / DIM) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch; + uint32_t bSpAddr = bSpAddrStart + (och / dim) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch; if (transWeight0132) { - bSpAddr = bSpAddrStart + (kch / DIM) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och; + bSpAddr = bSpAddrStart + (kch / dim) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och; } size_t offset = (krow*kernelDim*inChannels + kcol*inChannels + kch) * weightStride + och; if (dw) { @@ -1378,15 +1386,15 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } else if (transWeight0132) { offset = (krow * kernelDim * outChannels + kcol * outChannels + och) * inChannels + kch; } - gemminiMvinOffset(weights, offset, bSpAddr, J, K, rewriter); + gemminiMvinOffset(weights, offset, bSpAddr, J, K, addrLen, rewriter); // gemmini_extended_mvin2(w, bSpAddr, J, K); } } } // Compute { - const int b_it = transInput3120 ? DIM : 1; - const int ocol_it = transInput3120 ? 1 : (DIM << inputDilated); + const int b_it = transInput3120 ? dim : 1; + const int ocol_it = transInput3120 ? 
1 : (dim << inputDilated); if (transInput3120) { rewriter.create( loc, /*dataflow = */ OUTPUT_STATIONARY, /*act = */ 0, /*shift = */ 0, @@ -1397,10 +1405,10 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { // gemmini_extended3_config_ex(0, 0, 0, 0, orows * ocols, irows * icols, 0, // 0, true); } - for (int och = 0; och < ochs; och += DIM) { + for (int och = 0; och < ochs; och += dim) { for (int krow = 0; krow < krows; krow++) { for (int kcol = 0; kcol < kcols; kcol += maxPixelsPerRow) { - for (int kch = 0; kch < kchs; kch += DIM) { + for (int kch = 0; kch < kchs; kch += dim) { bool newWeights = true; for (int b = 0; b < batches; b += b_it) { for (int orow = 0; orow < orows; orow++) { @@ -1428,7 +1436,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { : kcols - kcol; const uint32_t cSpAddr = cSpAddrStart + - (och / DIM) * batches * orows * ocols + + (och / dim) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; // Over here, construct a new matrix // @@ -1440,41 +1448,41 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { // - I = ocols // - J = ochs // - K = kchs - int I = UNDILATED(ocols - ocol > (DIM << inputDilated) - ? (DIM << inputDilated) + int I = UNDILATED(ocols - ocol > (dim << inputDilated) + ? (dim << inputDilated) : ocols - ocol); - const int J = ochs - och > DIM ? DIM : ochs - och; + const int J = ochs - och > dim ? dim : ochs - och; const int K = - pixels * (kchs - kch > DIM ? DIM : kchs - kch); + pixels * (kchs - kch > dim ? dim : kchs - kch); if (transInput3120) { - I = batches - b > DIM ? DIM : batches - b; + I = batches - b > dim ? 
dim : batches - b; } uint32_t aSpAddr = aSpAddrStart + - (kch / DIM) * batches * DS(irows) * DS(icols) + + (kch / dim) * batches * DS(irows) * DS(icols) + b * DS(irows) * DS(icols) + DS(irow) * DS(icols) + DS(icol); if (transInput3120) { aSpAddr = aSpAddrStart + - (b / DIM) * kchs * DS(irows) * DS(icols) + + (b / dim) * kchs * DS(irows) * DS(icols) + kch * DS(irows) * DS(icols) + DS(irow) * DS(icols) + DS(icol); } const int krow_ = wrot180 ? krows - krow - 1 : krow; const int kcol_ = wrot180 ? kcols - kcol - 1 : kcol; uint32_t bSpAddr = - bSpAddrStart + (och / DIM) * krows * kcols * kchs + + bSpAddrStart + (och / dim) * krows * kcols * kchs + krow_ * kcols * kchs + kcol_ * kchs + kch; if (transWeight0132) { bSpAddr = bSpAddrStart + - (kch / DIM) * krows * kcols * ochs + + (kch / dim) * krows * kcols * ochs + krow_ * kcols * ochs + kcol_ * ochs + och; } const uint32_t perSpAddr = newWeights ? bSpAddr : GARBAGE_ADDR; Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); - Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(DIM)); + Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); Value iOp = rewriter.create(loc, rewriter.getI64IntegerAttr(I)); Value jOp = rewriter.create(loc, rewriter.getI64IntegerAttr(J)); Value kOp = rewriter.create(loc, rewriter.getI64IntegerAttr(K)); @@ -1517,12 +1525,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { if (noPool) { for (int b = 0; b < batches; b++) for (int orow = 0; orow < orows; orow++) - for (int ocol = 0; ocol < ocols; ocol += DIM) { - const int I = ocols - ocol > DIM ? DIM : ocols - ocol; - for (int och = 0; och < ochs; och += DIM) { - const int J = ochs - och > DIM ? DIM : ochs - och; + for (int ocol = 0; ocol < ocols; ocol += dim) { + const int I = ocols - ocol > dim ? dim : ocols - ocol; + for (int och = 0; och < ochs; och += dim) { + const int J = ochs - och > dim ? 
dim : ochs - och; const uint32_t cSpAddr = - cSpAddrStart + (och / DIM) * batches * orows * ocols + + cSpAddrStart + (och / dim) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; size_t outOffset = (b * outRowDim * outColDim + @@ -1536,7 +1544,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { outChannels + och; } - gemminiMvoutOffset(output, outOffset, cSpAddr, J, I, rewriter); + gemminiMvoutOffset(output, outOffset, cSpAddr, J, I, addrLen, rewriter); // gemmini_extended_mvout(out, cSpAddr, J, I); } } @@ -1547,9 +1555,11 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } } - void tiledConv(int batchSize, int inDim, int inChannels, int outChannels, - int outDim, int stride, int inputDilation, int kernelDilation, - int padding, int kernelDim, bool wrot180, bool transOutput1203, + void tiledConv(int batchSize, int inRowDim, int inColDim, int inChannels, int outChannels, + int outRowDim, int outColDim, int stride, int inputDilation, int kernelDilation, + int padding, int kernelDim, + int inStride, int weightStride, int outStride, + bool wrot180, bool transOutput1203, bool transInput3120, bool transWeight1203, bool transWeight0132, int batches, int porows, int pocols, int pochs, int krows, int kcols, int kchs, const Value &input, @@ -1564,7 +1574,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { poolStride = 1; poolPadding = 0; } - const bool downsample = stride == 2 && kernelDim == 1 && inDim % 2 == 0 && + const bool downsample = stride == 2 && kernelDim == 1 && inRowDim % 2 == 0 && inColDim % 2 == 0 && padding == 0 && noPool && inputDilation == 1 && !transInput3120; const int inputDilated = inputDilation == 2; @@ -1581,13 +1591,37 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { /*aStride = */ stride >> downsample, /*aTranspose = */ transInput3120, /*bTranspose*/ transWeight0132, /*setOnlyStrides = */ false); - const int poolOutDim = - (outDim + 2 * poolPadding - poolSize) / 
poolStride + 1; - const int dilatedInDim = inDim + (inputDilation - 1) * (inDim - 1); + const int poolOutRowDim = + (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutColDim = + (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int dilatedInRowDim = inRowDim + (inputDilation - 1) * (inRowDim - 1); + const int dilatedInColDim = inColDim + (inputDilation - 1) * (inColDim - 1); + + size_t aSpadId = 0; + size_t bSpadId = 0; + + int porowEnd = poolOutRowDim; + int porowStart = 0; + bool a_reuse = false; + bool b_reuse = false; + size_t num_kch = ceil_divide_int(inChannels, kchs); + size_t num_poch = ceil_divide_int(outChannels, pochs); + size_t num_b = ceil_divide_int(batchSize, batches); + size_t num_porow = ceil_divide_int((porowEnd - porowStart), porows); + size_t num_pocol = ceil_divide_int(poolOutColDim, pocols); + size_t num_krow = ceil_divide_int(kernelDim, krows); + size_t num_kcol = ceil_divide_int(kernelDim, kcols); + + if(num_kch * num_poch * num_krow * num_kcol <= 2) + b_reuse = true; + if(num_kch * num_krow * num_kcol * num_b * num_porow * num_pocol <= 2) + a_reuse = true; + for (int b = 0; b < batchSize; b += batches) { - for (int porow = 0; porow < poolOutDim; porow += porows) { + for (int porow = 0; porow < porowEnd; porow += porows) { const int orow = porow * poolStride - poolPadding; - for (int pocol = 0; pocol < poolOutDim; pocol += pocols) { + for (int pocol = 0; pocol < poolOutColDim; pocol += pocols) { const int ocol = pocol * poolStride - poolPadding; for (int poch = 0; poch < outChannels; poch += pochs) { for (int krow = 0; krow < kernelDim; krow += krows) { @@ -1602,8 +1636,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { for (int kch = 0; kch < inChannels; kch += kchs) { TypedAttr offsetAttr = - rewriter.getI64IntegerAttr(((b * poolOutDim * poolOutDim + - porow * poolOutDim + pocol) * + rewriter.getI64IntegerAttr(((b * poolOutRowDim * poolOutColDim + + porow * poolOutColDim + 
pocol) * outChannels + poch) * sizeOfElemT); @@ -1614,7 +1648,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { offsetValue); if (transOutput1203) { offsetAttr = rewriter.getI64IntegerAttr( - ((porow * poolOutDim * batchSize + pocol * batchSize + + ((porow * poolOutColDim * batchSize + pocol * batchSize + b) * outChannels + poch) * @@ -1645,9 +1679,9 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int batches_ = batchSize - b > batches ? batches : batchSize - b; const int porows_ = - poolOutDim - porow > porows ? porows : poolOutDim - porow; + poolOutRowDim - porow > porows ? porows : poolOutRowDim - porow; const int pocols_ = - poolOutDim - pocol > pocols ? pocols : poolOutDim - pocol; + poolOutColDim - pocol > pocols ? pocols : poolOutColDim - pocol; const int pochs_ = outChannels - poch > pochs ? pochs : outChannels - poch; const int krows_ = @@ -1662,10 +1696,10 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int plpad = ocol < 0 ? -ocol : 0; const int prpad = - ocol + ocols_ > outDim ? ocol + ocols_ - outDim : 0; + ocol + ocols_ > outColDim ? ocol + ocols_ - outColDim : 0; const int pupad = orow < 0 ? -orow : 0; const int pdpad = - orow + orows_ > outDim ? orow + orows_ - outDim : 0; + orow + orows_ > outRowDim ? orow + orows_ - outRowDim : 0; const int dilatedKrows_ = krows_ + (kernelDilation - 1) * (krows_ - 1); @@ -1678,12 +1712,12 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { (orows_ - pupad - pdpad) * stride + dilatedKrows_ - 1; int lpad = icol < 0 ? -icol : 0; - int rpad = icol + icols_ > dilatedInDim - ? icol + icols_ - dilatedInDim + int rpad = icol + icols_ > dilatedInColDim + ? icol + icols_ - dilatedInColDim : 0; int upad = irow < 0 ? -irow : 0; - int dpad = irow + irows_ > dilatedInDim - ? irow + irows_ - dilatedInDim + int dpad = irow + irows_ > dilatedInRowDim + ? 
irow + irows_ - dilatedInRowDim : 0; if (inputDilated) { @@ -1736,8 +1770,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { offsetValue); } offsetAttr = rewriter.getI64IntegerAttr( - ((b * inDim * inDim + - ((irow + upad) >> inputDilated) * inDim + + ((b * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + ((icol + lpad) >> inputDilated)) * inChannels + kch) * @@ -1749,8 +1783,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { offsetValue); if (transInput3120) { offsetAttr = rewriter.getI64IntegerAttr( - ((kch * inDim * inDim + - ((irow + upad) >> inputDilated) * inDim + + ((kch * inRowDim * inColDim + + ((irow + upad) >> inputDilated) * inColDim + ((icol + lpad) >> inputDilated)) * batchSize + b) * @@ -1760,9 +1794,11 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { input, offsetValue); } - spTiledConv(batchSize, inDim, inDim, inChannels, outChannels, outDim, outDim, - poolOutDim, poolOutDim, stride, padding, kernelDim, - kernelDilation, poolSize, poolStride, poolPadding, + spTiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim, outColDim, + poolOutRowDim, poolOutColDim, stride, padding, kernelDim, + kernelDilation, + inStride, weightStride, outStride, + poolSize, poolStride, poolPadding, batches_, porows_, pocols_, pochs_, krows_, kcols_, kchs_, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, in, weightsSlice, out, bias_, @@ -1826,9 +1862,9 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit GemminiTileConvLowering(LLVMTypeConverter &typeConverter, - int64_t dim, size_t sizeOfElemT, + int64_t dim, int64_t addrLen, size_t sizeOfElemT, size_t sizeOfAccT) - : ConvertOpToLLVMPattern(typeConverter), dim(dim), + : ConvertOpToLLVMPattern(typeConverter), dim(dim), addrLen(addrLen), sizeOfElemT(sizeOfElemT), sizeOfAccT(sizeOfAccT) {} LogicalResult matchAndRewrite(TileConvOp tileConvOp, 
OpAdaptor adaptor, @@ -1846,21 +1882,24 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { ArrayRef weightsShape = weightsType.getShape(); ArrayRef biasShape = biasType.getShape(); // inDim - if (inputShape[1] != inputShape[2]) { - llvm::outs() << "inDim error.\n"; - return failure(); - } + // if (inputShape[1] != inputShape[2]) { + // llvm::outs() << "inDim error.\n"; + // return failure(); + // } // outChannels - if (biasShape[0] != outputShape[1] || biasShape[0] != weightsShape[1]) { - llvm::outs() << "outChannels error.\n"; - return failure(); - } - Value outDimValue = tileConvOp.getOutDim(); - int outDim = getNumberFromValue(outDimValue); + // if (biasShape[0] != outputShape[1] || biasShape[0] != weightsShape[1]) { + // llvm::outs() << "outChannels error.\n"; + // return failure(); + // } + Value outRowDimValue = tileConvOp.getOutRowDim(); + int outRowDim = getNumberFromValue(outRowDimValue); + Value outColDimValue = tileConvOp.getOutColDim(); + int outColDim = getNumberFromValue(outColDimValue); Value kernelDimValue = tileConvOp.getKernelDim(); int kernelDim = getNumberFromValue(kernelDimValue); int batchSize = inputShape[0]; - int inDim = inputShape[1]; + int inRowDim = inputShape[1]; + int inColDim = inputShape[2]; int inChannels = inputShape[3]; int outChannels = biasShape[0]; int stride = tileConvOp.getStride(); @@ -1901,13 +1940,15 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { poolStride = 1; poolPadding = 0; } - const int poolOutDim = - (outDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutRowDim = + (outRowDim + 2 * poolPadding - poolSize) / poolStride + 1; + const int poolOutColDim = + (outColDim + 2 * poolPadding - poolSize) / poolStride + 1; const bool downsample = stride == 2 && kernelDim == 1 && padding == 0 && - noPool && inDim % 2 == 0; - int args[] = {batchSize, poolOutDim, poolOutDim, outChannels, + noPool && inRowDim % 2 == 0 && inColDim % 2 == 0; + int args[] = {batchSize, 
poolOutRowDim, poolOutColDim, outChannels, kernelDim, kernelDim, inChannels}; - const int maxArgs[] = {batchSize, poolOutDim, poolOutDim, outChannels, + const int maxArgs[] = {batchSize, poolOutRowDim, poolOutColDim, outChannels, kernelDim, kernelDim, inChannels}; const int orowsIdx = 1; const int ocolsIdx = 2; @@ -2015,9 +2056,15 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int krows = args[4]; const int kcols = args[5]; const int kchs = args[6]; - tiledConv(batchSize, inDim, inChannels, outChannels, outDim, stride, - inputDilation, kernelDilation, padding, kernelDim, wrot180, - transOutput1203, transInput3120, transWeight1203, transWeight0132, + + const int inStride = inChannels; + const int outStride = outChannels; + const int weightStride = outChannels; + tiledConv(batchSize, inRowDim, inColDim, inChannels, outChannels, outRowDim, outColDim, + stride, + inputDilation, kernelDilation, padding, kernelDim, + inStride, weightStride, outStride, + wrot180, transOutput1203, transInput3120, transWeight1203, transWeight0132, batches, orows, ocols, ochs, krows, kcols, kchs, inputIndexCastOp, weightsIndexCastOp, biasIndexCastOp, outputIndexCastOp, act, scale, poolSize, noPool ? 
0 : poolStride, poolPadding, tileConvOp, @@ -2027,6 +2074,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { private: int64_t dim; + int64_t addrLen; size_t sizeOfElemT; size_t sizeOfAccT; @@ -2053,7 +2101,7 @@ void mlir::populateGemminiLegalizeForLLVMExportPatterns( patterns.add(converter, addrLen); patterns.add(converter, dim, addrLen, sizeOfElemT, sizeOfAccT); - patterns.add(converter, dim, sizeOfElemT, + patterns.add(converter, dim, addrLen, sizeOfElemT, sizeOfAccT); } From f7237aa3c968f86fc94ef0ff26b28688eb15f77a Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 6 Sep 2023 01:46:48 +0800 Subject: [PATCH 27/51] fix bug: tile_conv interface changes --- .../Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp index 19cc8c0608..5930d33705 100644 --- a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp +++ b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp @@ -182,7 +182,7 @@ class Conv2DNchwFchwLowering loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(weightsShape[2])); rewriter.create( - loc, inputMat, weightsMat, bias, outputMat, outDim, kernelDim, + loc, inputMat, weightsMat, bias, outputMat, outDim, outDim, kernelDim, llvm::APFloat(float(1.0)), strides, dilations); rewriter.eraseOp(convOp); loopIvs0.clear(); @@ -309,7 +309,7 @@ class Conv2DNhwcHwcfLowering attr = rewriter.getI64IntegerAttr(kernelShape[1]); kernelDim = rewriter.create(loc, attr); rewriter.create( - loc, input, kernelMat, bias, outputMat, outDim, kernelDim, + loc, input, kernelMat, bias, outputMat, outDim, outDim, kernelDim, llvm::APFloat(float(1.0)), strides, dilations); // after the conv operation is completed, the data in outputmat needs to be // transferred into output. 
From 58f450fc69a80dc2175fee1e31ce4c6eba46cdf2 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 6 Sep 2023 15:08:59 +0800 Subject: [PATCH 28/51] modify tile_conv case in ciface.mlir and tile-conv.mlir; fix LegalizeForLLVMExport.cpp --- examples/GemminiDialect/ciface.mlir | 12 ++++----- examples/GemminiDialect/tile-conv.mlir | 4 +-- .../Transforms/LegalizeForLLVMExport.cpp | 25 +++++++++++-------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/examples/GemminiDialect/ciface.mlir b/examples/GemminiDialect/ciface.mlir index 004992f2a3..070ca581b2 100644 --- a/examples/GemminiDialect/ciface.mlir +++ b/examples/GemminiDialect/ciface.mlir @@ -127,7 +127,7 @@ func.func @linalg_conv6(%arg0 : memref<1x1x256x256xi8>, %arg1 : memref<1x1x13x13 func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8>, %bias: memref<1xi32>, %output: memref<64516x1xi8>) { %outdim = arith.constant 254 : i64 %kernelDim = arith.constant 3 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 return } @@ -136,7 +136,7 @@ func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8 func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi8>, %bias: memref<1xi32>, %output: memref<63504x1xi8>) { %outdim = arith.constant 252 : i64 %kernelDim = arith.constant 5 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 return } @@ -145,7 +145,7 @@ func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi8>, %bias: 
memref<1xi32>, %output: memref<62500x1xi8>) { %outdim = arith.constant 250 : i64 %kernelDim = arith.constant 7 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 return } @@ -154,7 +154,7 @@ func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi8>, %bias: memref<1xi32>, %output: memref<61504x1xi8>) { %outdim = arith.constant 248 : i64 %kernelDim = arith.constant 9 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 return } @@ -163,7 +163,7 @@ func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1xi8>, %bias: memref<1xi32>, %output: memref<60516x1xi8>) { %outdim = arith.constant 246 : i64 %kernelDim = arith.constant 11 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 return } @@ -172,7 +172,7 @@ func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1x func.func @gemmini_conv6(%input: memref<1x256x256x1xi8>, %weights: memref<169x1xi8>, %bias: memref<1xi32>, %output: memref<59536x1xi8>) { %outdim = arith.constant 244 : i64 %kernelDim = arith.constant 13 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim 
%kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 return } diff --git a/examples/GemminiDialect/tile-conv.mlir b/examples/GemminiDialect/tile-conv.mlir index 42f0085ce3..6c85572a48 100644 --- a/examples/GemminiDialect/tile-conv.mlir +++ b/examples/GemminiDialect/tile-conv.mlir @@ -32,8 +32,8 @@ func.func @main() -> i64 { // CHECK: "gemmini.intr.loop_conv_ws_config6" // CHECK: "gemmini.intr.loop_conv_ws" // CHECK: "gemmini.intr.flush" - gemmini.tile_conv %input %weight %bias %output %3 %3 {stride = 1}: - memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 gemmini.print %output : memref<9x2xi8> return %0 : i64 } diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 1597356723..88fe1248d2 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1230,18 +1230,21 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { if (output != 0) { cSpAddrRow = (cSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; } - gemminiLoopConvWs( - batchSize, inRowDim, inChannels, outChannels, outRowDim, poolOutRowDim, stride, - padding, kernelDim, kernelDilation, poolSize, poolStride, poolPadding, - batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, - dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, - input, noBias, noPool, downsample, wrot180, inputDilated, act, - transOutput1203, transWeight1203, transWeight0132, transInput3120, - maxPixelsPerRow, dw, tileConvOp, rewriter); + if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) { + gemminiLoopConvWs( + batchSize, inRowDim, inChannels, outChannels, outRowDim, + 
poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, + poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, + kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, + ocols, weights, output, bias, input, noBias, noPool, downsample, + wrot180, inputDilated, act, transOutput1203, transWeight1203, + transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, + rewriter); + return; + } if (!noPool) { - // TODO: Exit, but now I don't known how to do - // printf("Pooling with rectangular convolutions is currently not supported.\n"); - // exit(1); + llvm::outs() << "Pooling with rectangular convolutions is currently not supported.\n"; + return; } // Only rectangular convolutions will use the following C code // mvin bias From b57035876958306e220f3720ae779f2a60f4087d Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 6 Sep 2023 15:21:45 +0800 Subject: [PATCH 29/51] fix writing error in ciface.mlir --- examples/GemminiDialect/ciface.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/GemminiDialect/ciface.mlir b/examples/GemminiDialect/ciface.mlir index 070ca581b2..e45b6bed28 100644 --- a/examples/GemminiDialect/ciface.mlir +++ b/examples/GemminiDialect/ciface.mlir @@ -128,7 +128,7 @@ func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8 %outdim = arith.constant 254 : i64 %kernelDim = arith.constant 3 : i64 gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 + memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 i64 return } @@ -137,7 +137,7 @@ func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi %outdim = arith.constant 252 : i64 %kernelDim = arith.constant 5 : i64 gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> 
memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 + memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 i64 return } @@ -146,7 +146,7 @@ func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi %outdim = arith.constant 250 : i64 %kernelDim = arith.constant 7 : i64 gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 + memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 i64 return } @@ -155,7 +155,7 @@ func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi %outdim = arith.constant 248 : i64 %kernelDim = arith.constant 9 : i64 gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 + memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 i64 return } @@ -164,7 +164,7 @@ func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1x %outdim = arith.constant 246 : i64 %kernelDim = arith.constant 11 : i64 gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 + memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 i64 return } @@ -173,7 +173,7 @@ func.func @gemmini_conv6(%input: memref<1x256x256x1xi8>, %weights: memref<169x1x %outdim = arith.constant 244 : i64 %kernelDim = arith.constant 13 : i64 gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : - memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 + memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 i64 return } From cebddcd47176e815c496232d554b4dc7563503f6 Mon Sep 17 00:00:00 2001 From: yxy 
Date: Wed, 6 Sep 2023 20:01:02 +0800 Subject: [PATCH 30/51] fix bug in RISCVInstrInfoBuddyExt.td and delete comments in LegalizeForLLVMExport.cpp --- .../Target/RISCV/RISCVInstrInfoBuddyExt.td | 10 ++++- .../Transforms/LegalizeForLLVMExport.cpp | 38 ++----------------- 2 files changed, 11 insertions(+), 37 deletions(-) diff --git a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td index b5491d5045..adc172ab2f 100644 --- a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td +++ b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td @@ -37,13 +37,13 @@ def MVIN : RVInstR<0b0000010, 0b011, OPC_CUSTOM_3, (outs), let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in def MVIN2 : RVInstR<0b0000001, 0b011, OPC_CUSTOM_3, (outs), - (ins GPR:$rs1, GPR:$rs2), "mvin","$rs1, $rs2"> { + (ins GPR:$rs1, GPR:$rs2), "mvin2","$rs1, $rs2"> { let rd = 0; } let hasSideEffects = 1, mayLoad = 1, mayStore = 1, Predicates = [HasBuddyExt] in def MVIN3 : RVInstR<0b0001110, 0b011, OPC_CUSTOM_3, (outs), - (ins GPR:$rs1, GPR:$rs2), "mvin","$rs1, $rs2"> { + (ins GPR:$rs1, GPR:$rs2), "mvin3","$rs1, $rs2"> { let rd = 0; } @@ -182,6 +182,12 @@ def LOOP_CONV_WS_CONFIG6 : RVInstR<0b0010101, 0b011, OPC_CUSTOM_3, (outs), let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_mvin GPR:$rs1, GPR:$rs2), (MVIN GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin2 GPR:$rs1, GPR:$rs2), (MVIN2 GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasBuddyExt] in +def : Pat<(int_riscv_mvin3 GPR:$rs1, GPR:$rs2), (MVIN3 GPR:$rs1, GPR:$rs2)>; + let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_mvout GPR:$rs1, GPR:$rs2), (MVOUT GPR:$rs1, GPR:$rs2)>; diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 88fe1248d2..9fb8750856 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp 
+++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1255,10 +1255,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int maxOchsPerMvin = ochs < maxBlockLenAcc * dim ? ochs : maxBlockLenAcc * dim; Value zeroValue = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - // rewriter.create(loc, zeroValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, ) - // TODO: configLd op 这里不够用,需要加block_mvinStride和pixel_repeats,应该不难 -// gemmini_extended4_config_ld(); - rewriter.create(loc, zeroValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 2, batches * orows * ocols); + rewriter.create(loc, zeroValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 2, batches * orows * ocols); for (int b = 0; b < batches; b++) for (int orow = 0; orow < orows; orow++) for (int ocol = 0; ocol < ocols; ocol += dim) { @@ -1266,12 +1263,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { for (int och = 0; och < ochs; och += maxOchsPerMvin) { const int J = ochs - och > maxOchsPerMvin ? maxOchsPerMvin : ochs - och; const uint32_t dSpAddr = dSpAddrStart + (och / dim) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; -// const acc_t * bias_dram_addr = noBias ? 
NULL : bias + och; -// gemmini_extended_mvin3(bias_dram_addr, -// dSpAddr, -// J, I); if (noBias) { -// Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); gemminiMvinOffset(zeroValue, 0, dSpAddr, J, I, addrLen, rewriter); } else { gemminiMvinOffset(bias, och, dSpAddr, J, I, addrLen, rewriter); @@ -1294,8 +1286,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { ichs * (irows >> downsample) * (icols >> downsample) : batches * (irows >> downsample) * (icols >> downsample); Value strideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride << downsample)); - rewriter.create(loc, strideValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 0, spadStride, maxPixelsPerRow); -// gemmini_extended5_config_ld(dramStride << downsample, MVIN_SCALE_IDENTITY, false, spadStride, maxPixelsPerRow, 0); + rewriter.create(loc, strideValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 0, spadStride, maxPixelsPerRow); const int b_it = transInput3120 ? maxChsPerMvin : 1; const int ich_it = transInput3120 ? 
1 : maxChsPerMvin; for (int b = 0; b < batches; b += b_it) @@ -1322,20 +1313,14 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { aSpAddr = aSpAddrStart + (b / dim) * ichs * DS(irows) * DS(icols) + ich * DS(irows) * DS(icols) + DS(irowPadded) * DS(icols) + DS(icolPadded); } const bool is_zeros = irow < 0 || irow >= irowsUnpadded || icol < 0 || icol >= icolsUnpadded; -// const elem_t * in = input + (b*inRowDim*inColDim + irow*inColDim + icol) * inStride + ich; size_t offset = (b*inRowDim*inColDim + irow*inColDim + icol) * inStride + ich; Value memAddr = input; if (is_zeros) { memAddr = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); offset = 0; -// in = NULL; - } else if (transInput3120) { offset = (ich*inRowDim*inColDim + irow*inColDim + icol) * batchSize + b; } -// gemmini_extended_mvin(in, -// aSpAddr, -// K, I >> downsample); gemminiMvinOffset(memAddr, offset, aSpAddr, K, I >> downsample, addrLen, rewriter); } icol += I; @@ -1360,10 +1345,8 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } const size_t spadBlockStride = transWeight0132 ? krows * kcols * ochs : krows * kcols * kchs; -// gemmini_extended4_config_ld(dramStride, MVIN_SCALE_IDENTITY, false, -// spadBlockStride, 1); Value dramStrideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride)); - rewriter.create(loc, dramStrideValue, llvm::APFloat(MVIN_SCALE_IDENTITY), false, 2, spadBlockStride); + rewriter.create(loc, dramStrideValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 2, spadBlockStride); const size_t och_it = transWeight0132 ? dim : max_chs_per_mvin; const size_t kch_it = transWeight0132 ? 
max_chs_per_mvin : dim; @@ -1390,7 +1373,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { offset = (krow * kernelDim * outChannels + kcol * outChannels + och) * inChannels + kch; } gemminiMvinOffset(weights, offset, bSpAddr, J, K, addrLen, rewriter); -// gemmini_extended_mvin2(w, bSpAddr, J, K); } } } @@ -1405,8 +1387,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { /*aStride = */ irows * icols, /*aTranspose = */ 0, /*bTranspose*/ 0, /*setOnlyStrides = */ true); -// gemmini_extended3_config_ex(0, 0, 0, 0, orows * ocols, irows * icols, 0, -// 0, true); } for (int och = 0; och < ochs; och += dim) { for (int krow = 0; krow < krows; krow++) { @@ -1493,22 +1473,11 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { Value aSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aSpAddr)); Value cSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cSpAddr)); - - - - - // perform matmul -// gemmini_extended_preload(perSpAddr, cSpAddr, J, K, J, -// I); rewriter.create(loc, perSpAddrOp, cSpAddrOp, jOp, kOp, jOp, iOp); if (newWeights) { -// gemmini_extended_compute_preloaded( -// aSpAddr, GARBAGE_ADDR, K, I, J, I); rewriter.create(loc, aSpAddrOp, garbageAddrOp, kOp, iOp, jOp, iOp); } else { -// gemmini_extended_compute_accumulated( -// aSpAddr, GARBAGE_ADDR, K, I, J, I); rewriter.create(loc, aSpAddrOp, garbageAddrOp, kOp, iOp, jOp, iOp); } ocol += ocol_it; @@ -1548,7 +1517,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { och; } gemminiMvoutOffset(output, outOffset, cSpAddr, J, I, addrLen, rewriter); -// gemmini_extended_mvout(out, cSpAddr, J, I); } } } else { From ceb85c0495fa9bcc3ac19f8ab03e15adbcd66dc9 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 6 Sep 2023 23:59:32 +0800 Subject: [PATCH 31/51] fix bug in rectangle conv --- .../Transforms/LegalizeForLLVMExport.cpp | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git 
a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 9fb8750856..0613d4cc54 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -414,7 +414,7 @@ struct GemminiPreloadLowering : public ConvertOpToLLVMPattern { Value bdCols = preloadOp.getBdCols(); Value bdRows = preloadOp.getBdRows(); Value cCols = preloadOp.getCCols(); - Value cRows = preloadOp.getBdRows(); + Value cRows = preloadOp.getCRows(); Location loc = preloadOp.getLoc(); uint64_t bdAddrInt = getNumberFromValue(bdAddr); uint64_t cAddrInt = getNumberFromValue(cAddr); @@ -1165,6 +1165,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { TileConvOp &tileConvOp, ConversionPatternRewriter &rewriter ) const { + Location loc = tileConvOp.getLoc(); if (dw) { kchs = 1; @@ -1230,18 +1231,18 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { if (output != 0) { cSpAddrRow = (cSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; } - if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) { - gemminiLoopConvWs( - batchSize, inRowDim, inChannels, outChannels, outRowDim, - poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, - poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, - kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, - ocols, weights, output, bias, input, noBias, noPool, downsample, - wrot180, inputDilated, act, transOutput1203, transWeight1203, - transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, - rewriter); - return; - } +// if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) { +// gemminiLoopConvWs( +// batchSize, inRowDim, inChannels, outChannels, outRowDim, +// poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, +// poolStride, poolPadding, batches, porows, pocols, 
pochs, krows, kcols, +// kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, +// ocols, weights, output, bias, input, noBias, noPool, downsample, +// wrot180, inputDilated, act, transOutput1203, transWeight1203, +// transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, +// rewriter); +// return; +// } if (!noPool) { llvm::outs() << "Pooling with rectangular convolutions is currently not supported.\n"; return; @@ -1264,9 +1265,9 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int J = ochs - och > maxOchsPerMvin ? maxOchsPerMvin : ochs - och; const uint32_t dSpAddr = dSpAddrStart + (och / dim) * batches * orows * ocols + b * orows * ocols + orow * ocols + ocol; if (noBias) { - gemminiMvinOffset(zeroValue, 0, dSpAddr, J, I, addrLen, rewriter); + gemminiMvinOffset(zeroValue, 0 * sizeOfAccT, dSpAddr, J, I, addrLen, rewriter); } else { - gemminiMvinOffset(bias, och, dSpAddr, J, I, addrLen, rewriter); + gemminiMvinOffset(bias, och * sizeOfAccT, dSpAddr, J, I, addrLen, rewriter); } } } @@ -1321,7 +1322,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } else if (transInput3120) { offset = (ich*inRowDim*inColDim + irow*inColDim + icol) * batchSize + b; } - gemminiMvinOffset(memAddr, offset, aSpAddr, K, I >> downsample, addrLen, rewriter); + gemminiMvinOffset(memAddr, offset * sizeOfElemT, aSpAddr, K, I >> downsample, addrLen, rewriter); } icol += I; } @@ -1346,7 +1347,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const size_t spadBlockStride = transWeight0132 ? krows * kcols * ochs : krows * kcols * kchs; Value dramStrideValue = rewriter.create(loc, rewriter.getI64IntegerAttr(dramStride)); - rewriter.create(loc, dramStrideValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 2, spadBlockStride); + rewriter.create(loc, dramStrideValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 1, spadBlockStride); const size_t och_it = transWeight0132 ? 
dim : max_chs_per_mvin; const size_t kch_it = transWeight0132 ? max_chs_per_mvin : dim; @@ -1372,7 +1373,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } else if (transWeight0132) { offset = (krow * kernelDim * outChannels + kcol * outChannels + och) * inChannels + kch; } - gemminiMvinOffset(weights, offset, bSpAddr, J, K, addrLen, rewriter); + gemminiMvinOffset(weights, offset * sizeOfElemT, bSpAddr, J, K, addrLen, rewriter); } } } @@ -1473,12 +1474,11 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { Value aSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(aSpAddr)); Value cSpAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(cSpAddr)); - rewriter.create(loc, perSpAddrOp, cSpAddrOp, jOp, - kOp, jOp, iOp); + rewriter.create(loc, perSpAddrOp, cSpAddrOp, kOp, jOp, iOp, jOp); if (newWeights) { - rewriter.create(loc, aSpAddrOp, garbageAddrOp, kOp, iOp, jOp, iOp); + rewriter.create(loc, aSpAddrOp, garbageAddrOp, iOp, kOp, iOp, jOp); } else { - rewriter.create(loc, aSpAddrOp, garbageAddrOp, kOp, iOp, jOp, iOp); + rewriter.create(loc, aSpAddrOp, garbageAddrOp, iOp, kOp, iOp, jOp); } ocol += ocol_it; newWeights = false; @@ -1516,7 +1516,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { outChannels + och; } - gemminiMvoutOffset(output, outOffset, cSpAddr, J, I, addrLen, rewriter); + gemminiMvoutOffset(output, outOffset * sizeOfElemT, cSpAddr, J, I, addrLen, rewriter); } } } else { From 46eeb745fed1bcbc1f5d84202bc82c6396124870 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 7 Sep 2023 00:00:34 +0800 Subject: [PATCH 32/51] delete comment --- .../Transforms/LegalizeForLLVMExport.cpp | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 0613d4cc54..f8c0362983 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ 
b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1231,18 +1231,18 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { if (output != 0) { cSpAddrRow = (cSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; } -// if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) { -// gemminiLoopConvWs( -// batchSize, inRowDim, inChannels, outChannels, outRowDim, -// poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, -// poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, -// kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, -// ocols, weights, output, bias, input, noBias, noPool, downsample, -// wrot180, inputDilated, act, transOutput1203, transWeight1203, -// transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, -// rewriter); -// return; -// } + if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) { + gemminiLoopConvWs( + batchSize, inRowDim, inChannels, outChannels, outRowDim, + poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, + poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, + kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, + ocols, weights, output, bias, input, noBias, noPool, downsample, + wrot180, inputDilated, act, transOutput1203, transWeight1203, + transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, + rewriter); + return; + } if (!noPool) { llvm::outs() << "Pooling with rectangular convolutions is currently not supported.\n"; return; @@ -1852,16 +1852,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { ArrayRef outputShape = outputType.getShape(); ArrayRef weightsShape = weightsType.getShape(); ArrayRef biasShape = biasType.getShape(); - // inDim - // if (inputShape[1] != inputShape[2]) { - // llvm::outs() << "inDim error.\n"; - // return failure(); - // } - // outChannels - // if (biasShape[0] != outputShape[1] || biasShape[0] != 
weightsShape[1]) { - // llvm::outs() << "outChannels error.\n"; - // return failure(); - // } + Value outRowDimValue = tileConvOp.getOutRowDim(); int outRowDim = getNumberFromValue(outRowDimValue); Value outColDimValue = tileConvOp.getOutColDim(); From dd82aca6f3bb3a47e4ed800a89d8bfc07e54ed19 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 7 Sep 2023 00:12:50 +0800 Subject: [PATCH 33/51] add test case tile-rect-conv.mlir --- examples/GemminiDialect/makefile | 9 +++++ examples/GemminiDialect/tile-rect-conv.mlir | 41 +++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 examples/GemminiDialect/tile-rect-conv.mlir diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 3c5925d886..ef4a282a27 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -103,6 +103,15 @@ tile-conv-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-rect-conv-run: + @${BUDDY_OPT} ./tile-rect-conv.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + gemmini-linalg-matmul-run: @${BUDDY_OPT} ./matmul.mlir \ -convert-linalg-to-gemmini \ diff --git a/examples/GemminiDialect/tile-rect-conv.mlir b/examples/GemminiDialect/tile-rect-conv.mlir new file mode 100644 index 0000000000..0921093016 --- /dev/null +++ b/examples/GemminiDialect/tile-rect-conv.mlir @@ -0,0 +1,41 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputRowDim = 5 inputColDim inChannels = 2 +memref.global "private" @input : memref<1x5x10x1xi8> = dense<[[[[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], 
[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[1, 2], [1, 2], [1, 2], + [1, 2], [1, 2], [1, 2], + [1, 2], [1, 2], [1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %8 = arith.constant 8 : i64 + %input = memref.get_global @input : memref<1x5x10x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<24x2xi8> + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.mvin3" + // CHECK: "gemmini.intr.mvin" + // CHECK: "gemmini.intr.mvin2" + // CHECK: "gemmini.intr.preload" + // CHECK: "gemmini.intr.compute_preloaded" + // CHECK: "gemmini.intr.compute_accumulated" + gemmini.tile_conv %input %weight %bias %output %3 %8 %3 {stride = 1}: + memref<1x5x10x1xi8> memref<9x2xi8> memref<2xi32> memref<24x2xi8> i64 i64 i64 + gemmini.print %output : memref<24x2xi8> + return %0 : i64 +} From 5dc0b58b7ec7c1d4995da48449b6dbe863aa747e Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 24 Oct 2023 22:56:47 +0800 Subject: [PATCH 34/51] add blank lines --- examples/GemminiDialect/tile-matmul-os.mlir | 1 + examples/GemminiDialect/tile-matmul-ws-igelu.mlir | 1 + examples/GemminiDialect/tile-rect-conv.mlir | 1 + 3 files changed, 3 insertions(+) diff --git a/examples/GemminiDialect/tile-matmul-os.mlir b/examples/GemminiDialect/tile-matmul-os.mlir index dafcabef72..120a44654d 100644 --- a/examples/GemminiDialect/tile-matmul-os.mlir +++ b/examples/GemminiDialect/tile-matmul-os.mlir @@ -21,6 +21,7 @@ func.func @main() -> i8 { memref.store %i2I32, %dArray[%i, %j] : memref<64x64xi32> } } + gemmini.print %aArray : memref<64x64xi8> gemmini.print %bArray : 
memref<64x64xi8> gemmini.print %dArray : memref<64x64xi32> diff --git a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir index 78db052800..a3193cde1c 100644 --- a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir +++ b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir @@ -27,6 +27,7 @@ func.func @main() -> i8 { memref.store %dI32, %dArray[%i3, %j3] : memref<3x3xi32> } } + gemmini.print %aArray : memref<3x3xi8> gemmini.print %bArray : memref<3x3xi8> // CHECK: "gemmini.intr.config_ex" diff --git a/examples/GemminiDialect/tile-rect-conv.mlir b/examples/GemminiDialect/tile-rect-conv.mlir index 0921093016..1ef00f5e1a 100644 --- a/examples/GemminiDialect/tile-rect-conv.mlir +++ b/examples/GemminiDialect/tile-rect-conv.mlir @@ -21,6 +21,7 @@ func.func @main() -> i64 { %0 = arith.constant 0 : i64 %3 = arith.constant 3 : i64 %8 = arith.constant 8 : i64 + %input = memref.get_global @input : memref<1x5x10x1xi8> %weight = memref.get_global @weight : memref<9x2xi8> %bias = memref.get_global @bias : memref<2xi32> From 501f28b8f8d2e80ef44ec4fa2b157b194e5aaa72 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 24 Oct 2023 22:58:49 +0800 Subject: [PATCH 35/51] add a space --- midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index f8c0362983..874314e377 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -44,7 +44,7 @@ int64_t getNumberFromValue(Value &value) { } int ceil_divide_int(int a, int b){ - int c = (a % b == 0) ? ((int)(a/b)) :(((int)(a/b)) + 1); + int c = (a % b == 0) ? 
((int)(a/b)) : (((int)(a/b)) + 1); if(a < b) c = 1; return c; } From 6572223cb976527e459868351f50c992a1e61789 Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 24 Oct 2023 23:03:37 +0800 Subject: [PATCH 36/51] modify comments in tile-conv.mlir and tile-rect-conv.mlir --- examples/GemminiDialect/tile-conv.mlir | 2 +- examples/GemminiDialect/tile-rect-conv.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/GemminiDialect/tile-conv.mlir b/examples/GemminiDialect/tile-conv.mlir index 6c85572a48..9ac91acffc 100644 --- a/examples/GemminiDialect/tile-conv.mlir +++ b/examples/GemminiDialect/tile-conv.mlir @@ -2,7 +2,7 @@ // RUN: --lower-gemmini | \ // RUN: FileCheck %s -// batchSize = 1 inputDIm = 5 inChannels = 2 +// batchSize = 1 inputDim = 5 inChannels = 1 memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1]], diff --git a/examples/GemminiDialect/tile-rect-conv.mlir b/examples/GemminiDialect/tile-rect-conv.mlir index 1ef00f5e1a..0f6536da7c 100644 --- a/examples/GemminiDialect/tile-rect-conv.mlir +++ b/examples/GemminiDialect/tile-rect-conv.mlir @@ -2,7 +2,7 @@ // RUN: --lower-gemmini | \ // RUN: FileCheck %s -// batchSize = 1 inputRowDim = 5 inputColDim inChannels = 2 +// batchSize = 1 inputRowDim = 5 inputColDim = 10 inChannels = 2 memref.global "private" @input : memref<1x5x10x1xi8> = dense<[[[[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], From 0f5125a9b34260441b081925c0c098fb94ff6e9d Mon Sep 17 00:00:00 2001 From: yxy Date: Tue, 24 Oct 2023 23:21:38 +0800 Subject: [PATCH 37/51] modify tile-rect-conv.mlir --- examples/GemminiDialect/tile-rect-conv.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/GemminiDialect/tile-rect-conv.mlir b/examples/GemminiDialect/tile-rect-conv.mlir index 0f6536da7c..e982b3ad6f 
100644 --- a/examples/GemminiDialect/tile-rect-conv.mlir +++ b/examples/GemminiDialect/tile-rect-conv.mlir @@ -2,7 +2,7 @@ // RUN: --lower-gemmini | \ // RUN: FileCheck %s -// batchSize = 1 inputRowDim = 5 inputColDim = 10 inChannels = 2 +// batchSize = 1 inputRowDim = 5 inputColDim = 10 inChannels = 1 memref.global "private" @input : memref<1x5x10x1xi8> = dense<[[[[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], [[1], [0], [-1], [0], [1], [1], [0], [-1], [0], [1]], From 55fc0f317ab6e2a4bd056d17e051855549ffbb0c Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 26 Oct 2023 19:53:43 +0800 Subject: [PATCH 38/51] change matrix shape --- .../GemminiDialect/tile-matmul-ws-igelu.mlir | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir index a3193cde1c..2d2f45dcf4 100644 --- a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir +++ b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir @@ -2,8 +2,9 @@ // RUN: --lower-gemmini | \ // RUN: FileCheck %s -memref.global "private" @g1 : memref<3x3xi8> = dense<[[1, 0, 0], [1, -1, 1], [-1, 0, 1]]> -memref.global "private" @g2 : memref<3x3xi8> = dense<[[1, -1, 0], [1, 0, -1], [-1, -1, 0]]> +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 1, 0], [1, -1, 1, 0, 0], [-1, 0, 1, -1, 1], [1, 0, 0, 1, 0], [-1, 0, 0, -1, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[1, -1, 0, 0, 1], [1, 0, -1, 0, -1], [-1, -1, 0, -1, 1], [-1, 0, 0, 1, 0], [1, 0, 0, -1, 0]]> + func.func @main() -> i8 { %i0 = arith.constant 0 : i8 @@ -14,22 +15,20 @@ func.func @main() -> i8 { %dI32 = arith.constant 0 : i32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %aArray = memref.get_global @g1 : memref<3x3xi8> - %bArray = memref.get_global @g2 : memref<3x3xi8> - %cArray = memref.alloc() : memref<3x3xi8> - %dArray = memref.alloc() : 
memref<3x3xi32> - %dim_I = memref.dim %aArray, %c0 : memref<3x3xi8> - %dim_J = memref.dim %bArray, %c1 : memref<3x3xi8> - %dim_K = memref.dim %aArray, %c1 : memref<3x3xi8> + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> scf.for %i3 = %c0 to %dim_I step %c1 { scf.for %j3 = %c0 to %dim_J step %c1 { - memref.store %dI32, %dArray[%i3, %j3] : memref<3x3xi32> + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> } } - gemmini.print %aArray : memref<3x3xi8> - gemmini.print %bArray : memref<3x3xi8> // CHECK: "gemmini.intr.config_ex" // CHECK: "gemmini.intr.config_st" // CHECK: "gemmini.intr.config_ld" @@ -41,7 +40,7 @@ func.func @main() -> i8 { // CHECK: "gemmini.intr.loop_ws_config_strides_dc" // CHECK: "gemmini.intr.loop_ws" // CHECk: "gemmini.intr.flush" - gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=3, bertScale=0.8:f32}: memref<3x3xi8> memref<3x3xi8> memref<3x3xi8> memref<3x3xi32> - gemmini.print %cArray : memref<3x3xi8> + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=3, bertScale=0.8:f32}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> return %i0 : i8 } From 95fa5598a5caf92ef5b55a870c8c4c6b8e21a963 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 26 Oct 2023 20:37:04 +0800 Subject: [PATCH 39/51] add relu and softmax test --- examples/GemminiDialect/makefile | 18 ++++++++ .../GemminiDialect/tile-matmul-ws-relu.mlir | 46 +++++++++++++++++++ .../tile-matmul-ws-softmax.mlir | 46 +++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 examples/GemminiDialect/tile-matmul-ws-relu.mlir create mode 100644 examples/GemminiDialect/tile-matmul-ws-softmax.mlir diff 
--git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index ef4a282a27..88abb0647d 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -94,6 +94,24 @@ tile-matmul-ws-igelu-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-matmul-ws-relu-run: + @${BUDDY_OPT} ./tile-matmul-ws-relu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-matmul-ws-softmax-run: + @${BUDDY_OPT} ./tile-matmul-ws-softmax.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-conv-run: @${BUDDY_OPT} ./tile-conv.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/tile-matmul-ws-relu.mlir b/examples/GemminiDialect/tile-matmul-ws-relu.mlir new file mode 100644 index 0000000000..336014fb8c --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-ws-relu.mlir @@ -0,0 +1,46 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 1, 0], [1, -1, 1, 0, 0], [-1, 0, 1, -1, 1], [1, 0, 0, 1, 0], [-1, 0, 0, -1, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[1, -1, 0, 0, 1], [1, 0, -1, 0, -1], [-1, -1, 0, -1, 1], [-1, 0, 0, 1, 0], [1, 0, 0, -1, 0]]> + + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = 
arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> + } + } + + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.config_norm" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + return %i0 : i8 +} diff --git a/examples/GemminiDialect/tile-matmul-ws-softmax.mlir b/examples/GemminiDialect/tile-matmul-ws-softmax.mlir new file mode 100644 index 0000000000..fa505e31d5 --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-ws-softmax.mlir @@ -0,0 +1,46 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 1, 0], [1, -1, 1, 0, 0], [-1, 0, 1, -1, 1], [1, 0, 0, 1, 0], [-1, 0, 0, -1, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[1, -1, 0, 0, 1], [1, 0, -1, 0, -1], [-1, -1, 0, -1, 1], [-1, 0, 0, 1, 0], [1, 0, 0, -1, 0]]> + + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = 
arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> + } + } + + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.config_norm" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1, act=4, bertScale=0.05:f32}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + return %i0 : i8 +} From 74436edc1b94ef4be1e49f3e9b042512ecc810b1 Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 26 Oct 2023 21:07:22 +0800 Subject: [PATCH 40/51] fix test error --- examples/GemminiDialect/tile-matmul-ws-relu.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/GemminiDialect/tile-matmul-ws-relu.mlir b/examples/GemminiDialect/tile-matmul-ws-relu.mlir index 336014fb8c..3416367d38 100644 --- a/examples/GemminiDialect/tile-matmul-ws-relu.mlir +++ b/examples/GemminiDialect/tile-matmul-ws-relu.mlir @@ -32,7 +32,6 @@ func.func @main() -> i8 { // CHECK: "gemmini.intr.config_ex" // CHECK: "gemmini.intr.config_st" // CHECK: "gemmini.intr.config_ld" - // CHECK: 
"gemmini.intr.config_norm" // CHECK: "gemmini.intr.loop_ws_config_bounds" // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" From c7ee7fb0cc0c7894672a03a015bb8b84d044ddec Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 26 Oct 2023 22:44:44 +0800 Subject: [PATCH 41/51] add tile-conv-relu --- examples/GemminiDialect/makefile | 9 +++++ examples/GemminiDialect/tile-conv-relu.mlir | 39 +++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 examples/GemminiDialect/tile-conv-relu.mlir diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 88abb0647d..3b11f8fd12 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -121,6 +121,15 @@ tile-conv-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-conv-relu-run: + @${BUDDY_OPT} ./tile-conv-relu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-rect-conv-run: @${BUDDY_OPT} ./tile-rect-conv.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/tile-conv-relu.mlir b/examples/GemminiDialect/tile-conv-relu.mlir new file mode 100644 index 0000000000..21fd10113d --- /dev/null +++ b/examples/GemminiDialect/tile-conv-relu.mlir @@ -0,0 +1,39 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], 
[-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} From bc9b1e2c5fedd2b135213e036c3a0569514772ee Mon Sep 17 00:00:00 2001 From: yxy Date: Thu, 26 Oct 2023 22:51:51 +0800 Subject: [PATCH 42/51] add conv-igelu and conv-softmax --- examples/GemminiDialect/makefile | 18 +++++++++ examples/GemminiDialect/tile-conv-igelu.mlir | 39 +++++++++++++++++++ .../GemminiDialect/tile-conv-softmax.mlir | 39 +++++++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 examples/GemminiDialect/tile-conv-igelu.mlir create mode 100644 examples/GemminiDialect/tile-conv-softmax.mlir diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 3b11f8fd12..5b5521229b 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -121,6 +121,24 @@ tile-conv-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-conv-igelu-run: + @${BUDDY_OPT} ./tile-conv-igelu.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} 
--buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + +tile-conv-softmax-run: + @${BUDDY_OPT} ./tile-conv-softmax.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-conv-relu-run: @${BUDDY_OPT} ./tile-conv-relu.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/tile-conv-igelu.mlir b/examples/GemminiDialect/tile-conv-igelu.mlir new file mode 100644 index 0000000000..1bf3b80f5d --- /dev/null +++ b/examples/GemminiDialect/tile-conv-igelu.mlir @@ -0,0 +1,39 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: 
"gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 3}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} diff --git a/examples/GemminiDialect/tile-conv-softmax.mlir b/examples/GemminiDialect/tile-conv-softmax.mlir new file mode 100644 index 0000000000..a5cc908467 --- /dev/null +++ b/examples/GemminiDialect/tile-conv-softmax.mlir @@ -0,0 +1,39 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 
{stride = 1, act = 4}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} From ab50d5b116ac8d1f1bb40cedc330dda143ed8f70 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 16:10:00 +0800 Subject: [PATCH 43/51] print origin result of matmul --- examples/GemminiDialect/tile-matmul-ws-igelu.mlir | 3 +++ examples/GemminiDialect/tile-matmul-ws-relu.mlir | 3 +++ examples/GemminiDialect/tile-matmul-ws-softmax.mlir | 3 +++ 3 files changed, 9 insertions(+) diff --git a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir index 2d2f45dcf4..0edf6428bd 100644 --- a/examples/GemminiDialect/tile-matmul-ws-igelu.mlir +++ b/examples/GemminiDialect/tile-matmul-ws-igelu.mlir @@ -28,6 +28,9 @@ func.func @main() -> i8 { memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> } } + + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> // CHECK: "gemmini.intr.config_ex" // CHECK: "gemmini.intr.config_st" diff --git a/examples/GemminiDialect/tile-matmul-ws-relu.mlir b/examples/GemminiDialect/tile-matmul-ws-relu.mlir index 3416367d38..f461950615 100644 --- a/examples/GemminiDialect/tile-matmul-ws-relu.mlir +++ b/examples/GemminiDialect/tile-matmul-ws-relu.mlir @@ -28,6 +28,9 @@ func.func @main() -> i8 { memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> } } + + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> // CHECK: "gemmini.intr.config_ex" // CHECK: "gemmini.intr.config_st" diff --git a/examples/GemminiDialect/tile-matmul-ws-softmax.mlir b/examples/GemminiDialect/tile-matmul-ws-softmax.mlir index fa505e31d5..c81bccceac 100644 --- a/examples/GemminiDialect/tile-matmul-ws-softmax.mlir +++ 
b/examples/GemminiDialect/tile-matmul-ws-softmax.mlir @@ -29,6 +29,9 @@ func.func @main() -> i8 { } } + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + // CHECK: "gemmini.intr.config_ex" // CHECK: "gemmini.intr.config_st" // CHECK: "gemmini.intr.config_ld" From 3ef138fcec43a304bd79b30ee3468870d9a43350 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 17:31:21 +0800 Subject: [PATCH 44/51] fix small bug --- midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 874314e377..ec86228284 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1015,7 +1015,7 @@ class GemminiTileMatMulLowering : public ConvertOpToLLVMPattern { size_t tileI, tileJ, tileK; if (act == LAYERNORM || act == SOFTMAX) { tileI = 1; - tileJ = dimJPaded | dim; + tileJ = dimJPaded / dim; tileK = 1; } else { tileI = dimIPaded / dim < dbMaxTileIJ ? 
dimIPaded / dim : dbMaxTileIJ; From a7785ade565f6e65c172a9e805245d924b9117a4 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 17:31:42 +0800 Subject: [PATCH 45/51] add print origin result for conv test --- examples/GemminiDialect/tile-conv-igelu.mlir | 5 +++++ examples/GemminiDialect/tile-conv-relu.mlir | 5 +++++ examples/GemminiDialect/tile-conv-softmax.mlir | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/examples/GemminiDialect/tile-conv-igelu.mlir b/examples/GemminiDialect/tile-conv-igelu.mlir index 1bf3b80f5d..0db1ec5952 100644 --- a/examples/GemminiDialect/tile-conv-igelu.mlir +++ b/examples/GemminiDialect/tile-conv-igelu.mlir @@ -24,6 +24,11 @@ func.func @main() -> i64 { %weight = memref.get_global @weight : memref<9x2xi8> %bias = memref.get_global @bias : memref<2xi32> %output = memref.alloc() : memref<9x2xi8> + + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" // CHECK: "gemmini.intr.loop_conv_ws_config2" // CHECK: "gemmini.intr.loop_conv_ws_config3" diff --git a/examples/GemminiDialect/tile-conv-relu.mlir b/examples/GemminiDialect/tile-conv-relu.mlir index 21fd10113d..856fe1aa43 100644 --- a/examples/GemminiDialect/tile-conv-relu.mlir +++ b/examples/GemminiDialect/tile-conv-relu.mlir @@ -24,6 +24,11 @@ func.func @main() -> i64 { %weight = memref.get_global @weight : memref<9x2xi8> %bias = memref.get_global @bias : memref<2xi32> %output = memref.alloc() : memref<9x2xi8> + + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" // CHECK: "gemmini.intr.loop_conv_ws_config2" // CHECK: "gemmini.intr.loop_conv_ws_config3" diff --git a/examples/GemminiDialect/tile-conv-softmax.mlir 
b/examples/GemminiDialect/tile-conv-softmax.mlir index a5cc908467..70c158933f 100644 --- a/examples/GemminiDialect/tile-conv-softmax.mlir +++ b/examples/GemminiDialect/tile-conv-softmax.mlir @@ -24,6 +24,11 @@ func.func @main() -> i64 { %weight = memref.get_global @weight : memref<9x2xi8> %bias = memref.get_global @bias : memref<2xi32> %output = memref.alloc() : memref<9x2xi8> + + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" // CHECK: "gemmini.intr.loop_conv_ws_config2" // CHECK: "gemmini.intr.loop_conv_ws_config3" From 10deee86eec543a3fabeb681733fd5e063a66392 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 20:48:42 +0800 Subject: [PATCH 46/51] add layernorm test --- examples/GemminiDialect/makefile | 9 ++++ .../tile-matmul-ws-layernorm.mlir | 49 +++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 examples/GemminiDialect/tile-matmul-ws-layernorm.mlir diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 5b5521229b..3d65357832 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -112,6 +112,15 @@ tile-matmul-ws-softmax-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-matmul-ws-layernorm-run: + @${BUDDY_OPT} ./tile-matmul-ws-layernorm.mlir -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-conv-run: @${BUDDY_OPT} ./tile-conv.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir b/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir new 
file mode 100644 index 0000000000..a86adfa723 --- /dev/null +++ b/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir @@ -0,0 +1,49 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +memref.global "private" @g1 : memref<5x5xi8> = dense<[[1, 0, 0, 0, 0], [0, -1, 1, 0, 1], [-1, 0, -1, -1, 0], [-1, 0, 0, 1, 0], [0, 0, 0, 0, 0]]> +memref.global "private" @g2 : memref<5x5xi8> = dense<[[-1, 0, 1, 0, -1], [1, -1, 1, 0, -1], [-1, -1, -1, 1, 1], [-1, 1, 0, -1, 1], [-1, 0, 1, 1, 1]]> + + +func.func @main() -> i8 { + %i0 = arith.constant 0 : i8 + %i1I8 = arith.constant 1 : i8 + %minus1 = arith.constant -2 : i8 + %i2I8 = arith.constant 2 : i8 + %i2I32 = arith.constant 2 : i32 + %dI32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %aArray = memref.get_global @g1 : memref<5x5xi8> + %bArray = memref.get_global @g2 : memref<5x5xi8> + %cArray = memref.alloc() : memref<5x5xi8> + %dArray = memref.alloc() : memref<5x5xi32> + %dim_I = memref.dim %aArray, %c0 : memref<5x5xi8> + %dim_J = memref.dim %bArray, %c1 : memref<5x5xi8> + %dim_K = memref.dim %aArray, %c1 : memref<5x5xi8> + + scf.for %i3 = %c0 to %dim_I step %c1 { + scf.for %j3 = %c0 to %dim_J step %c1 { + memref.store %dI32, %dArray[%i3, %j3] : memref<5x5xi32> + } + } + + gemmini.tile_matmul %aArray %bArray %cArray %dArray {dataflow=1}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + + // CHECK: "gemmini.intr.config_ex" + // CHECK: "gemmini.intr.config_st" + // CHECK: "gemmini.intr.config_ld" + // CHECK: "gemmini.intr.config_norm" + // CHECK: "gemmini.intr.loop_ws_config_bounds" + // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" + // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" + // CHECK: "gemmini.intr.loop_ws_config_strides_ab" + // CHECK: "gemmini.intr.loop_ws_config_strides_dc" + // CHECK: "gemmini.intr.loop_ws" + // CHECk: "gemmini.intr.flush" + gemmini.tile_matmul %aArray %bArray %cArray 
%dArray {dataflow=1, act=2}: memref<5x5xi8> memref<5x5xi8> memref<5x5xi8> memref<5x5xi32> + gemmini.print %cArray : memref<5x5xi8> + return %i0 : i8 +} From 02a12919b787034db44a00e158f09ee00448b200 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 21:41:18 +0800 Subject: [PATCH 47/51] fix bug in tile-matmul-ws-layernorm --- examples/GemminiDialect/tile-matmul-ws-layernorm.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir b/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir index a86adfa723..cf3529c28b 100644 --- a/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir +++ b/examples/GemminiDialect/tile-matmul-ws-layernorm.mlir @@ -35,7 +35,6 @@ func.func @main() -> i8 { // CHECK: "gemmini.intr.config_ex" // CHECK: "gemmini.intr.config_st" // CHECK: "gemmini.intr.config_ld" - // CHECK: "gemmini.intr.config_norm" // CHECK: "gemmini.intr.loop_ws_config_bounds" // CHECK: "gemmini.intr.loop_ws_config_addrs_ab" // CHECK: "gemmini.intr.loop_ws_config_addrs_dc" From 834e8376dc5e85dd78644aa4579758ac11b40126 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 22:08:50 +0800 Subject: [PATCH 48/51] add layernorm and filecheck --- examples/GemminiDialect/makefile | 9 ++++ examples/GemminiDialect/tile-conv-igelu.mlir | 8 +++ .../GemminiDialect/tile-conv-layernorm.mlir | 52 +++++++++++++++++++ examples/GemminiDialect/tile-conv-relu.mlir | 8 +++ .../GemminiDialect/tile-conv-softmax.mlir | 8 +++ 5 files changed, 85 insertions(+) create mode 100644 examples/GemminiDialect/tile-conv-layernorm.mlir diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index 3d65357832..cba84b780a 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -157,6 +157,15 @@ tile-conv-relu-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +tile-conv-layernorm-run: + @${BUDDY_OPT} ./tile-conv-layernorm.mlir -lower-gemmini | \ + 
${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + tile-rect-conv-run: @${BUDDY_OPT} ./tile-rect-conv.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/tile-conv-igelu.mlir b/examples/GemminiDialect/tile-conv-igelu.mlir index 0db1ec5952..9eb34f1f5f 100644 --- a/examples/GemminiDialect/tile-conv-igelu.mlir +++ b/examples/GemminiDialect/tile-conv-igelu.mlir @@ -25,6 +25,14 @@ func.func @main() -> i64 { %bias = memref.get_global @bias : memref<2xi32> %output = memref.alloc() : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 gemmini.print %output : memref<9x2xi8> diff --git a/examples/GemminiDialect/tile-conv-layernorm.mlir b/examples/GemminiDialect/tile-conv-layernorm.mlir new file mode 100644 index 0000000000..1d440c125d --- /dev/null +++ b/examples/GemminiDialect/tile-conv-layernorm.mlir @@ -0,0 +1,52 @@ +// RUN: buddy-opt %s \ +// RUN: --lower-gemmini | \ +// RUN: FileCheck %s + +// batchSize = 1 inputDim = 5 inChannels = 1 +memref.global "private" @input : memref<1x5x5x1xi8> = dense<[[[[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]], + [[1], [0], [-1], [0], [1]]]]> + +// outChannels = 2 kernelDim = 3 inChannels = 1 +memref.global "private" @weight : memref<9x2xi8> = dense<[[-1, 2], [-1, 2], [-1, 
2], + [-1, 2], [-1, 2], [-1, 2], + [-1, 2], [-1, 2], [-1, 2]]> + +// outChannels = 2 +memref.global "private" @bias : memref<2xi32> = dense<[1,1]> + +func.func @main() -> i64 { + %0 = arith.constant 0 : i64 + %3 = arith.constant 3 : i64 + %input = memref.get_global @input : memref<1x5x5x1xi8> + %weight = memref.get_global @weight : memref<9x2xi8> + %bias = memref.get_global @bias : memref<2xi32> + %output = memref.alloc() : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1, act = 2}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 + gemmini.print %output : memref<9x2xi8> + return %0 : i64 +} diff --git a/examples/GemminiDialect/tile-conv-relu.mlir b/examples/GemminiDialect/tile-conv-relu.mlir index 856fe1aa43..99198f583a 100644 --- a/examples/GemminiDialect/tile-conv-relu.mlir +++ b/examples/GemminiDialect/tile-conv-relu.mlir @@ -25,6 +25,14 @@ func.func @main() -> i64 { %bias = memref.get_global @bias : memref<2xi32> %output = memref.alloc() : memref<9x2xi8> + // CHECK: 
"gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 gemmini.print %output : memref<9x2xi8> diff --git a/examples/GemminiDialect/tile-conv-softmax.mlir b/examples/GemminiDialect/tile-conv-softmax.mlir index 70c158933f..67a63d5f4e 100644 --- a/examples/GemminiDialect/tile-conv-softmax.mlir +++ b/examples/GemminiDialect/tile-conv-softmax.mlir @@ -25,6 +25,14 @@ func.func @main() -> i64 { %bias = memref.get_global @bias : memref<2xi32> %output = memref.alloc() : memref<9x2xi8> + // CHECK: "gemmini.intr.loop_conv_ws_config1" + // CHECK: "gemmini.intr.loop_conv_ws_config2" + // CHECK: "gemmini.intr.loop_conv_ws_config3" + // CHECK: "gemmini.intr.loop_conv_ws_config4" + // CHECK: "gemmini.intr.loop_conv_ws_config5" + // CHECK: "gemmini.intr.loop_conv_ws_config6" + // CHECK: "gemmini.intr.loop_conv_ws" + // CHECK: "gemmini.intr.flush" gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 gemmini.print %output : memref<9x2xi8> From 15ac6fd39dece0b684f11a8ac659105de719b9f1 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 22:14:43 +0800 Subject: [PATCH 49/51] delete useless var --- .../Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index ec86228284..378abb3b90 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ 
b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1569,13 +1569,9 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int dilatedInRowDim = inRowDim + (inputDilation - 1) * (inRowDim - 1); const int dilatedInColDim = inColDim + (inputDilation - 1) * (inColDim - 1); - size_t aSpadId = 0; - size_t bSpadId = 0; - int porowEnd = poolOutRowDim; int porowStart = 0; - bool a_reuse = false; - bool b_reuse = false; + size_t num_kch = ceil_divide_int(inChannels, kchs); size_t num_poch = ceil_divide_int(outChannels, pochs); size_t num_b = ceil_divide_int(batchSize, batches); @@ -1849,8 +1845,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { MemRefType weightsType = weights.getType().dyn_cast(); MemRefType biasType = bias.getType().dyn_cast(); ArrayRef inputShape = inputType.getShape(); - ArrayRef outputShape = outputType.getShape(); - ArrayRef weightsShape = weightsType.getShape(); ArrayRef biasShape = biasType.getShape(); Value outRowDimValue = tileConvOp.getOutRowDim(); From 5ea3690e313acf2b3df8b83b8b676278c9984be4 Mon Sep 17 00:00:00 2001 From: yxy Date: Fri, 27 Oct 2023 22:20:00 +0800 Subject: [PATCH 50/51] handle compiler warnings --- .../Gemmini/Transforms/LegalizeForLLVMExport.cpp | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 378abb3b90..32b19bcc0c 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1253,7 +1253,7 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const size_t maxBlockLenAcc = MAX_BYTES / (dim * 4); if (bias != NULL) { // TODO we probably don't need quite this many nested loops for this part - const int maxOchsPerMvin = ochs < maxBlockLenAcc * dim ? 
ochs : + const int maxOchsPerMvin = ochs < (int)(maxBlockLenAcc * dim) ? ochs : maxBlockLenAcc * dim; Value zeroValue = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); rewriter.create(loc, zeroValue, llvm::APFloat((float)MVIN_SCALE_IDENTITY), false, 2, batches * orows * ocols); @@ -1274,10 +1274,10 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } // mvin input if (input != NULL){ - int maxChsPerMvin = ichs < maxBlockLen * dim ? ichs : + int maxChsPerMvin = ichs < (int)(maxBlockLen * dim) ? ichs : maxBlockLen * dim; if (transInput3120) { - maxChsPerMvin = batches < maxBlockLen * dim ? batches : + maxChsPerMvin = batches < (int)(maxBlockLen * dim) ? batches : maxBlockLen * dim; } const int dramStride = transInput3120 ? @@ -1330,10 +1330,10 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { } // mvin weights if (weights != NULL) { - int max_chs_per_mvin = ochs < maxBlockLen * dim ? ochs : + int max_chs_per_mvin = ochs < (int)(maxBlockLen * dim) ? ochs : maxBlockLen * dim; if (transWeight0132) { - max_chs_per_mvin = kchs < maxBlockLen * dim ? kchs : + max_chs_per_mvin = kchs < (int)(maxBlockLen * dim) ? kchs : maxBlockLen * dim; } size_t dramStride = weightStride * sizeOfElemT; @@ -1466,7 +1466,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { newWeights ? 
bSpAddr : GARBAGE_ADDR; Value garbageAddrOp = rewriter.create(loc, rewriter.getI64IntegerAttr(GARBAGE_ADDR)); - Value dimOp = rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); Value iOp = rewriter.create(loc, rewriter.getI64IntegerAttr(I)); Value jOp = rewriter.create(loc, rewriter.getI64IntegerAttr(J)); Value kOp = rewriter.create(loc, rewriter.getI64IntegerAttr(K)); @@ -1580,11 +1579,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { size_t num_krow = ceil_divide_int(kernelDim, krows); size_t num_kcol = ceil_divide_int(kernelDim, kcols); - if(num_kch * num_poch * num_krow * num_kcol <= 2) - b_reuse = true; - if(num_kch * num_krow * num_kcol * num_b * num_porow * num_pocol <= 2) - a_reuse = true; - for (int b = 0; b < batchSize; b += batches) { for (int porow = 0; porow < porowEnd; porow += porows) { const int orow = porow * poolStride - poolPadding; From ae7ebcc7b1e40ae40f88eec3612a32c95dd55983 Mon Sep 17 00:00:00 2001 From: yxy Date: Sun, 29 Oct 2023 14:57:09 +0800 Subject: [PATCH 51/51] delete useless vars --- .../Transforms/LegalizeForLLVMExport.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 32b19bcc0c..b8d81bc4d5 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -43,12 +43,6 @@ int64_t getNumberFromValue(Value &value) { .getInt(); } -int ceil_divide_int(int a, int b){ - int c = (a % b == 0) ? 
((int)(a/b)) : (((int)(a/b)) + 1); - if(a < b) c = 1; - return c; -} - acc_scale_t_bits acc_scale_t_to_acc_scale_t_bits(acc_scale_t x) { union { acc_scale_t_bits b; @@ -1569,15 +1563,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { const int dilatedInColDim = inColDim + (inputDilation - 1) * (inColDim - 1); int porowEnd = poolOutRowDim; - int porowStart = 0; - - size_t num_kch = ceil_divide_int(inChannels, kchs); - size_t num_poch = ceil_divide_int(outChannels, pochs); - size_t num_b = ceil_divide_int(batchSize, batches); - size_t num_porow = ceil_divide_int((porowEnd - porowStart), porows); - size_t num_pocol = ceil_divide_int(poolOutColDim, pocols); - size_t num_krow = ceil_divide_int(kernelDim, krows); - size_t num_kcol = ceil_divide_int(kernelDim, kcols); for (int b = 0; b < batchSize; b += batches) { for (int porow = 0; porow < porowEnd; porow += porows) { @@ -1835,8 +1820,6 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { Value weights = tileConvOp.getWeights(); Value bias = tileConvOp.getBias(); MemRefType inputType = input.getType().dyn_cast(); - MemRefType outputType = output.getType().dyn_cast(); - MemRefType weightsType = weights.getType().dyn_cast(); MemRefType biasType = bias.getType().dyn_cast(); ArrayRef inputShape = inputType.getShape(); ArrayRef biasShape = biasType.getShape();