Skip to content

Commit

Permalink
add scf_xdsl rule and collect cycles
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 28, 2023
1 parent 118c9d4 commit 4f0e0c9
Show file tree
Hide file tree
Showing 32 changed files with 179 additions and 17 deletions.
10 changes: 10 additions & 0 deletions kernels/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ CONV_8_TESTS += $(CONV_8)/baseline.csv
CONV_8_TESTS += $(CONV_8)/linalg.csv
CONV_8_TESTS += $(CONV_8)/snitch_stream.csv
CONV_8_TESTS += $(CONV_8)/snrt.csv
# spill to j register...
# CONV_8_TESTS += $(CONV_8)/scf_xdsl.csv

$(CONV_8)/tests.csv: $(CONV_8_TESTS)
./generate_tests_csv.sh $@ $^
Expand All @@ -47,6 +49,8 @@ DENSE_8_TESTS += $(DENSE_8)/fused.csv
DENSE_8_TESTS += $(DENSE_8)/linalg.csv
DENSE_8_TESTS += $(DENSE_8)/snrt.csv
DENSE_8_TESTS += $(DENSE_8)/snitch_stream.csv
# spill to j register...
# DENSE_8_TESTS += $(DENSE_8)/scf_xdsl.csv

$(DENSE_8)/tests.csv: $(DENSE_8_TESTS)
./generate_tests_csv.sh $@ $^
Expand All @@ -65,6 +69,7 @@ DSUM_8_16_TESTS += $(DSUM_8_16)/snrt.csv
DSUM_8_16_TESTS += $(DSUM_8_16)/ssr1d.csv
DSUM_8_16_TESTS += $(DSUM_8_16)/ssr2d.csv
DSUM_8_16_TESTS += $(DSUM_8_16)/snitch_stream.csv
DSUM_8_16_TESTS += $(DSUM_8_16)/scf_xdsl.csv

$(DSUM_8_16)/tests.csv: $(DSUM_8_16_TESTS)
./generate_tests_csv.sh $@ $^
Expand All @@ -80,6 +85,7 @@ MATMUL_8_TESTS += $(MATMUL_8)/baseline.csv
MATMUL_8_TESTS += $(MATMUL_8)/linalg.csv
MATMUL_8_TESTS += $(MATMUL_8)/snitch_stream.csv
MATMUL_8_TESTS += $(MATMUL_8)/snrt.csv
MATMUL_8_TESTS += $(MATMUL_8)/scf_xdsl.csv

$(MATMUL_8)/tests.csv: $(MATMUL_8_TESTS)
./generate_tests_csv.sh $@ $^
Expand All @@ -95,6 +101,7 @@ MAX_POOL_16_TESTS += $(MAX_POOL_16)/baseline.csv
MAX_POOL_16_TESTS += $(MAX_POOL_16)/linalg.csv
MAX_POOL_16_TESTS += $(MAX_POOL_16)/snitch_stream.csv
MAX_POOL_16_TESTS += $(MAX_POOL_16)/snrt.csv
MAX_POOL_16_TESTS += $(MAX_POOL_16)/scf_xdsl.csv

$(MAX_POOL_16)/tests.csv: $(MAX_POOL_16_TESTS)
./generate_tests_csv.sh $@ $^
Expand All @@ -110,6 +117,7 @@ SUM_POOL_16_TESTS += $(SUM_POOL_16)/baseline.csv
SUM_POOL_16_TESTS += $(SUM_POOL_16)/linalg.csv
SUM_POOL_16_TESTS += $(SUM_POOL_16)/snitch_stream.csv
SUM_POOL_16_TESTS += $(SUM_POOL_16)/snrt.csv
SUM_POOL_16_TESTS += $(SUM_POOL_16)/scf_xdsl.csv

$(SUM_POOL_16)/tests.csv: $(SUM_POOL_16_TESTS)
./generate_tests_csv.sh $@ $^
Expand All @@ -127,6 +135,7 @@ RELU_16_TESTS += $(RELU_16)/ssr_frep.csv
RELU_16_TESTS += $(RELU_16)/snrt.csv
RELU_16_TESTS += $(RELU_16)/linalg.csv
RELU_16_TESTS += $(RELU_16)/snitch_stream.csv
RELU_16_TESTS += $(RELU_16)/scf_xdsl.csv

$(RELU_16)/tests.csv: $(RELU_16_TESTS)
./generate_tests_csv.sh $@ $^
Expand All @@ -141,6 +150,7 @@ FILL_16_TESTS += $(FILL_16)/baseline.csv
FILL_16_TESTS += $(FILL_16)/linalg.csv
FILL_16_TESTS += $(FILL_16)/snitch_stream.csv
FILL_16_TESTS += $(FILL_16)/snrt.csv
FILL_16_TESTS += $(FILL_16)/scf_xdsl.csv


$(FILL_16)/tests.csv: $(FILL_16_TESTS)
Expand Down
1 change: 1 addition & 0 deletions kernels/dsum/8x16xf32/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ TESTS += snitch_stream.x
TESTS += linalg.x
TESTS += linalg_xdsl.x
TESTS += scf.x
TESTS += scf_xdsl.x

include ../../Makefile.kernels
1 change: 1 addition & 0 deletions kernels/dsum/8x16xf32/scf_xdsl.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3564
17 changes: 17 additions & 0 deletions kernels/dsum/8x16xf32/scf_xdsl.xdsl.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module {
func.func public @dsum(%arg0: memref<8x16xf64>, %arg1: memref<8x16xf64>, %arg2: memref<8x16xf64>) -> memref<8x16xf64> {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
scf.for %arg3 = %c0 to %c8 step %c1 {
scf.for %arg4 = %c0 to %c16 step %c1 {
%0 = memref.load %arg0[%arg3, %arg4] : memref<8x16xf64>
%1 = memref.load %arg1[%arg3, %arg4] : memref<8x16xf64>
%2 = arith.addf %0, %1 : f64
memref.store %2, %arg2[%arg3, %arg4] : memref<8x16xf64>
}
}
return %arg2 : memref<8x16xf64>
}
}
1 change: 1 addition & 0 deletions kernels/dsum/8x16xf32/tests.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ snrt,187
ssr1d,253
ssr2d,273
snitch_stream,206
scf_xdsl,3564
1 change: 1 addition & 0 deletions kernels/dsum/params.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
8x16xf32,ssr1d,253
8x16xf32,ssr2d,273
8x16xf32,snitch_stream,206
8x16xf32,scf_xdsl,3564
1 change: 1 addition & 0 deletions kernels/fill/16x16xf64/scf_xdsl.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
2676
13 changes: 13 additions & 0 deletions kernels/fill/16x16xf64/scf_xdsl.xdsl.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module {
func.func public @fill(%arg0: f64, %arg1: memref<16x16xf64>) -> memref<16x16xf64> {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
scf.for %arg2 = %c0 to %c16 step %c1 {
scf.for %arg3 = %c0 to %c16 step %c1 {
memref.store %arg0, %arg1[%arg2, %arg3] : memref<16x16xf64>
}
}
return %arg1 : memref<16x16xf64>
}
}
1 change: 1 addition & 0 deletions kernels/fill/16x16xf64/tests.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ baseline,370
linalg,347
snitch_stream,296
snrt,299
scf_xdsl,2676
1 change: 1 addition & 0 deletions kernels/fill/params.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
16x16xf64,linalg,347
16x16xf64,snitch_stream,296
16x16xf64,snrt,299
16x16xf64,scf_xdsl,2676
6 changes: 6 additions & 0 deletions kernels/kernels.csv
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,32 @@ dsum,8x16xf32,snrt,187
dsum,8x16xf32,ssr1d,253
dsum,8x16xf32,ssr2d,273
dsum,8x16xf32,snitch_stream,206
dsum,8x16xf32,scf_xdsl,3564
fill,16x16xf64,baseline,370
fill,16x16xf64,linalg,347
fill,16x16xf64,snitch_stream,296
fill,16x16xf64,snrt,299
fill,16x16xf64,scf_xdsl,2676
matmul,8x8xf64,baseline,4230
matmul,8x8xf64,linalg,6220
matmul,8x8xf64,snitch_stream,2339
matmul,8x8xf64,snrt,2321
matmul,8x8xf64,scf_xdsl,21363
pooling_nchw_max_d1_s2_3x3,1x16x16x1xf64,baseline,1406
pooling_nchw_max_d1_s2_3x3,1x16x16x1xf64,linalg,2330
pooling_nchw_max_d1_s2_3x3,1x16x16x1xf64,snitch_stream,1116
pooling_nchw_max_d1_s2_3x3,1x16x16x1xf64,snrt,1124
pooling_nchw_max_d1_s2_3x3,1x16x16x1xf64,scf_xdsl,24729
pooling_nchw_sum_d1_s2_3x3,1x16x16x1xf64,baseline,1982
pooling_nchw_sum_d1_s2_3x3,1x16x16x1xf64,linalg,3212
pooling_nchw_sum_d1_s2_3x3,1x16x16x1xf64,snitch_stream,2012
pooling_nchw_sum_d1_s2_3x3,1x16x16x1xf64,snrt,2006
pooling_nchw_sum_d1_s2_3x3,1x16x16x1xf64,scf_xdsl,24729
relu,16x16xf64,baseline,1339
relu,16x16xf64,ssr,846
relu,16x16xf64,ssr_frep,327
relu,16x16xf64,snrt,334
relu,16x16xf64,linalg,1335
relu,16x16xf64,snitch_stream,322
relu,16x16xf64,scf_xdsl,5020
softmax,16xf64,baseline,32611
1 change: 1 addition & 0 deletions kernels/matmul/8x8xf64/scf_xdsl.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
21363
20 changes: 20 additions & 0 deletions kernels/matmul/8x8xf64/scf_xdsl.xdsl.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module {
func.func public @matmul(%arg0: memref<8x8xf64>, %arg1: memref<8x8xf64>, %arg2: memref<8x8xf64>) -> memref<8x8xf64> {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
scf.for %arg3 = %c0 to %c8 step %c1 {
scf.for %arg4 = %c0 to %c8 step %c1 {
scf.for %arg5 = %c0 to %c8 step %c1 {
%0 = memref.load %arg0[%arg3, %arg5] : memref<8x8xf64>
%1 = memref.load %arg1[%arg5, %arg4] : memref<8x8xf64>
%2 = memref.load %arg2[%arg3, %arg4] : memref<8x8xf64>
%3 = arith.mulf %0, %1 : f64
%4 = arith.addf %2, %3 : f64
memref.store %4, %arg2[%arg3, %arg4] : memref<8x8xf64>
}
}
}
return %arg2 : memref<8x8xf64>
}
}
1 change: 1 addition & 0 deletions kernels/matmul/8x8xf64/tests.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ baseline,4230
linalg,6220
snitch_stream,2339
snrt,2321
scf_xdsl,21363
1 change: 1 addition & 0 deletions kernels/matmul/params.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
8x8xf64,linalg,6220
8x8xf64,snitch_stream,2339
8x8xf64,snrt,2321
8x8xf64,scf_xdsl,21363
20 changes: 10 additions & 10 deletions kernels/pivoted.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
combined,baseline,linalg,linalg_xdsl,snitch_stream,snrt,min_llvm_mlir,speedup
conv2d_d1_s1_3x3 1x8x8x1xf64,1850,3946,,1500,1498,1850,1.23x
dense 8x8xf64,3240,7069,,2721,2711,3240,1.19x
dsum 8x16xf32,1202,1077,194,206,187,1077,5.23x
fill 16x16xf64,370,347,,296,299,347,1.17x
matmul 8x8xf64,4230,6220,,2339,2321,4230,1.81x
pooling_nchw_max_d1_s2_3x3 1x16x16x1xf64,1406,2330,,1116,1124,1406,1.26x
pooling_nchw_sum_d1_s2_3x3 1x16x16x1xf64,1982,3212,,2012,2006,1982,0.99x
relu 16x16xf64,1339,1335,,322,334,1335,4.15x
softmax 16xf64,32611,,,,,32611,?x
combined,baseline,linalg,linalg_xdsl,scf_xdsl,snitch_stream,snrt,min_llvm_mlir,speedup
conv2d_d1_s1_3x3 1x8x8x1xf64,1850,3946,,,1500,1498,1850,1.23x
dense 8x8xf64,3240,7069,,,2721,2711,3240,1.19x
dsum 8x16xf32,1202,1077,194,3564,206,187,1077,5.23x
fill 16x16xf64,370,347,,2676,296,299,347,1.17x
matmul 8x8xf64,4230,6220,,21363,2339,2321,4230,1.81x
pooling_nchw_max_d1_s2_3x3 1x16x16x1xf64,1406,2330,,24729,1116,1124,1406,1.26x
pooling_nchw_sum_d1_s2_3x3 1x16x16x1xf64,1982,3212,,24729,2012,2006,1982,0.99x
relu 16x16xf64,1339,1335,,5020,322,334,1335,4.15x
softmax 16xf64,32611,,,,,,32611,?x
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
24729
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module {
func.func public @pooling_nchw_max_d1_s2_3x3(%arg0: memref<1x1x16x16xf64>, %arg1: memref<1x1x7x7xf64>) -> memref<1x1x7x7xf64> {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c7 = arith.constant 7 : index
%c3 = arith.constant 3 : index
scf.for %arg2 = %c0 to %c7 step %c1 {
scf.for %arg3 = %c0 to %c7 step %c1 {
scf.for %arg4 = %c0 to %c3 step %c1 {
scf.for %arg5 = %c0 to %c3 step %c1 {
%0 = arith.muli %arg2, %c2 : index
%1 = arith.addi %0, %arg4 : index
%2 = arith.muli %arg3, %c2 : index
%3 = arith.addi %2, %arg5 : index
%4 = memref.load %arg0[%c0, %c0, %1, %3] : memref<1x1x16x16xf64>
%5 = memref.load %arg1[%c0, %c0, %arg2, %arg3] : memref<1x1x7x7xf64>
%6 = arith.maxf %5, %4 : f64
memref.store %6, %arg1[%c0, %c0, %arg2, %arg3] : memref<1x1x7x7xf64>
}
}
}
}
return %arg1 : memref<1x1x7x7xf64>
}
}
1 change: 1 addition & 0 deletions kernels/pooling_nchw_max_d1_s2_3x3/1x16x16x1xf64/tests.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ baseline,1406
linalg,2330
snitch_stream,1116
snrt,1124
scf_xdsl,24729
1 change: 1 addition & 0 deletions kernels/pooling_nchw_max_d1_s2_3x3/params.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
1x16x16x1xf64,linalg,2330
1x16x16x1xf64,snitch_stream,1116
1x16x16x1xf64,snrt,1124
1x16x16x1xf64,scf_xdsl,24729
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
24729
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module {
func.func public @pooling_nchw_sum_d1_s2_3x3(%arg0: memref<1x1x16x16xf64>, %arg1: memref<1x1x7x7xf64>) -> memref<1x1x7x7xf64> {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c7 = arith.constant 7 : index
%c3 = arith.constant 3 : index
scf.for %arg2 = %c0 to %c7 step %c1 {
scf.for %arg3 = %c0 to %c7 step %c1 {
scf.for %arg4 = %c0 to %c3 step %c1 {
scf.for %arg5 = %c0 to %c3 step %c1 {
%0 = arith.muli %arg2, %c2 : index
%1 = arith.addi %0, %arg4 : index
%2 = arith.muli %arg3, %c2 : index
%3 = arith.addi %2, %arg5 : index
%4 = memref.load %arg0[%c0, %c0, %1, %3] : memref<1x1x16x16xf64>
%5 = memref.load %arg1[%c0, %c0, %arg2, %arg3] : memref<1x1x7x7xf64>
%6 = arith.addf %5, %4 : f64
memref.store %6, %arg1[%c0, %c0, %arg2, %arg3] : memref<1x1x7x7xf64>
}
}
}
}
return %arg1 : memref<1x1x7x7xf64>
}
}
1 change: 1 addition & 0 deletions kernels/pooling_nchw_sum_d1_s2_3x3/1x16x16x1xf64/tests.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ baseline,1982
linalg,3212
snitch_stream,2012
snrt,2006
scf_xdsl,24729
1 change: 1 addition & 0 deletions kernels/pooling_nchw_sum_d1_s2_3x3/params.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
1x16x16x1xf64,linalg,3212
1x16x16x1xf64,snitch_stream,2012
1x16x16x1xf64,snrt,2006
1x16x16x1xf64,scf_xdsl,24729
1 change: 1 addition & 0 deletions kernels/relu/16x16xf64/scf_xdsl.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5020
16 changes: 16 additions & 0 deletions kernels/relu/16x16xf64/scf_xdsl.xdsl.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module {
func.func public @relu(%arg0: memref<16x16xf64>, %arg1: memref<16x16xf64>) -> memref<16x16xf64> {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f64
scf.for %arg2 = %c0 to %c16 step %c1 {
scf.for %arg3 = %c0 to %c16 step %c1 {
%0 = memref.load %arg0[%arg2, %arg3] : memref<16x16xf64>
%1 = arith.maxf %0, %cst : f64
memref.store %1, %arg1[%arg2, %arg3] : memref<16x16xf64>
}
}
return %arg1 : memref<16x16xf64>
}
}
1 change: 1 addition & 0 deletions kernels/relu/16x16xf64/tests.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ ssr_frep,327
snrt,334
linalg,1335
snitch_stream,322
scf_xdsl,5020
1 change: 1 addition & 0 deletions kernels/relu/params.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
16x16xf64,snrt,334
16x16xf64,linalg,1335
16x16xf64,snitch_stream,322
16x16xf64,scf_xdsl,5020
2 changes: 1 addition & 1 deletion scripts/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

pivoted = df.pivot(index="combined", columns="impl")["cycles"]

PIVOTED_COLS = set(("linalg", "baseline", "snitch_stream", "snrt", "linalg_xdsl"))
PIVOTED_COLS = set(("linalg", "baseline", "snitch_stream", "snrt", "linalg_xdsl", "scf_xdsl"))

for col in pivoted:
if col not in PIVOTED_COLS:
Expand Down
17 changes: 13 additions & 4 deletions snitch/Makefile.rules
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ XDSLOPTFLAGS += -t riscv-asm
%/linalg_xdsl.xdsl.mlir: %/linalg.mlir
$(MLIROPT) $(MLIROPTFLAGS_XDSL) --mlir-print-local-scope -o $@ $<

.PRECIOUS: %/scf_xdsl.xdsl.mlir
%/scf_xdsl.xdsl.mlir: %/linalg.mlir
$(MLIROPT) $(MLIROPTFLAGS_SCF_XDSL) --mlir-print-local-scope -o $@ $<

%.S: %.xdsl.mlir $(XDSL_COMMIT_FILE)
$(XDSLOPT) $(XDSLOPTFLAGS) $< -o $@

Expand All @@ -151,10 +155,12 @@ MLIROPTFLAGS_XDSL += -empty-tensor-to-alloc-tensor
MLIROPTFLAGS_XDSL += --one-shot-bufferize='bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map'
MLIROPTFLAGS_XDSL += --canonicalize

MLIROPTFLAGS = $(MLIROPTFLAGS_XDSL)
MLIROPTFLAGS += --convert-linalg-to-loops
MLIROPTFLAGS += --lower-affine
MLIROPTFLAGS += --canonicalize
MLIROPTFLAGS_SCF_XDSL = $(MLIROPTFLAGS_XDSL)
MLIROPTFLAGS_SCF_XDSL += --convert-linalg-to-loops
MLIROPTFLAGS_SCF_XDSL += --lower-affine
MLIROPTFLAGS_SCF_XDSL += --canonicalize

MLIROPTFLAGS = $(MLIROPTFLAGS_SCF_XDSL)
MLIROPTFLAGS += --convert-scf-to-cf
MLIROPTFLAGS += --canonicalize
MLIROPTFLAGS += --cse
Expand Down Expand Up @@ -204,6 +210,9 @@ MLIROPTFLAGS += --reconcile-unrealized-casts
%/fused.x: %/fused.o %/main.o %/data.o
cd $(dir $<); $(LD) $(LDFLAGS) $(notdir $^) -o $(notdir $@)

%/scf_xdsl.x: %/scf_xdsl.o %/main.o %/data.o
cd $(dir $<); $(LD) $(LDFLAGS) $(notdir $^) -o $(notdir $@)

# Trace rules

LOG_DIR = $<.logs
Expand Down
2 changes: 1 addition & 1 deletion xdsl_commit.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d40ad7beed1fc9e9bb8a2cdbba43545645927f7d
8bddd0aada21c3591a4ab0d325b44116515bf19e

0 comments on commit 4f0e0c9

Please sign in to comment.