Commit 4be3a45

[distributed] add function to create ipc buffers directly (vllm-project#10064)

Signed-off-by: youkaichao <[email protected]>
youkaichao authored Nov 6, 2024
1 parent 4089985 commit 4be3a45
Showing 3 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -510,6 +510,7 @@ steps:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py
- torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py
- TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus
- pytest -v -s -x lora/test_mixtral.py

59 changes: 59 additions & 0 deletions tests/distributed/test_ca_buffer_sharing.py
@@ -0,0 +1,59 @@
# can only run on machines with p2p access across GPUs
# can only run with torchrun:
# torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py

import ctypes

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
    CustomAllreduce)

# create a cpu process group for communicating metadata (ipc handle)
dist.init_process_group(backend="gloo")
rank = local_rank = dist.get_rank()
world_size = dist.get_world_size()

# every process sets its own device (differently)
lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

buffer_size_in_bytes = 1024
byte_value = 2 # the value we write to the buffer for verification

pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)

print(f"Rank {rank} has pointers {pointers}")

dist.barrier()
torch.cuda.synchronize()

if rank == 0:
    # the first rank tries to write to all buffers
    for p in pointers:
        pointer = ctypes.c_void_p(p)
        lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)

dist.barrier()
torch.cuda.synchronize()

host_data = (ctypes.c_char * buffer_size_in_bytes)()

# all ranks read from all buffers, and check if the data is correct
for p in pointers:
    pointer = ctypes.c_void_p(p)
    lib.cudaMemcpy(host_data, pointer, buffer_size_in_bytes)
    for i in range(buffer_size_in_bytes):
        assert ord(host_data[i]) == byte_value, (
            f"Rank {rank} failed"
            f" to verify buffer {p}. Expected {byte_value}, "
            f"got {ord(host_data[i])}")

print(f"Rank {rank} verified all buffers")

dist.barrier()
torch.cuda.synchronize()

CustomAllreduce.free_shared_buffer(pointers)
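
Not part of the commit, but for illustration: the pointers returned by create_shared_buffer are ordinary device addresses, so data can also be staged into them from PyTorch tensors. A rough sketch along the lines of the test above, assuming CudaRTLibrary.cudaMemcpy copies with cudaMemcpyDefault (so raw device and host pointers are accepted on either side) and the same torchrun/P2P setup:

# hedged sketch (not from this commit): rank 0 pushes the contents of a CUDA
# tensor into every rank's shared buffer through the IPC-mapped pointers,
# then each rank reads its own buffer back into host memory and verifies it.
# assumed launch: torchrun --nproc_per_node=2 <this_sketch>.py
import ctypes

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa
    CustomAllreduce)

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

nbytes = 1024
pointers = CustomAllreduce.create_shared_buffer(nbytes)

if rank == 0:
    # fill a CUDA tensor on GPU 0 and copy it into every rank's buffer
    src = torch.full((nbytes, ), 2, dtype=torch.uint8, device="cuda")
    torch.cuda.synchronize()
    for p in pointers:
        lib.cudaMemcpy(ctypes.c_void_p(p),
                       ctypes.c_void_p(src.data_ptr()), nbytes)

dist.barrier()
torch.cuda.synchronize()

# every rank reads its own buffer back into (pageable) host memory
dst = torch.empty(nbytes, dtype=torch.uint8)
lib.cudaMemcpy(ctypes.c_void_p(dst.data_ptr()),
               ctypes.c_void_p(pointers[rank]), nbytes)
assert torch.all(dst == 2), f"Rank {rank} read unexpected data"

dist.barrier()
CustomAllreduce.free_shared_buffer(pointers)
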
31 changes: 31 additions & 0 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,3 +1,4 @@
import ctypes
from contextlib import contextmanager
from typing import Any, List, Optional, Union

@@ -7,6 +8,7 @@

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
@@ -174,6 +176,35 @@ def __init__(self,
                                       offsets, rank, self.full_nvlink)
        self.register_buffer(self.buffer)

    @staticmethod
    def create_shared_buffer(
            size_in_bytes: int,
            group: Optional[ProcessGroup] = None) -> List[int]:
        # allocate a local buffer and export a CUDA IPC handle for it
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        # exchange handles over the (CPU) process group
        dist.all_gather_object(handles, handle, group=group)

        # build this rank's pointer list: its own allocation at index `rank`,
        # and IPC-mapped views of every other rank's allocation elsewhere
        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(
                    lib.cudaIpcOpenMemHandle(h).value)  # type: ignore

        return pointers

    @staticmethod
    def free_shared_buffer(pointers: List[int],
                           group: Optional[ProcessGroup] = None) -> None:
        # each rank frees only the allocation it owns; synchronizing before
        # the call is the caller's responsibility
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    @contextmanager
    def capture(self):
        """
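
The new helpers also accept an optional group argument, which the test above does not exercise. A minimal sketch of how that might look (my own illustration, not from this commit), assuming four GPUs with P2P access within each pair:

# hedged sketch (not from this commit): share buffers only within a sub-group.
# assumed launch: torchrun --nproc_per_node=4 <this_sketch>.py, with P2P
# access between GPUs 0/1 and between GPUs 2/3.
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa
    CustomAllreduce)

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

# every process must create both groups, in the same order
group_01 = dist.new_group(ranks=[0, 1])
group_23 = dist.new_group(ranks=[2, 3])
subgroup = group_01 if rank < 2 else group_23

# each rank gets two pointers: its own allocation plus an IPC-mapped view of
# its peer's allocation; ranks outside the sub-group never see these buffers
pointers = CustomAllreduce.create_shared_buffer(1 << 20, group=subgroup)
assert len(pointers) == 2

dist.barrier()

# collective cleanup: inside free_shared_buffer, each rank frees only the
# allocation it owns (pointers[rank] within its sub-group)
CustomAllreduce.free_shared_buffer(pointers, group=subgroup)

Because each rank frees only its own allocation, free_shared_buffer has to be called by every member of the group, and only after a synchronization point that guarantees peers are done with the buffers.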
