Merge pull request #317 from opencompl/sasha/matmul-f32
add linalg and linalg_xdsl matmul f32
superlopuh authored Oct 21, 2024
2 parents 1c387b6 + 85869fa commit 1fc89dd
Showing 14 changed files with 59 additions and 23 deletions.
6 changes: 5 additions & 1 deletion Snakefile
@@ -104,7 +104,11 @@ TESTSET_FAST = [
     # 3d templated kernels
     *expand(
         "matmul_transb/4x16x16xf32/{variant}",
-        variant=["baseline", "snrt", "snitch_stream"],
+        variant=["linalg", "baseline", "snrt", "snitch_stream"],
     ),
+    *expand(
+        "matmul_transb/4x16x16xf64/{variant}",
+        variant=["linalg", "linalg_xdsl"],
+    ),
     *expand(
         "matmul/4x16x8xf64/{variant}",
19 changes: 15 additions & 4 deletions kernels/matmul_transb/data.h.template
@@ -1,12 +1,23 @@
 #pragma once
 
+#define PRECISION {{precision}}
+#if PRECISION == 16
+#define DTYPE __fp16
+#elif PRECISION == 32
+#define DTYPE float
+#elif PRECISION == 64
+#define DTYPE double
+#else
+#error Unsupported precision
+#endif
+
 #define M {{M}}
 #define K {{K}}
 #define N {{N}}
 
-extern const float X[M * K];
-extern const float Y[N * K];
-extern const float G_IN[M * N];
-extern const float G_OUT[M * N];
+extern const DTYPE X[M * K];
+extern const DTYPE Y[N * K];
+extern const DTYPE G_IN[M * N];
+extern const DTYPE G_OUT[M * N];
 
 #define TEST_COUNT (M * N)
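As an illustration only (not part of this diff): rendering the template for an f32 kernel substitutes {{precision}} with 32, so the preprocessor block resolves to the `float` branch and the header effectively reduces to:

#define PRECISION 32           /* substituted from {{precision}} */
#define DTYPE float            /* selected by the `#elif PRECISION == 32` branch */

extern const float X[M * K];   /* what `extern const DTYPE X[M * K];` expands to */

and every `sizeof(DTYPE)` in main.c evaluates to 4. Rendering with 64 selects `double`, and 16 selects `__fp16`.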
8 changes: 4 additions & 4 deletions kernels/matmul_transb/gendata.py
@@ -3,7 +3,7 @@
 # usage: `python -m matmul.gendata -p matmul/4x4xf64/params.json`
 
 import numpy as np
-from typing import Iterator
+from collections.abc import Iterator
 
 from gendatautils import main, Define, Array
 
@@ -15,7 +15,7 @@ def matrix_data(
     yield Define("K", K)
     yield Define("N", N)
 
-    t = {32: np.float32}[precision]
+    t = {32: np.float32, 64: np.float64}[precision]
 
     # Errors accumulate a lot with the strategy used in snrt,
     # especially due to the repeated (quite a bit) flaky SIMD reductions.
@@ -28,8 +28,8 @@ def matrix_data(
     n = N
     k = K
     np.random.seed(0)
-    x = np.random.uniform(rmin, rmax, m * k).astype(t).reshape((m, k))
-    y = np.random.uniform(rmin, rmax, k * n).astype(t).reshape((k, n))
+    x = np.random.uniform(rmin, rmax, (m, k)).astype(t)
+    y = np.random.uniform(rmin, rmax, (k, n)).astype(t)
     g_in = np.zeros((m, n), dtype=t)
 
     g_out = x @ y
8 changes: 8 additions & 0 deletions kernels/matmul_transb/linalg.mlir.template
@@ -0,0 +1,8 @@
+func.func public @matmul_transb(%X: tensor<{{M}}x{{K}}xf{{precision}}> {"llvm.noalias"},
+                                %Y: tensor<{{N}}x{{K}}xf{{precision}}> {"llvm.noalias"},
+                                %Z: tensor<{{M}}x{{N}}xf{{precision}}> {"llvm.noalias"}) -> tensor<{{M}}x{{N}}xf{{precision}}> {
+  %zero = arith.constant 0.0 : f{{precision}}
+  %zeros = linalg.fill ins(%zero : f{{precision}}) outs(%Z : tensor<{{M}}x{{N}}xf{{precision}}>) -> tensor<{{M}}x{{N}}xf{{precision}}>
+  %res = linalg.matmul_transpose_b ins(%X, %Y : tensor<{{M}}x{{K}}xf{{precision}}>, tensor<{{N}}x{{K}}xf{{precision}}>) outs(%zeros : tensor<{{M}}x{{N}}xf{{precision}}>) -> tensor<{{M}}x{{N}}xf{{precision}}>
+  func.return %res : tensor<{{M}}x{{N}}xf{{precision}}>
+}
16 changes: 8 additions & 8 deletions kernels/matmul_transb/main.c
@@ -5,22 +5,22 @@
 #include <math.h>
 
 // Kernel provided via external definition
-void matmul_transb(const float *x, const float *y, float *g);
+void matmul_transb(const DTYPE *x, const DTYPE *y, DTYPE *g);
 
 int main() {
   // Allocate shared local memory
   // By avoiding allocators and bumping by a known offset a base pointer
   // (snrt_l1_next()) that is the same for all the cores in the cluster, we are
   // essentially providing the same memory regions to all the cores in this cluster.
-  float *local_x = (float *)snrt_l1_next();
-  float *local_y = local_x + M * K;
-  float *local_z = local_y + N * K;
+  DTYPE *local_x = (DTYPE *)snrt_l1_next();
+  DTYPE *local_y = local_x + M * K;
+  DTYPE *local_z = local_y + N * K;
 
   // Copy data in shared local memory
   if (snrt_is_dm_core()) {
-    snrt_dma_start_1d(local_x, X, M * K * sizeof(float));
-    snrt_dma_start_1d(local_y, Y, N * K * sizeof(float));
-    snrt_dma_start_1d(local_z, G_IN, M * N * sizeof(float));
+    snrt_dma_start_1d(local_x, X, M * K * sizeof(DTYPE));
+    snrt_dma_start_1d(local_y, Y, N * K * sizeof(DTYPE));
+    snrt_dma_start_1d(local_z, G_IN, M * N * sizeof(DTYPE));
     snrt_dma_wait_all();
   }
 
@@ -39,7 +39,7 @@ int main() {
   // Correctness check
   int nerr = 0;
   for (int i = 0; i < TEST_COUNT; i++) {
-    float d = fabsf(local_z[i] - G_OUT[i]);
+    DTYPE d = fabs(local_z[i] - G_OUT[i]);
     nerr += !(d <= 1E-2f); // Make sure to take into account NaNs (e.g.: happy path
                            // on the taken branch)
  }
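For context, here is a minimal scalar sketch of the contract behind that prototype, i.e. the computation that `linalg.matmul_transpose_b` in linalg.mlir.template expresses, with Y holding B transposed (N x K, row-major). It is an illustration only, not one of the baseline/snrt/snitch_stream/linalg variants benchmarked in this PR; it assumes the `DTYPE`/`M`/`N`/`K` macros from data.h above, and the name `matmul_transb_reference` is hypothetical.

// Illustrative reference: g[i][j] += sum_k x[i][k] * y[j][k]
// x: M x K (row-major), y: N x K (row-major, i.e. B transposed), g: M x N
void matmul_transb_reference(const DTYPE *x, const DTYPE *y, DTYPE *g) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      DTYPE acc = g[i * N + j]; // start from the accumulator passed in (G_IN)
      for (int k = 0; k < K; ++k) {
        acc += x[i * K + k] * y[j * K + k];
      }
      g[i * N + j] = acc;
    }
  }
}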
3 changes: 3 additions & 0 deletions results/kernels.csv
@@ -26,8 +26,11 @@ matmul,4x16x8xf64,baseline,2495,3293,3290,2.9941520467836256,1.4991334488734835,
matmul,4x16x8xf64,linalg,2694,3483,3480,2.9941520467836256,1.4745484400656814,512,513,1536,0.19042316258351893,0.44415584415584414,1155,898,609,0.4287305122494432,0,33,1.0,1.0,1,3.4,1155,0.73520050922979,416,17,5,0.1544172234595397,69,790,0.0,0.5831477357089829,0.0
matmul,4x16x8xf64,linalg_xdsl,708,1493,1490,2.811418685121107,0.0,512,578,1625,0.8163841807909604,0.996551724137931,580,0,0,0.8192090395480226,0,0,5.37037037037037,5.37037037037037,1,0.0,108,0.5869565217391305,76,0,0,0.10734463276836158,0,786,0.0,0.9265536723163842,0.0
matmul_transb,4x16x16xf32,baseline,3386,4184,4181,2.539660056657224,1.4921875,0,706,1793,0.20850561134081513,0.3935340022296544,1794,1528,1024,0.5298287064382753,0,64,1.0,1.0,1,0.0,1794,0.5561066336019839,1432,0,0,0.42291789722386297,0,799,0.0,0.9527466036621383,0.0
matmul_transb,4x16x16xf32,linalg,5038,5831,5828,2.9970731707317073,1.4995663486556807,0,1025,3072,0.203453751488686,0.4569772625947392,2243,1729,1153,0.44521635569670504,0,65,1.0,1.0,1,3.4,2243,0.8036546040845575,548,17,5,0.10877332274712187,69,794,0.0,0.553989678443827,0.0
matmul_transb,4x16x16xf32,snitch_stream,845,1636,1633,2.7429906542056073,0.0,0,642,1761,0.7597633136094675,0.9067796610169492,708,0,0,0.8378698224852071,0,64,2.0823529411764703,2.0823529411764707,1,0.0,340,0.7296137339055794,126,0,0,0.14911242603550295,0,792,0.0,0.98698224852071,0.0
matmul_transb,4x16x16xf32,snrt,849,1612,1609,2.648367952522255,0.0,0,674,1785,0.7938751472320377,0.9519774011299436,708,0,0,0.833922261484099,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.8924731182795699,40,0,0,0.04711425206124853,0,764,0.0,0.8810365135453475,0.0
matmul_transb,4x16x16xf64,linalg,5142,5967,5964,2.9970731707317073,1.4995663486556807,1024,1025,3072,0.19933877868533645,0.4569772625947392,2243,1729,1153,0.43621159082069233,0,65,1.0,1.0,1,3.4,2243,0.7769310703152061,644,17,5,0.12524309607156747,133,826,0.0,0.5614546868922599,0.0
matmul_transb,4x16x16xf64,linalg_xdsl,1295,2124,2121,2.815424610051993,0.0,1024,1154,3249,0.8911196911196911,0.9982698961937716,1156,0,0,0.8926640926640926,0,0,5.452830188679246,5.452830188679245,1,0.0,212,0.6794871794871795,100,0,0,0.07722007722007722,0,830,0.0,0.9698841698841698,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,584,1328,1325,0.995575221238938,1.1226415094339623,0,226,225,0.386986301369863,0.6330532212885154,357,119,106,0.6113013698630136,0,25,1.0,1.0,1,0.0,357,0.9153846153846154,33,0,0,0.05650684931506849,0,745,0.0,0.6678082191780821,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg,484,1242,1239,0.993103448275862,1.0909090909090908,0,145,144,0.29958677685950413,0.5823293172690763,249,96,88,0.5144628099173554,0,16,1.0,1.0,1,0.0,249,0.8498293515358362,44,0,0,0.09090909090909091,32,759,0.0,0.6053719008264463,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1018,1015,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,744,0.0,0.8363636363636364,0.0
3 changes: 3 additions & 0 deletions results/kernels.fast.csv
@@ -26,8 +26,11 @@ matmul,4x16x8xf64,baseline,2495,3293,3290,2.9941520467836256,1.4991334488734835,
matmul,4x16x8xf64,linalg,2694,3483,3480,2.9941520467836256,1.4745484400656814,512,513,1536,0.19042316258351893,0.44415584415584414,1155,898,609,0.4287305122494432,0,33,1.0,1.0,1,3.4,1155,0.73520050922979,416,17,5,0.1544172234595397,69,790,0.0,0.5831477357089829,0.0
matmul,4x16x8xf64,linalg_xdsl,708,1493,1490,2.811418685121107,0.0,512,578,1625,0.8163841807909604,0.996551724137931,580,0,0,0.8192090395480226,0,0,5.37037037037037,5.37037037037037,1,0.0,108,0.5869565217391305,76,0,0,0.10734463276836158,0,786,0.0,0.9265536723163842,0.0
matmul_transb,4x16x16xf32,baseline,3386,4184,4181,2.539660056657224,1.4921875,0,706,1793,0.20850561134081513,0.3935340022296544,1794,1528,1024,0.5298287064382753,0,64,1.0,1.0,1,0.0,1794,0.5561066336019839,1432,0,0,0.42291789722386297,0,799,0.0,0.9527466036621383,0.0
matmul_transb,4x16x16xf32,linalg,5038,5831,5828,2.9970731707317073,1.4995663486556807,0,1025,3072,0.203453751488686,0.4569772625947392,2243,1729,1153,0.44521635569670504,0,65,1.0,1.0,1,3.4,2243,0.8036546040845575,548,17,5,0.10877332274712187,69,794,0.0,0.553989678443827,0.0
matmul_transb,4x16x16xf32,snitch_stream,845,1636,1633,2.7429906542056073,0.0,0,642,1761,0.7597633136094675,0.9067796610169492,708,0,0,0.8378698224852071,0,64,2.0823529411764703,2.0823529411764707,1,0.0,340,0.7296137339055794,126,0,0,0.14911242603550295,0,792,0.0,0.98698224852071,0.0
matmul_transb,4x16x16xf32,snrt,849,1612,1609,2.648367952522255,0.0,0,674,1785,0.7938751472320377,0.9519774011299436,708,0,0,0.833922261484099,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.8924731182795699,40,0,0,0.04711425206124853,0,764,0.0,0.8810365135453475,0.0
matmul_transb,4x16x16xf64,linalg,5142,5967,5964,2.9970731707317073,1.4995663486556807,1024,1025,3072,0.19933877868533645,0.4569772625947392,2243,1729,1153,0.43621159082069233,0,65,1.0,1.0,1,3.4,2243,0.7769310703152061,644,17,5,0.12524309607156747,133,826,0.0,0.5614546868922599,0.0
matmul_transb,4x16x16xf64,linalg_xdsl,1295,2124,2121,2.815424610051993,0.0,1024,1154,3249,0.8911196911196911,0.9982698961937716,1156,0,0,0.8926640926640926,0,0,5.452830188679246,5.452830188679245,1,0.0,212,0.6794871794871795,100,0,0,0.07722007722007722,0,830,0.0,0.9698841698841698,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,584,1328,1325,0.995575221238938,1.1226415094339623,0,226,225,0.386986301369863,0.6330532212885154,357,119,106,0.6113013698630136,0,25,1.0,1.0,1,0.0,357,0.9153846153846154,33,0,0,0.05650684931506849,0,745,0.0,0.6678082191780821,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg,484,1242,1239,0.993103448275862,1.0909090909090908,0,145,144,0.29958677685950413,0.5823293172690763,249,96,88,0.5144628099173554,0,16,1.0,1.0,1,0.0,249,0.8498293515358362,44,0,0,0.09090909090909091,32,759,0.0,0.6053719008264463,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1018,1015,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,744,0.0,0.8363636363636364,0.0
3 changes: 2 additions & 1 deletion results/pivoted.csv
@@ -4,7 +4,8 @@ ddot 128xf64,956,965,,213,577
 dense 8x8xf64,3206,3530,,2741,2723
 fill 4x4xf64,50,50,64,,
 matmul 4x16x8xf64,2495,2694,708,,
-matmul_transb 4x16x16xf32,3386,,,845,849
+matmul_transb 4x16x16xf32,3386,5038,,845,849
+matmul_transb 4x16x16xf64,,5142,1295,,
 pooling_nchw_max_d1_s2_3x3 4x4xf64,584,484,275,,
 pooling_nchw_sum_d1_s2_3x3 4x4xf64,902,832,271,,
 relu 4x4xf64,142,125,72,,
3 changes: 2 additions & 1 deletion results/pivoted.fast.csv
@@ -4,7 +4,8 @@ ddot 128xf64,956,965,,213,577
 dense 8x8xf64,3206,3530,,2741,2723
 fill 4x4xf64,50,50,64,,
 matmul 4x16x8xf64,2495,2694,708,,
-matmul_transb 4x16x16xf32,3386,,,845,849
+matmul_transb 4x16x16xf32,3386,5038,,845,849
+matmul_transb 4x16x16xf64,,5142,1295,,
 pooling_nchw_max_d1_s2_3x3 4x4xf64,584,484,275,,
 pooling_nchw_sum_d1_s2_3x3 4x4xf64,902,832,271,,
 relu 4x4xf64,142,125,72,,
3 changes: 2 additions & 1 deletion results/pivoted_fpu.csv
@@ -4,7 +4,8 @@ ddot 128xf64,0.13,0.13,,0.64,0.22
 dense 8x8xf64,0.20,0.18,,0.26,0.26
 fill 4x4xf64,0.02,0.02,0.28,,
 matmul 4x16x8xf64,0.21,0.19,0.82,,
-matmul_transb 4x16x16xf32,0.21,,,0.76,0.79
+matmul_transb 4x16x16xf32,0.21,0.20,,0.76,0.79
+matmul_transb 4x16x16xf64,,0.20,0.89,,
 pooling_nchw_max_d1_s2_3x3 4x4xf64,0.39,0.30,0.65,,
 pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,0.17,0.66,,
 relu 4x4xf64,0.13,0.14,0.25,,
3 changes: 2 additions & 1 deletion results/pivoted_fpu.fast.csv
@@ -4,7 +4,8 @@ ddot 128xf64,0.13,0.13,,0.64,0.22
 dense 8x8xf64,0.20,0.18,,0.26,0.26
 fill 4x4xf64,0.02,0.02,0.28,,
 matmul 4x16x8xf64,0.21,0.19,0.82,,
-matmul_transb 4x16x16xf32,0.21,,,0.76,0.79
+matmul_transb 4x16x16xf32,0.21,0.20,,0.76,0.79
+matmul_transb 4x16x16xf64,,0.20,0.89,,
 pooling_nchw_max_d1_s2_3x3 4x4xf64,0.39,0.30,0.65,,
 pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,0.17,0.66,,
 relu 4x4xf64,0.13,0.14,0.25,,
3 changes: 2 additions & 1 deletion results/pivoted_ipc.csv
@@ -4,7 +4,8 @@ ddot 128xf64,0.95,0.94,,0.74,0.25
 dense 8x8xf64,0.51,0.55,,0.39,0.33
 fill 4x4xf64,0.46,0.46,0.53,,
 matmul 4x16x8xf64,0.56,0.58,0.93,,
-matmul_transb 4x16x16xf32,0.95,,,0.99,0.88
+matmul_transb 4x16x16xf32,0.95,0.55,,0.99,0.88
+matmul_transb 4x16x16xf64,,0.56,0.97,,
 pooling_nchw_max_d1_s2_3x3 4x4xf64,0.67,0.61,0.84,,
 pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.40,0.43,0.84,,
 relu 4x4xf64,0.40,0.46,0.53,,
3 changes: 2 additions & 1 deletion results/pivoted_ipc.fast.csv
@@ -4,7 +4,8 @@ ddot 128xf64,0.95,0.94,,0.74,0.25
 dense 8x8xf64,0.51,0.55,,0.39,0.33
 fill 4x4xf64,0.46,0.46,0.53,,
 matmul 4x16x8xf64,0.56,0.58,0.93,,
-matmul_transb 4x16x16xf32,0.95,,,0.99,0.88
+matmul_transb 4x16x16xf32,0.95,0.55,,0.99,0.88
+matmul_transb 4x16x16xf64,,0.56,0.97,,
 pooling_nchw_max_d1_s2_3x3 4x4xf64,0.67,0.61,0.84,,
 pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.40,0.43,0.84,,
 relu 4x4xf64,0.40,0.46,0.53,,
1 change: 1 addition & 0 deletions results/regalloc.fast.csv
@@ -5,6 +5,7 @@ dense,8x8xf64,5,11
 fill,4x4xf64,3,3
 matmul,4x16x8xf64,8,8
 matmul_transb,4x16x16xf32,11,12
+matmul_transb,4x16x16xf64,8,8
 pooling_nchw_max_d1_s2_3x3,4x4xf64,7,6
 pooling_nchw_sum_d1_s2_3x3,4x4xf64,7,6
 relu,4x4xf64,3,5
