Revert nccl changes (#351)
* Revert "[distributed] remove pynccl's redundant change_state (vllm-project#11749)"

This reverts commit 9e764e7.

* Revert "[distributed] remove pynccl's redundant stream (vllm-project#11744)"

This reverts commit 635b897.
gshtras authored Jan 8, 2025
1 parent 88e020d commit c040f0e
Showing 3 changed files with 82 additions and 36 deletions.
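
Note: this revert restores PyNcclCommunicator's dedicated internal CUDA stream and its change_state context manager, so NCCL collectives once again have to be explicitly enabled and, during CUDA graph capture, run on the communicator's own stream. Below is a minimal sketch of the restored usage pattern, mirroring worker_fn_with_cudagraph in the test diff that follows; it assumes torch.distributed and the communicator are already initialized, and the helper name is illustrative only:

import torch

def capture_allreduce_graph(pynccl_comm):
    # Illustrative helper (not part of this commit). `pynccl_comm` is an
    # already-initialized PyNcclCommunicator bound to the local GPU.
    graph = torch.cuda.CUDAGraph()
    a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
    torch.cuda.synchronize()
    # With this revert, capture happens on the communicator's dedicated
    # stream and collectives must be enabled via change_state.
    with torch.cuda.graph(graph, stream=pynccl_comm.stream), \
            pynccl_comm.change_state(enable=True):
        a_out = pynccl_comm.all_reduce(a)
    torch.cuda.synchronize()
    graph.replay()
    return a, a_out
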
tests/distributed/test_pynccl.py (38 additions, 27 deletions)
@@ -59,7 +59,8 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    tensor = pynccl_comm.all_reduce(tensor)
+    with pynccl_comm.change_state(enable=True):
+        tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        tensor = pynccl_comm.all_reduce(tensor)
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 4).cpu().item()
-    else:
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 2).cpu().item()
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 4).cpu().item()
+        else:
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -135,7 +137,9 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(graph):
+    with torch.cuda.graph(
+            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
@@ -164,7 +168,8 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)

-    pynccl_comm.all_gather(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -201,7 +206,8 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)

-    pynccl_comm.reduce_scatter(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -228,13 +234,15 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-
-    if pynccl_comm.rank == 0:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()

@@ -265,12 +273,15 @@ def multiple_send_recv_worker_fn():
                          1024,
                          dtype=torch.float32,
                          device=device)
-    if torch.distributed.get_rank() in [0, 1]:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
vllm/distributed/device_communicators/pynccl.py (35 additions, 8 deletions)
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import Optional, Union

 # ===================== import region =====================
@@ -50,6 +51,7 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
+            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -58,6 +60,7 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
+            self.stream = None
             return

         self.available = True
@@ -95,12 +98,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()

-            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            stream.synchronize()
+            self.stream.synchronize()
             del data

     def all_reduce(self,
@@ -119,7 +122,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -141,7 +144,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +165,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +180,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +192,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +204,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -212,3 +215,27 @@
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        if stream is None:
+            stream = self.stream
+
+        old_disable = self.disabled
+        old_stream = self.stream
+
+        self.stream = stream
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
+        self.stream = old_stream
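
A brief usage sketch for the restored change_state context manager (names are illustrative; it assumes comm is an initialized PyNcclCommunicator and capture_stream is a torch.cuda.Stream on the same device). Inside the block the communicator is enabled and issues work on the given stream; the previous disabled flag and stream are restored on exit:

import torch

def run_enabled_on(comm, capture_stream):
    # Illustrative only: temporarily enable the communicator and point it
    # at capture_stream; the prior state is restored when the block exits.
    with comm.change_state(enable=True, stream=capture_stream):
        data = torch.ones(8, device=comm.device)
        data = comm.all_reduce(data)
    capture_stream.synchronize()
    return data
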
vllm/distributed/parallel_state.py (9 additions, 1 deletion)
@@ -305,7 +305,15 @@ def graph_capture(
             stream.wait_stream(curr_stream)

         with torch.cuda.stream(stream), maybe_ca_context:
-            yield graph_capture_context
+            pynccl_comm = self.pynccl_comm
+            maybe_pynccl_context: Any
+            if not pynccl_comm:
+                maybe_pynccl_context = nullcontext()
+            else:
+                maybe_pynccl_context = pynccl_comm.change_state(
+                    stream=torch.cuda.current_stream())
+            with maybe_pynccl_context:
+                yield graph_capture_context

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
