Skip to content

Commit

Permalink
[Misc] Support TPU in initialize_ray_cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon committed Jul 26, 2024
1 parent 2eb9f4f commit e6c1765
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_xpu
from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
Expand Down Expand Up @@ -93,6 +93,7 @@ def initialize_ray_cluster(
# Placement group is already set.
return

device_str = "GPU" if not is_tpu() else "TPU"
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
Expand All @@ -101,24 +102,27 @@ def initialize_ray_cluster(
# Verify that we can use the placement group.
gpu_bundles = 0
for bundle in bundles:
bundle_gpus = bundle.get("GPU", 0)
bundle_gpus = bundle.get(device_str, 0)
if bundle_gpus > 1:
raise ValueError(
"Placement group bundle cannot have more than 1 GPU.")
"Placement group bundle cannot have more than 1 "
f"{device_str}.")
if bundle_gpus:
gpu_bundles += 1
if parallel_config.world_size > gpu_bundles:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the placement group.")
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group.")
else:
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
num_gpus_in_cluster = ray.cluster_resources().get(device_str, 0)
if parallel_config.world_size > num_gpus_in_cluster:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs in the cluster.")
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group.")
# Create a new placement group
placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
placement_group_specs = ([{
device_str: 1
}] * parallel_config.world_size)
current_placement_group = ray.util.placement_group(
placement_group_specs)
# Wait until PG is ready - this will block until all
Expand Down

0 comments on commit e6c1765

Please sign in to comment.