Skip to content

Commit

Permalink
custom allreduce + torch.compile (vllm-project#10121)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
  • Loading branch information
2 people authored and weilong.yu committed Dec 13, 2024
1 parent 2eecda0 commit a43682f
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 101 deletions.
1 change: 0 additions & 1 deletion docs/source/getting_started/debugging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
pynccl.disabled = False
s = torch.cuda.Stream()
with torch.cuda.stream(s):
Expand Down
15 changes: 6 additions & 9 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def worker_fn():
tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True):
pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == pynccl_comm.world_size

Expand All @@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn():
with pynccl_comm.change_state(enable=True):
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.all_reduce(tensor)
pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 4
else:
pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == 2

Expand Down Expand Up @@ -140,14 +140,11 @@ def worker_fn_with_cudagraph():
with torch.cuda.graph(
graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
enable=True):
# operation during the graph capture is recorded but not executed
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
pynccl_comm.all_reduce(a)
a_out = pynccl_comm.all_reduce(a)
pynccl_comm.stream.synchronize()
assert a.mean().cpu().item() == pynccl_comm.world_size**0
graph.replay()
pynccl_comm.stream.synchronize()
assert a.mean().cpu().item() == pynccl_comm.world_size**1
assert a_out.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
Expand Down
2 changes: 0 additions & 2 deletions tests/distributed/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
rank=rank,
world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False
if rank <= 2:
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank)
pynccl2.disabled = False
data = torch.tensor([rank]).cuda()
pynccl1.all_reduce(data)
pg1.barrier()
Expand Down
26 changes: 13 additions & 13 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,30 +106,30 @@ def __init__(
self.stream.synchronize()
del data

# by default it is disabled, e.g. in profiling models and prefill phase.
# to use it, use under `with obj.change_state(enable=True)`, usually
# when we are using CUDA graph.
self.disabled = True

def all_reduce(self,
tensor: torch.Tensor,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
stream=None) -> torch.Tensor:
if self.disabled:
return
return None
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
assert in_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
f"but the input tensor is on {in_tensor.device}")

out_tensor = torch.empty_like(in_tensor)

if stream is None:
stream = self.stream
self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
return out_tensor

def all_gather(self,
output_tensor: torch.Tensor,
Expand Down
110 changes: 36 additions & 74 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,42 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None:
_groups[group.unique_name] = weakref.ref(group)


if supports_custom_op():

def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce_in_place(tensor)

def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
return
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)

direct_register_custom_op(
op_name="inplace_all_reduce",
op_func=inplace_all_reduce,
mutates_args=["tensor"],
fake_impl=inplace_all_reduce_fake,
)

def outplace_all_reduce(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)

def outplace_all_reduce_fake(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)

if supports_custom_op():
direct_register_custom_op(
op_name="outplace_all_reduce",
op_func=outplace_all_reduce,
op_name="all_reduce",
op_func=all_reduce,
mutates_args=[],
fake_impl=outplace_all_reduce_fake,
fake_impl=all_reduce_fake,
)


Expand Down Expand Up @@ -317,30 +299,13 @@ def graph_capture(
stream.wait_stream(curr_stream)

with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the
# tensor size is too large, it will fallback to the next
# available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using
# CUDA graph, we use either custom all-reduce kernel or
# PyTorch NCCL. We always prioritize using custom all-reduce
# kernel but fall back to PyTorch or pynccl if it is
# disabled or not supported.
pynccl_comm = self.pynccl_comm
maybe_pynccl_context: Any
if not pynccl_comm:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream())
stream=torch.cuda.current_stream())
with maybe_pynccl_context:
yield graph_capture_context

Expand All @@ -356,8 +321,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
a new tensor in the same op. So we always make the all-reduce operation
out-of-place.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
Expand All @@ -368,10 +333,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_

if not supports_custom_op():
self._all_reduce_in_place(input_)
return input_

if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
Expand All @@ -385,30 +346,31 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)

if self.ca_comm is not None and \
not self.ca_comm.disabled and \
self.ca_comm.should_custom_ar(input_):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
torch.ops.vllm.inplace_all_reduce(input_,
group_name=self.unique_name)
return input_
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)

def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
# always try custom allreduce first,
# and then pynccl.
ca_comm = self.ca_comm
assert ca_comm is not None
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out

def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
assert pynccl_comm is not None
# TODO: pynccl should not use `stream=`
# it can just always use the current stream.
out = pynccl_comm.all_reduce(input_,
stream=torch.cuda.current_stream())
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
Expand Down
6 changes: 4 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from vllm.compilation.compile_context import set_compile_context
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
Expand Down Expand Up @@ -570,8 +571,9 @@ def capture_model(self) -> None:
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self._dummy_run(self.model, num_tokens, self.kv_caches)
with graph_capture():
for num_tokens in reversed(self.cudagraph_batch_sizes):
self._dummy_run(self.model, num_tokens, self.kv_caches)

end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
Expand Down

0 comments on commit a43682f

Please sign in to comment.