diff --git a/Snakefile b/Snakefile index 0fb93c74..db1008ed 100644 --- a/Snakefile +++ b/Snakefile @@ -104,7 +104,11 @@ TESTSET_FAST = [ # 3d templated kernels *expand( "matmul_transb/4x16x16xf32/{variant}", - variant=["baseline", "snrt", "snitch_stream"], + variant=["linalg", "baseline", "snrt", "snitch_stream"], + ), + *expand( + "matmul_transb/4x16x16xf64/{variant}", + variant=["linalg", "linalg_xdsl"], ), *expand( "matmul/4x16x8xf64/{variant}", diff --git a/kernels/matmul_transb/data.h.template b/kernels/matmul_transb/data.h.template index 23080e7b..a1b14d3f 100644 --- a/kernels/matmul_transb/data.h.template +++ b/kernels/matmul_transb/data.h.template @@ -1,12 +1,23 @@ #pragma once +#define PRECISION {{precision}} +#if PRECISION == 16 +#define DTYPE __fp16 +#elif PRECISION == 32 +#define DTYPE float +#elif PRECISION == 64 +#define DTYPE double +#else +#error Unsupported precision +#endif + #define M {{M}} #define K {{K}} #define N {{N}} -extern const float X[M * K]; -extern const float Y[N * K]; -extern const float G_IN[M * N]; -extern const float G_OUT[M * N]; +extern const DTYPE X[M * K]; +extern const DTYPE Y[N * K]; +extern const DTYPE G_IN[M * N]; +extern const DTYPE G_OUT[M * N]; #define TEST_COUNT (M * N) diff --git a/kernels/matmul_transb/gendata.py b/kernels/matmul_transb/gendata.py index e843744f..0b0d81bc 100644 --- a/kernels/matmul_transb/gendata.py +++ b/kernels/matmul_transb/gendata.py @@ -3,7 +3,7 @@ # usage: `python -m matmul.gendata -p matmul/4x4xf64/params.json` import numpy as np -from typing import Iterator +from collections.abc import Iterator from gendatautils import main, Define, Array @@ -15,7 +15,7 @@ def matrix_data( yield Define("K", K) yield Define("N", N) - t = {32: np.float32}[precision] + t = {32: np.float32, 64: np.float64}[precision] # Errors accumulate a lot with the strategy used in snrt, # especially due to the repeated (quite a bit) flaky SIMD reductions. @@ -28,8 +28,8 @@ def matrix_data( n = N k = K np.random.seed(0) - x = np.random.uniform(rmin, rmax, m * k).astype(t).reshape((m, k)) - y = np.random.uniform(rmin, rmax, k * n).astype(t).reshape((k, n)) + x = np.random.uniform(rmin, rmax, (m, k)).astype(t) + y = np.random.uniform(rmin, rmax, (k, n)).astype(t) g_in = np.zeros((m, n), dtype=t) g_out = x @ y diff --git a/kernels/matmul_transb/linalg.mlir.template b/kernels/matmul_transb/linalg.mlir.template new file mode 100644 index 00000000..f0ee3f67 --- /dev/null +++ b/kernels/matmul_transb/linalg.mlir.template @@ -0,0 +1,8 @@ +func.func public @matmul_transb(%X: tensor<{{M}}x{{K}}xf{{precision}}> {"llvm.noalias"}, + %Y: tensor<{{N}}x{{K}}xf{{precision}}> {"llvm.noalias"}, + %Z: tensor<{{M}}x{{N}}xf{{precision}}> {"llvm.noalias"}) -> tensor<{{M}}x{{N}}xf{{precision}}> { + %zero = arith.constant 0.0 : f{{precision}} + %zeros = linalg.fill ins(%zero : f{{precision}}) outs(%Z : tensor<{{M}}x{{N}}xf{{precision}}>) -> tensor<{{M}}x{{N}}xf{{precision}}> + %res = linalg.matmul_transpose_b ins(%X, %Y : tensor<{{M}}x{{K}}xf{{precision}}>, tensor<{{N}}x{{K}}xf{{precision}}>) outs(%zeros : tensor<{{M}}x{{N}}xf{{precision}}>) -> tensor<{{M}}x{{N}}xf{{precision}}> + func.return %res : tensor<{{M}}x{{N}}xf{{precision}}> +} diff --git a/kernels/matmul_transb/main.c b/kernels/matmul_transb/main.c index 8d360a6f..e72b532b 100644 --- a/kernels/matmul_transb/main.c +++ b/kernels/matmul_transb/main.c @@ -5,22 +5,22 @@ #include // Kernel provided via external definition -void matmul_transb(const float *x, const float *y, float *g); +void matmul_transb(const DTYPE *x, const DTYPE *y, DTYPE *g); int main() { // Allocate shared local memory // By avoiding allocators and bumping by a known offset a base pointer // (snrt_l1_next()) that is the same for all the cores in the cluster, we are // essentially providing the same memory regions to all the cores in this cluster. - float *local_x = (float *)snrt_l1_next(); - float *local_y = local_x + M * K; - float *local_z = local_y + N * K; + DTYPE *local_x = (DTYPE *)snrt_l1_next(); + DTYPE *local_y = local_x + M * K; + DTYPE *local_z = local_y + N * K; // Copy data in shared local memory if (snrt_is_dm_core()) { - snrt_dma_start_1d(local_x, X, M * K * sizeof(float)); - snrt_dma_start_1d(local_y, Y, N * K * sizeof(float)); - snrt_dma_start_1d(local_z, G_IN, M * N * sizeof(float)); + snrt_dma_start_1d(local_x, X, M * K * sizeof(DTYPE)); + snrt_dma_start_1d(local_y, Y, N * K * sizeof(DTYPE)); + snrt_dma_start_1d(local_z, G_IN, M * N * sizeof(DTYPE)); snrt_dma_wait_all(); } @@ -39,7 +39,7 @@ int main() { // Correctness check int nerr = 0; for (int i = 0; i < TEST_COUNT; i++) { - float d = fabsf(local_z[i] - G_OUT[i]); + DTYPE d = fabs(local_z[i] - G_OUT[i]); nerr += !(d <= 1E-2f); // Make sure to take into account NaNs (e.g.: happy path // on the taken branch) } diff --git a/results/kernels.csv b/results/kernels.csv index fbe80bb2..cef48fc4 100644 --- a/results/kernels.csv +++ b/results/kernels.csv @@ -26,8 +26,11 @@ matmul,4x16x8xf64,baseline,2495,3293,3290,2.9941520467836256,1.4991334488734835, matmul,4x16x8xf64,linalg,2694,3483,3480,2.9941520467836256,1.4745484400656814,512,513,1536,0.19042316258351893,0.44415584415584414,1155,898,609,0.4287305122494432,0,33,1.0,1.0,1,3.4,1155,0.73520050922979,416,17,5,0.1544172234595397,69,790,0.0,0.5831477357089829,0.0 matmul,4x16x8xf64,linalg_xdsl,708,1493,1490,2.811418685121107,0.0,512,578,1625,0.8163841807909604,0.996551724137931,580,0,0,0.8192090395480226,0,0,5.37037037037037,5.37037037037037,1,0.0,108,0.5869565217391305,76,0,0,0.10734463276836158,0,786,0.0,0.9265536723163842,0.0 matmul_transb,4x16x16xf32,baseline,3386,4184,4181,2.539660056657224,1.4921875,0,706,1793,0.20850561134081513,0.3935340022296544,1794,1528,1024,0.5298287064382753,0,64,1.0,1.0,1,0.0,1794,0.5561066336019839,1432,0,0,0.42291789722386297,0,799,0.0,0.9527466036621383,0.0 +matmul_transb,4x16x16xf32,linalg,5038,5831,5828,2.9970731707317073,1.4995663486556807,0,1025,3072,0.203453751488686,0.4569772625947392,2243,1729,1153,0.44521635569670504,0,65,1.0,1.0,1,3.4,2243,0.8036546040845575,548,17,5,0.10877332274712187,69,794,0.0,0.553989678443827,0.0 matmul_transb,4x16x16xf32,snitch_stream,845,1636,1633,2.7429906542056073,0.0,0,642,1761,0.7597633136094675,0.9067796610169492,708,0,0,0.8378698224852071,0,64,2.0823529411764703,2.0823529411764707,1,0.0,340,0.7296137339055794,126,0,0,0.14911242603550295,0,792,0.0,0.98698224852071,0.0 matmul_transb,4x16x16xf32,snrt,849,1612,1609,2.648367952522255,0.0,0,674,1785,0.7938751472320377,0.9519774011299436,708,0,0,0.833922261484099,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.8924731182795699,40,0,0,0.04711425206124853,0,764,0.0,0.8810365135453475,0.0 +matmul_transb,4x16x16xf64,linalg,5142,5967,5964,2.9970731707317073,1.4995663486556807,1024,1025,3072,0.19933877868533645,0.4569772625947392,2243,1729,1153,0.43621159082069233,0,65,1.0,1.0,1,3.4,2243,0.7769310703152061,644,17,5,0.12524309607156747,133,826,0.0,0.5614546868922599,0.0 +matmul_transb,4x16x16xf64,linalg_xdsl,1295,2124,2121,2.815424610051993,0.0,1024,1154,3249,0.8911196911196911,0.9982698961937716,1156,0,0,0.8926640926640926,0,0,5.452830188679246,5.452830188679245,1,0.0,212,0.6794871794871795,100,0,0,0.07722007722007722,0,830,0.0,0.9698841698841698,0.0 pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,584,1328,1325,0.995575221238938,1.1226415094339623,0,226,225,0.386986301369863,0.6330532212885154,357,119,106,0.6113013698630136,0,25,1.0,1.0,1,0.0,357,0.9153846153846154,33,0,0,0.05650684931506849,0,745,0.0,0.6678082191780821,0.0 pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg,484,1242,1239,0.993103448275862,1.0909090909090908,0,145,144,0.29958677685950413,0.5823293172690763,249,96,88,0.5144628099173554,0,16,1.0,1.0,1,0.0,249,0.8498293515358362,44,0,0,0.09090909090909091,32,759,0.0,0.6053719008264463,0.0 pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1018,1015,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,744,0.0,0.8363636363636364,0.0 diff --git a/results/kernels.fast.csv b/results/kernels.fast.csv index fbe80bb2..cef48fc4 100644 --- a/results/kernels.fast.csv +++ b/results/kernels.fast.csv @@ -26,8 +26,11 @@ matmul,4x16x8xf64,baseline,2495,3293,3290,2.9941520467836256,1.4991334488734835, matmul,4x16x8xf64,linalg,2694,3483,3480,2.9941520467836256,1.4745484400656814,512,513,1536,0.19042316258351893,0.44415584415584414,1155,898,609,0.4287305122494432,0,33,1.0,1.0,1,3.4,1155,0.73520050922979,416,17,5,0.1544172234595397,69,790,0.0,0.5831477357089829,0.0 matmul,4x16x8xf64,linalg_xdsl,708,1493,1490,2.811418685121107,0.0,512,578,1625,0.8163841807909604,0.996551724137931,580,0,0,0.8192090395480226,0,0,5.37037037037037,5.37037037037037,1,0.0,108,0.5869565217391305,76,0,0,0.10734463276836158,0,786,0.0,0.9265536723163842,0.0 matmul_transb,4x16x16xf32,baseline,3386,4184,4181,2.539660056657224,1.4921875,0,706,1793,0.20850561134081513,0.3935340022296544,1794,1528,1024,0.5298287064382753,0,64,1.0,1.0,1,0.0,1794,0.5561066336019839,1432,0,0,0.42291789722386297,0,799,0.0,0.9527466036621383,0.0 +matmul_transb,4x16x16xf32,linalg,5038,5831,5828,2.9970731707317073,1.4995663486556807,0,1025,3072,0.203453751488686,0.4569772625947392,2243,1729,1153,0.44521635569670504,0,65,1.0,1.0,1,3.4,2243,0.8036546040845575,548,17,5,0.10877332274712187,69,794,0.0,0.553989678443827,0.0 matmul_transb,4x16x16xf32,snitch_stream,845,1636,1633,2.7429906542056073,0.0,0,642,1761,0.7597633136094675,0.9067796610169492,708,0,0,0.8378698224852071,0,64,2.0823529411764703,2.0823529411764707,1,0.0,340,0.7296137339055794,126,0,0,0.14911242603550295,0,792,0.0,0.98698224852071,0.0 matmul_transb,4x16x16xf32,snrt,849,1612,1609,2.648367952522255,0.0,0,674,1785,0.7938751472320377,0.9519774011299436,708,0,0,0.833922261484099,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.8924731182795699,40,0,0,0.04711425206124853,0,764,0.0,0.8810365135453475,0.0 +matmul_transb,4x16x16xf64,linalg,5142,5967,5964,2.9970731707317073,1.4995663486556807,1024,1025,3072,0.19933877868533645,0.4569772625947392,2243,1729,1153,0.43621159082069233,0,65,1.0,1.0,1,3.4,2243,0.7769310703152061,644,17,5,0.12524309607156747,133,826,0.0,0.5614546868922599,0.0 +matmul_transb,4x16x16xf64,linalg_xdsl,1295,2124,2121,2.815424610051993,0.0,1024,1154,3249,0.8911196911196911,0.9982698961937716,1156,0,0,0.8926640926640926,0,0,5.452830188679246,5.452830188679245,1,0.0,212,0.6794871794871795,100,0,0,0.07722007722007722,0,830,0.0,0.9698841698841698,0.0 pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,584,1328,1325,0.995575221238938,1.1226415094339623,0,226,225,0.386986301369863,0.6330532212885154,357,119,106,0.6113013698630136,0,25,1.0,1.0,1,0.0,357,0.9153846153846154,33,0,0,0.05650684931506849,0,745,0.0,0.6678082191780821,0.0 pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg,484,1242,1239,0.993103448275862,1.0909090909090908,0,145,144,0.29958677685950413,0.5823293172690763,249,96,88,0.5144628099173554,0,16,1.0,1.0,1,0.0,249,0.8498293515358362,44,0,0,0.09090909090909091,32,759,0.0,0.6053719008264463,0.0 pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1018,1015,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,744,0.0,0.8363636363636364,0.0 diff --git a/results/pivoted.csv b/results/pivoted.csv index 80e299f3..e70a112e 100644 --- a/results/pivoted.csv +++ b/results/pivoted.csv @@ -4,7 +4,8 @@ ddot 128xf64,956,965,,213,577 dense 8x8xf64,3206,3530,,2741,2723 fill 4x4xf64,50,50,64,, matmul 4x16x8xf64,2495,2694,708,, -matmul_transb 4x16x16xf32,3386,,,845,849 +matmul_transb 4x16x16xf32,3386,5038,,845,849 +matmul_transb 4x16x16xf64,,5142,1295,, pooling_nchw_max_d1_s2_3x3 4x4xf64,584,484,275,, pooling_nchw_sum_d1_s2_3x3 4x4xf64,902,832,271,, relu 4x4xf64,142,125,72,, diff --git a/results/pivoted.fast.csv b/results/pivoted.fast.csv index 80e299f3..e70a112e 100644 --- a/results/pivoted.fast.csv +++ b/results/pivoted.fast.csv @@ -4,7 +4,8 @@ ddot 128xf64,956,965,,213,577 dense 8x8xf64,3206,3530,,2741,2723 fill 4x4xf64,50,50,64,, matmul 4x16x8xf64,2495,2694,708,, -matmul_transb 4x16x16xf32,3386,,,845,849 +matmul_transb 4x16x16xf32,3386,5038,,845,849 +matmul_transb 4x16x16xf64,,5142,1295,, pooling_nchw_max_d1_s2_3x3 4x4xf64,584,484,275,, pooling_nchw_sum_d1_s2_3x3 4x4xf64,902,832,271,, relu 4x4xf64,142,125,72,, diff --git a/results/pivoted_fpu.csv b/results/pivoted_fpu.csv index 4b657c1c..70b4c53c 100644 --- a/results/pivoted_fpu.csv +++ b/results/pivoted_fpu.csv @@ -4,7 +4,8 @@ ddot 128xf64,0.13,0.13,,0.64,0.22 dense 8x8xf64,0.20,0.18,,0.26,0.26 fill 4x4xf64,0.02,0.02,0.28,, matmul 4x16x8xf64,0.21,0.19,0.82,, -matmul_transb 4x16x16xf32,0.21,,,0.76,0.79 +matmul_transb 4x16x16xf32,0.21,0.20,,0.76,0.79 +matmul_transb 4x16x16xf64,,0.20,0.89,, pooling_nchw_max_d1_s2_3x3 4x4xf64,0.39,0.30,0.65,, pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,0.17,0.66,, relu 4x4xf64,0.13,0.14,0.25,, diff --git a/results/pivoted_fpu.fast.csv b/results/pivoted_fpu.fast.csv index 4b657c1c..70b4c53c 100644 --- a/results/pivoted_fpu.fast.csv +++ b/results/pivoted_fpu.fast.csv @@ -4,7 +4,8 @@ ddot 128xf64,0.13,0.13,,0.64,0.22 dense 8x8xf64,0.20,0.18,,0.26,0.26 fill 4x4xf64,0.02,0.02,0.28,, matmul 4x16x8xf64,0.21,0.19,0.82,, -matmul_transb 4x16x16xf32,0.21,,,0.76,0.79 +matmul_transb 4x16x16xf32,0.21,0.20,,0.76,0.79 +matmul_transb 4x16x16xf64,,0.20,0.89,, pooling_nchw_max_d1_s2_3x3 4x4xf64,0.39,0.30,0.65,, pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,0.17,0.66,, relu 4x4xf64,0.13,0.14,0.25,, diff --git a/results/pivoted_ipc.csv b/results/pivoted_ipc.csv index 3f816012..bcec93a6 100644 --- a/results/pivoted_ipc.csv +++ b/results/pivoted_ipc.csv @@ -4,7 +4,8 @@ ddot 128xf64,0.95,0.94,,0.74,0.25 dense 8x8xf64,0.51,0.55,,0.39,0.33 fill 4x4xf64,0.46,0.46,0.53,, matmul 4x16x8xf64,0.56,0.58,0.93,, -matmul_transb 4x16x16xf32,0.95,,,0.99,0.88 +matmul_transb 4x16x16xf32,0.95,0.55,,0.99,0.88 +matmul_transb 4x16x16xf64,,0.56,0.97,, pooling_nchw_max_d1_s2_3x3 4x4xf64,0.67,0.61,0.84,, pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.40,0.43,0.84,, relu 4x4xf64,0.40,0.46,0.53,, diff --git a/results/pivoted_ipc.fast.csv b/results/pivoted_ipc.fast.csv index 3f816012..bcec93a6 100644 --- a/results/pivoted_ipc.fast.csv +++ b/results/pivoted_ipc.fast.csv @@ -4,7 +4,8 @@ ddot 128xf64,0.95,0.94,,0.74,0.25 dense 8x8xf64,0.51,0.55,,0.39,0.33 fill 4x4xf64,0.46,0.46,0.53,, matmul 4x16x8xf64,0.56,0.58,0.93,, -matmul_transb 4x16x16xf32,0.95,,,0.99,0.88 +matmul_transb 4x16x16xf32,0.95,0.55,,0.99,0.88 +matmul_transb 4x16x16xf64,,0.56,0.97,, pooling_nchw_max_d1_s2_3x3 4x4xf64,0.67,0.61,0.84,, pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.40,0.43,0.84,, relu 4x4xf64,0.40,0.46,0.53,, diff --git a/results/regalloc.fast.csv b/results/regalloc.fast.csv index e088657a..825a9b6c 100644 --- a/results/regalloc.fast.csv +++ b/results/regalloc.fast.csv @@ -5,6 +5,7 @@ dense,8x8xf64,5,11 fill,4x4xf64,3,3 matmul,4x16x8xf64,8,8 matmul_transb,4x16x16xf32,11,12 +matmul_transb,4x16x16xf64,8,8 pooling_nchw_max_d1_s2_3x3,4x4xf64,7,6 pooling_nchw_sum_d1_s2_3x3,4x4xf64,7,6 relu,4x4xf64,3,5