Commit de0b8a1: fix
Signed-off-by: Hanzhi Zhou <[email protected]>
hanzhi713 committed Nov 7, 2024
1 parent 9307bfa commit de0b8a1
1 changed file: vllm/distributed/device_communicators/custom_all_reduce.py (12 additions, 2 deletions)
@@ -222,8 +222,18 @@ def capture(self):
     def register_graph_buffers(self):
         handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
         logger.info("Registering %d cuda graph addresses", len(offset))
-        all_data = [None] * dist.get_world_size(group=self.group)
-        dist.all_gather_object(all_data, (handle, offset), group=self.group)
+        # We cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data = [[None, None]
+                    for _ in range(dist.get_world_size(group=self.group))]
+        all_data[self.rank] = [handle, offset]
+        ranks = sorted(dist.get_process_group_ranks(group=self.group))
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(all_data[i],
+                                       src=rank,
+                                       group=self.group,
+                                       device="cpu")
         # Unpack list of tuples to tuple of lists.
         handles = [d[0] for d in all_data]  # type: ignore
         offsets = [d[1] for d in all_data]  # type: ignore
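The replacement loop is an all-gather built out of per-rank broadcasts: each rank pre-sizes the result list, fills only its own slot, then one `broadcast_object_list` per source rank overwrites that slot everywhere. A minimal single-process sketch of that communication pattern, in plain Python with no torch dependency (`simulated_all_gather` is a hypothetical helper; the inner loop stands in for `dist.broadcast_object_list`'s in-place overwrite semantics):

```python
def simulated_all_gather(world_size, payloads):
    """Emulate the fix's pattern: fill your own slot, then broadcast
    each slot from its owning rank so every rank converges on the
    same list. `payloads[r]` is rank r's (handle, offset) pair; the
    return value is each rank's view of the gathered data."""
    # Each rank starts with a pre-sized list with only its own slot
    # filled, mirroring `all_data = [[None, None] for _ in ...]`.
    per_rank_views = [
        [[None, None] for _ in range(world_size)]
        for _ in range(world_size)
    ]
    for r in range(world_size):
        per_rank_views[r][r] = list(payloads[r])
    # One broadcast per source rank: slot i on every rank is
    # overwritten with rank i's local copy of that slot.
    for i in range(world_size):
        src_slot = per_rank_views[i][i]
        for r in range(world_size):
            per_rank_views[r][i] = list(src_slot)
    return per_rank_views


views = simulated_all_gather(3, [("h0", [0]), ("h1", [1]), ("h2", [2])])
# Every rank ends with the same list, so the unpacking step in the
# diff works identically on all ranks:
handles = [d[0] for d in views[0]]  # ["h0", "h1", "h2"]
offsets = [d[1] for d in views[0]]  # [[0], [1], [2]]
```

The real code additionally sorts `dist.get_process_group_ranks(...)` so every rank issues the broadcasts in the same order, which is required for collectives to match up across processes.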
