
Commit d5e390d
prune tests, add skipif, has_good_tensor->is_weak_contiguous
Signed-off-by: Randall Smith <[email protected]>
rasmith committed Nov 6, 2024
1 parent 5abafe4 commit d5e390d
Showing 2 changed files with 7 additions and 5 deletions.
6 changes: 4 additions & 2 deletions tests/kernels/test_scaled_mm_triton.py
@@ -38,14 +38,16 @@ def get_8bit_types():
     return types


-@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
-@pytest.mark.parametrize("N", [2048, 8192, 16384, 256, 1024])
+@pytest.mark.parametrize("M", [1, 33, 64, 512])
+@pytest.mark.parametrize("N", [256, 971, 20486])
 @pytest.mark.parametrize("K", [128, 496, 1024])
 @pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("in_dtype", get_8bit_types())
 @pytest.mark.parametrize("use_scalar_scale_a", [True, False])
 @pytest.mark.parametrize("use_scalar_scale_b", [True, False])
 @pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(not current_platform.has_device_capability(89),
+                    reason="FP8 is not supported on this GPU type.")
 def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
                    use_scalar_scale_b, use_bias):
     is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
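The new skipif gate uses vLLM's current_platform.has_device_capability(89); compute capability 8.9 (Ada Lovelace) is the first generation with native FP8 support. A minimal standalone sketch of the same check, assuming plain PyTorch with CUDA in place of vLLM's platform helper (the helper below is illustrative, not vLLM's implementation):

```python
import pytest
import torch


def has_device_capability(minimum: int) -> bool:
    # Encode (major, minor) as major * 10 + minor, so 8.9 -> 89.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= minimum


@pytest.mark.skipif(not has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_requires_fp8_gpu():
    ...  # runs only on capability >= 8.9 (Ada, Hopper, and newer)
```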
6 changes: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 import triton.language as tl


-def has_good_tensor_strides(x: torch.Tensor):
+def is_weak_contiguous(x: torch.Tensor):
     strides = x.stride()
     sizes = x.shape
     is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
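The hunk is truncated mid-function. Based on the visible lines, a plausible completion adds the symmetric check for the other layout, so "weakly contiguous" means densely packed in either row-major or column-major order; the second check and the return below are an assumption, not shown in this diff:

```python
import torch


def is_weak_contiguous(x: torch.Tensor) -> bool:
    strides = x.stride()
    sizes = x.shape
    # Unit stride along dim 0 (column-major layout), with the other stride
    # large enough that columns are packed back to back.
    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
    # Assumed symmetric case: unit stride along dim 1 (row-major layout).
    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
    return is_transpose or is_not_transpose


# Either dense layout passes; a strided view does not:
a = torch.empty(4, 8)                     # row-major, strides (8, 1)
assert is_weak_contiguous(a)
assert is_weak_contiguous(a.t())          # column-major view, strides (1, 8)
assert not is_weak_contiguous(a[:, ::2])  # stride 2 -> not densely packed
```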
@@ -142,8 +142,8 @@ def scaled_mm_triton(input: torch.Tensor,
                                                  [N, 1])
     assert out_dtype.is_floating_point
     assert bias is None or bias.is_floating_point()
-    assert has_good_tensor_strides(input)
-    assert has_good_tensor_strides(weight)
+    assert is_weak_contiguous(input)
+    assert is_weak_contiguous(weight)

     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
         N, META['BLOCK_SIZE_N']), )
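The grid lambda in the unchanged context launches a 1D grid with one Triton program per output tile, ceil(M / BLOCK_SIZE_M) * ceil(N / BLOCK_SIZE_N) programs in total. A small sketch with illustrative block sizes (the real values come from the autotuned META at launch):

```python
import triton

M, N = 512, 971  # sample sizes from the pruned test matrix above
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) *
                     triton.cdiv(N, META['BLOCK_SIZE_N']), )

# With 64x64 tiles: ceil(512 / 64) * ceil(971 / 64) = 8 * 16 = 128 programs.
print(grid({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}))  # (128,)
```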
