
Commit

use new api
hanzhi713 committed Nov 7, 2024
1 parent d23eba9 commit 59b70cb
Showing 8 changed files with 145 additions and 184 deletions.
103 changes: 53 additions & 50 deletions csrc/custom_all_reduce.cu
@@ -5,14 +5,15 @@

#include "custom_all_reduce.cuh"

// fake pointer type, must match fptr_t type in ops.h
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_ar(const std::vector<torch::Tensor>& ipc_tensors,
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
int world_size = ipc_tensors.size();
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
@@ -22,7 +23,7 @@ fptr_t init_custom_ar(const std::vector<torch::Tensor>& ipc_tensors,

vllm::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(ipc_tensors[i].data_ptr());
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
rank_data.numel(), rank, world_size,
@@ -51,26 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t.numel() * t.element_size());
}

void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) {
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();

TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_is_weak_contiguous(out));
TORCH_CHECK(_is_weak_contiguous(inp));
auto input_size = inp.numel() * inp.element_size();
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
cudaMemcpyDeviceToDevice, stream));
} else {
reg_buffer = inp.data_ptr();
}
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
@@ -81,61 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
}
}

void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}

void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();

auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, cudaMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}

void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa;
delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
}

int64_t meta_size() { return sizeof(vllm::Signal); }

void register_buffer(fptr_t _fa,
const std::vector<torch::Tensor>& ipc_tensors) {
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(ipc_tensors.size() == fa->world_size_);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < ipc_tensors.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(ipc_tensors[i].data_ptr());
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
}
fa->register_buffer(ipc_ptrs);
}

std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
// Use vector<int64_t> to represent byte data for Python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
std::vector<int64_t> bytes(handle.begin(), handle.end());
return std::make_tuple(bytes, offsets);
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
// Use vector<int64_t> to represent byte data for Python binding compatibility.
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
std::vector<std::string> bytes;
bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
bytes.emplace_back(handles[i].begin(), handles[i].end());
}
bytes.reserve(handles.size());
fa->register_graph_buffers(bytes, offsets);
}
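
Note: the two removed entry points (all_reduce_reg and all_reduce_unreg) are both expressible through the new unified all_reduce op. A minimal Python-side sketch of that mapping, assuming fa is the handle returned by init_custom_ar and reg_buffer is an IPC-registered staging tensor obtained elsewhere (these shim names are illustrative, not part of the commit):

import torch
from vllm import _custom_ops as ops


def all_reduce_reg_compat(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    # Registered path: pass a null staging pointer (0), so the kernel reads
    # inp.data_ptr() directly; inp must already be IPC-registered.
    ops.all_reduce(fa, inp, out, 0, 0)


def all_reduce_unreg_compat(fa: int, inp: torch.Tensor,
                            reg_buffer: torch.Tensor, out: torch.Tensor) -> None:
    # Unregistered path: the op first copies inp into reg_buffer, which must be
    # at least as large as the input (enforced by TORCH_CHECK_LE above).
    ops.all_reduce(fa, inp, out, reg_buffer.data_ptr(),
                   reg_buffer.numel() * reg_buffer.element_size())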
24 changes: 19 additions & 5 deletions csrc/custom_all_reduce.cuh
@@ -285,12 +285,27 @@ class CustomAllreduce {
int world_size_;
bool full_nvlink_;

// below are device pointers
RankSignals sg_;
// Stores a map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;

// stores the registered device pointers from all ranks
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each address used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
@@ -332,11 +347,10 @@ class CustomAllreduce {
return it->second;
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta() {
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
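
The four-step process in the comment above is driven from Python through the ops touched by this commit. A hedged sketch of that flow, assuming fa is the fptr_t handle from init_custom_ar and using torch.distributed's all_gather_object as one possible way to exchange the metadata (step 3 is not prescribed by this diff):

import torch.distributed as dist

from vllm import _custom_ops as ops


def exchange_and_register_graph_buffers(fa: int, world_size: int) -> None:
    # Step 2: per-rank IPC handle bytes (as a flat list of ints) plus the
    # offset of every buffer recorded during graph capture.
    handle, offsets = ops.get_graph_buffer_ipc_meta(fa)

    # Step 3: gather (handle, offsets) from all ranks; all_gather_object is
    # just one way to do this.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, (handle, offsets))

    # Step 4: open the peer handles and fill the rank-data slots reserved
    # during capture.
    ops.register_graph_buffers(fa, [h for h, _ in gathered],
                               [o for _, o in gathered])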
16 changes: 8 additions & 8 deletions csrc/ops.h
@@ -199,16 +199,16 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,

#ifndef USE_ROCM
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<torch::Tensor>& ipc_tensors,
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<torch::Tensor>& ipc_tensors);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif
16 changes: 5 additions & 11 deletions csrc/torch_bindings.cpp
@@ -411,24 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def(
"init_custom_ar(Tensor[] ipc_tensors, Tensor rank_data, "
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);

custom_ar.def(
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
"()");
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);

custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size);

custom_ar.def("register_buffer(int fa, Tensor[] ipc_tensors) -> ()");
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);

custom_ar.def("register_buffer", &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
}
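
For reference, once the compiled extension is loaded (importing vllm does this), the new schema can be inspected from Python. The _schema attribute is a private PyTorch detail, so treat this only as a debugging sketch:

import torch
import vllm  # noqa: F401  (loads the _C_custom_ar library)

# Prints something along the lines of:
#   _C_custom_ar::all_reduce(int fa, Tensor inp, Tensor! out,
#                            int reg_buffer, int reg_buffer_sz_bytes) -> ()
# "Tensor! out" marks out as mutated in place; the int arguments carry the
# fptr_t handle and the raw staging-buffer pointer and size.
print(torch.ops._C_custom_ar.all_reduce.default._schema)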
4 changes: 2 additions & 2 deletions tests/distributed/test_custom_all_reduce.py
@@ -95,13 +95,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
out = fa.all_reduce(out, registered=False)
torch.testing.assert_close(out, inp * (tp_size**num_communication))

inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
out = fa.all_reduce(out, registered=False)
torch.testing.assert_close(out, inp * (tp_size**num_communication))


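
The fa.all_reduce(..., registered=False) calls above go through a Python-level wrapper that is not part of this diff. A hedged sketch of how such a wrapper could route the flag onto the new op (the class and attribute names here are illustrative, not the actual vLLM wrapper):

import torch
from vllm import _custom_ops as ops


class CustomAllreduceSketch:
    # Illustrative only; the real wrapper lives in
    # vllm/distributed/device_communicators/custom_all_reduce.py.
    def __init__(self, fa: int, staging_buffer: torch.Tensor):
        self._ptr = fa
        self.buffer = staging_buffer  # IPC-registered scratch space
        self.max_size = staging_buffer.numel() * staging_buffer.element_size()

    def all_reduce(self, inp: torch.Tensor, *,
                   registered: bool = False) -> torch.Tensor:
        out = torch.empty_like(inp)
        if registered:
            # Input already lives in an IPC-registered region: no staging copy.
            ops.all_reduce(self._ptr, inp, out, 0, 0)
        else:
            # Let the op copy inp into the registered staging buffer first.
            ops.all_reduce(self._ptr, inp, out, self.buffer.data_ptr(),
                           self.max_size)
        return out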
32 changes: 16 additions & 16 deletions tools/profiler/visualize_layerwise_profile.py
@@ -196,8 +196,8 @@ def is_cross_device_reduce_1stage(op_name: str):
def is_cross_device_reduce_2stage(op_name: str):
return "cross_device_reduce_2stage" in op_name

def is_custom_ar_all_reduce_unreg(op_name: str):
return "_C_custom_ar::all_reduce_unreg" in op_name
def is_custom_ar_all_reduce(op_name: str):
return "_C_custom_ar::all_reduce" in op_name

def is_reduce_kernel(op_name: str):
return "reduce_kernel" in op_name
@@ -246,9 +246,9 @@ def is_reduce_kernel(op_name: str):
filter(lambda x: is_cross_device_reduce_2stage(x), ops))
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))

custom_ar_all_reduce_unreg_ops = list(
filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))
custom_ar_all_reduce_ops = list(
filter(lambda x: is_custom_ar_all_reduce(x), ops))
ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))

reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
@@ -289,21 +289,21 @@ def is_reduce_kernel(op_name: str):
if len(cross_device_reduce_2stage_ops):
trace_df['cross_device_reduce_2stage_ops'] = trace_df[
cross_device_reduce_2stage_ops].agg("sum", axis=1)
if len(custom_ar_all_reduce_unreg_ops):
trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
if len(custom_ar_all_reduce_ops):
trace_df['custom_ar_all_reduce_ops'] = trace_df[
custom_ar_all_reduce_ops].agg("sum", axis=1)
if len(reduce_kernel_ops):
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
axis=1)

trace_df.drop(
attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
reduce_kernel_ops,
axis=1,
inplace=True)
trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
vocab_embed_ops + mem_ops + elementwise_ops +
nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
nccl_other_ops + cross_device_reduce_1stage_ops +
cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
reduce_kernel_ops,
axis=1,
inplace=True)
return trace_df


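
Beyond the rename, the profiler keeps the same pattern: pick trace columns with an op-name predicate, sum them into one aggregate column, then drop the originals. A toy, self-contained illustration of that pattern (column names and timings invented):

import pandas as pd


def is_custom_ar_all_reduce(op_name: str) -> bool:
    return "_C_custom_ar::all_reduce" in op_name


# One column per op, one row per layer (made-up microsecond timings).
trace_df = pd.DataFrame({
    "_C_custom_ar::all_reduce#0": [10.0, 12.0],
    "_C_custom_ar::all_reduce#1": [3.0, 4.0],
    "aten::mm": [50.0, 55.0],
})

custom_ar_all_reduce_ops = list(
    filter(is_custom_ar_all_reduce, trace_df.columns))
if len(custom_ar_all_reduce_ops):
    trace_df["custom_ar_all_reduce_ops"] = trace_df[
        custom_ar_all_reduce_ops].agg("sum", axis=1)
trace_df.drop(custom_ar_all_reduce_ops, axis=1, inplace=True)
print(trace_df)  # remaining columns: aten::mm, custom_ar_all_reduce_ops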
17 changes: 7 additions & 10 deletions vllm/_custom_ops.py
@@ -918,13 +918,10 @@ def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes: int) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
reg_buffer_sz_bytes)


def dispose(fa: int) -> None:
@@ -935,15 +932,15 @@ def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, ipc_tensors: List[torch.Tensor]) -> None:
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
def register_graph_buffers(fa: int, handles: List[List[int]],
offsets: List[List[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

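
One consequence of the new signatures above: the IPC handle now crosses the binding as a list of ints (one per byte, converted from a C++ std::string whose char may be signed) rather than as a str/bytes object. A hedged sketch of recovering raw bytes on the Python side, should a caller need them:

from vllm import _custom_ops as ops


def graph_buffer_handle_bytes(fa: int) -> bytes:
    # fa is the handle returned by init_custom_ar (obtained elsewhere).
    handle_ints, _offsets = ops.get_graph_buffer_ipc_meta(fa)
    # The ints carry the raw bytes of the concatenated cudaIpcMemHandle_t
    # structs; mask to 0..255 in case signed chars produced negative values.
    return bytes(b & 0xFF for b in handle_ints)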