diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 3c2ca1bddd906..79647589d5204 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -11,7 +11,7 @@ import pytest from vllm import LLM -from vllm.utils import is_hip +from vllm.platforms import current_platform from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from ..models.utils import check_outputs_equal @@ -51,7 +51,7 @@ def test_models( enforce_eager: bool, ) -> None: - if backend == "FLASHINFER" and is_hip(): + if backend == "FLASHINFER" and current_platform.is_rocm(): pytest.skip("Flashinfer does not support ROCm/HIP.") os.environ["VLLM_ATTENTION_BACKEND"] = backend diff --git a/tests/compile/utils.py b/tests/compile/utils.py index c69343b51ae02..64fc08e80de3b 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -5,7 +5,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams from vllm.compilation.levels import CompilationLevel -from vllm.utils import is_hip +from vllm.platforms import current_platform TEST_MODELS = [ ("facebook/opt-125m", {}), @@ -55,7 +55,7 @@ "quantization": "marlin" })) -if not is_hip() and is_quant_method_supported("awq"): +if not current_platform.is_rocm() and is_quant_method_supported("awq"): TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { "quantization": "AWQ" })) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 8f6a54ff5979c..f2358940fc7b8 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -2,12 +2,13 @@ import torch -from vllm.utils import is_hip +from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for rocm. ROCM_FP8_MAX = 224.0 -FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn +FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \ + else torch.float8_e4m3fn def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: @@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor, qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ else torch.finfo(quant_dtype) - qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max - qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min + qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ + else qtype_traits.max + qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \ + else qtype_traits.min qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ -> Tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max - fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min + fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ + else fp8_traits.max + fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \ + else fp8_traits.min fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 52f1ecd176963..1604aa4d2d6e5 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,11 +6,12 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops -from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything +from vllm.platforms import current_platform +from vllm.utils import get_max_shared_memory_bytes, seed_everything from .allclose_default import get_default_atol, get_default_rtol -if not is_hip(): +if not current_platform.is_rocm(): from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask @@ -23,8 +24,9 @@ NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [torch.half, torch.bfloat16, torch.float - ] if not is_hip() else [torch.half, torch.bfloat16] +DTYPES = [ + torch.half, torch.bfloat16, torch.float +] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing @@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( - "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]) + "version", + ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -317,8 +320,8 @@ def test_paged_attention( # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 + atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. @@ -368,7 +371,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(is_hip(), +@pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") @torch.inference_mode() def test_multi_query_kv_attention( @@ -425,6 +428,6 @@ def test_multi_query_kv_attention( scale, dtype, ) - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 + atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index df3e770e260e0..3fe9ca0b0450f 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch): False) assert backend.name == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): + with patch("vllm.attention.selector.current_platform.is_rocm", + return_value=True): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "ROCM_FLASH" diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index f3bd8f0524264..b65efb3abc230 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -7,7 +7,8 @@ from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) -from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything +from vllm.platforms import current_platform +from vllm.utils import get_max_shared_memory_bytes, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -316,8 +317,8 @@ def test_paged_attention( # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 + atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 6b979d0558c46..bc99c5559d388 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -18,7 +18,7 @@ from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.selector import (_Backend, global_force_attn_backend_context_manager) -from vllm.utils import is_hip +from vllm.platforms import current_platform # List of support backends for encoder/decoder models LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] @@ -82,7 +82,7 @@ class TestResources(NamedTuple): will leverage attn_backend for the purpose of constructing backend-compatible attention metadata instances - + Attributes: * scale: 1/sqrt(d) scale factor for attn @@ -105,10 +105,10 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources: Build key components for performing encoder/decoder attention test. Note that - (1) The Attention instance constructed here, automatically selects + (1) The Attention instance constructed here, automatically selects an attention backend class based on platform info & a set of canned heuristics, so - (2) The attention backend instance constructed here is thus *not + (2) The attention backend instance constructed here is thus *not the same backend instance* used by attn, but rather it is intended to be a *different instance* of the *same backend class*; therefore, @@ -156,7 +156,7 @@ def _encoder_attn_setup( ''' Set up test vectors & data structures for encoder attention test. - A triplet of synthetic query/key/value tensors are constructed. + A triplet of synthetic query/key/value tensors are constructed. Given this is an encoder attention test, the key & value sequences will have the same length as the corresponding queries. @@ -169,14 +169,14 @@ def _encoder_attn_setup( Arguments: * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, + following fields: batch_size, num_heads, head_size, block_size, max_q_seq_len * test_rsrcs: TestResources data structure; this function relies on the scale field - + Returns: - + * PhaseTestParameters data structure comprising (1) packed query/key/value tensors, (2) the ideal output of attention computed using a naive implementation, and (3) KVCache field set to None @@ -265,7 +265,7 @@ def _decoder_attn_setup( Arguments: * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, + following fields: batch_size, num_heads, head_size, block_size, max_q_seq_len * test_rsrcs: TestResources data structure; this function relies on the scale field @@ -275,14 +275,14 @@ def _decoder_attn_setup( * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size) query/key/value tensors * Prefill-phase decoder self-attention PhaseTestParameters data structure, - including (1) packed (number_of_tokens x num_heads x head_size) + including (1) packed (number_of_tokens x num_heads x head_size) query/key/value tensors along with (2) ideal attention output - computed using a naive implementation, and (3) memory-mapping data + computed using a naive implementation, and (3) memory-mapping data structures appropriate for prefill phase. - * Decode-phase decoder self-attention PhaseTestParameters data structure, - including (1) packed (number_of_tokens x num_heads x head_size) - query/key/value tensors along with (2) ideal attention output - computed using a naive implementation, and (3) memory-mapping data + * Decode-phase decoder self-attention PhaseTestParameters data structure, + including (1) packed (number_of_tokens x num_heads x head_size) + query/key/value tensors along with (2) ideal attention output + computed using a naive implementation, and (3) memory-mapping data structures appropriate for decode phase. * max_block_idx: max physical address in decoder self-attention block-table (intended to be used as the base address for the encoder/ @@ -436,12 +436,12 @@ def _enc_dec_cross_attn_setup_reuses_query( This function also constructs the cross-attention KV cache memory mapping (slot mapping and block table), ensuring that the block table starts at - block_base_addr. + block_base_addr. Arguments: * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x - num_heads x head_size) decoder self-attention inputs; + num_heads x head_size) decoder self-attention inputs; this function relies on the query and q_seq_lens fields * encoder_test_params: PhaseTestParameters data structure which was @@ -452,7 +452,7 @@ def _enc_dec_cross_attn_setup_reuses_query( self-attention; all fields including KV cache required * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, + following fields: batch_size, num_heads, head_size, block_size, max_q_seq_len * test_rsrcs: TestResources data structure; this function relies on the scale field @@ -460,16 +460,16 @@ def _enc_dec_cross_attn_setup_reuses_query( Returns: - * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data - structure, including (1) packed + * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data + structure, including (1) packed (number_of_tokens x num_heads x head_size) query/key/value tensors - along with (2) ideal attention output computed using a + along with (2) ideal attention output computed using a naive implementation, and (3) memory-mapping data structures appropriate for prefill phase. - * Decode-phase encoder/decoder cross-attention PhaseTestParameters data + * Decode-phase encoder/decoder cross-attention PhaseTestParameters data structure, including (1) packed (number_of_tokens x num_heads x head_size) query/key/value tensors - along with (2) ideal attention output computed using a + along with (2) ideal attention output computed using a naive implementation, and (3) memory-mapping data structures appropriate for decode phase. ''' @@ -596,7 +596,7 @@ def _run_encoder_attention_test( ''' Run encoder attention. - attn.forward() is passed attn_type=AttentionType.ENCODER in order + attn.forward() is passed attn_type=AttentionType.ENCODER in order to configure the kernel invocation for encoder attention Requires attn_metadata.num_decode_tokens == 0 @@ -607,7 +607,7 @@ def _run_encoder_attention_test( * attn: Attention wrapper instance * encoder_test_params: encoder PhaseTestParameters data structure; this function relies on the packed - (number_of_tokens x num_heads x head_size) + (number_of_tokens x num_heads x head_size) query/key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention @@ -646,7 +646,7 @@ def _run_decoder_self_attention_test( and attn (Attention wrapper instance) fields * decoder_test_params: decoder PhaseTestParameters data structure; this function relies on the packed - (number_of_tokens x num_heads x head_size) + (number_of_tokens x num_heads x head_size) query/key/value fields * attn_metadata: attention metadata for decoder-self attention (contains KV cache memory-mapping) @@ -694,11 +694,11 @@ def _run_encoder_decoder_cross_attention_test( and attn (Attention wrapper instance) fields * decoder_test_params: decoder PhaseTestParameters data structure; this function relies on the packed - (number_of_tokens x num_heads x head_size) + (number_of_tokens x num_heads x head_size) query field * cross_test_params: encoder/decoder PhaseTestParameters data structure; this function relies on the packed - (number_of_tokens x num_heads x head_size) + (number_of_tokens x num_heads x head_size) key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention @@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test( attn_type=attn_type) -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.skipif(current_platform.is_rocm(), + reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @@ -755,7 +756,8 @@ def test_encoder_only( No KV cache is required for encoder-only attention. Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if is_hip(). + AMD GPUs, therefore this test simply is skipped if + current_platform.is_rocm(). This test globally forces an override of the usual backend auto-selection process, forcing the specific backend-under-test @@ -811,7 +813,8 @@ def test_encoder_only( assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.skipif(current_platform.is_rocm(), + reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @@ -837,14 +840,14 @@ def test_e2e_enc_dec_attn( attributes for prefill-phase, and (2) an analogous attention metadata structure but for decode-phase * Test attention steps in the following order - + * Encoder attention * Prefill self-attention * Prefill cross-attention * Decode self-attention * Decode cross-attention - * Besides being reflective of realistic use-cases, this order would - exacerbate any accidental overlap in the self-/cross-attention + * Besides being reflective of realistic use-cases, this order would + exacerbate any accidental overlap in the self-/cross-attention block tables, which one hopes to avoid @@ -864,10 +867,11 @@ def test_e2e_enc_dec_attn( to be utilized. Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if is_hip(). + AMD GPUs, therefore this test simply is skipped if + current_platform.is_rocm(). Note on metadata: there is a single attention metadata structure shared by - all prefill-phase attention operations (encoder, decoder, enc/dec cross), + all prefill-phase attention operations (encoder, decoder, enc/dec cross), and a single one shared by all decode-phase attention operations (decoder & enc/dec cross.) This is intended to reflect the behavior of EncoderDecoderModelRunner, which constructs a single attention metadata diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index c0053071258ea..4bfc089c82179 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.utils import is_hip, seed_everything +from vllm.utils import seed_everything @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype): @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("is_k_full", [True, False]) -@pytest.mark.skipif(is_hip(), reason="Skip for rocm") +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, n: int, @@ -256,7 +257,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("is_k_full", [True, False]) -@pytest.mark.skipif(is_hip(), reason="Skip for rocm") +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_single_marlin_moe_multiply( m: int, n: int, diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index f7c1d4f041c12..15ec66b0f5502 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -4,7 +4,7 @@ import vllm from vllm.lora.request import LoRARequest -from vllm.utils import is_hip +from vllm.platforms import current_platform MODEL_PATH = "google/gemma-7b" @@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts -@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm") +@pytest.mark.xfail(current_platform.is_rocm(), + reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index d004c65929418..5432fa4ad0d3a 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -8,7 +8,7 @@ import vllm from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest -from vllm.utils import is_hip +from vllm.platforms import current_platform @dataclass @@ -19,7 +19,7 @@ class ModelWithQuantization: MODELS: List[ModelWithQuantization] #AWQ quantization is currently not supported in ROCm. -if is_hip(): +if current_platform.is_rocm(): MODELS = [ ModelWithQuantization( model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", diff --git a/tests/models/decoder_only/vision_language/test_paligemma.py b/tests/models/decoder_only/vision_language/test_paligemma.py index a3ca0845e5ff8..69189ba2f25cb 100644 --- a/tests/models/decoder_only/vision_language/test_paligemma.py +++ b/tests/models/decoder_only/vision_language/test_paligemma.py @@ -6,8 +6,9 @@ BatchEncoding) from vllm.multimodal.utils import rescale_image_size +from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ...utils import check_logprobs_close @@ -24,7 +25,7 @@ # ROCm Triton FA can run into compilation issues with these models due to, # excessive use of shared memory. Use other backends in the meantime. # FIXME (mattwong, gshtrasb, hongxiayan) -if is_hip(): +if current_platform.is_rocm(): os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" @@ -70,7 +71,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -151,7 +152,7 @@ def process(hf_inputs: BatchEncoding): pytest.param( "float", marks=pytest.mark.skipif( - is_hip(), + current_platform.is_rocm(), reason= "ROCm FA does not yet fully support 32-bit precision on PaliGemma") ), "half" diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index dfe10629f1c66..1840b4bb8574c 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -12,7 +12,6 @@ from vllm.multimodal.utils import rescale_image_size from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs -from vllm.utils import is_hip from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, _ImageAssets) @@ -56,7 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, # ROCm Triton FA can run into shared memory issues with these models, # use other backends in the meantime # FIXME (mattwong, gshtrasb, hongxiayan) -if is_hip(): +if current_platform.is_rocm(): os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index b829d1a5be784..25562ca85adf4 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -5,7 +5,7 @@ import pytest import torch -from vllm.utils import is_hip +from vllm.platforms import current_platform from .conftest import run_equality_correctness_test_tp @@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, batch_size: int, output_len: int, seed: int): """Verify greedy equality when tensor parallelism is used. """ - if is_hip(): + if current_platform.is_rocm(): pytest.skip("hip is not well-supported yet") run_equality_correctness_test_tp("JackFram/llama-68m", common_llm_kwargs, diff --git a/tests/utils.py b/tests/utils.py index e983104e3cb0c..0c61891cfefec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,7 +26,7 @@ from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import (FlexibleArgumentParser, GB_bytes, - cuda_device_count_stateless, get_open_port, is_hip) + cuda_device_count_stateless, get_open_port) if current_platform.is_rocm(): from amdsmi import (amdsmi_get_gpu_vram_usage, @@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int], output: Dict[int, str] = {} output_raw: Dict[int, float] = {} for device in devices: - if is_hip(): + if current_platform.is_rocm(): dev_handle = amdsmi_get_processor_handles()[device] mem_info = amdsmi_get_gpu_vram_usage(dev_handle) gb_used = mem_info["vram_used"] / 2**10 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f57414bd5197e..46a2fb8bc80a2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -659,11 +659,11 @@ def scaled_fp8_quant( Args: input: The input tensor to be quantized to FP8 scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic + scale_ub: Optional upper bound for scaling factor in dynamic per token case num_token_padding: If specified, pad the first dimension of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token + use_per_token_if_dynamic: Whether to do per_tensor or per_token in the dynamic quantization case. Returns: @@ -674,8 +674,8 @@ def scaled_fp8_quant( assert (input.ndim == 2) shape: Union[Tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \ - else torch.float8_e4m3fn + out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + if current_platform.is_rocm() else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) output = torch.empty(shape, device=input.device, dtype=out_dtype) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py index e4dc576d27932..a98eb431ac7fc 100644 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -3,7 +3,6 @@ import torch from vllm.platforms import current_platform -from vllm.utils import is_hip from .utils import (dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask) @@ -32,8 +31,9 @@ def __init__( ): super().__init__() if use_spda is None: - use_spda = is_hip() or current_platform.is_cpu() or not \ - IS_COMPUTE_8_OR_ABOVE + use_spda = current_platform.is_rocm() or \ + current_platform.is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE device = device or (torch.cuda.current_device() if current_platform.is_cuda_alike() else "cpu") device = torch.device(device) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 10d4509b38279..376b3136f0fb8 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -10,7 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR, is_hip +from vllm.utils import STR_BACKEND_ENV_VAR logger = init_logger(__name__) @@ -208,7 +208,7 @@ def which_attn_to_use( logger.info("Cannot use %s backend on TPU.", selected_backend) return _Backend.PALLAS - if is_hip(): + if current_platform.is_rocm(): # AMD GPUs. selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) diff --git a/vllm/config.py b/vllm/config.py index a1fba98233b80..99a82c8f1b40b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,7 +17,7 @@ get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - is_hip, print_warning_once) + print_warning_once) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -43,7 +43,7 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. - It is also used as the content for `model_name` tag in metrics + It is also used as the content for `model_name` tag in metrics output when `served_model_name` is not specified. task: The task to use the model for. Each vLLM instance only supports one task, even if the same model can be used for multiple tasks. @@ -99,15 +99,15 @@ class ModelConfig: skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data instances per modality + limit_mm_per_prompt: Maximum number of data instances per modality per prompt. Only applicable for multimodal models. - override_neuron_config: Initialize non default neuron config or - override default neuron config that are specific to Neuron devices, - this argument will be used to configure the neuron config that - can not be gathered from the vllm arguments. + override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. mm_processor_kwargs: Arguments to be forwarded to the model's processor @@ -350,7 +350,7 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") - if is_hip( + if current_platform.is_rocm( ) and self.quantization not in rocm_supported_quantization: raise ValueError( f"{self.quantization} quantization is currently not " @@ -365,7 +365,7 @@ def _verify_quantization(self) -> None: "%s quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.", self.quantization) - if (self.quantization == "awq" and is_hip() + if (self.quantization == "awq" and current_platform.is_rocm() and not envs.VLLM_USE_TRITON_AWQ): logger.warning( "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" @@ -385,7 +385,7 @@ def _verify_cuda_graph(self) -> None: def _verify_bnb_config(self) -> None: """ - The current version of bitsandbytes (0.44.0) with 8-bit models does not + The current version of bitsandbytes (0.44.0) with 8-bit models does not yet support CUDA graph. """ is_bitsandbytes = self.quantization == "bitsandbytes" @@ -810,7 +810,7 @@ class LoadConfig: fast weight loading. "bitsandbytes" will load nf4 type weights. ignore_patterns: The list of patterns to ignore when loading the model. - Default to "original/**/*" to avoid repeated loading of llama's + Default to "original/**/*" to avoid repeated loading of llama's checkpoints. """ @@ -843,7 +843,8 @@ def _verify_load_format(self) -> None: self.load_format = LoadFormat(load_format) rocm_not_supported_load_format: List[str] = [] - if is_hip() and load_format in rocm_not_supported_load_format: + if current_platform.is_rocm( + ) and load_format in rocm_not_supported_load_format: rocm_supported_load_format = [ f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format) @@ -967,7 +968,7 @@ def _verify_args(self) -> None: if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() - if is_hip(): + if current_platform.is_rocm(): self.disable_custom_all_reduce = True logger.info( "Disabled the custom all-reduce kernel because it is not " @@ -996,7 +997,7 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. - preemption_mode: Whether to perform preemption by swapping or + preemption_mode: Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than swapping. However, when the sequence group has multiple sequences @@ -1215,7 +1216,7 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior probability of a token in the target model for it to be - accepted. This threshold is used only when we use the + accepted. This threshold is used only when we use the TypicalAcceptanceSampler for token acceptance. typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the @@ -1225,7 +1226,7 @@ def maybe_create_spec_config( If set to False, token log probabilities are returned according to the log probability settings in SamplingParams. If not specified, it defaults to True. - + Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. @@ -1470,13 +1471,13 @@ def __init__( typical_acceptance_sampler_posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior probability of a token in the target model for it to be - accepted. This threshold is used only when we use the + accepted. This threshold is used only when we use the TypicalAcceptanceSampler for token acceptance. typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the TypicalAcceptanceSampler. disable_logprobs: If set to True, token log probabilities will not - be returned even if requested by sampling parameters. This + be returned even if requested by sampling parameters. This reduces latency by skipping logprob calculation in proposal sampling, target sampling, and after accepted tokens are determined. If set to False, log probabilities will be @@ -1843,10 +1844,10 @@ def get_min_sliding_window( def get_served_model_name(model: str, served_model_name: Optional[Union[str, List[str]]]): """ - If the input is a non-empty list, the first model_name in - `served_model_name` is taken. - If the input is a non-empty string, it is used directly. - For cases where the input is either an empty string or an + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an empty list, the fallback is to use `self.model`. """ if not served_model_name: diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 0af7b3386d895..aa546ebada473 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import get_ip, is_hip +from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -231,7 +231,7 @@ def initialize_ray_cluster( assert_ray_available() # Connect to a ray cluster. - if is_hip() or current_platform.is_xpu(): + if current_platform.is_rocm() or current_platform.is_xpu(): ray.init(address=ray_address, ignore_reinit_error=True, num_gpus=parallel_config.world_size) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 71eed6eb68d78..83910339f3c9f 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -7,7 +7,7 @@ from vllm.compilation.levels import CompilationLevel from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_hip, print_warning_once +from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -72,7 +72,7 @@ def dispatch_forward(self): if not enabled: return self.forward_native - if is_hip(): + if current_platform.is_rocm(): return self.forward_hip elif current_platform.is_cpu(): return self.forward_cpu diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c21aaa40ff2cc..be3d3985a74ad 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_hip, print_warning_once +from vllm.platforms import current_platform +from vllm.utils import print_warning_once class GPTQMarlinState(Enum): @@ -150,7 +151,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_input_scale.max(), requires_grad=False) # If rocm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if current_platform.is_rocm(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 7270b302ef965..73cc8ce0d2a4b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -12,7 +12,7 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) -from vllm.utils import is_hip +from vllm.platforms import current_platform __all__ = ["CompressedTensorsW8A8Fp8"] @@ -40,7 +40,7 @@ def process_weights_after_loading(self, layer) -> None: logical_widths=layer.logical_widths, ) - if is_hip(): + if current_platform.is_rocm(): weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=max_w_scale, @@ -56,7 +56,7 @@ def process_weights_after_loading(self, layer) -> None: elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight - if is_hip(): + if current_platform.is_rocm(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index f26907176ad1a..825d01d1b3551 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -19,7 +19,6 @@ from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter) from vllm.platforms import current_platform -from vllm.utils import is_hip logger = init_logger(__name__) @@ -127,7 +126,7 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight - if is_hip(): + if current_platform.is_rocm(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b5feb55db0e74..d34579b7099bb 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -26,7 +26,7 @@ PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_hip, print_warning_once +from vllm.utils import print_warning_once ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -123,7 +123,7 @@ def __init__(self, quant_config: Fp8Config): self.use_marlin = (not current_platform.has_device_capability(89) or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm - if is_hip(): + if current_platform.is_rocm(): self.use_marlin = False def create_weights( @@ -226,7 +226,7 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale = layer.weight_scale # If rocm, use float8_e4m3fnuz. - if is_hip(): + if current_platform.is_rocm(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, @@ -372,7 +372,7 @@ def process_weights_after_loading(self, layer: Module) -> None: if not self.quant_config.is_checkpoint_fp8_serialized: # If rocm, use float8_e4m3fnuz as dtype fp8_dtype = torch.float8_e4m3fnuz \ - if is_hip() else torch.float8_e4m3fn + if current_platform.is_rocm() else torch.float8_e4m3fn w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -420,7 +420,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) # If rocm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if current_platform.is_rocm(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 411af922149fd..1879d2855d93d 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -4,16 +4,16 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import is_hip # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None +TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \ + if current_platform.is_rocm() else None def cutlass_fp8_supported() -> bool: # cutlass is not supported on Rocm - if is_hip(): + if current_platform.is_rocm(): return False capability_tuple = current_platform.get_device_capability() diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 4126ceb7117d4..22f194c776b69 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -49,9 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.exaone import ExaoneConfig -from vllm.utils import is_hip from .interfaces import SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, @@ -595,7 +595,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if not isinstance(self.transformer.h[layer_idx], nn.Identity): layer_self_attn = self.transformer.h[layer_idx].attn - if is_hip(): + if current_platform.is_rocm(): # The scaling factor convention we are assuming is # quantized_value * scaling_factor ~= true_value # which is consistent with the practice of setting diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 5a397ed8ff6a0..c968817747754 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -49,8 +49,8 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import is_hip from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -534,7 +534,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if not isinstance(self.model.layers[layer_idx], nn.Identity): layer_self_attn = self.model.layers[layer_idx].self_attn - if is_hip(): + if current_platform.is_rocm(): # The scaling factor convention we are assuming is # quantized_value * scaling_factor ~= true_value # which is consistent with the practice of setting diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c346e3e808e3f..b0ca1fe006239 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -50,8 +50,8 @@ default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, PoolerOutput -from vllm.utils import is_hip from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, @@ -423,7 +423,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if not isinstance(self.layers[layer_idx], nn.Identity): layer_self_attn = self.layers[layer_idx].self_attn - if is_hip(): + if current_platform.is_rocm(): # The scaling factor convention we are assuming is # quantized_value * scaling_factor ~= true_value # which is consistent with the practice of setting diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index e62aa9e77121c..fc25ed8ea82a7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -12,7 +12,7 @@ import torch.nn as nn from vllm.logger import init_logger -from vllm.utils import is_hip +from vllm.platforms import current_platform from .interfaces import (has_inner_state, is_attention_free, supports_multimodal, supports_pp) @@ -248,7 +248,7 @@ def _try_load_model_cls( model_arch: str, model: _BaseRegisteredModel, ) -> Optional[Type[nn.Module]]: - if is_hip(): + if current_platform.is_rocm(): if model_arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError(f"Model architecture '{model_arch}' is not " "supported by ROCm for now.") diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 5a3dd3c02b85b..e3e7ccb5cf179 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -49,8 +49,8 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import is_hip from .interfaces import SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, @@ -558,7 +558,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if not isinstance(self.model.layers[layer_idx], nn.Identity): layer_self_attn = self.model.layers[layer_idx].self_attn - if is_hip(): + if current_platform.is_rocm(): # The scaling factor convention we are assuming is # quantized_value * scaling_factor ~= true_value # which is consistent with the practice of setting diff --git a/vllm/utils.py b/vllm/utils.py index d4f2c936ca9cc..c3f9a6bdd8b80 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -314,10 +314,6 @@ def reset(self): self._index = 0 -def is_hip() -> bool: - return torch.version.hip is not None - - @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" @@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless( if not torch.cuda._is_compiled(): return 0 - if is_hip(): + if current_platform.is_rocm(): # ROCm uses amdsmi instead of nvml for stateless device count # This requires a sufficiently modern version of Torch 2.4.0 raw_count = torch.cuda._device_count_amdsmi() if (hasattr( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4a287e3741d0f..233a9e664d845 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -41,6 +41,7 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry) +from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( @@ -49,7 +50,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.transformers_utils.config import uses_mrope from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available, + flatten_2d_lists, is_pin_memory_available, supports_dynamo, weak_ref_tensor) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -737,13 +738,13 @@ def _get_cuda_graph_pad_size(self, family of functions. Args: - num_seqs (int): Number of sequences scheduled to run. + num_seqs (int): Number of sequences scheduled to run. max_decode_seq_len (int): Greatest of all the decode sequence lengths. Used only in checking the viablility of using CUDA graphs. max_encoder_seq_len (int, optional): Greatest of all the encode sequence lengths. Defaults to 0. Used only in checking the - viability of using CUDA graphs. + viability of using CUDA graphs. Returns: int: Returns the determined number of padding sequences. If CUDA graphs is not viable, returns -1. @@ -1103,7 +1104,7 @@ def load_model(self) -> None: self.prompt_adapter_manager.create_prompt_adapter_manager( self.model)) - if self.kv_cache_dtype == "fp8" and is_hip(): + if self.kv_cache_dtype == "fp8" and current_platform.is_rocm(): # Currently only ROCm accepts kv-cache scaling factors # via quantization_param_path and this will be deprecated # in the future.