From 91f50a6fe240b2c92a99e171bb11d083f82e4a84 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 23 Apr 2024 18:32:19 -0700 Subject: [PATCH] [Core][Distributed] use cpu/gloo to initialize pynccl (#4248) --- tests/distributed/test_pynccl.py | 15 ++- .../device_communicators/pynccl.py | 122 ++++++++++-------- .../device_communicators/pynccl_utils.py | 12 +- vllm/distributed/parallel_state.py | 6 + vllm/worker/worker.py | 9 +- 5 files changed, 93 insertions(+), 71 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index d58f621d36b86..6d7d4a5806bd0 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,6 +5,7 @@ from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, ncclGetUniqueId) +from vllm.distributed.parallel_state import init_distributed_environment from vllm.utils import update_environment_variables @@ -26,19 +27,23 @@ def distributed_run(fn, world_size): for p in processes: p.join() + for p in processes: + assert p.exitcode == 0 + -def update_env(fn): +def worker_fn_wrapper(fn): # `multiprocessing.Process` cannot accept environment variables directly # so we need to pass the environment variables as arguments # and update the environment variables in the function - def wrapper(env): + def wrapped_fn(env): update_environment_variables(env) + init_distributed_environment() fn() - return wrapper + return wrapped_fn -@update_env +@worker_fn_wrapper def worker_fn(): comm = NCCLCommunicator() tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) @@ -53,7 +58,7 @@ def test_pynccl(): distributed_run(worker_fn, 2) -@update_env +@worker_fn_wrapper def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 0707afe922f40..fcedf0fed34cb 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -20,14 +20,15 @@ # variable in the code. 
import ctypes -import datetime import platform +from typing import Optional, Union # ===================== import region ===================== import torch import torch.distributed as dist -from torch.distributed import ReduceOp +from torch.distributed import ProcessGroup, ReduceOp +from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger from vllm.utils import find_nccl_library, nccl_integrity_check @@ -59,6 +60,18 @@ ncclResult_t = ctypes.c_int +_c_ncclGetErrorString = nccl.ncclGetErrorString +_c_ncclGetErrorString.restype = ctypes.c_char_p +_c_ncclGetErrorString.argtypes = [ncclResult_t] + + +def NCCL_CHECK(result: ncclResult_t) -> None: + if result != 0: + error_str = _c_ncclGetErrorString(result) + error_str = error_str.decode("utf-8") + raise RuntimeError(f"NCCL error: {error_str}") + + # equivalent to c declaration: # ncclResult_t ncclGetVersion(int *version); _c_ncclGetVersion = nccl.ncclGetVersion @@ -68,8 +81,7 @@ def ncclGetVersion() -> str: version = ctypes.c_int() - result = _c_ncclGetVersion(ctypes.byref(version)) - assert result == 0 + NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version))) # something like 21903 --> "2.19.3" version_str = str(version.value) major = version_str[0].lstrip("0") @@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure): def ncclGetUniqueId() -> NcclUniqueId: unique_id = NcclUniqueId() - result = _c_ncclGetUniqueId(ctypes.byref(unique_id)) - assert result == 0 + NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id))) return unique_id @@ -199,66 +210,75 @@ class NCCLCommunicator: def __init__( self, - backend=None, - init_method=None, - timeout=datetime.timedelta(seconds=10), - world_size: int = -1, - rank: int = -1, - store=None, - group_name: str = "", - pg_options=None, - local_rank: int = -1, + group: Optional[ProcessGroup] = None, + device: Optional[Union[int, str, torch.device]] = None, ): - if not dist.is_initialized(): - backend = backend or "nccl" - assert backend == 'nccl', ( - "only use nccl backend for starting the NCCL communicator") - dist.init_process_group(backend=backend, - init_method=init_method, - timeout=timeout, - world_size=world_size, - rank=rank, - store=store, - group_name=group_name, - pg_options=pg_options) - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - if local_rank == -1: - local_rank = self.rank - self.local_rank = local_rank - # don't use these args, as they can be -1 - # use `self.rank`, `self.local_rank` and `self.world_size` instead - del world_size, rank, local_rank - torch.cuda.set_device(self.local_rank) + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the NCCLCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device. 
+ """ + assert dist.is_initialized() + group = get_cpu_world_group() if group is None else group + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "NCCLCommunicator should be attached to a non-NCCL group.") + self.group = group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) if self.rank == 0: self.unique_id = ncclGetUniqueId() else: self.unique_id = NcclUniqueId() - tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda( - self.local_rank) - dist.broadcast(tensor, src=0) - byte_list = tensor.cpu().tolist() + tensor = torch.ByteTensor(list(self.unique_id.internal)) + dist.broadcast(tensor, src=0, group=group) + byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte self.comm = ctypes.c_void_p() - result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank) - assert result == 0 - self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}") + if device is None: + local_rank = get_local_rank() + device = torch.device(f"cuda:{local_rank}") + elif isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + current_device = torch.cuda.current_device() + try: + torch.cuda.set_device(device) + NCCL_CHECK( + _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, + self.unique_id, self.rank)) + self.stream = torch.cuda.Stream() + finally: + torch.cuda.set_device(current_device) def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None): + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), - ctypes.c_void_p(tensor.data_ptr()), - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - ctypes.c_void_p(stream.cuda_stream)) - assert result == 0 + NCCL_CHECK( + _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), + ctypes.c_void_p(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + ctypes.c_void_p(stream.cuda_stream))) def __del__(self): # `dist` module might have been already destroyed diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index 916dc814af7eb..a717fddb695ba 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from torch.distributed import ReduceOp +from torch.distributed import ProcessGroup, ReduceOp from vllm.logger import init_logger @@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream): pass -def init_process_group(world_size: int, - rank: int, - init_method: str, - local_rank: int = -1) -> None: +def init_process_group(group: Optional[ProcessGroup] = None) -> None: assert not is_initialized() global comm logger.info(f"vLLM is using nccl=={ncclGetVersion()}") - comm = 
NCCLCommunicator(init_method=init_method, - world_size=world_size, - local_rank=local_rank, - rank=rank) + comm = NCCLCommunicator(group=group) def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e2473736375e0..515f2212511b7 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -4,6 +4,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" import contextlib +import os from typing import Optional import torch @@ -73,6 +74,11 @@ def init_distributed_environment( ranks = list(range(torch.distributed.get_world_size())) _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, backend="gloo") + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1 and distributed_init_method == "env://": + local_rank = int(os.environ['LOCAL_RANK']) global _LOCAL_RANK _LOCAL_RANK = local_rank diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2203570b37ad6..39ad428f16fe3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -298,12 +298,9 @@ def init_worker_distributed_environment( elif parallel_config.world_size > 1: # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. - pynccl_utils.init_process_group( - world_size=parallel_config.world_size, - local_rank=local_rank, - rank=rank, - init_method=distributed_init_method, - ) + # NOTE(kaichao): By default, pynccl will use information inside + # `parallel_state` for initialization. + pynccl_utils.init_process_group() ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size)
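
A minimal usage sketch of the reworked initialization path (not part of the patch itself), modeled on worker_fn_wrapper/worker_fn in the updated tests/distributed/test_pynccl.py: the gloo-backed CPU world group is created first via init_distributed_environment(), and NCCLCommunicator() then bootstraps NCCL on top of it by broadcasting the unique id over that CPU group. It assumes the process was launched with MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK set (the "env://" convention), with LOCAL_RANK equal to RANK as in the single-node test; example_worker is an illustrative name.

    import torch

    from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
    from vllm.distributed.parallel_state import init_distributed_environment
    from vllm.utils import update_environment_variables


    def example_worker(env):
        # Each spawned process receives its distributed settings as a dict of
        # environment variables, mirroring worker_fn_wrapper in the test above.
        update_environment_variables(env)
        # Creates the default torch.distributed process group plus the
        # gloo-backed CPU world group; with the default "env://" init method,
        # LOCAL_RANK is read from the environment here.
        init_distributed_environment()
        # The communicator now takes rank/world_size from the CPU group,
        # broadcasts the NCCL unique id over gloo (no NCCL group needed for
        # bootstrap), and binds to cuda:{local_rank} unless a device is given.
        comm = NCCLCommunicator()
        # The tensor must live on the communicator's device; here rank ==
        # local_rank, as in the single-node test setup.
        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
        comm.all_reduce(tensor)
        # After all-reducing ones across world_size ranks, every element
        # equals world_size.
        assert tensor.mean().cpu().item() == comm.world_size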