diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index c491ce0782c0f..258818ba30957 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -84,8 +84,8 @@ steps:
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
-  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
@@ -110,7 +110,7 @@ steps:
   # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
   # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8123648f7e6e4..f6d1335989b23 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -49,7 +49,7 @@
 logger = init_logger(__name__)
 
 _LOCAL_LOGGING_INTERVAL_SEC = 5
-USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER
+USE_SPMD_WORKER = envs.VLLM_USE_RAY_SPMD_WORKER
 
 
 def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
@@ -419,7 +419,7 @@ def from_engine_args(
                 MultiprocessingGPUExecutor)
             assert not USE_SPMD_WORKER, (
                 "multiprocessing distributed executor backend does not "
-                "support VLLM_USE_SPMD_WORKER=1")
+                "support VLLM_USE_RAY_SPMD_WORKER=1")
             executor_class = MultiprocessingGPUExecutor
         else:
             from vllm.executor.gpu_executor import GPUExecutor
diff --git a/vllm/envs.py b/vllm/envs.py
index 0bd52b85bdc25..595992e51db87 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -34,7 +34,7 @@
     VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
     VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
     VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
-    VLLM_USE_SPMD_WORKER: bool = False
+    VLLM_USE_RAY_SPMD_WORKER: bool = False
     VLLM_USE_RAY_COMPILED_DAG: bool = False
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
@@ -265,9 +265,9 @@ def get_default_config_root():
     # If the env var is set, then all workers will execute as separate
     # processes from the engine, and we use the same mechanism to trigger
     # execution on all workers.
-    # Run vLLM with VLLM_USE_SPMD_WORKER=1 to enable it.
-    "VLLM_USE_SPMD_WORKER":
-    lambda: bool(os.getenv("VLLM_USE_SPMD_WORKER", 0)),
+    # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
+    "VLLM_USE_RAY_SPMD_WORKER":
+    lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
 
     # If the env var is set, it uses the Ray's compiled DAG API
     # which optimizes the control plane overhead.
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 6a37b9582f040..ac4869488e48f 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -22,27 +22,29 @@
 logger = init_logger(__name__)
 
-# If the env var is set, it uses the Ray's compiled DAG API
-# which optimizes the control plane overhead.
-# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
-# Currently, this requires USE_SPMD_WORKER=True.
-USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
-# If the env var is set, then we do not distinguish between the "driver worker"
-# vs other workers. Also, the rank 0 worker will be executed in a remote Ray
-# worker. Currently this requires USE_RAY_COMPILED_DAG=True.
-USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER
-
 
 class RayGPUExecutor(DistributedGPUExecutor):
 
     def _init_executor(self) -> None:
-        if USE_RAY_COMPILED_DAG:
-            assert USE_SPMD_WORKER, (
-                "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_SPMD_WORKER=1")
-        if USE_SPMD_WORKER:
+        # If the env var is set, it uses the Ray's compiled DAG API
+        # which optimizes the control plane overhead.
+        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+        # Currently, this requires USE_RAY_SPMD_WORKER=True.
+        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
+        # If the env var is set, then we do not distinguish between the
+        # "driver worker" vs other workers. Also, the rank 0 worker will
+        # be executed in a remote Ray worker. Currently this requires
+        # USE_RAY_COMPILED_DAG=True.
+        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
+        if self.use_ray_compiled_dag:
+            assert self.use_ray_spmd_worker, (
+                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
+                "VLLM_USE_RAY_SPMD_WORKER=1")
+        if self.use_ray_spmd_worker:
             # TODO: Support SPMD worker for non-DAG Ray executor.
-            assert USE_RAY_COMPILED_DAG, ("VLLM_USE_SPMD_WORKER=1 requires "
-                                          "VLLM_USE_RAY_COMPILED_DAG=1")
+            assert self.use_ray_compiled_dag, (
+                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
+                "VLLM_USE_RAY_COMPILED_DAG=1")
 
         assert self.parallel_config.distributed_executor_backend == "ray"
         placement_group = self.parallel_config.placement_group
@@ -119,10 +121,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 worker_module_name=worker_module_name,
                 worker_class_name=worker_class_name,
                 trust_remote_code=self.model_config.trust_remote_code,
-                use_spmd_worker=USE_SPMD_WORKER,
             )
 
-            if USE_SPMD_WORKER:
+            if self.use_ray_spmd_worker:
                 self.workers.append(worker)
             else:
                 worker_ip = ray.get(worker.get_node_ip.remote())
@@ -139,7 +140,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                     # Else, added to the list of workers.
                     self.workers.append(worker)
 
-        if not USE_SPMD_WORKER and self.driver_dummy_worker is None:
+        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
             raise ValueError(
                 "Ray does not allocate any GPUs on the driver node. Consider "
                 "adjusting the Ray placement group or running the driver on a "
@@ -269,15 +270,15 @@ def _driver_execute_model(
         Passing None will cause the driver to stop the model execution
         loop running in each of the remote workers.
         """
-        assert not USE_SPMD_WORKER, (
-            "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1")
+        assert not self.use_ray_spmd_worker, (
+            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
         return self.driver_worker.execute_method("execute_model",
                                                  execute_model_req)
 
     def execute_model(
             self,
             execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        if not USE_SPMD_WORKER:
+        if not self.use_ray_spmd_worker:
             return super().execute_model(execute_model_req)
 
         if self.forward_dag is None:
@@ -309,7 +310,7 @@ def _run_workers(
         - all_args/all_kwargs: args/kwargs for each worker are specified
           individually
         """
-        if USE_SPMD_WORKER:
+        if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
                 "spmd mode.")
@@ -324,7 +325,7 @@ def _run_workers(
         # If using SPMD worker, all workers are the same, so we should execute
         # the args on all workers. Otherwise, we skip the first worker's args
         # because those args will go to the driver worker.
-        first_worker_args_index: int = 0 if USE_SPMD_WORKER else 1
+        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
         all_worker_args = repeat(args, count) if all_args is None \
             else islice(all_args, first_worker_args_index, None)
         all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
@@ -348,7 +349,7 @@ def _run_workers(
         # In SPMD mode, the driver worker is the same as any other worker,
         # so we only explicitly execute on the driver worker if using a
         # non-SPMD worker class.
-        if not USE_SPMD_WORKER:
+        if not self.use_ray_spmd_worker:
             driver_args = args if all_args is None else all_args[0]
             driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
 
@@ -400,19 +401,26 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
             ])
         return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
 
+    def __del__(self):
+        if self.forward_dag is not None:
+            self.forward_dag.teardown()
+            import ray
+            for worker in self.workers:
+                ray.kill(worker)
 
 class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if not USE_SPMD_WORKER:
+        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
+        if not self.use_ray_compiled_dag:
             self.driver_exec_method = make_async(
                 self.driver_worker.execute_method)
 
     async def execute_model_async(
             self,
             execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        if not USE_SPMD_WORKER:
+        if not self.use_ray_spmd_worker:
             return await super().execute_model_async(execute_model_req)
 
         if self.forward_dag is None:
@@ -426,8 +434,8 @@ async def _driver_execute_model_async(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[SamplerOutput]:
-        assert not USE_SPMD_WORKER, (
-            "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1")
+        assert not self.use_ray_spmd_worker, (
+            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
         if self.pp_locks is None:
             # This locks each pipeline parallel stage so multiple virtual
             # engines can't execute on the same stage at the same time
@@ -461,8 +469,8 @@ async def _run_task_with_lock(task, lock, *args, **kwargs):
         return results[-1]
 
     async def _start_worker_execution_loop(self):
-        assert not USE_SPMD_WORKER, (
-            "worker loop is disabled for VLLM_USE_SPMD_WORKER=1")
+        assert not self.use_ray_spmd_worker, (
+            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
         coros = [
             worker.execute_method.remote("start_worker_execution_loop")
             for worker in self.non_driver_workers
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index ba7a23d10603f..fcbfa30d7a38a 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -42,7 +42,7 @@ def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
                 torch.cuda.set_device(self.worker.device)
                 self.compiled_dag_cuda_device_set = True
 
-            return self.worker.execute_model(execute_model_req)
+            return self.worker._execute_model_spmd(execute_model_req)
 
     ray_import_err = None
diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py
index 259b88800ebc6..2a93616ced06c 100644
--- a/vllm/executor/ray_xpu_executor.py
+++ b/vllm/executor/ray_xpu_executor.py
@@ -30,12 +30,7 @@
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
-# Currently, this is not supported yet.
 USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
-# If the env var is set, then we do not distinguish between the "driver worker"
-# vs other workers. Also, the rank 0 worker will be executed in a remote Ray
-# worker. Currently this is not supported yet.
-USE_SPMD_WORKER = envs.VLLM_USE_SPMD_WORKER
 
 
 class RayXPUExecutor(DistributedGPUExecutor):
@@ -77,7 +72,9 @@ def __init__(
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)
 
-        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
 
         # This is non-None when the execute model loop is running
         # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
@@ -87,10 +84,7 @@ def __init__(
         self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
 
     def _init_executor(self) -> None:
-        assert not USE_RAY_COMPILED_DAG, (
-            "Compiled DAG is not supported for XPU yet")
-        assert not USE_SPMD_WORKER, (
-            "SPMD worker is not supported for XPU yet")
+        pass
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
@@ -115,10 +109,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
-        assert not USE_RAY_COMPILED_DAG, (
-            "Compiled DAG is not supported for XPU yet")
-        assert not USE_SPMD_WORKER, (
-            "SPMD worker is not supported for XPU yet")
         if self.parallel_config.tensor_parallel_size == 1:
             # For single GPU case, we use a ray worker with constrained memory.
             num_gpus = self.cache_config.gpu_memory_utilization
@@ -250,18 +240,9 @@ def _driver_execute_model(
         Passing None will cause the driver to stop the model execution
         loop running in each of the remote workers.
         """
-        assert not USE_SPMD_WORKER, (
-            "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1")
         return self.driver_worker.execute_method("execute_model",
                                                  execute_model_req)
 
-    def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        assert not USE_SPMD_WORKER, (
-            "SPMD worker is not supported for XPU yet")
-        return super().execute_model(execute_model_req)
-
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
         return self._run_workers(
@@ -322,7 +303,6 @@ def _run_workers(
             return ray_worker_outputs
 
         driver_worker_output = []
-        assert not USE_SPMD_WORKER
         driver_args = args if all_args is None else all_args[0]
         driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
         # Start the driver worker after all the ray workers.
@@ -393,25 +373,14 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.driver_exec_method = make_async(self.driver_worker.execute_method)
 
-    async def execute_model_async(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        assert not USE_SPMD_WORKER, (
-            "SPMD worker is not supported for XPU yet")
-        return super().execute_model(execute_model_req)
-
     async def _driver_execute_model_async(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[SamplerOutput]:
-        assert not USE_SPMD_WORKER, (
-            "driver_worker does not exist for VLLM_USE_SPMD_WORKER=1")
         return await self.driver_exec_method("execute_model",
                                              execute_model_req)
 
     async def _start_worker_execution_loop(self):
-        assert not USE_SPMD_WORKER, (
-            "worker loop is disabled for VLLM_USE_SPMD_WORKER=1")
         coros = [
             worker.execute_method.remote("start_worker_execution_loop")
             for worker in self.workers
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 8b06a18c62b48..3c22c73267b7f 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -171,7 +171,6 @@ def __init__(
             kv_cache_dtype=kv_cache_dtype,
             prompt_adapter_config=self.prompt_adapter_config,
             is_driver_worker=is_driver_worker)
-        self.use_spmd_worker = False
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CPUCacheEngine]
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 82e740cf4f75c..f3c379d1aa34d 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -101,7 +101,6 @@ def __init__(
             multimodal_config=multimodal_config,
             **speculative_args,
         )
-        self.use_spmd_worker: bool = False
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 2304b34796340..a10281b02db89 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -23,7 +23,6 @@ class WorkerBase(ABC):
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.
    """
-    use_spmd_worker: bool
 
     @abstractmethod
     def init_device(self) -> None:
@@ -219,23 +218,6 @@ def execute_model(
     ) -> Optional[List[SamplerOutput]]:
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
-        if self.use_spmd_worker:
-            assert execute_model_req is not None, (
-                "VLLM_USE_SPMD_WORKER=1 requires each worker to take in an "
-                "ExecuteModelRequest")
-            return self._execute_model_spmd(execute_model_req)
-
-        return self._execute_model_with_nccl_control_plane(execute_model_req)
-
-    def _execute_model_with_nccl_control_plane(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
-        """
-        Execute model with NCCL control plane. To execute model on all workers,
-        the driver worker first uses NCCL broadcasting primitive to broadcast
-        input data to all other workers.
-        """
         if self.is_driver_worker:
             if execute_model_req is None:
                 if self.do_metadata_broadcast:
@@ -307,6 +289,9 @@ def _execute_model_spmd(
         All workers take the same request, prepare the input and execute
         the model.
""" + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") worker_input: WorkerInput = self.prepare_worker_input( execute_model_req=execute_model_req) model_input: ModelRunnerInputBase = ( @@ -335,11 +320,9 @@ class WorkerWrapperBase: def __init__(self, worker_module_name: str, worker_class_name: str, - trust_remote_code: bool = False, - use_spmd_worker: bool = False) -> None: + trust_remote_code: bool = False) -> None: self.worker_module_name = worker_module_name self.worker_class_name = worker_class_name - self.use_spmd_worker = use_spmd_worker self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -367,14 +350,9 @@ def init_worker(self, *args, **kwargs): mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) - if self.use_spmd_worker: - assert issubclass(worker_class, LocalOrDistributedWorkerBase), ( - f"VLLM_USE_SPMD_WORKER=1 requires worker class {worker_class}" - " to inherit from LocalOrDistributedWorkerBase") self.worker = worker_class(*args, **kwargs) assert self.worker is not None - self.worker.use_spmd_worker = self.use_spmd_worker def execute_method(self, method, *args, **kwargs): try: