diff --git a/requirements-test.in b/requirements-test.in index c0b228148ab31..57fddb416317e 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -13,7 +13,7 @@ einops # required for MPT, qwen-vl and Mamba httpx librosa # required for audio tests peft -ray[adag]==2.35 +ray[adag]==2.40.0 sentence-transformers # required for embedding tests soundfile # required for audio tests timm # required for internvl test diff --git a/requirements-test.txt b/requirements-test.txt index 8ceb705cdffd7..c786a1249bddb 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -410,7 +410,7 @@ pyyaml==6.0.2 # ray # timm # transformers -ray[adag]==2.35.0 +ray[adag]==2.40.0 # via -r requirements-test.in redis==5.2.0 # via tensorizer diff --git a/vllm/envs.py b/vllm/envs.py index be5d9985b63a4..bc8c1499e9534 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -45,6 +45,7 @@ VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True + VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = True VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -337,6 +338,13 @@ def get_default_config_root(): lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) ), + # If the env var is set, it enables GPU communication overlap in + # Ray's compiled DAG. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. + "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "1")) + ), + # Use dedicated multiprocess context for workers. # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4263fb27265f6..4bf5cbbd18ffe 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -414,12 +414,10 @@ def _check_ray_adag_installation(self): import pkg_resources from packaging import version - required_version = version.parse("2.35") + required_version = version.parse("2.40") current_version = version.parse( pkg_resources.get_distribution("ray").version) - # TODO: update the constraint once we adapt to the backward - # incompatible API change from ray 2.36 - if current_version != required_version: + if current_version < required_version: raise ValueError(f"Ray version {required_version} is " f"required, but found {current_version}") @@ -445,6 +443,8 @@ def _compiled_ray_dag(self, enable_asyncio: bool): logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) + logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) with InputNode() as input_data: # Example DAG: PP=2, TP=4 # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 @@ -480,7 +480,10 @@ def _compiled_ray_dag(self, enable_asyncio: bool): forward_dag = MultiOutputNode(outputs) - return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) + return forward_dag.experimental_compile( + enable_asyncio=enable_asyncio, + _overlap_gpu_communication=envs. + VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) def __del__(self): self.shutdown() @@ -507,8 +510,8 @@ async def execute_model_async( serialized_data = self.input_encoder.encode(execute_model_req) dag_future = await self.forward_dag.execute_async(serialized_data) - outputs = await dag_future - return self.output_decoder.decode(outputs[0]) + output = await dag_future[0] + return self.output_decoder.decode(output) async def _driver_execute_model_async( self,