diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index ca4ba66f09cb8..3a2f7db679358 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -343,6 +343,11 @@ def from_engine_args(
         if engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
+        elif engine_config.device_config.device_type == "cpu":
+            assert not engine_config.parallel_config.worker_use_ray, (
+                "Ray is not supported with the CPU backend.")
+            from vllm.executor.cpu_executor import CPUExecutorAsync
+            executor_class = CPUExecutorAsync
         elif engine_config.parallel_config.worker_use_ray:
             initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 35249cd7302cb..8d6a1fff91fd8 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -4,11 +4,12 @@
 import torch
 
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
-from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
-from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        make_async)
 
 logger = init_logger(__name__)
 
@@ -100,6 +101,28 @@ def check_health(self) -> None:
         return
 
 
+class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        output = await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy)
+        return output
+
+    async def check_health_async(self) -> None:
+        # CPUExecutor will always be healthy as long as
+        # it's running.
+        return
+
+
 def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
     if config.dtype == torch.float16:
         logger.warning("float16 is not supported on CPU, casting to bfloat16.")
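
Note on the `make_async` call in `execute_model_async`: it wraps the blocking `driver_worker.execute_model` so the asyncio event loop is not stalled while the CPU-bound model step runs. Below is a minimal sketch of the pattern, assuming a `run_in_executor`-based helper in the spirit of the one this diff imports from `vllm.utils` (not vLLM's exact implementation); `slow_step` and `_demo` are illustrative names, not part of the PR:

```python
import asyncio
from functools import partial
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Sketch of a make_async-style wrapper (not vLLM's exact code).

    The blocking call is handed off to the event loop's default
    thread-pool executor, so awaiting it does not block the loop.
    """

    def _wrapper(*args, **kwargs) -> Awaitable[T]:
        loop = asyncio.get_running_loop()
        # run_in_executor(None, ...) uses the default ThreadPoolExecutor.
        return loop.run_in_executor(None, partial(func, *args, **kwargs))

    return _wrapper


async def _demo() -> None:
    # Stand-in for the blocking, CPU-bound model step.
    def slow_step(x: int) -> int:
        return x * x

    result = await make_async(slow_step)(x=7)
    print(result)  # 49


asyncio.run(_demo())
```

Offloading to a worker thread is effective here because PyTorch generally releases the GIL inside its heavy kernels, so the model step can make real progress while the event loop keeps serving other requests.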