
Commit

fix lint
root committed Jul 31, 2024
1 parent 7151610 commit 028eb85
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions vllm/distributed/communication_op.py
@@ -1,12 +1,11 @@
-from collections import namedtuple
-from contextlib import contextmanager, nullcontext, suppress
+from contextlib import contextmanager, nullcontext, suppress
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed import ProcessGroup
 
 # for shm
 from multiprocessing import shared_memory
 from unittest.mock import patch
 import atexit
@@ -21,10 +20,11 @@
     get_tp_pynccl_communicator)
 
 from vllm.distributed.device_communicators.shm_broadcast import (
-    ShmRingBufferIO)
+    ShmRingBufferIO)
 
 shm_broadcaster: Optional[ShmRingBufferIO] = None
 
+
 def init_shm_broadcaster(group):
     global shm_broadcaster
     world_size = get_tensor_model_parallel_world_size()
@@ -33,6 +33,7 @@ def init_shm_broadcaster(group):
     shm_broadcaster = ShmRingBufferIO.create_from_process_group(
         group, 1 << 20, 6)
 
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream
@@ -207,7 +208,7 @@ def broadcast(input_: torch.Tensor,
     torch.distributed.broadcast(input_, src=src, group=group)
     return input_
 
-def broadcast_object(obj: Optional[Any] = None, src: int = 0):#, group: Optional[ProcessGroup] = None):
+def broadcast_object(obj: Optional[Any] = None, src: int = 0):
     """Broadcast the input object.
     NOTE: `src` is the local rank of the source rank.
     """
@@ -225,17 +226,14 @@ def broadcast_object(obj: Optional[Any] = None, src: int = 0):#, group: Optional[ProcessGroup] = None):
         assert src == 0, "Shared memory broadcaster only supports src=0"
         return shm_broadcaster.broadcast_object(obj)
     if torch.distributed.get_rank() == src:
-        torch.distributed.broadcast_object_list([obj],
-                                                src=src,
-                                                group=group)
+        torch.distributed.broadcast_object_list([obj], src=src, group=group)
         return obj
     else:
         recv = [None]
-        torch.distributed.broadcast_object_list(recv,
-                                                src=src,
-                                                group=group)
+        torch.distributed.broadcast_object_list(recv, src=src, group=group)
         return recv[0]
 
+
 def broadcast_object_list(obj_list: List[Any],
                           src: int = 0,
                           group: Optional[ProcessGroup] = None):
@@ -370,6 +368,7 @@ def broadcast_tensor_dict(
             async_handle.wait()
     return tensor_dict
 
+
 def is_in_the_same_node(pg: ProcessGroup):
     """
     This is a collective operation that checks if all processes in the group
@@ -434,8 +433,10 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     return is_in_the_same_node.sum().item() == world_size
 
+
 def destroy_shm_broadcaster():
     global shm_broadcaster
     shm_broadcaster = None
 
+
 atexit.register(destroy_shm_broadcaster)

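Below is a minimal usage sketch, not part of the commit, showing how the helpers touched by this diff (init_shm_broadcaster and broadcast_object) might be wired together. It assumes a torch.distributed process group has already been initialized elsewhere, and example_broadcast is a hypothetical caller written only for illustration.

# Hypothetical usage sketch; assumes torch.distributed is already initialized
# and that this fork exposes the helpers from vllm.distributed.communication_op.
from typing import Any, Optional

import torch.distributed as dist

from vllm.distributed.communication_op import (broadcast_object,
                                               init_shm_broadcaster)


def example_broadcast(group) -> Optional[Any]:
    # Set up the shared-memory ring buffer fast path used by broadcast_object.
    init_shm_broadcaster(group)

    # Rank 0 publishes a picklable object; every other rank receives a copy.
    # The shm path asserts src == 0 (see the diff), so keep the source at rank 0.
    obj = {"step": 1} if dist.get_rank() == 0 else None
    return broadcast_object(obj, src=0)

No explicit teardown is needed here: the module registers destroy_shm_broadcaster with atexit.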