
[Platform] Move device related code to platform
Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan committed Nov 30, 2024
1 parent 661175b commit f84256a
Showing 7 changed files with 44 additions and 17 deletions.
16 changes: 1 addition & 15 deletions vllm/config.py
@@ -481,11 +481,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,

         # Reminder: Please update docs/source/serving/compatibility_matrix.rst
         # If the feature combo become valid
-        if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"):
-            logger.warning(
-                "Async output processing is only supported for CUDA, TPU, XPU "
-                "and HPU."
-                "Disabling it for other platforms.")
+        if not current_platform.is_async_output_supported(self.enforce_eager):
             self.use_async_output_proc = False
             return

@@ -495,16 +491,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return

-        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
-        # If the feature combo become valid
-        if device_config.device_type == "cuda" and self.enforce_eager:
-            logger.warning(
-                "To see benefits of async output processing, enable CUDA "
-                "graph. Since, enforce-eager is enabled, async output "
-                "processor cannot be used")
-            self.use_async_output_proc = not self.enforce_eager
-            return
-
         # Async postprocessor is not necessary with embedding mode
         # since there is no token generation
         if self.task == "embedding":
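Taken together, the config.py change collapses per-device string checks into a single query against the active platform object. A minimal self-contained sketch of the dispatch pattern introduced here (the `Demo*` names are illustrative only, not vLLM code; the real classes live in the platform files below):

```python
from typing import Optional


class DemoPlatform:
    """Stand-in for vllm.platforms.interface.Platform."""
    is_async_output_support: bool = False  # conservative default

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return cls.is_async_output_support


class DemoCudaPlatform(DemoPlatform):
    is_async_output_support = True

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        # Async output only pays off when CUDA graphs are enabled.
        return False if enforce_eager else cls.is_async_output_support


# In vLLM, `current_platform` is resolved once at import time; the config
# layer then needs no device_type branching at all:
current_platform = DemoCudaPlatform

use_async_output_proc = current_platform.is_async_output_supported(
    enforce_eager=False)  # True on this demo "CUDA" platform
```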
13 changes: 12 additions & 1 deletion vllm/platforms/cuda.py
@@ -4,7 +4,7 @@

 import os
 from functools import lru_cache, wraps
-from typing import TYPE_CHECKING, Callable, List, TypeVar
+from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar

 import pynvml
 import torch
@@ -75,6 +75,7 @@ class CudaPlatformBase(Platform):
     device_name: str = "cuda"
     device_type: str = "cuda"
     dispatch_key: str = "CUDA"
+    is_async_output_support: bool = True

     @classmethod
     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
@@ -112,6 +113,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        if enforce_eager:
+            logger.warning(
+                "To see the benefits of async output processing, enable "
+                "CUDA graphs. Since enforce-eager is enabled, the async "
+                "output processor cannot be used.")
+            return False
+        return cls.is_async_output_support
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
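For illustration, a hypothetical call against the CUDA hook (this assumes the module exports a concrete `CudaPlatform` subclass of `CudaPlatformBase`; the snippet is not part of the commit):

```python
from vllm.platforms.cuda import CudaPlatform

# With eager mode forced, the hook logs the warning and reports False.
assert CudaPlatform.is_async_output_supported(enforce_eager=True) is False

# With CUDA graphs available (enforce_eager falsy), support is reported.
assert CudaPlatform.is_async_output_supported(enforce_eager=None) is True
```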
1 change: 1 addition & 0 deletions vllm/platforms/hpu.py
@@ -15,6 +15,7 @@ class HpuPlatform(Platform):
     device_name: str = "hpu"
     device_type: str = "hpu"
     dispatch_key: str = "HPU"
+    is_async_output_support: bool = True

     @classmethod
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
16 changes: 16 additions & 0 deletions vllm/platforms/interface.py
@@ -5,11 +5,15 @@

 import numpy as np
 import torch

+from vllm.logger import init_logger
+
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 else:
     VllmConfig = None

+logger = init_logger(__name__)
+

 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
@@ -63,6 +67,7 @@ class Platform:
     # use "CPU" as a fallback for platforms not registered in PyTorch
     dispatch_key: str = "CPU"
     supported_quantization: list[str] = []
+    is_async_output_support: bool = False

     def is_cuda(self) -> bool:
         return self._enum == PlatformEnum.CUDA
@@ -184,6 +189,17 @@ def verify_quantization(cls, quant: str) -> None:
                 f"{quant} quantization is currently not supported in "
                 f"{cls.device_name}.")

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        """
+        Check if the current platform supports async output.
+        """
+        if not cls.is_async_output_support:
+            warn_msg = ("Async output processing is not supported on the "
+                        f"current platform type {cls.device_type}.")
+            logger.warning(warn_msg)
+        return cls.is_async_output_support
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
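The base class thus defaults to "unsupported, with a warning", and each platform opts in either by flipping the class attribute or by overriding the hook. A hedged sketch of how a hypothetical platform subclass would do both (`MyAcceleratorPlatform` is illustrative only and not part of this commit; real subclasses also set fields such as `_enum`):

```python
from typing import Optional

from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):
    device_name: str = "myaccel"
    device_type: str = "myaccel"

    # Simplest opt-in: with the flag set, the inherited
    # is_async_output_supported() returns True and skips the warning.
    is_async_output_support: bool = True

    # Conditional opt-in: override the hook, similar in spirit to the
    # cuda.py/rocm.py overrides above (minus the warning).
    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return not enforce_eager and cls.is_async_output_support
```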
13 changes: 12 additions & 1 deletion vllm/platforms/rocm.py
@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

@@ -43,6 +43,7 @@ class RocmPlatform(Platform):
         "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
         "fbgemm_fp8", "gguf"
     ]
+    is_async_output_support: bool = True

     @classmethod
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
@@ -94,3 +95,13 @@ def verify_quantization(cls, quant: str) -> None:
                 "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                 " is not set, enabling VLLM_USE_TRITON_AWQ.")
             envs.VLLM_USE_TRITON_AWQ = True
+
+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        if enforce_eager:
+            logger.warning(
+                "To see the benefits of async output processing, enable "
+                "CUDA graphs. Since enforce-eager is enabled, the async "
+                "output processor cannot be used.")
+            return False
+        return cls.is_async_output_support
1 change: 1 addition & 0 deletions vllm/platforms/tpu.py
@@ -20,6 +20,7 @@ class TpuPlatform(Platform):
     device_type: str = "tpu"
     dispatch_key: str = "XLA"
     supported_quantization: list[str] = ["tpu_int8"]
+    is_async_output_support: bool = True

     @classmethod
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
1 change: 1 addition & 0 deletions vllm/platforms/xpu.py
@@ -19,6 +19,7 @@ class XPUPlatform(Platform):
     device_name: str = "xpu"
     device_type: str = "xpu"
     dispatch_key: str = "XPU"
+    is_async_output_support: bool = True

     @classmethod
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
