Skip to content

Commit

Permalink
fix ray spmd
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Dec 7, 2024
1 parent 9faf44f commit 0869c45
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions vllm/executor/ray_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ def sort_by_driver_then_worker_ip(worker):
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions vllm/executor/ray_hpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ def sort_by_driver_then_worker_ip(worker):

worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions vllm/executor/ray_tpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def sort_by_driver_then_worker_ip(worker):
# Get the set of TPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote()) \
) # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions vllm/executor/ray_xpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def _get_env_vars_to_be_updated(self):
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote())) # type: ignore

Expand Down

0 comments on commit 0869c45

Please sign in to comment.