Skip to content

Commit

Permalink
[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Apr 24, 2024
1 parent 79a268c commit 91f50a6
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 71 deletions.
15 changes: 10 additions & 5 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.utils import update_environment_variables


Expand All @@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
for p in processes:
p.join()

for p in processes:
assert p.exitcode == 0


def update_env(fn):
def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
def wrapped_fn(env):
update_environment_variables(env)
init_distributed_environment()
fn()

return wrapper
return wrapped_fn


@update_env
@worker_fn_wrapper
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
Expand All @@ -53,7 +58,7 @@ def test_pynccl():
distributed_run(worker_fn, 2)


@update_env
@worker_fn_wrapper
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
Expand Down
122 changes: 71 additions & 51 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
# variable in the code.

import ctypes
import datetime
import platform
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch.distributed import ProcessGroup, ReduceOp

from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check

Expand Down Expand Up @@ -59,6 +60,18 @@

ncclResult_t = ctypes.c_int

_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]


def NCCL_CHECK(result: ncclResult_t) -> None:
if result != 0:
error_str = _c_ncclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"NCCL error: {error_str}")


# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
Expand All @@ -68,8 +81,7 @@

def ncclGetVersion() -> str:
version = ctypes.c_int()
result = _c_ncclGetVersion(ctypes.byref(version))
assert result == 0
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
Expand All @@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure):

def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
assert result == 0
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
return unique_id


Expand Down Expand Up @@ -199,66 +210,75 @@ class NCCLCommunicator:

def __init__(
self,
backend=None,
init_method=None,
timeout=datetime.timedelta(seconds=10),
world_size: int = -1,
rank: int = -1,
store=None,
group_name: str = "",
pg_options=None,
local_rank: int = -1,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
):
if not dist.is_initialized():
backend = backend or "nccl"
assert backend == 'nccl', (
"only use nccl backend for starting the NCCL communicator")
dist.init_process_group(backend=backend,
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=rank,
store=store,
group_name=group_name,
pg_options=pg_options)
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
if local_rank == -1:
local_rank = self.rank
self.local_rank = local_rank
# don't use these args, as they can be -1
# use `self.rank`, `self.local_rank` and `self.world_size` instead
del world_size, rank, local_rank
torch.cuda.set_device(self.local_rank)
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the NCCLCommunicator to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"NCCLCommunicator should be attached to a non-NCCL group.")
self.group = group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
self.local_rank)
dist.broadcast(tensor, src=0)
byte_list = tensor.cpu().tolist()
tensor = torch.ByteTensor(list(self.unique_id.internal))
dist.broadcast(tensor, src=0, group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank)
assert result == 0
self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
if device is None:
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
current_device = torch.cuda.current_device()
try:
torch.cuda.set_device(device)
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
finally:
torch.cuda.set_device(current_device)

def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
NCCL_CHECK(
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream)))

def __del__(self):
# `dist` module might have been already destroyed
Expand Down
12 changes: 3 additions & 9 deletions vllm/distributed/device_communicators/pynccl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import torch
from torch.distributed import ReduceOp
from torch.distributed import ProcessGroup, ReduceOp

from vllm.logger import init_logger

Expand Down Expand Up @@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass


def init_process_group(world_size: int,
rank: int,
init_method: str,
local_rank: int = -1) -> None:
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
comm = NCCLCommunicator(init_method=init_method,
world_size=world_size,
local_rank=local_rank,
rank=rank)
comm = NCCLCommunicator(group=group)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
Expand Down
6 changes: 6 additions & 0 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
import os
from typing import Optional

import torch
Expand Down Expand Up @@ -73,6 +74,11 @@ def init_distributed_environment(
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1 and distributed_init_method == "env://":
local_rank = int(os.environ['LOCAL_RANK'])
global _LOCAL_RANK
_LOCAL_RANK = local_rank

Expand Down
9 changes: 3 additions & 6 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,9 @@ def init_worker_distributed_environment(
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
pynccl_utils.init_process_group(
world_size=parallel_config.world_size,
local_rank=local_rank,
rank=rank,
init_method=distributed_init_method,
)
# NOTE(kaichao): By default, pynccl will use information inside
# `parallel_state` for initialization.
pynccl_utils.init_process_group()

ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
Expand Down

0 comments on commit 91f50a6

Please sign in to comment.