From 59b70cb2fdff8173a56119ca5f334e936656b8af Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Thu, 7 Nov 2024 03:25:25 +0000 Subject: [PATCH] use new api --- csrc/custom_all_reduce.cu | 103 +++++++-------- csrc/custom_all_reduce.cuh | 24 +++- csrc/ops.h | 16 +-- csrc/torch_bindings.cpp | 16 +-- tests/distributed/test_custom_all_reduce.py | 4 +- tools/profiler/visualize_layerwise_profile.py | 32 ++--- vllm/_custom_ops.py | 17 ++- .../device_communicators/custom_all_reduce.py | 117 ++++++------------ 8 files changed, 145 insertions(+), 184 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 83e813d82014b..c50dac19aaeae 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -5,14 +5,15 @@ #include "custom_all_reduce.cuh" -// fake pointer type, must match fptr_t type in ops.h +// Fake pointer type, must match fptr_t type in ops.h. +// We use this type alias to indicate when pointers are passed in as int64_t. using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(const std::vector& ipc_tensors, +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink) { - int world_size = ipc_tensors.size(); + int world_size = fake_ipc_ptrs.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); if (world_size % 2 != 0) @@ -22,7 +23,7 @@ fptr_t init_custom_ar(const std::vector& ipc_tensors, vllm::Signal* ipc_ptrs[8]; for (int i = 0; i < world_size; i++) { - ipc_ptrs[i] = reinterpret_cast(ipc_tensors[i].data_ptr()); + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, @@ -51,26 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - cudaStream_t stream) { +/** + * Performs an out-of-place allreduce and stores result in out. + * + * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. + * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first + * copied into _reg_buffer. 
+ */
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
+                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+
+  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+  TORCH_CHECK_EQ(inp.numel(), out.numel());
   TORCH_CHECK(_is_weak_contiguous(out));
+  TORCH_CHECK(_is_weak_contiguous(inp));
+  auto input_size = inp.numel() * inp.element_size();
+  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
+  if (reg_buffer) {
+    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
+    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
+                                  cudaMemcpyDeviceToDevice, stream));
+  } else {
+    reg_buffer = inp.data_ptr();
+  }
   switch (out.scalar_type()) {
     case at::ScalarType::Float: {
-      fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
+      fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
                            reinterpret_cast<float*>(out.data_ptr()),
                            out.numel());
       break;
     }
     case at::ScalarType::Half: {
-      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
+      fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
       break;
     }
 #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
     case at::ScalarType::BFloat16: {
       fa->allreduce<nv_bfloat16>(
-          stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
+          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
           reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
       break;
     }
@@ -81,61 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
   }
 }
 
-void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
-  auto stream = c10::cuda::getCurrentCUDAStream().stream();
-  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
-  TORCH_CHECK_EQ(inp.numel(), out.numel());
-  _all_reduce(_fa, inp, out, stream);
-}
-
-void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
-                      torch::Tensor& out) {
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
-  auto stream = c10::cuda::getCurrentCUDAStream().stream();
-
-  auto input_size = inp.numel() * inp.element_size();
-  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
-  TORCH_CHECK_EQ(inp.numel(), out.numel());
-  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
-              "registered buffer is too small to contain the input");
-  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
-                                input_size, cudaMemcpyDeviceToDevice, stream));
-  _all_reduce(_fa, reg_buffer, out, stream);
-}
-
 void dispose(fptr_t _fa) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  delete fa;
+  delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
 }
 
 int64_t meta_size() { return sizeof(vllm::Signal); }
 
-void register_buffer(fptr_t _fa,
-                     const std::vector<torch::Tensor>& ipc_tensors) {
+void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  TORCH_CHECK(ipc_tensors.size() == fa->world_size_);
+  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
   void* ipc_ptrs[8];
-  for (int i = 0; i < ipc_tensors.size(); i++) {
-    ipc_ptrs[i] = reinterpret_cast<void*>(ipc_tensors[i].data_ptr());
+  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
+    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
   }
   fa->register_buffer(ipc_ptrs);
 }
 
-std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
-    fptr_t _fa) {
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+std::tuple<std::vector<int64_t>, std::vector<int64_t>>
+get_graph_buffer_ipc_meta(fptr_t _fa) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
-  auto options =
-      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
-  auto handles =
-      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
-  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
-  return {handles, std::move(offsets)};
+  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
+  std::vector<int64_t> bytes(handle.begin(), handle.end());
+  return std::make_tuple(bytes, offsets);
 }
 
-void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+void register_graph_buffers(fptr_t _fa,
+                            const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  fa->register_graph_buffers(handles, offsets);
+  std::vector<std::string> bytes;
+  bytes.reserve(handles.size());
+  for (int i = 0; i < handles.size(); i++) {
+    bytes.emplace_back(handles[i].begin(), handles[i].end());
+  }
+  bytes.reserve(handles.size());
+  fa->register_graph_buffers(bytes, offsets);
 }
diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh
index f48969eeddb0d..6be4d4f2b2eb8 100644
--- a/csrc/custom_all_reduce.cuh
+++ b/csrc/custom_all_reduce.cuh
@@ -285,12 +285,27 @@ class CustomAllreduce {
   int world_size_;
   bool full_nvlink_;
 
-  // below are device pointers
   RankSignals sg_;
+  // Stores a map from a pointer to its peer pointers from all ranks.
   std::unordered_map<void*, RankData*> buffers_;
   Signal* self_sg_;
 
-  // stores the registered device pointers from all ranks
+  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
+  // For cuda graph to work, all kernel arguments must be fixed during graph
+  // capture time. However, the peer pointers are not known during graph capture
+  // time. Therefore, during capture, we increment the rank data pointer and use
+  // that as the argument to the kernel. The kernel arguments are stored in
+  // graph_unreg_buffers_. The actual peer pointers will be filled in at the
+  // memory pointed to by the pointers in graph_unreg_buffers_ when
+  // the IPC handles are exchanged between ranks.
+  //
+  // The overall process looks like this:
+  // 1. Graph capture.
+  // 2. Each rank obtains the IPC handles for each address used during cuda
+  //    graph capture using get_graph_buffer_ipc_meta.
+  // 3. (In Python) all gather the IPC handles.
+  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
+  //    the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers @@ -332,11 +347,10 @@ class CustomAllreduce { return it->second; } - std::pair, std::vector> - get_graph_buffer_ipc_meta() { + std::pair> get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); auto handle_sz = sizeof(cudaIpcMemHandle_t); - std::vector handles(handle_sz * num_buffers, 0); + std::string handles(handle_sz * num_buffers, static_cast(0)); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; diff --git a/csrc/ops.h b/csrc/ops.h index 1930c77bcb805..e0775ee1891df 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -199,16 +199,16 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, #ifndef USE_ROCM using fptr_t = int64_t; -fptr_t init_custom_ar(const std::vector& ipc_tensors, +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, - torch::Tensor& out); +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); void dispose(fptr_t _fa); int64_t meta_size(); -void register_buffer(fptr_t _fa, const std::vector& ipc_tensors); -std::tuple> get_graph_buffer_ipc_meta( - fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector& handles, +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, const std::vector>& offsets); #endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 953c89aceb049..971a45d50ffa4 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -411,24 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels custom_ar.def( - "init_custom_ar(Tensor[] ipc_tensors, Tensor rank_data, " + "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " "int rank, bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - - custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); - custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); - custom_ar.def( - "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> " - "()"); - custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); + "all_reduce(int fa, Tensor inp, Tensor! 
out, int reg_buffer, " + "int reg_buffer_sz_bytes) -> ()"); + custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce); custom_ar.def("dispose", &dispose); custom_ar.def("meta_size", &meta_size); - custom_ar.def("register_buffer(int fa, Tensor[] ipc_tensors) -> ()"); - custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer); - + custom_ar.def("register_buffer", ®ister_buffer); custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); custom_ar.def("register_graph_buffers", ®ister_graph_buffers); } diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 95435e753058a..86ca1948ef94a 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -95,13 +95,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): inp = torch.ones(sz, dtype=torch.float32, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce_unreg(out) + out = fa.all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce_unreg(out) + out = fa.all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index efd6beee865c2..adc44474aa4c1 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -196,8 +196,8 @@ def is_cross_device_reduce_1stage(op_name: str): def is_cross_device_reduce_2stage(op_name: str): return "cross_device_reduce_2stage" in op_name - def is_custom_ar_all_reduce_unreg(op_name: str): - return "_C_custom_ar::all_reduce_unreg" in op_name + def is_custom_ar_all_reduce(op_name: str): + return "_C_custom_ar::all_reduce" in op_name def is_reduce_kernel(op_name: str): return "reduce_kernel" in op_name @@ -246,9 +246,9 @@ def is_reduce_kernel(op_name: str): filter(lambda x: is_cross_device_reduce_2stage(x), ops)) ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops)) - custom_ar_all_reduce_unreg_ops = list( - filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops)) - ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops)) + custom_ar_all_reduce_ops = list( + filter(lambda x: is_custom_ar_all_reduce(x), ops)) + ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops)) reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops)) ops = list(filter(lambda x: x not in reduce_kernel_ops, ops)) @@ -289,21 +289,21 @@ def is_reduce_kernel(op_name: str): if len(cross_device_reduce_2stage_ops): trace_df['cross_device_reduce_2stage_ops'] = trace_df[ cross_device_reduce_2stage_ops].agg("sum", axis=1) - if len(custom_ar_all_reduce_unreg_ops): - trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[ - custom_ar_all_reduce_unreg_ops].agg("sum", axis=1) + if len(custom_ar_all_reduce_ops): + trace_df['custom_ar_all_reduce_ops'] = trace_df[ + custom_ar_all_reduce_ops].agg("sum", axis=1) if len(reduce_kernel_ops): trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", axis=1) - trace_df.drop( - attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + - mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops + - nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops + - 
cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops + - reduce_kernel_ops, - axis=1, - inplace=True) + trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + + vocab_embed_ops + mem_ops + elementwise_ops + + nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops + + nccl_other_ops + cross_device_reduce_1stage_ops + + cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops + + reduce_kernel_ops, + axis=1, + inplace=True) return trace_df diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e014971dfa627..767d45ede7e87 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -918,13 +918,10 @@ def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor, full_nvlink) -def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) - - -def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, - out: torch.Tensor) -> None: - torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out) +def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, + reg_buffer_sz_bytes: int) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, + reg_buffer_sz_bytes) def dispose(fa: int) -> None: @@ -935,15 +932,15 @@ def meta_size() -> int: return torch.ops._C_custom_ar.meta_size() -def register_buffer(fa: int, ipc_tensors: List[torch.Tensor]) -> None: +def register_buffer(fa: int, ipc_tensors: List[int]) -> None: return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) -def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]: +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers(fa: int, handles: List[str], +def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a48556bd01e68..b8a27f40070c2 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,11 +1,10 @@ import ctypes from contextlib import contextmanager -from typing import Any, List, Optional, Union +from typing import List, Optional, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from torch.multiprocessing.reductions import reduce_tensor import vllm.envs as envs from vllm import _custom_ops as ops @@ -152,14 +151,11 @@ def __init__(self, # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. - self.meta = torch.zeros(ops.meta_size() + max_size, - dtype=torch.uint8, - device=self.device) + self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, + group=group) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed - self.buffer = torch.empty(max_size, - dtype=torch.uint8, - device=self.device) + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. 
Allocating 8MB @@ -171,17 +167,19 @@ def __init__(self, self.max_size = max_size self.rank = rank self.world_size = world_size - # Retain reference to IPC tensors to prevent garbage collection. - self._ipc_tensors: List[torch.Tensor] = [] self.full_nvlink = full_nvlink - self._ptr = ops.init_custom_ar(self._get_ipc_tensors(self.meta), - self.rank_data, rank, self.full_nvlink) - self.register_buffer(self.buffer) + self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, + self.full_nvlink) + ops.register_buffer(self._ptr, self.buffer_ptrs) @staticmethod def create_shared_buffer( size_in_bytes: int, group: Optional[ProcessGroup] = None) -> List[int]: + """ + Creates shared buffer and returns a list of pointers + representing the buffer on all processes in the group. + """ lib = CudaRTLibrary() pointer = lib.cudaMalloc(size_in_bytes) handle = lib.cudaIpcGetMemHandle(pointer) @@ -222,59 +220,14 @@ def capture(self): if not self.disabled: self.register_graph_buffers() - def _get_ipc_tensors(self, inp: torch.Tensor) -> List[torch.Tensor]: - """Gather the ipc-enabled tensors of `inp` from all ranks.""" - all_meta = self._all_gather_object(reduce_tensor(inp)) - all_tensors = [] - for i, obj in enumerate(all_meta): - func = obj[0][0] - args = list(obj[0][1]) - # This might break in the future since what `args` encompasses - # may change. - args[6] = inp.device.index - if i != self.rank: - all_tensors.append(func(*args)) - else: - all_tensors.append(inp) - self._ipc_tensors.extend(all_tensors) - return all_tensors - - def _all_gather_object(self, data: Any) -> List[List[Any]]: - """All gather serializable objects.""" - all_data: List[List[Any]] = [[None] for i in range(self.world_size)] - all_data[self.rank][0] = data - # We cannot directly use `dist.all_gather_object` here - # because it is incompatible with `gloo` backend under inference mode. - # see https://github.com/pytorch/pytorch/issues/126032 for details. - ranks = dist.get_process_group_ranks(group=self.group) - ranks.sort() - for i, rank in enumerate(ranks): - dist.broadcast_object_list(all_data[i], - src=rank, - group=self.group, - device="cpu") - - return all_data - - def _gather_ipc_meta(self, shard_data): - # Note: don't use `[[None]] * self.world_size` here - # because it will create a list of the same reference. - all_data = self._all_gather_object(shard_data) - - handles = [] - offsets = [] - for i in range(len(all_data)): - handles.append(all_data[i][0][0]) # type: ignore - offsets.append(all_data[i][0][1]) # type: ignore - return handles, offsets - - def register_buffer(self, inp: torch.Tensor): - ops.register_buffer(self._ptr, self._get_ipc_tensors(inp)) - def register_graph_buffers(self): handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) logger.info("Registering %d cuda graph addresses", len(offset)) + all_data = [None] * dist.get_world_size(group=self.group) + dist.all_gather_object(all_data, (handle, offset), group=self.group) + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): @@ -292,35 +245,33 @@ def should_custom_ar(self, inp: torch.Tensor): return inp_size < self.max_size return False - def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): - """Performs all reduce. 
+ def all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place all reduce. - This method assumes inp tensor is IPC registered with register_buffer, - or, in the context of cuda graphs, register_graph_buffers. + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. """ if out is None: out = torch.empty_like(inp) - ops.all_reduce_reg(self._ptr, inp, out) - return out - - def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): - """Performs all reduce, assuming inp tensor is not IPC registered.""" - if out is None: - out = torch.empty_like(inp) - ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], + self.max_size) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: - """Conditionally performs custom allreduce. - - Returns the allreduced result, or None if custom allreduce is not - suitable for the given Tensor under the current context. - """ + # when custom allreduce is disabled, this will be None if self.disabled or not self.should_custom_ar(input): return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self.all_reduce_reg(input) + return self.all_reduce(input, registered=True) else: # if warm up, mimic the allocation pattern # since custom allreduce is out-of-place @@ -330,12 +281,14 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: # custom allreduce incurs a cost of cudaMemcpy, which should # be small(<=1% of overall latency) compared to the performance # gains of using custom kernels - return self.all_reduce_unreg(input) + return self.all_reduce(input, registered=False) def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr) self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) def __del__(self): self.close()
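For reference, below is a minimal Python sketch of how the consolidated all_reduce op introduced by this patch is driven, mirroring the registered and unregistered paths of CustomAllreduce.all_reduce above. The wrapper name run_custom_allreduce and its buffer_ptr/buffer_sz_bytes arguments are illustrative assumptions, not part of the patch; fa is the opaque handle returned by ops.init_custom_ar, and the pre-registered buffer corresponds to buffer_ptrs[rank] and max_size in the class.

import torch

from vllm import _custom_ops as ops


def run_custom_allreduce(fa: int, inp: torch.Tensor, buffer_ptr: int,
                         buffer_sz_bytes: int,
                         registered: bool) -> torch.Tensor:
    """Illustrative wrapper over the fused custom-allreduce op (out-of-place)."""
    out = torch.empty_like(inp)
    if registered:
        # inp's pointer is already IPC-registered (e.g. a buffer captured in a
        # cuda graph), so no staging copy is needed: pass a null reg_buffer.
        ops.all_reduce(fa, inp, out, 0, 0)
    else:
        # Eager path: the op first copies inp into this rank's pre-registered
        # IPC buffer, which must be at least inp.numel() * element_size bytes.
        ops.all_reduce(fa, inp, out, buffer_ptr, buffer_sz_bytes)
    return out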