Skip to content

Commit

Permalink
[core] Bump ray to use _overlap_gpu_communication in compiled graph tests (vllm-project#10410)
Browse files Browse the repository at this point in the history

Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
  • Loading branch information
ruisearch42 and Rui Qiao authored Dec 11, 2024
1 parent 66aaa77 commit 72ff3a9
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion requirements-test.in
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
peft
ray[adag]==2.35
ray[adag]==2.40.0
sentence-transformers # required for embedding tests
soundfile # required for audio tests
timm # required for internvl test
Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ pyyaml==6.0.2
# ray
# timm
# transformers
ray[adag]==2.35.0
ray[adag]==2.40.0
# via -r requirements-test.in
redis==5.2.0
# via tensorizer
Expand Down
8 changes: 8 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = True
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
Expand Down Expand Up @@ -337,6 +338,13 @@ def get_default_config_root():
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
),

# If the env var is set, it enables GPU communication overlap in
# Ray's compiled DAG. This flag is ignored if
# VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM":
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "1"))
),

# Use dedicated multiprocess context for workers.
# Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD":
Expand Down
17 changes: 10 additions & 7 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,12 +414,10 @@ def _check_ray_adag_installation(self):
import pkg_resources
from packaging import version

required_version = version.parse("2.35")
required_version = version.parse("2.40")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
# TODO: update the constraint once we adapt to the backward
# incompatible API change from ray 2.36
if current_version != required_version:
if current_version < required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")

Expand All @@ -445,6 +443,8 @@ def _compiled_ray_dag(self, enable_asyncio: bool):

logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
# (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501
Expand Down Expand Up @@ -480,7 +480,10 @@ def _compiled_ray_dag(self, enable_asyncio: bool):

forward_dag = MultiOutputNode(outputs)

return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
return forward_dag.experimental_compile(
enable_asyncio=enable_asyncio,
_overlap_gpu_communication=envs.
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

def __del__(self):
self.shutdown()
Expand All @@ -507,8 +510,8 @@ async def execute_model_async(

serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
outputs = await dag_future
return self.output_decoder.decode(outputs[0])
output = await dag_future[0]
return self.output_decoder.decode(output)

async def _driver_execute_model_async(
self,
Expand Down

0 comments on commit 72ff3a9

Please sign in to comment.