Naive vec add (#2)
mayank31398 authored Jun 26, 2024
1 parent 94eea2d commit f545648
Showing 4 changed files with 59 additions and 5 deletions.
2 changes: 2 additions & 0 deletions kernel_hyperdrive/__init__.py
@@ -1,9 +1,11 @@
from .utils import compile_helpers
from .vector_addition import (
    VectorAddition_CUDA,
    VectorAddition_Naive,
    VectorAddition_PyTorch,
    VectorAddition_Triton,
    vector_addition_cuda,
    vector_addition_naive,
    vector_addition_pytorch,
    vector_addition_triton,
)
1 change: 1 addition & 0 deletions kernel_hyperdrive/vector_addition/__init__.py
@@ -1,3 +1,4 @@
from .cuda_kernel import VectorAddition_CUDA, vector_addition_cuda
from .naive import VectorAddition_Naive, vector_addition_naive
from .pytorch import VectorAddition_PyTorch, vector_addition_pytorch
from .triton_kernel import VectorAddition_Triton, vector_addition_triton
37 changes: 37 additions & 0 deletions kernel_hyperdrive/vector_addition/naive.py
@@ -0,0 +1,37 @@
from typing import Tuple

import torch
import torch.nn as nn


def _vector_addition_naive(
    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, num_elements: int, BLOCK_SIZE: int
) -> None:
    # add the vectors one BLOCK_SIZE-sized chunk at a time
    for block_start in range(0, num_elements, BLOCK_SIZE):
        block_end = min(block_start + BLOCK_SIZE, num_elements)

        output[block_start:block_end] = x[block_start:block_end] + y[block_start:block_end]


class _VectorAddition_Naive(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 1
        output = torch.empty_like(x)

        num_elements = x.numel()

        _vector_addition_naive(x, y, output, num_elements, BLOCK_SIZE=1024)

        return output

    @staticmethod
    def backward(ctx, output_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # addition passes the upstream gradient through to both inputs unchanged
        return output_grad, output_grad


def vector_addition_naive(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return _VectorAddition_Naive.apply(x, y)


class VectorAddition_Naive(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return vector_addition_naive(x, y)
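
For context, here is a minimal usage sketch of the new naive path. It is not part of this commit; the tensor length, device, and `requires_grad` settings are illustrative assumptions.

```python
import torch

from kernel_hyperdrive import VectorAddition_Naive, vector_addition_naive

# any two 1D tensors of equal length on the same device will do; 4096 and "cuda" are arbitrary
x = torch.randn(4096, device="cuda", requires_grad=True)
y = torch.randn(4096, device="cuda", requires_grad=True)

# functional form
z = vector_addition_naive(x, y)

# nn.Module form
z = VectorAddition_Naive()(x, y)

# the custom backward routes the upstream gradient to both inputs unchanged
z.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
assert torch.equal(y.grad, torch.ones_like(y))
```
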
24 changes: 19 additions & 5 deletions tests/vector_addition_test.py
@@ -1,9 +1,15 @@
from typing import Callable

import torch
from kernel_hyperdrive import vector_addition_cuda, vector_addition_pytorch, vector_addition_triton
from parameterized import parameterized

from kernel_hyperdrive import (
    vector_addition_cuda,
    vector_addition_naive,
    vector_addition_pytorch,
    vector_addition_triton,
)

from .test_commons import TestCommons


@@ -13,16 +19,24 @@ class VectorAdditionTest(TestCommons):
            TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
        )
    )
    def test_vector_addition_triton(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
        self._test_vector_addition(size, device, dtype, vector_addition_triton)
    def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
        self._test_vector_addition(size, device, dtype, vector_addition_cuda)

    @parameterized.expand(
        TestCommons.make_args_matrix(
            TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
        )
    )
    def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
        self._test_vector_addition(size, device, dtype, vector_addition_cuda)
    def test_vector_addition_naive(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
        self._test_vector_addition(size, device, dtype, vector_addition_naive)

    @parameterized.expand(
        TestCommons.make_args_matrix(
            TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
        )
    )
    def test_vector_addition_triton(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
        self._test_vector_addition(size, device, dtype, vector_addition_triton)

    def _test_vector_addition(self, size: int, device: torch.device, dtype: torch.dtype, function: Callable) -> None:
        x = torch.randn(size, device=device, dtype=dtype)
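
The body of `_test_vector_addition` is collapsed in this view. As a rough sketch only, and assuming the helper compares each kernel's output against the eager PyTorch reference with `torch.testing.assert_close`, it presumably does something like the following (the function below is hypothetical, not the repository's code):

```python
from typing import Callable

import torch


def _test_vector_addition_sketch(size: int, device: torch.device, dtype: torch.dtype, function: Callable) -> None:
    # hypothetical stand-in for the collapsed _test_vector_addition helper
    x = torch.randn(size, device=device, dtype=dtype)
    y = torch.randn(size, device=device, dtype=dtype)

    z_kernel = function(x, y)  # kernel under test: naive, CUDA, or Triton
    z_expected = x + y         # eager PyTorch reference

    torch.testing.assert_close(z_kernel, z_expected)
```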
