From 217b86261e5f632030a43c94f812260d7cde0b5a Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 28 Jun 2024 13:20:58 -0700 Subject: [PATCH] SPMD worker Signed-off-by: Stephanie Wang --- vllm/engine/llm_engine.py | 7 + vllm/envs.py | 8 ++ vllm/executor/distributed_gpu_executor.py | 8 +- vllm/executor/ray_gpu_executor.py | 164 +++++++++++++--------- vllm/executor/ray_utils.py | 20 +-- vllm/executor/ray_xpu_executor.py | 153 ++++++++++++-------- vllm/worker/worker_base.py | 26 ++++ 7 files changed, 247 insertions(+), 139 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 622221d2dd13e..d8d526476b64d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,6 +6,7 @@ from transformers import PreTrainedTokenizer +import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, @@ -48,6 +49,8 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER + def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: config = try_get_generation_config( @@ -413,6 +416,9 @@ def from_engine_args( elif distributed_executor_backend == "mp": from vllm.executor.multiproc_gpu_executor import ( MultiprocessingGPUExecutor) + assert not USE_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_SPMD_WORKER=1") executor_class = MultiprocessingGPUExecutor else: from vllm.executor.gpu_executor import GPUExecutor @@ -424,6 +430,7 @@ def from_engine_args( log_stats=not engine_args.disable_log_stats, usage_context=usage_context, ) + return engine def __reduce__(self): diff --git a/vllm/envs.py b/vllm/envs.py index f3b6d2788d392..0bd52b85bdc25 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -34,6 +34,7 @@ VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + VLLM_USE_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") @@ -261,6 +262,13 @@ def get_default_config_root(): "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), + # If the env var is set, then all workers will execute as separate + # processes from the engine, and we use the same mechanism to trigger + # execution on all workers. + # Run vLLM with VLLM_USE_SPMD_WORKER=1 to enable it. + "VLLM_USE_SPMD_WORKER": + lambda: bool(os.getenv("VLLM_USE_SPMD_WORKER", 0)), + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 3db82eb1fe790..4df54a09e5e8c 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[SamplerOutput]]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -73,7 +73,9 @@ def execute_model( **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. - return self._driver_execute_model(execute_model_req) + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs def stop_remote_worker_execution_loop(self) -> None: if self.parallel_worker_tasks is None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 388f934ef75a6..6e2b94323dc83 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -1,6 +1,5 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -23,12 +22,28 @@ logger = init_logger(__name__) +# If the env var is set, it uses the Ray's compiled DAG API +# which optimizes the control plane overhead. +# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. +# Currently, this requires USE_SPMD_WORKER=True. USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG +# If the env var is set, then we do not distinguish between the "driver worker" +# vs other workers. Also, the rank 0 worker will be executed in a remote Ray +# worker. Currently this requires USE_RAY_COMPILED_DAG=True. +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: + if USE_RAY_COMPILED_DAG: + assert USE_SPMD_WORKER, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_SPMD_WORKER=1") + if USE_SPMD_WORKER: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert USE_RAY_COMPILED_DAG, ("VLLM_USE_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group @@ -40,11 +55,7 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. self._init_workers_ray(placement_group) - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() - self.extra_execute_model_run_workers_kwargs[ - "use_ray_compiled_dag"] = True + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]: @@ -110,21 +121,24 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", trust_remote_code=self.model_config.trust_remote_code, ) - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. + if USE_SPMD_WORKER: self.workers.append(worker) - - if self.driver_dummy_worker is None: + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + if not USE_SPMD_WORKER and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " @@ -240,9 +254,23 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") return self.driver_worker.execute_method("execute_model", execute_model_req) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + outputs = ray.get(self.forward_dag.execute(execute_model_req)) + return outputs + def _run_workers( self, method: str, @@ -252,7 +280,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -280,64 +307,57 @@ def _run_workers( all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ else islice(all_kwargs, 1, None) - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - ray_worker_outputs = [] - else: - # Start the ray workers first. - ray_workers = self.workers - if async_run_tensor_parallel_workers_only: - ray_workers = self.non_driver_workers - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(ray_workers, all_worker_args, all_worker_kwargs) - ] + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] if async_run_tensor_parallel_workers_only: # Just return futures return ray_worker_outputs - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not USE_SPMD_WORKER: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] + else: + assert self.driver_dummy_worker is not None + driver_worker_output = [ + ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + ] - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources + + # TODO(swang): Upgrade version. required_version = "2.9" current_version = pkg_resources.get_distribution("ray").version if current_version < required_version: @@ -355,7 +375,7 @@ def _compiled_ray_dag(self): bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): @@ -364,10 +384,24 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.driver_exec_method = make_async(self.driver_worker.execute_method) + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + outputs = await self.forward_dag.execute_async(execute_model_req) + return await outputs + async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") if self.pp_locks is None: # This locks each pipeline parallel stage so multiple virtual # engines can't execute on the same stage at the same time @@ -401,6 +435,8 @@ async def _run_task_with_lock(task, lock, *args, **kwargs): return results[-1] async def _start_worker_execution_loop(self): + assert not USE_SPMD_WORKER, ( + "worker loop is disabled for VLLM_USE_SPMD_WORKER=1") coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.non_driver_workers diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 242d6c136655f..bd74e94c8fa17 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,8 +1,8 @@ -import pickle from typing import List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest from vllm.utils import get_ip, is_hip, is_xpu from vllm.worker.worker_base import WorkerWrapperBase @@ -31,16 +31,16 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model_compiled_dag_remote(self, ignored): - """Used only when compiled DAG is enabled.""" - import torch - if not self.compiled_dag_cuda_device_set: - torch.cuda.set_device(self.worker.device) - self.compiled_dag_cuda_device_set = True + def execute_model(self, execute_model_req: ExecuteModelRequest): + """Used only when SPMD worker and compiled DAG are both + enabled.""" + ## TODO(swang): remove? + #import torch + #if not self.compiled_dag_cuda_device_set: + # torch.cuda.set_device(self.worker.device) + # self.compiled_dag_cuda_device_set = True - output = self.worker.execute_model() - output = pickle.dumps(output) - return output + return self.worker.execute_model(execute_model_req) ray_import_err = None diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 33f9321b5ff36..b900a80ff8fd4 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,11 +1,11 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, Tuple, Union) +import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, @@ -30,7 +30,12 @@ # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) +# Currently, this requires USE_SPMD_WORKER=True. +USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG +# If the env var is set, then we do not distinguish between the "driver worker" +# vs other workers. Also, the rank 0 worker will be executed in a remote Ray +# worker. Currently this requires USE_RAY_COMPILED_DAG=True. +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER class RayXPUExecutor(DistributedGPUExecutor): @@ -72,10 +77,7 @@ def __init__( # Create the parallel GPU workers. self._init_workers_ray(placement_group) - # Profile the memory usage and initialize the cache. - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. @@ -85,7 +87,13 @@ def __init__( self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} def _init_executor(self) -> None: - pass + if USE_RAY_COMPILED_DAG: + assert USE_SPMD_WORKER, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_SPMD_WORKER=1") + if USE_SPMD_WORKER: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert USE_RAY_COMPILED_DAG, ("VLLM_USE_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. @@ -144,20 +152,23 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", trust_remote_code=self.model_config.trust_remote_code, ) - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name="vllm.worker.xpu_worker", - worker_class_name="XPUWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. + if USE_SPMD_WORKER: self.workers.append(worker) - if self.driver_dummy_worker is None: + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name="vllm.worker.xpu_worker", + worker_class_name="XPUWorker", + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + if not USE_SPMD_WORKER and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " @@ -241,9 +252,23 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") return self.driver_worker.execute_method("execute_model", execute_model_req) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + outputs = ray.get(self.forward_dag.execute(execute_model_req)) + return outputs + def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return self._run_workers( @@ -270,7 +295,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -293,59 +317,48 @@ def _run_workers( all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ else islice(all_kwargs, 1, None) - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - else: - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) - ] + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + if async_run_remote_workers_only: # Just return futures return ray_worker_outputs - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not USE_SPMD_WORKER: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources + + # TODO(swang): Upgrade version. required_version = "2.9" current_version = pkg_resources.get_distribution("ray").version if current_version < required_version: @@ -353,7 +366,7 @@ def _compiled_ray_dag(self): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. @@ -363,7 +376,7 @@ def _compiled_ray_dag(self): bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) def check_health(self) -> None: """Raises an error if engine is unhealthy.""" @@ -389,14 +402,30 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.driver_exec_method = make_async(self.driver_worker.execute_method) + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + outputs = await self.forward_dag.execute_async(execute_model_req) + return await outputs + async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") return await self.driver_exec_method("execute_model", execute_model_req) async def _start_worker_execution_loop(self): + assert not USE_SPMD_WORKER, ( + "worker loop is disabled for VLLM_USE_SPMD_WORKER=1") coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.workers diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 93ffea9106501..1897243d7719d 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -6,6 +6,7 @@ import torch +import vllm.envs as envs from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -17,6 +18,8 @@ logger = init_logger(__name__) +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER + class WorkerBase(ABC): """Worker interface that allows vLLM to cleanly separate implementations for @@ -215,6 +218,24 @@ def execute_worker(self, worker_input: WorkerInput) -> None: def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + if USE_SPMD_WORKER: + assert execute_model_req is not None, ( + "VLLM_USE_SPMD_WORKER=1 requires each worker to take in an " + "ExecuteModelRequest") + return self._execute_model_spmd(execute_model_req) + + return self._execute_model_with_nccl_control_plane(execute_model_req) + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest = None + ) -> Optional[List[SamplerOutput]]: + pass + + def _execute_model_with_nccl_control_plane( + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" @@ -323,6 +344,11 @@ def init_worker(self, *args, **kwargs): mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) + if USE_SPMD_WORKER: + assert isinstance(worker_class, LocalOrDistributedWorkerBase), ( + "VLLM_USE_SPMD_WORKER=1 is currently only supported with " + "workers that inherit from LocalOrDistributedWorkerBase") + self.worker = worker_class(*args, **kwargs) def execute_method(self, method, *args, **kwargs):