From e6c1765f50fd84dc1ae96abf7d803b642a5525bd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 25 Jul 2024 18:31:33 -0700 Subject: [PATCH] [Misc] Support TPU in initialize_ray_cluster --- vllm/executor/ray_utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index fcbfa30d7a38a..b23039ff997c3 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -3,7 +3,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest -from vllm.utils import get_ip, is_hip, is_xpu +from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -93,6 +93,7 @@ def initialize_ray_cluster( # Placement group is already set. return + device_str = "GPU" if not is_tpu() else "TPU" # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() if current_placement_group: @@ -101,24 +102,27 @@ def initialize_ray_cluster( # Verify that we can use the placement group. 
gpu_bundles = 0 for bundle in bundles: - bundle_gpus = bundle.get("GPU", 0) + bundle_gpus = bundle.get(device_str, 0) if bundle_gpus > 1: raise ValueError( - "Placement group bundle cannot have more than 1 GPU.") + "Placement group bundle cannot have more than 1 " + f"{device_str}.") if bundle_gpus: gpu_bundles += 1 if parallel_config.world_size > gpu_bundles: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the placement group.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") else: - num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) + num_gpus_in_cluster = ray.cluster_resources().get(device_str, 0) if parallel_config.world_size > num_gpus_in_cluster: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the cluster.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the cluster.") # Create a new placement group - placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size) + placement_group_specs = ([{ + device_str: 1 + }] * parallel_config.world_size) current_placement_group = ray.util.placement_group( placement_group_specs) # Wait until PG is ready - this will block until all