diff --git a/tests/kernels/test_scaled_mm_triton.py b/tests/kernels/test_scaled_mm_triton.py
index ed5ca9de2971d..7db770452d1f5 100644
--- a/tests/kernels/test_scaled_mm_triton.py
+++ b/tests/kernels/test_scaled_mm_triton.py
@@ -8,7 +8,7 @@
 import pytest
 import torch
 
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
 
 device = "cuda"
 
@@ -42,7 +42,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
     is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
                                                     ).is_floating_point()
 
-    seed_everything(0)
+    current_platform.seed_everything(0)
 
     # NOTE: There are cases, where if the matrix is large enough, an output
     # like 65504.4 can be produced, and can easily turn into inf when
@@ -70,7 +70,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
     if use_scalar_scale_b:
         scale_b = torch.rand((1, 1), device=device)
     else:
-        scale_b = 0.25 * torch.rand((1, 1), device=device)
+        scale_b = 0.25 * torch.rand((N, 1), device=device)
 
     bias = None
     if use_bias:
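
Why the shape change in the last hunk matters: when use_scalar_scale_b is False, the test is meant to exercise per-channel (per-output-column) scales, so scale_b must hold one factor per output column rather than a single scalar, which the old (1, 1) shape silently reduced it to. A minimal standalone sketch of the broadcasting involved (plain PyTorch, not vLLM's kernel; the names and shapes below are illustrative assumptions):

    import torch

    M, N, K = 32, 16, 64
    a = torch.rand((M, K))
    b = torch.rand((K, N))
    scale_a = torch.rand((1, 1))         # scalar scale for A
    scale_b = 0.25 * torch.rand((N, 1))  # one scale per output column

    # scale_b.t() has shape (1, N) and broadcasts over the M rows of the
    # (M, N) product, applying a distinct factor to each output column.
    out = scale_a * (a @ b) * scale_b.t()
    assert out.shape == (M, N)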