From de0b8a120c247d27263b7036b7ec339df137db02 Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Thu, 7 Nov 2024 06:34:11 +0000
Subject: [PATCH] fix

Signed-off-by: Hanzhi Zhou
---
 .../device_communicators/custom_all_reduce.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 42fc364b3e0dc..62929dc0feaaf 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -222,8 +222,18 @@ def capture(self):
     def register_graph_buffers(self):
         handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
         logger.info("Registering %d cuda graph addresses", len(offset))
-        all_data = [None] * dist.get_world_size(group=self.group)
-        dist.all_gather_object(all_data, (handle, offset), group=self.group)
+        # We cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data = [[None, None]
+                    for _ in range(dist.get_world_size(group=self.group))]
+        all_data[self.rank] = [handle, offset]
+        ranks = sorted(dist.get_process_group_ranks(group=self.group))
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(all_data[i],
+                                       src=rank,
+                                       group=self.group,
+                                       device="cpu")
         # Unpack list of tuples to tuple of lists.
         handles = [d[0] for d in all_data]  # type: ignore
         offsets = [d[1] for d in all_data]  # type: ignore
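
The patch replaces `dist.all_gather_object` with one `dist.broadcast_object_list`
call per rank, prefilling this rank's own slot so every process ends up with the
full list while avoiding the `gloo`-backend incompatibility under inference mode
(pytorch/pytorch#126032). Below is a minimal standalone sketch of that gather
pattern outside of vLLM; the helper name `all_gather_via_broadcast`, the file
name, and the torchrun launch command are illustrative assumptions, and only the
`torch.distributed` calls mirror the patch.

# gather_sketch.py -- illustrative, not part of vLLM
import torch.distributed as dist


def all_gather_via_broadcast(obj, group=None):
    # Hypothetical helper: gather `obj` from every rank without
    # dist.all_gather_object, which breaks on gloo under inference mode.
    world_size = dist.get_world_size(group=group)
    if group is None:
        # Global ranks of the default group are simply 0..world_size-1.
        ranks = list(range(world_size))
    else:
        # For a subgroup, map its members to their global ranks,
        # as the patch does with dist.get_process_group_ranks.
        ranks = sorted(dist.get_process_group_ranks(group))
    my_pos = ranks.index(dist.get_rank())
    # One single-element slot per rank; broadcast_object_list fills it in place.
    gathered = [[None] for _ in ranks]
    gathered[my_pos] = [obj]
    for i, src in enumerate(ranks):
        dist.broadcast_object_list(gathered[i],
                                   src=src,
                                   group=group,
                                   device="cpu")
    return [slot[0] for slot in gathered]


if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc-per-node=2 gather_sketch.py
    dist.init_process_group(backend="gloo")
    print(dist.get_rank(), all_gather_via_broadcast(f"rank-{dist.get_rank()}"))
    dist.destroy_process_group()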