Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
mayank31398 committed Aug 30, 2024
1 parent 5ac156d commit 8e1a22d
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions tests/vector_addition_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Callable

import torch
Expand All @@ -11,25 +12,25 @@
class VectorAdditionTest(TestCommons):
@parameterized.expand(
TestCommons.make_args_matrix(
TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes(), [False, True]
)
)
def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
self._test_vector_addition(size, device, dtype, vector_addition_cuda)
def test_vector_addition_cuda(self, size: int, device: torch.device, dtype: torch.dtype, in_place: bool) -> None:
self._test_vector_addition(size, device, dtype, partial(vector_addition_cuda, in_place=in_place))

@parameterized.expand(
TestCommons.make_args_matrix(
TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes()
TestCommons.get_1d_tensor_sizes(), [torch.device("cuda")], TestCommons.get_dtypes(), [False, True]
)
)
def test_vector_addition_triton(self, size: int, device: torch.device, dtype: torch.dtype) -> None:
self._test_vector_addition(size, device, dtype, vector_addition_triton)
def test_vector_addition_triton(self, size: int, device: torch.device, dtype: torch.dtype, in_place: bool) -> None:
self._test_vector_addition(size, device, dtype, partial(vector_addition_cuda, in_place=in_place))

def _test_vector_addition(self, size: int, device: torch.device, dtype: torch.dtype, function: Callable) -> None:
x = torch.randn(size, device=device, dtype=dtype)
y = torch.randn(size, device=device, dtype=dtype)

z_expected = vector_addition_torch(x, y, in_place=False)
z_kernel = function(x, y)
z_expected = vector_addition_torch(x, y)

self.assert_equal_tensors(z_kernel, z_expected, True)

0 comments on commit 8e1a22d

Please sign in to comment.