diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 5d77d8abb4718..50444d3abfaf2 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -43,12 +43,15 @@ def test_cuda_device_count_stateless(): def cpu_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", + pg1 = StatelessProcessGroup.create(host="127.0.0.1", + port=port1, rank=rank, world_size=WORLD_SIZE) if rank <= 2: - pg2 = StatelessProcessGroup.create( - init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3) + pg2 = StatelessProcessGroup.create(host="127.0.0.1", + port=port2, + rank=rank, + world_size=3) data = torch.tensor([rank]) data = pg1.broadcast_obj(data, src=2) assert data.item() == 2 @@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2): def gpu_worker(rank, WORLD_SIZE, port1, port2): torch.cuda.set_device(rank) - pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", + pg1 = StatelessProcessGroup.create(host="127.0.0.1", + port=port1, rank=rank, world_size=WORLD_SIZE) pynccl1 = PyNcclCommunicator(pg1, device=rank) pynccl1.disabled = False if rank <= 2: - pg2 = StatelessProcessGroup.create( - init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3) + pg2 = StatelessProcessGroup.create(host="127.0.0.1", + port=port2, + rank=rank, + world_size=3) pynccl2 = PyNcclCommunicator(pg2, device=rank) pynccl2.disabled = False data = torch.tensor([rank]).cuda() @@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): def broadcast_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", + pg1 = StatelessProcessGroup.create(host="127.0.0.1", + port=port1, rank=rank, world_size=WORLD_SIZE) if rank == 2: @@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2): def allgather_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", + pg1 = StatelessProcessGroup.create(host="127.0.0.1", + port=port1, rank=rank, world_size=WORLD_SIZE) data = pg1.all_gather_obj(rank) @@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2): pg1.barrier() -# TODO: investigate why this test is flaky. It hangs during initialization. -@pytest.mark.skip("Skip the test because it is flaky.") @multi_gpu_test(num_gpus=4) @pytest.mark.parametrize( "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index a77b41322f376..dcfcb848cbe06 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -9,7 +9,7 @@ from typing import Any, Deque, Dict, Optional, Sequence, Tuple import torch -from torch.distributed.rendezvous import rendezvous +from torch.distributed import TCPStore import vllm.envs as envs from vllm.logger import init_logger @@ -97,7 +97,6 @@ class StatelessProcessGroup: group. Only use it to communicate metadata between processes. For data-plane communication, create NCCL-related objects. """ - prefix: str rank: int world_size: int store: torch._C._distributed_c10d.Store @@ -127,7 +126,7 @@ def __post_init__(self): def send_obj(self, obj: Any, dst: int): """Send an object to a destination rank.""" self.expire_data() - key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}" + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" self.store.set(key, pickle.dumps(obj)) self.send_dst_counter[dst] += 1 self.entries.append((key, time.time())) @@ -147,8 +146,7 @@ def recv_obj(self, src: int) -> Any: """Receive an object from a source rank.""" obj = pickle.loads( self.store.get( - f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}" - )) + f"send_to/{self.rank}/{self.recv_src_counter[src]}")) self.recv_src_counter[src] += 1 return obj @@ -159,14 +157,14 @@ def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """ if self.rank == src: self.expire_data() - key = (f"{self.prefix}/broadcast_from/{src}/" + key = (f"broadcast_from/{src}/" f"{self.broadcast_send_counter}") self.store.set(key, pickle.dumps(obj)) self.broadcast_send_counter += 1 self.entries.append((key, time.time())) return obj else: - key = (f"{self.prefix}/broadcast_from/{src}/" + key = (f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}") recv_obj = pickle.loads(self.store.get(key)) self.broadcast_recv_src_counter[src] += 1 @@ -194,7 +192,8 @@ def barrier(self): @staticmethod def create( - init_method: str, + host: str, + port: int, rank: int, world_size: int, data_expiration_seconds: int = 3600, @@ -214,15 +213,14 @@ def create( can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. """ # noqa - from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT - timeout = _DEFAULT_PG_TIMEOUT - - store, rank, world_size = next( - rendezvous(init_method, rank, world_size, timeout=timeout)) - store.set_timeout(timeout) + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=(rank == 0), + ) return StatelessProcessGroup( - prefix=init_method, rank=rank, world_size=world_size, store=store,