diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py
index 36b7e2265efab..ba6177e51a453 100644
--- a/vllm/executor/xpu_executor.py
+++ b/vllm/executor/xpu_executor.py
@@ -1,8 +1,5 @@
 from typing import Callable, List, Optional, Tuple, Type, Union
 
-import torch
-
-from vllm.config import ModelConfig, ParallelConfig
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
@@ -23,7 +20,6 @@ def _init_executor(self) -> None:
         assert self.speculative_config is None, (
             "Speculative decoding not yet supported for XPU backend")
 
-        self.model_config = _verify_and_get_model_config(self.model_config)
         GPUExecutor._init_executor(self)
 
     def _get_worker_module_and_class(
@@ -53,26 +49,3 @@ async def execute_model_async(
         output = await make_async(self.driver_worker.execute_model
                                   )(execute_model_req=execute_model_req)
         return output
-
-
-def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
-    if config.dtype == torch.bfloat16:
-        logger.warning(
-            "bfloat16 is not fully supported on XPU, casting to float16.")
-        config.dtype = torch.float16
-    if not config.enforce_eager:
-        logger.warning(
-            "CUDA graph is not supported on XPU, fallback to the eager "
-            "mode.")
-        config.enforce_eager = True
-    return config
-
-
-def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
-    if (config.distributed_executor_backend is not None
-            and config.distributed_executor_backend != "ray"):
-        logger.warning(
-            "%s is not supported on XPU, fallback to ray distributed executor "
-            "backend.", config.distributed_executor_backend)
-        config.distributed_executor_backend = "ray"
-    return config
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index d0b3dca9a4195..62db285f6696a 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -1,9 +1,16 @@
+from typing import TYPE_CHECKING
+
 import torch
 
 from vllm.logger import init_logger
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+
 logger = init_logger(__name__)
 
 
@@ -34,3 +41,17 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
     @staticmethod
     def inference_mode():
         return torch.no_grad()
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        # check and update model config
+        model_config = vllm_config.model_config
+        if model_config.dtype == torch.bfloat16:
+            logger.warning(
+                "bfloat16 is not fully supported on XPU, casting to float16.")
+            model_config.dtype = torch.float16
+        if not model_config.enforce_eager:
+            logger.warning(
+                "CUDA graph is not supported on XPU, fallback to the eager "
+                "mode.")
+            model_config.enforce_eager = True
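
A minimal sketch of how the relocated check would be exercised. The call site
is not part of this diff, so the `current_platform` dispatch below is an
assumption based on vLLM's platform interface, and `finalize_config` is a
hypothetical helper, not an existing function:

    # Hypothetical call site: engine setup resolves the active platform and
    # lets it rewrite the config before any workers are created.
    from vllm.platforms import current_platform

    def finalize_config(vllm_config):
        # On XPU this casts bfloat16 to float16 and forces eager mode,
        # mirroring the logic moved out of xpu_executor.py; other platforms
        # can apply their own adjustments through the same hook.
        current_platform.check_and_update_config(vllm_config)
        return vllm_config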