Add comprehensive tests exercising the kernels across available dtypes.
Added softmax and gemm kernels to test across the available float and int dtypes.
Prashant Kumar committed Feb 27, 2024
1 parent 971231c commit 05d834c
Showing 1 changed file with 106 additions and 0 deletions.
core/tests/kernel/coverage_test.py
@@ -0,0 +1,106 @@
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import pytest


TKL_TO_TORCH_DTYPE = {
    tkl.f16: torch.half,
    tkl.f32: torch.float,
    tkl.f64: torch.double,
    tkl.bool: torch.bool,
    tkl.i8: torch.int8,
    tkl.i16: torch.int16,
    tkl.i32: torch.int32,
    tkl.i64: torch.int64,
}

FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
# tkl.i4 and tkl.index have no entry in TKL_TO_TORCH_DTYPE, so the
# parametrized tests below skip them via the `dtype in TKL_TO_TORCH_DTYPE`
# guard.
INT_DTYPES = [
    tkl.bool,
    tkl.i4,
    tkl.i8,
    tkl.i16,
    tkl.i32,
    tkl.i64,
    tkl.index,
]


def softmax_krnl(dtype, input, output):
    M = tkl.sym.M
    K = tkl.sym.K

    @tk.gen.thread(M)
    def softmax_kernel(
        input: tk.lang.InputBuffer[M, K, dtype],
        output: tk.lang.OutputBuffer[M, K, dtype],
    ):
        row_index = tk.lang.program_id(0)
        input_row = input[row_index, :]
        numerator = tkl.exp2(input_row - tkl.max(input_row))
        if dtype in INT_DTYPES:
            # Integer dtypes cannot use true division; fall back to floor
            # division so the kernel still lowers for int inputs.
            output_row = numerator // tkl.sum(numerator)
        else:
            output_row = numerator / tkl.sum(numerator)
        output[row_index, :] = output_row

    with tk.gen.TestLaunchContext():
        softmax_kernel(input, output)
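

# A minimal eager-mode reference for what softmax_kernel computes above (an
# illustrative sketch, not part of the original commit): a base-2 softmax,
# since the kernel normalizes with tkl.exp2 rather than the natural exponent.
def softmax_reference(x: torch.Tensor) -> torch.Tensor:
    # Subtract the row max for numerical stability, mirroring the kernel.
    numerator = torch.exp2(x - x.max(dim=-1, keepdim=True).values)
    return numerator / numerator.sum(dim=-1, keepdim=True)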


def gemm_fx_kernel(dtype, A, B, output):
    N = tkl.sym.N
    M = tkl.sym.M
    K = tkl.sym.K
    BLOCK_SIZE = tkl.sym.BLOCK_SIZE

    @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
    def gemm_kernel(
        A: tkl.InputBuffer[N, K, dtype],
        B: tkl.InputBuffer[K, M, dtype],
        output: tkl.OutputBuffer[N, M, dtype],
    ):
        grid_n = tkl.program_id(0)
        grid_m = tkl.program_id(1)

        # TODO: Only considering the float and integer cases.
        if dtype in INT_DTYPES:
            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0)
        else:
            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0)

        @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
        def body(i, c):
            a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
            b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
            return (tkl.dot(a, b, c),)

        tkl.store(output, (grid_n, grid_m), body[0])

    with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
        gemm_kernel(A, B, output)
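

# A minimal eager-mode reference for gemm_kernel above (an illustrative
# sketch, not part of the original commit): the blocked tkl.dot loop over
# K tiles accumulates the same result as a single dense matmul.
def gemm_reference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    return torch.matmul(A, B)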


@pytest.mark.parametrize(
    ("dtype",),
    [(x,) for x in FLOAT_DTYPES],
)
def test_softmax_krnl(dtype):
    if dtype in TKL_TO_TORCH_DTYPE:
        input = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype])
        output = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype])
        softmax_krnl(dtype, input, output)


@pytest.mark.parametrize(
    ("dtype",),
    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_gemm_krnl(dtype):
    if dtype in TKL_TO_TORCH_DTYPE:
        A = torch.randn(512, 1024).to(TKL_TO_TORCH_DTYPE[dtype])
        B = torch.randn(1024, 2048).to(TKL_TO_TORCH_DTYPE[dtype])
        output = torch.zeros(512, 2048).to(TKL_TO_TORCH_DTYPE[dtype])
        gemm_fx_kernel(dtype, A, B, output)
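

if __name__ == "__main__":
    # Convenience entry point (a sketch, not in the original commit) so the
    # file can be run directly; pytest.main is pytest's documented
    # programmatic equivalent of `pytest <this file> -v`.
    pytest.main([__file__, "-v"])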
