diff --git a/kernel_hyperdrive/__init__.py b/kernel_hyperdrive/__init__.py
index ef576fbe..95e71f90 100644
--- a/kernel_hyperdrive/__init__.py
+++ b/kernel_hyperdrive/__init__.py
@@ -1,9 +1,11 @@
 from .utils import compile_helpers
 from .vector_addition import (
     VectorAddition_CUDA,
+    VectorAddition_Naive,
     VectorAddition_PyTorch,
     VectorAddition_Triton,
     vector_addition_cuda,
+    vector_addition_naive,
     vector_addition_pytorch,
     vector_addition_triton,
 )
diff --git a/kernel_hyperdrive/vector_addition/__init__.py b/kernel_hyperdrive/vector_addition/__init__.py
index c821a5c8..35ce8945 100644
--- a/kernel_hyperdrive/vector_addition/__init__.py
+++ b/kernel_hyperdrive/vector_addition/__init__.py
@@ -1,3 +1,4 @@
 from .cuda_kernel import VectorAddition_CUDA, vector_addition_cuda
+from .naive import VectorAddition_Naive, vector_addition_naive
 from .pytorch import VectorAddition_PyTorch, vector_addition_pytorch
 from .triton_kernel import VectorAddition_Triton, vector_addition_triton
diff --git a/kernel_hyperdrive/vector_addition/naive.py b/kernel_hyperdrive/vector_addition/naive.py
new file mode 100644
index 00000000..47e8d3d0
--- /dev/null
+++ b/kernel_hyperdrive/vector_addition/naive.py
@@ -0,0 +1,40 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+
+def _vector_addition_naive(
+    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, num_elements: int, BLOCK_SIZE: int
+) -> None:
+    # process the vector in BLOCK_SIZE chunks, mirroring how the CUDA/Triton kernels tile the work
+    for block_start in range(0, num_elements, BLOCK_SIZE):
+        block_end = min(block_start + BLOCK_SIZE, num_elements)
+
+        output[block_start:block_end] = x[block_start:block_end] + y[block_start:block_end]
+
+
+class _VectorAddition_Naive(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        assert x.dim() == 1
+        output = torch.empty_like(x)
+
+        num_elements = x.numel()
+
+        _vector_addition_naive(x, y, output, num_elements, BLOCK_SIZE=1024)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, output_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        return output_grad, output_grad
+
+
+def vector_addition_naive(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    return _VectorAddition_Naive.apply(x, y)
+
+
+class VectorAddition_Naive(nn.Module):
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return vector_addition_naive(x, y)
diff --git a/tests/vector_addition_test.py b/tests/vector_addition_test.py
index da5eb44e..2d913576 100644
--- a/tests/vector_addition_test.py
+++ b/tests/vector_addition_test.py
@@ -1,9 +1,15 @@
 from typing import Callable
 
 import torch
-from kernel_hperdrive import vector_addition_cuda, vector_addition_pytorch, vector_addition_triton
 from parameterized import parameterized
 
+from kernel_hyperdrive import (
+    vector_addition_cuda,
+    vector_addition_naive,
+    vector_addition_pytorch,
+    vector_addition_triton,
+)
+
 from .test_commons import TestCommons
 
 
@@ -13,16 +19,24 @@ class VectorAdditionTest(TestCommons):
             TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
         )
     )
-    def test_vector_addition_triton(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
-        self._test_vector_addition(size, device, dtype, vector_addition_triton)
+    def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
+        self._test_vector_addition(size, device, dtype, vector_addition_cuda)
 
     @parameterized.expand(
         TestCommons.make_args_matrix(
             TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
         )
     )
-    def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
-        self._test_vector_addition(size, device, dtype, vector_addition_cuda)
+    def test_vector_addition_naive(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
+        self._test_vector_addition(size, device, dtype, vector_addition_naive)
+
+    @parameterized.expand(
+        TestCommons.make_args_matrix(
+            TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
+        )
+    )
+    def test_vector_addition_triton(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
+        self._test_vector_addition(size, device, dtype, vector_addition_triton)
 
     def _test_vector_addition(self, size: int, device: torch.device, dtype: torch.dtype, function: Callable) -> None:
         x = torch.randn(size, device=device, dtype=dtype)
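
For reviewers who want to exercise the new path by hand, below is a minimal sanity-check sketch. It is not part of the diff; it assumes the package is importable as kernel_hyperdrive exactly as laid out above (the top-level import also pulls in the CUDA and Triton kernels, so it needs an environment where those are buildable).

import torch

from kernel_hyperdrive import VectorAddition_Naive, vector_addition_naive

# 4097 elements is deliberately not a multiple of BLOCK_SIZE=1024,
# so the last iteration of the blocked loop handles a partial block
x = torch.randn(4097, requires_grad=True)
y = torch.randn(4097, requires_grad=True)

# forward: the blocked loop should match plain elementwise addition
output = vector_addition_naive(x, y)
assert torch.equal(output, x + y)

# backward: addition passes the upstream gradient through unchanged to both inputs
output.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
assert torch.equal(y.grad, torch.ones_like(y))

# the nn.Module wrapper routes through the same autograd.Function
assert torch.equal(VectorAddition_Naive()(x, y), x + y)

Because addition distributes the upstream gradient unchanged to both operands, backward can simply return output_grad twice and nothing needs to be saved in ctx during forward.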