Skip to content

Commit

Permalink
modify tile_conv case in ciface.mlir and tile-conv.mlir; fix Legalize…
Browse files Browse the repository at this point in the history
…ForLLVMExport.cpp
  • Loading branch information
Xinyu302 committed Sep 6, 2023
1 parent f7237aa commit 58f450f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 19 deletions.
12 changes: 6 additions & 6 deletions examples/GemminiDialect/ciface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func.func @linalg_conv6(%arg0 : memref<1x1x256x256xi8>, %arg1 : memref<1x1x13x13
func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8>, %bias: memref<1xi32>, %output: memref<64516x1xi8>) {
%outdim = arith.constant 254 : i64
%kernelDim = arith.constant 3 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<9x1xi8> memref<1xi32> memref<64516x1xi8> i64 i64
return
}
Expand All @@ -136,7 +136,7 @@ func.func @gemmini_conv1(%input: memref<1x256x256x1xi8>, %weights: memref<9x1xi8
func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi8>, %bias: memref<1xi32>, %output: memref<63504x1xi8>) {
%outdim = arith.constant 252 : i64
%kernelDim = arith.constant 5 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<25x1xi8> memref<1xi32> memref<63504x1xi8> i64 i64
return
}
Expand All @@ -145,7 +145,7 @@ func.func @gemmini_conv2(%input: memref<1x256x256x1xi8>, %weights: memref<25x1xi
func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi8>, %bias: memref<1xi32>, %output: memref<62500x1xi8>) {
%outdim = arith.constant 250 : i64
%kernelDim = arith.constant 7 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<49x1xi8> memref<1xi32> memref<62500x1xi8> i64 i64
return
}
Expand All @@ -154,7 +154,7 @@ func.func @gemmini_conv3(%input: memref<1x256x256x1xi8>, %weights: memref<49x1xi
func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi8>, %bias: memref<1xi32>, %output: memref<61504x1xi8>) {
%outdim = arith.constant 248 : i64
%kernelDim = arith.constant 9 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<81x1xi8> memref<1xi32> memref<61504x1xi8> i64 i64
return
}
Expand All @@ -163,7 +163,7 @@ func.func @gemmini_conv4(%input: memref<1x256x256x1xi8>, %weights: memref<81x1xi
func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1xi8>, %bias: memref<1xi32>, %output: memref<60516x1xi8>) {
%outdim = arith.constant 246 : i64
%kernelDim = arith.constant 11 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<121x1xi8> memref<1xi32> memref<60516x1xi8> i64 i64
return
}
Expand All @@ -172,7 +172,7 @@ func.func @gemmini_conv5(%input: memref<1x256x256x1xi8>, %weights: memref<121x1x
func.func @gemmini_conv6(%input: memref<1x256x256x1xi8>, %weights: memref<169x1xi8>, %bias: memref<1xi32>, %output: memref<59536x1xi8>) {
%outdim = arith.constant 244 : i64
%kernelDim = arith.constant 13 : i64
gemmini.tile_conv %input %weights %bias %output %outdim %kernelDim {stride = 1} :
gemmini.tile_conv %input %weights %bias %output %outdim %outdim %kernelDim {stride = 1} :
memref<1x256x256x1xi8> memref<169x1xi8> memref<1xi32> memref<59536x1xi8> i64 i64
return
}
Expand Down
4 changes: 2 additions & 2 deletions examples/GemminiDialect/tile-conv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ func.func @main() -> i64 {
// CHECK: "gemmini.intr.loop_conv_ws_config6"
// CHECK: "gemmini.intr.loop_conv_ws"
// CHECK: "gemmini.intr.flush"
gemmini.tile_conv %input %weight %bias %output %3 %3 {stride = 1}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64
gemmini.tile_conv %input %weight %bias %output %3 %3 %3 {stride = 1}:
memref<1x5x5x1xi8> memref<9x2xi8> memref<2xi32> memref<9x2xi8> i64 i64 i64
gemmini.print %output : memref<9x2xi8>
return %0 : i64
}
25 changes: 14 additions & 11 deletions midend/lib/Dialect/Gemmini/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,18 +1230,21 @@ class GemminiTileConvLowering : public ConvertOpToLLVMPattern<TileConvOp> {
if (output != 0) {
cSpAddrRow = (cSpAddrRow + ACC_ROWS / 2) % ACC_ROWS;
}
gemminiLoopConvWs(
batchSize, inRowDim, inChannels, outChannels, outRowDim, poolOutRowDim, stride,
padding, kernelDim, kernelDilation, poolSize, poolStride, poolPadding,
batches, porows, pocols, pochs, krows, kcols, kchs, lpad, rpad, upad,
dpad, plpad, prpad, pupad, pdpad, orows, ocols, weights, output, bias,
input, noBias, noPool, downsample, wrot180, inputDilated, act,
transOutput1203, transWeight1203, transWeight0132, transInput3120,
maxPixelsPerRow, dw, tileConvOp, rewriter);
if (inRowDim == inColDim && outRowDim == outColDim && poolOutRowDim == poolOutColDim) {
gemminiLoopConvWs(
batchSize, inRowDim, inChannels, outChannels, outRowDim,
poolOutRowDim, stride, padding, kernelDim, kernelDilation, poolSize,
poolStride, poolPadding, batches, porows, pocols, pochs, krows, kcols,
kchs, lpad, rpad, upad, dpad, plpad, prpad, pupad, pdpad, orows,
ocols, weights, output, bias, input, noBias, noPool, downsample,
wrot180, inputDilated, act, transOutput1203, transWeight1203,
transWeight0132, transInput3120, maxPixelsPerRow, dw, tileConvOp,
rewriter);
return;
}
if (!noPool) {
// TODO: This should abort, but it is not yet clear how to signal a hard failure here.
// printf("Pooling with rectangular convolutions is currently not supported.\n");
// exit(1);
llvm::outs() << "Pooling with rectangular convolutions is currently not supported.\n";
return;
}
// Only rectangular convolutions will use the following C code
// mvin bias
Expand Down

0 comments on commit 58f450f

Please sign in to comment.