From 58f450fc69a80dc2175fee1e31ce4c6eba46cdf2 Mon Sep 17 00:00:00 2001 From: yxy Date: Wed, 6 Sep 2023 15:08:59 +0800 Subject: [PATCH] modify tile_conv case in ciface.mlir and tile-conv.mlir; fix LegalizeForLLVMExport.cpp --- examples/GemminiDialect/ciface.mlir | 12 ++++----- examples/GemminiDialect/tile-conv.mlir | 4 +-- .../Transforms/LegalizeForLLVMExport.cpp | 25 +++++++++++-------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/examples/GemminiDialect/ciface.mlir b/examples/GemminiDialect/ciface.mlir index 004992f2a3..070ca581b2 100644 --- a/examples/GemminiDialect/ciface.mlir +++ b/examples/GemminiDialect/ciface.mlir @@ -127,7 +127,7 @@ func.func @linalg_conv6(%arg0 : memref<1x1x256x256xi8>, %arg1 : memref<1x1x13x13 func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8>, %bias: memref<1xi32>, %output: memref<64516x1xi8>) { %outdim = arith.constant 254 : i64 %kernelDim = arith.constant 3 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64 return } @@ -136,7 +136,7 @@ func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8 func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi8>, %bias: memref<1xi32>, %output: memref<63504x1xi8>) { %outdim = arith.constant 252 : i64 %kernelDim = arith.constant 5 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64 return } @@ -145,7 +145,7 @@ func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi8>, %bias: memref<1xi32>, %output: memref<62500x1xi8>) { %outdim = arith.constant 250 : i64 %kernelDim = arith.constant 7 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64 return } @@ -154,7 +154,7 @@ func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi8>, %bias: memref<1xi32>, %output: memref<61504x1xi8>) { %outdim = arith.constant 248 : i64 %kernelDim = arith.constant 9 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64 return } @@ -163,7 +163,7 @@ func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1xi8>, %bias: memref<1xi32>, %output: memref<60516x1xi8>) { %outdim = arith.constant 246 : i64 %kernelDim = arith.constant 11 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64 return } @@ -172,7 +172,7 @@ func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1x func.func @gemmini_conv6(%input: memref<1x256x256x1xi8>, %weights: memref<169x1xi8>, %bias: memref<1xi32>, %output: memref<59536x1xi8>) { %outdim = arith.constant 244 : i64 %kernelDim = arith.constant 13 : i64 - gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} : + gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} : memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64 return } diff --git a/examples/GemminiDialect/tile-conv.mlir b/examples/GemminiDialect/tile-conv.mlir index 42f0085ce3..6c85572a48 100644 --- a/examples/GemminiDialect/tile-conv.mlir +++ b/examples/GemminiDialect/tile-conv.mlir @@ -32,8 +32,8 @@ func.func @main() -> i64 { // CHECK: "gemmini.intr.loop_conv_ws_config6" // CHECK: "gemmini.intr.loop_conv_ws" // CHECK: "gemmini.intr.flush" - gemmini.tile_conv %input %weight %bias %output %3 %3 {stride = 1}: - memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 + gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}: + memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64 gemmini.print %output : memref<9x2xi8> return %0 : i64 } diff --git a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp index 1597356723..88fe1248d2 100644 --- a/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp @@ -1230,18 +1230,21 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern { if (output != 0) { cSpAddrRow = (cSpAddrRow + ACC_ROWS / 2) % ACC_ROWS; } - gemminiLoopConvWs( - batchSize, inRowDim, inChannels, outChannels, outRowDim, poolOutRowDim, stride, - padding, kernelDim, kernelDilation, poolSize, poolStride, poolPadding, - batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad, - dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias, - input, noBias, noPool, downsample, wrot180, inputDilated, act, - transOutput1203, transWeight1203, transWeight0132, transInput3120, - maxPixelsPerRow, dw, tileConvOp, rewriter); + if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) { + gemminiLoopConvWs( + batchSize, inRowDim, inChannels, outChannels, outRowDim, + poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize, + poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols, + kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows, + ocols, weights, output, bias, input, noBias, noPool, downsample, + wrot180, inputDilated, act, transOutput1203, transWeight1203, + transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp, + rewriter); + return; + } if (!noPool) { - // TODO: Exit, but now I don't known how to do - // printf("Pooling with rectangular convolutions is currently not supported.\n"); - // exit(1); + llvm::outs() << "Pooling with rectangular convolutions is currently not supported.\n"; + return; } // Only rectangular convolutions will use the following C code // mvin bias