diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 1eb749f64d36b..3e940549862ea 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -510,6 +510,7 @@ steps:
   # NOTE: don't test llama model here, it seems hf implementation is buggy
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
+  - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py
   - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus
   - pytest -v -s -x lora/test_mixtral.py
diff --git a/tests/distributed/test_ca_buffer_sharing.py b/tests/distributed/test_ca_buffer_sharing.py
new file mode 100644
index 0000000000000..fc4043cd3014e
--- /dev/null
+++ b/tests/distributed/test_ca_buffer_sharing.py
@@ -0,0 +1,59 @@
+# can only run on machines with p2p access across GPUs
+# can only run with torchrun:
+# torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py
+
+import ctypes
+
+import torch
+import torch.distributed as dist
+
+from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
+from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa
+    CustomAllreduce)
+
+# create a cpu process group for communicating metadata (ipc handle)
+dist.init_process_group(backend="gloo")
+rank = local_rank = dist.get_rank()
+world_size = dist.get_world_size()
+
+# every process sets its own device (differently)
+lib = CudaRTLibrary()
+lib.cudaSetDevice(rank)
+
+buffer_size_in_bytes = 1024
+byte_value = 2  # the value we write to the buffer for verification
+
+pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)
+
+print(f"Rank {rank} has pointers {pointers}")
+
+dist.barrier()
+torch.cuda.synchronize()
+
+if rank == 0:
+    # the first rank tries to write to all buffers
+    for p in pointers:
+        pointer = ctypes.c_void_p(p)
+        lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)
+
+dist.barrier()
+torch.cuda.synchronize()
+
+host_data = (ctypes.c_char * buffer_size_in_bytes)()
+
+# all ranks read from all buffers, and check if the data is correct
+for p in pointers:
+    pointer = ctypes.c_void_p(p)
+    lib.cudaMemcpy(host_data, pointer, buffer_size_in_bytes)
+    for i in range(buffer_size_in_bytes):
+        assert ord(host_data[i]) == byte_value, (
+            f"Rank {rank} failed"
+            f" to verify buffer {p}. Expected {byte_value}, "
+            f"got {ord(host_data[i])}")
+
+print(f"Rank {rank} verified all buffers")
+
+dist.barrier()
+torch.cuda.synchronize()
+
+CustomAllreduce.free_shared_buffer(pointers)
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index c3632aee6d11a..3b5d92561cf25 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,3 +1,4 @@
+import ctypes
 from contextlib import contextmanager
 from typing import Any, List, Optional, Union
 
@@ -7,6 +8,7 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
@@ -174,6 +176,35 @@ def __init__(self,
                                        offsets, rank, self.full_nvlink)
         self.register_buffer(self.buffer)
 
+    @staticmethod
+    def create_shared_buffer(
+            size_in_bytes: int,
+            group: Optional[ProcessGroup] = None) -> List[int]:
+        lib = CudaRTLibrary()
+        pointer = lib.cudaMalloc(size_in_bytes)
+        handle = lib.cudaIpcGetMemHandle(pointer)
+        world_size = dist.get_world_size(group=group)
+        rank = dist.get_rank(group=group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=group)
+
+        pointers: List[int] = []
+        for i, h in enumerate(handles):
+            if i == rank:
+                pointers.append(pointer.value)  # type: ignore
+            else:
+                pointers.append(
+                    lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
+
+        return pointers
+
+    @staticmethod
+    def free_shared_buffer(pointers: List[int],
+                           group: Optional[ProcessGroup] = None) -> None:
+        rank = dist.get_rank(group=group)
+        lib = CudaRTLibrary()
+        lib.cudaFree(ctypes.c_void_p(pointers[rank]))
+
     @contextmanager
     def capture(self):
         """
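
Usage note (not part of the patch): the new static helpers exchange CUDA IPC handles over a CPU process group, so every rank ends up with a raw device pointer to every rank's allocation. Below is a minimal sketch of how they could be driven directly; the gloo initialization, the 1 MiB size, and the file name in the launch command are assumptions for illustration, mirroring the flow in tests/distributed/test_ca_buffer_sharing.py.

# hypothetical sketch -- requires p2p-capable GPUs, launched e.g. with
# torchrun --nproc_per_node=2 buffer_sharing_sketch.py
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa
    CustomAllreduce)

# CPU (gloo) group: only the IPC handles travel over it, no GPU data
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# each rank binds to its own GPU before allocating
lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

# pointers[i] is a device pointer (as int) to rank i's buffer, valid in this process
pointers = CustomAllreduce.create_shared_buffer(1 << 20)

# ... hand the raw pointers to kernels, cudaMemset/cudaMemcpy, etc. ...

dist.barrier()  # make sure no peer is still using the buffers
CustomAllreduce.free_shared_buffer(pointers)  # each rank frees only its own allocation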