diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index ede070f5082aa..bfab9c15f4fb2 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -1,12 +1,11 @@
 from collections import namedtuple
-from contextlib import contextmanager, nullcontext, suppress
+from contextlib import contextmanager, nullcontext, suppress
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed import ProcessGroup
-# for shm
 from multiprocessing import shared_memory
 from unittest.mock import patch
 import atexit
 
@@ -21,10 +20,11 @@
     get_tp_pynccl_communicator)
 from vllm.distributed.device_communicators.shm_broadcast import (
-    ShmRingBufferIO)
+    ShmRingBufferIO)
 
 
 shm_broadcaster: Optional[ShmRingBufferIO] = None
 
+
 def init_shm_broadcaster(group):
     global shm_broadcaster
     world_size = get_tensor_model_parallel_world_size()
@@ -33,6 +33,7 @@ def init_shm_broadcaster(group):
         shm_broadcaster = ShmRingBufferIO.create_from_process_group(
             group, 1 << 20, 6)
 
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream
@@ -207,7 +208,7 @@ def broadcast(input_: torch.Tensor,
     torch.distributed.broadcast(input_, src=src, group=group)
     return input_
 
-def broadcast_object(obj: Optional[Any] = None, src: int = 0):#, group: Optional[ProcessGroup] = None):
+def broadcast_object(obj: Optional[Any] = None, src: int = 0):
     """Broadcast the input object.
     NOTE: `src` is the local rank of the source rank.
     """
@@ -225,17 +226,14 @@ def broadcast_object(obj: Optional[Any] = None, src: int = 0):#, group: Optional
         assert src == 0, "Shared memory broadcaster only supports src=0"
         return shm_broadcaster.broadcast_object(obj)
     if torch.distributed.get_rank() == src:
-        torch.distributed.broadcast_object_list([obj],
-                                                src=src,
-                                                group=group)
+        torch.distributed.broadcast_object_list([obj], src=src, group=group)
         return obj
     else:
         recv = [None]
-        torch.distributed.broadcast_object_list(recv,
-                                                src=src,
-                                                group=group)
+        torch.distributed.broadcast_object_list(recv, src=src, group=group)
         return recv[0]
 
+
 def broadcast_object_list(obj_list: List[Any],
                           src: int = 0,
                           group: Optional[ProcessGroup] = None):
@@ -370,6 +368,7 @@
         async_handle.wait()
     return tensor_dict
 
+
 def is_in_the_same_node(pg: ProcessGroup):
     """
     This is a collective operation that checks if all processes in the group
@@ -434,8 +433,10 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     return is_in_the_same_node.sum().item() == world_size
 
+
 def destroy_shm_broadcaster():
     global shm_broadcaster
     shm_broadcaster = None
 
+
 atexit.register(destroy_shm_broadcaster)
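
Usage sketch (not part of the patch): the snippet below shows how the helpers touched by this diff might be exercised. It is illustrative only and makes assumptions beyond the diff itself: that torch.distributed is already initialized, and that `cpu_group` and `sync_metadata` are hypothetical stand-ins for whatever process group and call site vLLM actually uses; the dictionary contents are made up.

    # Hypothetical usage of the broadcast helpers from this diff.
    import torch.distributed as dist

    from vllm.distributed.communication_op import (broadcast_object,
                                                   init_shm_broadcaster)


    def sync_metadata(cpu_group: dist.ProcessGroup):
        # Set up the shared-memory ring buffer broadcaster for this process;
        # the diff registers its teardown via atexit.
        init_shm_broadcaster(cpu_group)

        # Only the source rank needs to hold the object. broadcast_object
        # either goes through the shm ring buffer (in which case src must be
        # 0, per the assert in the diff) or falls back to
        # torch.distributed.broadcast_object_list.
        obj = {"step": 1} if dist.get_rank() == 0 else None
        return broadcast_object(obj, src=0)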