diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 445d74d6d9bbe..258818ba30957 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -84,6 +84,8 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py @@ -108,6 +110,7 @@ steps: # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 68ca9a97a3c61..77539eab0db23 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,6 +6,7 @@ from transformers import PreTrainedTokenizer +import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, @@ -414,6 +415,9 @@ def from_engine_args( elif distributed_executor_backend == "mp": from vllm.executor.multiproc_gpu_executor import ( MultiprocessingGPUExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") executor_class = MultiprocessingGPUExecutor else: from vllm.executor.gpu_executor import GPUExecutor @@ -426,6 +430,7 @@ def from_engine_args( usage_context=usage_context, stat_loggers=stat_loggers, ) + return engine def __reduce__(self): diff --git a/vllm/envs.py b/vllm/envs.py index f3b6d2788d392..595992e51db87 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -34,6 +34,7 @@ VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") @@ -261,6 +262,13 @@ def get_default_config_root(): "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), + # If the env var is set, then all workers will execute as separate + # processes from the engine, and we use the same mechanism to trigger + # execution on all workers. + # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. + "VLLM_USE_RAY_SPMD_WORKER": + lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)), + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 3db82eb1fe790..4df54a09e5e8c 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[SamplerOutput]]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -73,7 +73,9 @@ def execute_model( **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. - return self._driver_execute_model(execute_model_req) + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs def stop_remote_worker_execution_loop(self) -> None: if self.parallel_worker_tasks is None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index edff9b6c93e09..92899ba5b0217 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -1,6 +1,5 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -23,12 +22,30 @@ logger = init_logger(__name__) -USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG - class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + # Currently, this requires USE_RAY_SPMD_WORKER=True. + self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG + # If the env var is set, then we do not distinguish between the + # "driver worker" vs other workers. Also, the rank 0 worker will + # be executed in a remote Ray worker. Currently this requires + # USE_RAY_COMPILED_DAG=True. + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if self.use_ray_compiled_dag: + assert self.use_ray_spmd_worker, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires " + "VLLM_USE_RAY_SPMD_WORKER=1") + if self.use_ray_spmd_worker: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert self.use_ray_compiled_dag, ( + "VLLM_USE_RAY_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group @@ -40,11 +57,7 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. self._init_workers_ray(placement_group) - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() - self.extra_execute_model_run_workers_kwargs[ - "use_ray_compiled_dag"] = True + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]: @@ -110,21 +123,24 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", trust_remote_code=self.model_config.trust_remote_code, ) - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. + if self.use_ray_spmd_worker: self.workers.append(worker) - - if self.driver_dummy_worker is None: + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " @@ -254,9 +270,23 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") return self.driver_worker.execute_method("execute_model", execute_model_req) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + outputs = ray.get(self.forward_dag.execute(execute_model_req)) + return outputs[0] + def _run_workers( self, method: str, @@ -266,7 +296,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -281,6 +310,10 @@ def _run_workers( - all_args/all_kwargs: args/kwargs for each worker are specified individually """ + if self.use_ray_spmd_worker: + assert not async_run_tensor_parallel_workers_only, ( + "async_run_tensor_parallel_workers_only is not supported for " + "spmd mode.") if max_concurrent_workers: raise NotImplementedError( @@ -289,71 +322,69 @@ def _run_workers( count = len(self.workers) if not \ async_run_tensor_parallel_workers_only \ else len(self.non_driver_workers) + # If using SPMD worker, all workers are the same, so we should execute + # the args on all workers. Otherwise, we skip the first worker's args + # because those args will go to the driver worker. + first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1 all_worker_args = repeat(args, count) if all_args is None \ - else islice(all_args, 1, None) + else islice(all_args, first_worker_args_index, None) all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ - else islice(all_kwargs, 1, None) - - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - ray_worker_outputs = [] - else: - # Start the ray workers first. - ray_workers = self.workers - if async_run_tensor_parallel_workers_only: - ray_workers = self.non_driver_workers - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(ray_workers, all_worker_args, all_worker_kwargs) - ] + else islice(all_kwargs, first_worker_args_index, None) + + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(ray_workers, all_worker_args, all_worker_kwargs) + ] if async_run_tensor_parallel_workers_only: # Just return futures return ray_worker_outputs - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not self.use_ray_spmd_worker: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] + else: + assert self.driver_dummy_worker is not None + driver_worker_output = [ + ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + ] - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version + from packaging import version + + required_version = version.parse("2.32") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") @@ -365,23 +396,47 @@ def _compiled_ray_dag(self): # a dummy value for now. It will be fixed soon. with InputNode() as input_data: forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote. - bind( # type: ignore[attr-defined] + worker.execute_model_spmd.bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) + + def __del__(self): + if self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.driver_exec_method = make_async(self.driver_worker.execute_method) + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if not self.use_ray_compiled_dag: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + dag_future = await self.forward_dag.execute_async(execute_model_req) + outputs = await dag_future + return outputs[0] async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") if self.pp_locks is None: # This locks each pipeline parallel stage so multiple virtual # engines can't execute on the same stage at the same time @@ -415,8 +470,17 @@ async def _run_task_with_lock(task, lock, *args, **kwargs): return results[-1] async def _start_worker_execution_loop(self): + assert not self.use_ray_spmd_worker, ( + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.non_driver_workers ] return await asyncio.gather(*coros) + + def __del__(self): + if self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 242d6c136655f..fcbfa30d7a38a 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,8 +1,8 @@ -import pickle from typing import List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest from vllm.utils import get_ip, is_hip, is_xpu from vllm.worker.worker_base import WorkerWrapperBase @@ -31,16 +31,18 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model_compiled_dag_remote(self, ignored): - """Used only when compiled DAG is enabled.""" + def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): + """Used only when SPMD worker and compiled DAG are both + enabled.""" + # TODO(swang): This is needed right now because Ray aDAG executes + # on a background thread, so we need to reset torch's current + # device. import torch if not self.compiled_dag_cuda_device_set: torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - output = self.worker.execute_model() - output = pickle.dumps(output) - return output + return self.worker._execute_model_spmd(execute_model_req) ray_import_err = None diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 33f9321b5ff36..2a93616ced06c 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,11 +1,11 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, Tuple, Union) +import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, @@ -30,7 +30,7 @@ # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) +USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG class RayXPUExecutor(DistributedGPUExecutor): @@ -72,10 +72,9 @@ def __init__( # Create the parallel GPU workers. self._init_workers_ray(placement_group) - # Profile the memory usage and initialize the cache. self.forward_dag = None if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. @@ -270,7 +269,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -293,26 +291,20 @@ def _run_workers( all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ else islice(all_kwargs, 1, None) - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - else: - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) - ] + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + if async_run_remote_workers_only: # Just return futures return ray_worker_outputs + driver_worker_output = [] driver_args = args if all_args is None else all_args[0] driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( @@ -324,36 +316,28 @@ def _run_workers( method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version + from packaging import version + + required_version = version.parse("2.32") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. @@ -363,7 +347,7 @@ def _compiled_ray_dag(self): bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) def check_health(self) -> None: """Raises an error if engine is unhealthy.""" diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 93ffea9106501..a10281b02db89 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -281,6 +281,33 @@ def execute_model( # list to conform to interface. return output + def _execute_model_spmd( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None) + class WorkerWrapperBase: """ @@ -296,7 +323,7 @@ def __init__(self, trust_remote_code: bool = False) -> None: self.worker_module_name = worker_module_name self.worker_class_name = worker_class_name - self.worker = None + self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -323,7 +350,9 @@ def init_worker(self, *args, **kwargs): mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) + self.worker = worker_class(*args, **kwargs) + assert self.worker is not None def execute_method(self, method, *args, **kwargs): try: