-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add comprehensive tests to test the kernel across available dtypes.
Added softmax and gemm kernel to test across the available float and int dtypes.
- Loading branch information
Prashant Kumar
committed
Feb 27, 2024
1 parent
971231c
commit 05d834c
Showing
1 changed file
with
106 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch | ||
import shark_turbine.kernel as tk | ||
import shark_turbine.kernel.lang as tkl | ||
import pytest | ||
|
||
|
||
TKL_TO_TORCH_DTYPE = { | ||
tkl.f16: torch.half, | ||
tkl.f32: torch.float, | ||
tkl.f64: torch.double, | ||
tkl.bool: torch.bool, | ||
tkl.i8: torch.int8, | ||
tkl.i16: torch.int16, | ||
tkl.i32: torch.int32, | ||
tkl.i64: torch.int64, | ||
} | ||
|
||
FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64] | ||
INT_DTYPES = [ | ||
tkl.bool, | ||
tkl.i4, | ||
tkl.i8, | ||
tkl.i16, | ||
tkl.i32, | ||
tkl.i64, | ||
tkl.index, | ||
] | ||
|
||
|
||
def softmax_krnl(dtype, input, output): | ||
M = tkl.sym.M | ||
K = tkl.sym.K | ||
|
||
@tk.gen.thread(M) | ||
def softmax_kernel( | ||
input: tk.lang.InputBuffer[M, K, dtype], | ||
output: tk.lang.OutputBuffer[M, K, dtype], | ||
): | ||
row_index = tk.lang.program_id(0) | ||
input_row = input[row_index, :] | ||
numerator = tkl.exp2(input_row - tkl.max(input_row)) | ||
if dtype in INT_DTYPES: | ||
output_row = numerator // tkl.sum(numerator) | ||
else: | ||
output_row = numerator / tkl.sum(numerator) | ||
output[row_index, :] = output_row | ||
|
||
with tk.gen.TestLaunchContext(): | ||
softmax_kernel(input, output) | ||
|
||
|
||
def gemm_fx_kernel(dtype, A, B, output): | ||
N = tkl.sym.N | ||
M = tkl.sym.M | ||
K = tkl.sym.K | ||
BLOCK_SIZE = tkl.sym.BLOCK_SIZE | ||
|
||
@tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE) | ||
def gemm_kernel( | ||
A: tkl.InputBuffer[N, K, dtype], | ||
B: tkl.InputBuffer[K, M, dtype], | ||
output: tkl.OutputBuffer[N, M, dtype], | ||
): | ||
grid_n = tkl.program_id(0) | ||
grid_m = tkl.program_id(1) | ||
|
||
acc = None | ||
# TODO: Only considering the float and integer cases. | ||
if dtype in INT_DTYPES: | ||
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0) | ||
else: | ||
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0) | ||
|
||
@tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc]) | ||
def body(i, c): | ||
a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE)) | ||
b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE)) | ||
return (tkl.dot(a, b, c),) | ||
|
||
tkl.store(output, (grid_n, grid_m), body[0]) | ||
|
||
with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}): | ||
gemm_kernel(A, B, output) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("dtype",), | ||
[(x,) for x in FLOAT_DTYPES], | ||
) | ||
def test_softmax_krnl(dtype): | ||
if dtype in TKL_TO_TORCH_DTYPE: | ||
input = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype]) | ||
output = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype]) | ||
softmax_krnl(dtype, input, output) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("dtype",), | ||
[(x,) for x in FLOAT_DTYPES + INT_DTYPES], | ||
) | ||
def test_gemm_krnl(dtype): | ||
if dtype in TKL_TO_TORCH_DTYPE: | ||
A = torch.randn(512, 1024).to(TKL_TO_TORCH_DTYPE[dtype]) | ||
B = torch.randn(1024, 2048).to(TKL_TO_TORCH_DTYPE[dtype]) | ||
output = torch.zeros(512, 2048).to(TKL_TO_TORCH_DTYPE[dtype]) | ||
gemm_fx_kernel(dtype, A, B, output) |