From 217b86261e5f632030a43c94f812260d7cde0b5a Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 28 Jun 2024 13:20:58 -0700 Subject: [PATCH 1/2] SPMD worker Signed-off-by: Stephanie Wang --- vllm/engine/llm_engine.py | 7 + vllm/envs.py | 8 ++ vllm/executor/distributed_gpu_executor.py | 8 +- vllm/executor/ray_gpu_executor.py | 164 +++++++++++++--------- vllm/executor/ray_utils.py | 20 +-- vllm/executor/ray_xpu_executor.py | 153 ++++++++++++-------- vllm/worker/worker_base.py | 26 ++++ 7 files changed, 247 insertions(+), 139 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 622221d2dd13e..d8d526476b64d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,6 +6,7 @@ from transformers import PreTrainedTokenizer +import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, @@ -48,6 +49,8 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER + def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: config = try_get_generation_config( @@ -413,6 +416,9 @@ def from_engine_args( elif distributed_executor_backend == "mp": from vllm.executor.multiproc_gpu_executor import ( MultiprocessingGPUExecutor) + assert not USE_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_SPMD_WORKER=1") executor_class = MultiprocessingGPUExecutor else: from vllm.executor.gpu_executor import GPUExecutor @@ -424,6 +430,7 @@ def from_engine_args( log_stats=not engine_args.disable_log_stats, usage_context=usage_context, ) + return engine def __reduce__(self): diff --git a/vllm/envs.py b/vllm/envs.py index f3b6d2788d392..0bd52b85bdc25 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -34,6 +34,7 @@ VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + VLLM_USE_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") @@ -261,6 +262,13 @@ def get_default_config_root(): "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), + # If the env var is set, then all workers will execute as separate + # processes from the engine, and we use the same mechanism to trigger + # execution on all workers. + # Run vLLM with VLLM_USE_SPMD_WORKER=1 to enable it. + "VLLM_USE_SPMD_WORKER": + lambda: bool(os.getenv("VLLM_USE_SPMD_WORKER", 0)), + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 3db82eb1fe790..4df54a09e5e8c 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[SamplerOutput]]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -73,7 +73,9 @@ def execute_model( **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. - return self._driver_execute_model(execute_model_req) + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs def stop_remote_worker_execution_loop(self) -> None: if self.parallel_worker_tasks is None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 388f934ef75a6..6e2b94323dc83 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -1,6 +1,5 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -23,12 +22,28 @@ logger = init_logger(__name__) +# If the env var is set, it uses the Ray's compiled DAG API +# which optimizes the control plane overhead. +# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. +# Currently, this requires USE_SPMD_WORKER=True. USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG +# If the env var is set, then we do not distinguish between the "driver worker" +# vs other workers. Also, the rank 0 worker will be executed in a remote Ray +# worker. Currently this requires USE_RAY_COMPILED_DAG=True. +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: + if USE_RAY_COMPILED_DAG: + assert USE_SPMD_WORKER, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_SPMD_WORKER=1") + if USE_SPMD_WORKER: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert USE_RAY_COMPILED_DAG, ("VLLM_USE_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group @@ -40,11 +55,7 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. self._init_workers_ray(placement_group) - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() - self.extra_execute_model_run_workers_kwargs[ - "use_ray_compiled_dag"] = True + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]: @@ -110,21 +121,24 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", trust_remote_code=self.model_config.trust_remote_code, ) - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. + if USE_SPMD_WORKER: self.workers.append(worker) - - if self.driver_dummy_worker is None: + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + if not USE_SPMD_WORKER and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " @@ -240,9 +254,23 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") return self.driver_worker.execute_method("execute_model", execute_model_req) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + outputs = ray.get(self.forward_dag.execute(execute_model_req)) + return outputs + def _run_workers( self, method: str, @@ -252,7 +280,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -280,64 +307,57 @@ def _run_workers( all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ else islice(all_kwargs, 1, None) - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - ray_worker_outputs = [] - else: - # Start the ray workers first. - ray_workers = self.workers - if async_run_tensor_parallel_workers_only: - ray_workers = self.non_driver_workers - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(ray_workers, all_worker_args, all_worker_kwargs) - ] + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] if async_run_tensor_parallel_workers_only: # Just return futures return ray_worker_outputs - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not USE_SPMD_WORKER: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] + else: + assert self.driver_dummy_worker is not None + driver_worker_output = [ + ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + ] - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources + + # TODO(swang): Upgrade version. required_version = "2.9" current_version = pkg_resources.get_distribution("ray").version if current_version < required_version: @@ -355,7 +375,7 @@ def _compiled_ray_dag(self): bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): @@ -364,10 +384,24 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.driver_exec_method = make_async(self.driver_worker.execute_method) + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + outputs = await self.forward_dag.execute_async(execute_model_req) + return await outputs + async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") if self.pp_locks is None: # This locks each pipeline parallel stage so multiple virtual # engines can't execute on the same stage at the same time @@ -401,6 +435,8 @@ async def _run_task_with_lock(task, lock, *args, **kwargs): return results[-1] async def _start_worker_execution_loop(self): + assert not USE_SPMD_WORKER, ( + "worker loop is disabled for VLLM_USE_SPMD_WORKER=1") coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.non_driver_workers diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 242d6c136655f..bd74e94c8fa17 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,8 +1,8 @@ -import pickle from typing import List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest from vllm.utils import get_ip, is_hip, is_xpu from vllm.worker.worker_base import WorkerWrapperBase @@ -31,16 +31,16 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model_compiled_dag_remote(self, ignored): - """Used only when compiled DAG is enabled.""" - import torch - if not self.compiled_dag_cuda_device_set: - torch.cuda.set_device(self.worker.device) - self.compiled_dag_cuda_device_set = True + def execute_model(self, execute_model_req: ExecuteModelRequest): + """Used only when SPMD worker and compiled DAG are both + enabled.""" + ## TODO(swang): remove? + #import torch + #if not self.compiled_dag_cuda_device_set: + # torch.cuda.set_device(self.worker.device) + # self.compiled_dag_cuda_device_set = True - output = self.worker.execute_model() - output = pickle.dumps(output) - return output + return self.worker.execute_model(execute_model_req) ray_import_err = None diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 33f9321b5ff36..b900a80ff8fd4 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,11 +1,11 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, Tuple, Union) +import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, @@ -30,7 +30,12 @@ # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) +# Currently, this requires USE_SPMD_WORKER=True. +USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG +# If the env var is set, then we do not distinguish between the "driver worker" +# vs other workers. Also, the rank 0 worker will be executed in a remote Ray +# worker. Currently this requires USE_RAY_COMPILED_DAG=True. +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER class RayXPUExecutor(DistributedGPUExecutor): @@ -72,10 +77,7 @@ def __init__( # Create the parallel GPU workers. self._init_workers_ray(placement_group) - # Profile the memory usage and initialize the cache. - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. @@ -85,7 +87,13 @@ def __init__( self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} def _init_executor(self) -> None: - pass + if USE_RAY_COMPILED_DAG: + assert USE_SPMD_WORKER, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_SPMD_WORKER=1") + if USE_SPMD_WORKER: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert USE_RAY_COMPILED_DAG, ("VLLM_USE_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. @@ -144,20 +152,23 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", trust_remote_code=self.model_config.trust_remote_code, ) - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name="vllm.worker.xpu_worker", - worker_class_name="XPUWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. + if USE_SPMD_WORKER: self.workers.append(worker) - if self.driver_dummy_worker is None: + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name="vllm.worker.xpu_worker", + worker_class_name="XPUWorker", + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + if not USE_SPMD_WORKER and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " @@ -241,9 +252,23 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") return self.driver_worker.execute_method("execute_model", execute_model_req) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + outputs = ray.get(self.forward_dag.execute(execute_model_req)) + return outputs + def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return self._run_workers( @@ -270,7 +295,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -293,59 +317,48 @@ def _run_workers( all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ else islice(all_kwargs, 1, None) - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - else: - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) - ] + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + if async_run_remote_workers_only: # Just return futures return ray_worker_outputs - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not USE_SPMD_WORKER: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources + + # TODO(swang): Upgrade version. required_version = "2.9" current_version = pkg_resources.get_distribution("ray").version if current_version < required_version: @@ -353,7 +366,7 @@ def _compiled_ray_dag(self): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. @@ -363,7 +376,7 @@ def _compiled_ray_dag(self): bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) def check_health(self) -> None: """Raises an error if engine is unhealthy.""" @@ -389,14 +402,30 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.driver_exec_method = make_async(self.driver_worker.execute_method) + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not USE_SPMD_WORKER: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + outputs = await self.forward_dag.execute_async(execute_model_req) + return await outputs + async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + assert not USE_SPMD_WORKER, ( + "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1") return await self.driver_exec_method("execute_model", execute_model_req) async def _start_worker_execution_loop(self): + assert not USE_SPMD_WORKER, ( + "worker loop is disabled for VLLM_USE_SPMD_WORKER=1") coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.workers diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 93ffea9106501..1897243d7719d 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -6,6 +6,7 @@ import torch +import vllm.envs as envs from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -17,6 +18,8 @@ logger = init_logger(__name__) +USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER + class WorkerBase(ABC): """Worker interface that allows vLLM to cleanly separate implementations for @@ -215,6 +218,24 @@ def execute_worker(self, worker_input: WorkerInput) -> None: def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + if USE_SPMD_WORKER: + assert execute_model_req is not None, ( + "VLLM_USE_SPMD_WORKER=1 requires each worker to take in an " + "ExecuteModelRequest") + return self._execute_model_spmd(execute_model_req) + + return self._execute_model_with_nccl_control_plane(execute_model_req) + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest = None + ) -> Optional[List[SamplerOutput]]: + pass + + def _execute_model_with_nccl_control_plane( + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" @@ -323,6 +344,11 @@ def init_worker(self, *args, **kwargs): mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) + if USE_SPMD_WORKER: + assert isinstance(worker_class, LocalOrDistributedWorkerBase), ( + "VLLM_USE_SPMD_WORKER=1 is currently only supported with " + "workers that inherit from LocalOrDistributedWorkerBase") + self.worker = worker_class(*args, **kwargs) def execute_method(self, method, *args, **kwargs): From 8df6b83e0eb4e1119233a23ddbbb5807bcc0c8e2 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 28 Jun 2024 18:38:46 -0700 Subject: [PATCH 2/2] up Signed-off-by: Rui Qiao --- .buildkite/test-pipeline.yaml | 3 + vllm/executor/ray_gpu_executor.py | 45 +++++++++---- vllm/executor/ray_utils.py | 14 ++-- vllm/executor/ray_xpu_executor.py | 104 +++++++++++++----------------- vllm/worker/cpu_worker.py | 1 + vllm/worker/worker.py | 1 + vllm/worker/worker_base.py | 61 ++++++++++++------ 7 files changed, 133 insertions(+), 96 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index cd3a5e80d7bd0..d6b847bb790e3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -84,6 +84,8 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py @@ -108,6 +110,7 @@ steps: # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6e2b94323dc83..0525dad1b9abf 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -119,6 +119,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_module_name=worker_module_name, worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, + use_spmd_worker=USE_SPMD_WORKER, ) if USE_SPMD_WORKER: @@ -269,7 +270,7 @@ def execute_model( self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) outputs = ray.get(self.forward_dag.execute(execute_model_req)) - return outputs + return outputs[0] def _run_workers( self, @@ -294,6 +295,10 @@ def _run_workers( - all_args/all_kwargs: args/kwargs for each worker are specified individually """ + if USE_SPMD_WORKER: + assert not async_run_tensor_parallel_workers_only, ( + "async_run_tensor_parallel_workers_only is not supported for " + "spmd mode.") if max_concurrent_workers: raise NotImplementedError( @@ -302,19 +307,23 @@ def _run_workers( count = len(self.workers) if not \ async_run_tensor_parallel_workers_only \ else len(self.non_driver_workers) + # If using SPMD worker, all workers are the same, so we should execute + # the args on all workers. Otherwise, we skip the first worker's args + # because those args will go to the driver worker. + first_worker_args_index: int = 0 if USE_SPMD_WORKER else 1 all_worker_args = repeat(args, count) if all_args is None \ - else islice(all_args, 1, None) + else islice(all_args, first_worker_args_index, None) all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ - else islice(all_kwargs, 1, None) + else islice(all_kwargs, first_worker_args_index, None) # Start the ray workers first. ray_workers = self.workers if async_run_tensor_parallel_workers_only: - ray_workers = self.non_driver_workers + ray_workers = self.non_driver_workers ray_worker_outputs = [ worker.execute_method.remote(method, *worker_args, **worker_kwargs) for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ) in zip(ray_workers, all_worker_args, all_worker_kwargs) ] if async_run_tensor_parallel_workers_only: @@ -356,10 +365,11 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources + from packaging import version - # TODO(swang): Upgrade version. - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version + required_version = version.parse("2.32") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") @@ -371,8 +381,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool): # a dummy value for now. It will be fixed soon. with InputNode() as input_data: forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote. - bind( # type: ignore[attr-defined] + worker.execute_model_spmd.bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) @@ -382,7 +391,9 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.driver_exec_method = make_async(self.driver_worker.execute_method) + if not USE_SPMD_WORKER: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) async def execute_model_async( self, @@ -393,8 +404,9 @@ async def execute_model_async( if self.forward_dag is None: self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) - outputs = await self.forward_dag.execute_async(execute_model_req) - return await outputs + dag_future = await self.forward_dag.execute_async(execute_model_req) + outputs = await dag_future + return outputs[0] async def _driver_execute_model_async( self, @@ -442,3 +454,10 @@ async def _start_worker_execution_loop(self): for worker in self.non_driver_workers ] return await asyncio.gather(*coros) + + def __del__(self): + if self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index bd74e94c8fa17..ba7a23d10603f 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -31,14 +31,16 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model(self, execute_model_req: ExecuteModelRequest): + def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): """Used only when SPMD worker and compiled DAG are both enabled.""" - ## TODO(swang): remove? - #import torch - #if not self.compiled_dag_cuda_device_set: - # torch.cuda.set_device(self.worker.device) - # self.compiled_dag_cuda_device_set = True + # TODO(swang): This is needed right now because Ray aDAG executes + # on a background thread, so we need to reset torch's current + # device. + import torch + if not self.compiled_dag_cuda_device_set: + torch.cuda.set_device(self.worker.device) + self.compiled_dag_cuda_device_set = True return self.worker.execute_model(execute_model_req) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index b900a80ff8fd4..259b88800ebc6 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -30,11 +30,11 @@ # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -# Currently, this requires USE_SPMD_WORKER=True. +# Currently, this is not supported yet. USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG # If the env var is set, then we do not distinguish between the "driver worker" # vs other workers. Also, the rank 0 worker will be executed in a remote Ray -# worker. Currently this requires USE_RAY_COMPILED_DAG=True. +# worker. Currently this is not supported yet. USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER @@ -87,13 +87,10 @@ def __init__( self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} def _init_executor(self) -> None: - if USE_RAY_COMPILED_DAG: - assert USE_SPMD_WORKER, ( - "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_SPMD_WORKER=1") - if USE_SPMD_WORKER: - # TODO: Support SPMD worker for non-DAG Ray executor. - assert USE_RAY_COMPILED_DAG, ("VLLM_USE_SPMD_WORKER=1 requires " - "VLLM_USE_RAY_COMPILED_DAG=1") + assert not USE_RAY_COMPILED_DAG, ( + "Compiled DAG is not supported for XPU yet") + assert not USE_SPMD_WORKER, ( + "SPMD worker is not supported for XPU yet") def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. @@ -118,6 +115,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): + assert not USE_RAY_COMPILED_DAG, ( + "Compiled DAG is not supported for XPU yet") + assert not USE_SPMD_WORKER, ( + "SPMD worker is not supported for XPU yet") if self.parallel_config.tensor_parallel_size == 1: # For single GPU case, we use a ray worker with constrained memory. num_gpus = self.cache_config.gpu_memory_utilization @@ -152,23 +153,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", trust_remote_code=self.model_config.trust_remote_code, ) - if USE_SPMD_WORKER: - self.workers.append(worker) + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name="vllm.worker.xpu_worker", + worker_class_name="XPUWorker", + trust_remote_code=self.model_config.trust_remote_code, + ) else: - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name="vllm.worker.xpu_worker", - worker_class_name="XPUWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. - self.workers.append(worker) - if not USE_SPMD_WORKER and self.driver_dummy_worker is None: + # Else, added to the list of workers. + self.workers.append(worker) + if self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " @@ -260,14 +258,9 @@ def _driver_execute_model( def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if not USE_SPMD_WORKER: - return super().execute_model(execute_model_req) - - if self.forward_dag is None: - self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - - outputs = ray.get(self.forward_dag.execute(execute_model_req)) - return outputs + assert not USE_SPMD_WORKER, ( + "SPMD worker is not supported for XPU yet") + return super().execute_model(execute_model_req) def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." @@ -329,21 +322,18 @@ def _run_workers( return ray_worker_outputs driver_worker_output = [] - # In SPMD mode, the driver worker is the same as any other worker, - # so we only explicitly execute on the driver worker if using a - # non-SPMD worker class. - if not USE_SPMD_WORKER: - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) + assert not USE_SPMD_WORKER + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: ray_worker_outputs = ray.get(ray_worker_outputs) @@ -357,10 +347,11 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources + from packaging import version - # TODO(swang): Upgrade version. - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version + required_version = version.parse("2.32") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") @@ -405,14 +396,9 @@ def __init__(self, *args, **kwargs): async def execute_model_async( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if not USE_SPMD_WORKER: - return super().execute_model(execute_model_req) - - if self.forward_dag is None: - self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) - - outputs = await self.forward_dag.execute_async(execute_model_req) - return await outputs + assert not USE_SPMD_WORKER, ( + "SPMD worker is not supported for XPU yet") + return super().execute_model(execute_model_req) async def _driver_execute_model_async( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3c22c73267b7f..8b06a18c62b48 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -171,6 +171,7 @@ def __init__( kv_cache_dtype=kv_cache_dtype, prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) + self.use_spmd_worker = False # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: List[CPUCacheEngine] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 56d8587f8f010..0a194abe3c9e5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -101,6 +101,7 @@ def __init__( multimodal_config=multimodal_config, **speculative_args, ) + self.use_spmd_worker: bool = False # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: List[CacheEngine] diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1897243d7719d..2304b34796340 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -6,7 +6,6 @@ import torch -import vllm.envs as envs from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -18,14 +17,13 @@ logger = init_logger(__name__) -USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER - class WorkerBase(ABC): """Worker interface that allows vLLM to cleanly separate implementations for different hardware. Also abstracts control plane communication, e.g., to communicate request metadata to other workers. """ + use_spmd_worker: bool @abstractmethod def init_device(self) -> None: @@ -219,7 +217,9 @@ def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: - if USE_SPMD_WORKER: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" + if self.use_spmd_worker: assert execute_model_req is not None, ( "VLLM_USE_SPMD_WORKER=1 requires each worker to take in an " "ExecuteModelRequest") @@ -227,18 +227,15 @@ def execute_model( return self._execute_model_with_nccl_control_plane(execute_model_req) - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest = None - ) -> Optional[List[SamplerOutput]]: - pass - def _execute_model_with_nccl_control_plane( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences, unless no - sequences are provided.""" + """ + Execute model with NCCL control plane. To execute model on all workers, + the driver worker first uses NCCL broadcasting primitive to broadcast + input data to all other workers. + """ if self.is_driver_worker: if execute_model_req is None: if self.do_metadata_broadcast: @@ -302,6 +299,30 @@ def _execute_model_with_nccl_control_plane( # list to conform to interface. return output + def _execute_model_spmd( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None) + class WorkerWrapperBase: """ @@ -314,10 +335,12 @@ class WorkerWrapperBase: def __init__(self, worker_module_name: str, worker_class_name: str, - trust_remote_code: bool = False) -> None: + trust_remote_code: bool = False, + use_spmd_worker: bool = False) -> None: self.worker_module_name = worker_module_name self.worker_class_name = worker_class_name - self.worker = None + self.use_spmd_worker = use_spmd_worker + self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -344,12 +367,14 @@ def init_worker(self, *args, **kwargs): mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) - if USE_SPMD_WORKER: - assert isinstance(worker_class, LocalOrDistributedWorkerBase), ( - "VLLM_USE_SPMD_WORKER=1 is currently only supported with " - "workers that inherit from LocalOrDistributedWorkerBase") + if self.use_spmd_worker: + assert issubclass(worker_class, LocalOrDistributedWorkerBase), ( + f"VLLM_USE_SPMD_WORKER=1 requires worker class {worker_class}" + " to inherit from LocalOrDistributedWorkerBase") self.worker = worker_class(*args, **kwargs) + assert self.worker is not None + self.worker.use_spmd_worker = self.use_spmd_worker def execute_method(self, method, *args, **kwargs): try: