Skip to content

Commit

Permalink
[Hardware][ROCM] using current_platform.is_rocm (vllm-project#9642)
Browse files Browse the repository at this point in the history
Signed-off-by: wangshuai09 <[email protected]>
  • Loading branch information
wangshuai09 authored Oct 28, 2024
1 parent 34a9941 commit 4e2d95e
Show file tree
Hide file tree
Showing 32 changed files with 162 additions and 148 deletions.
4 changes: 2 additions & 2 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from vllm import LLM
from vllm.utils import is_hip
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

from ..models.utils import check_outputs_equal
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_models(
enforce_eager: bool,
) -> None:

if backend == "FLASHINFER" and is_hip():
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")

os.environ["VLLM_ATTENTION_BACKEND"] = backend
Expand Down
4 changes: 2 additions & 2 deletions tests/compile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip
from vllm.platforms import current_platform

TEST_MODELS = [
("facebook/opt-125m", {}),
Expand Down Expand Up @@ -55,7 +55,7 @@
"quantization": "marlin"
}))

if not is_hip() and is_quant_method_supported("awq"):
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))
Expand Down
17 changes: 11 additions & 6 deletions tests/kernels/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import torch

from vllm.utils import is_hip
from vllm.platforms import current_platform

# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
else torch.float8_e4m3fn


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
Expand All @@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,

qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
Expand Down Expand Up @@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:

fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

Expand Down
23 changes: 13 additions & 10 deletions tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything

from .allclose_default import get_default_atol, get_default_rtol

if not is_hip():
if not current_platform.is_rocm():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

Expand All @@ -23,8 +24,9 @@
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16]
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
Expand Down Expand Up @@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(


@pytest.mark.parametrize(
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
Expand Down Expand Up @@ -317,8 +320,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
Expand Down Expand Up @@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(),
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
Expand Down Expand Up @@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
scale,
dtype,
)
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
3 changes: 2 additions & 1 deletion tests/kernels/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
Expand Down
7 changes: 4 additions & 3 deletions tests/kernels/test_blocksparse_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything

from .allclose_default import get_default_atol, get_default_rtol

Expand Down Expand Up @@ -316,8 +317,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
Expand Down
Loading

0 comments on commit 4e2d95e

Please sign in to comment.