custom allreduce + torch.compile #10121

Merged: 15 commits, Nov 26, 2024
1 change: 0 additions & 1 deletion docs/source/getting_started/debugging.rst
@@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

 pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
-pynccl.disabled = False

 s = torch.cuda.Stream()
 with torch.cuda.stream(s):
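With this change, the communicator no longer starts disabled, so the debugging guide's sanity-check snippet works without flipping `disabled` by hand. A hedged sketch of the resulting snippet, assuming `gloo_group`, `local_rank`, and `world_size` are set up earlier in the guide (they are not defined here), and using the out-of-place `all_reduce` introduced by this PR:

```python
# Sketch of the debugging-guide snippet after this PR; `gloo_group`,
# `local_rank`, and `world_size` come from the guide's earlier setup.
import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
# no `pynccl.disabled = False` needed anymore

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    data = torch.ones(1, device=f"cuda:{local_rank}")
    # all_reduce is now out-of-place and returns a fresh tensor
    data = pynccl.all_reduce(data, stream=s)
    value = data.mean().item()
    assert value == world_size
```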
15 changes: 6 additions & 9 deletions tests/distributed/test_pynccl.py
@@ -60,7 +60,7 @@ def worker_fn():
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
     result = tensor.mean().cpu().item()
     assert result == pynccl_comm.world_size

@@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn():
     with pynccl_comm.change_state(enable=True):
         # two groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.all_reduce(tensor)
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 2

@@ -140,14 +140,11 @@ def worker_fn_with_cudagraph():
     with torch.cuda.graph(
             graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
                 enable=True):
-        # operation during the graph capture is recorded but not executed
-        # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
-        pynccl_comm.all_reduce(a)
+        a_out = pynccl_comm.all_reduce(a)
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**0
     graph.replay()
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**1
+    assert a_out.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
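The cudagraph test above relies on the fact that collectives issued during capture are recorded, not executed, so the result is only observable after `graph.replay()`. A standalone sketch of that pattern, assuming `pynccl_comm` is an initialized, enabled PyNcclCommunicator as in the tests (shapes are illustrative):

```python
# Minimal sketch: capture an out-of-place pynccl all-reduce in a CUDA
# graph, then replay it.
import torch

a = torch.ones(4, 4).cuda(pynccl_comm.rank)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=pynccl_comm.stream):
    # recorded, not executed: a_out holds no valid result until replay
    a_out = pynccl_comm.all_reduce(a)
graph.replay()
pynccl_comm.stream.synchronize()
# after replay, every rank sees the sum over all ranks
assert a_out.mean().item() == pynccl_comm.world_size
```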
2 changes: 0 additions & 2 deletions tests/distributed/test_utils.py
@@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
-    pynccl1.disabled = False
     if rank <= 2:
         pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                            port=port2,
                                            rank=rank,
                                            world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
-        pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
     pynccl1.all_reduce(data)
     pg1.barrier()
26 changes: 13 additions & 13 deletions vllm/distributed/device_communicators/pynccl.py
@@ -106,30 +106,30 @@ def __init__(
         self.stream.synchronize()
         del data

-        # by default it is disabled, e.g. in profiling models and prefill phase.
-        # to use it, use under `with obj.change_state(enable=True)`, usually
-        # when we are using CUDA graph.
-        self.disabled = True
-
     def all_reduce(self,
-                   tensor: torch.Tensor,
+                   in_tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
-                   stream=None):
+                   stream=None) -> torch.Tensor:
         if self.disabled:
-            return
+            return None
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
-        assert tensor.device == self.device, (
+        assert in_tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
+            f"but the input tensor is on {in_tensor.device}")
+
+        out_tensor = torch.empty_like(in_tensor)

         if stream is None:
             stream = self.stream
-        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
-                                buffer_type(tensor.data_ptr()), tensor.numel(),
-                                ncclDataTypeEnum.from_torch(tensor.dtype),
+        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
+                                buffer_type(out_tensor.data_ptr()),
+                                in_tensor.numel(),
+                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                 ncclRedOpTypeEnum.from_torch(op), self.comm,
                                 cudaStream_t(stream.cuda_stream))
+        return out_tensor

     def all_gather(self,
                    output_tensor: torch.Tensor,
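After this change, `PyNcclCommunicator.all_reduce` is out-of-place: it allocates a fresh output with `torch.empty_like`, leaves the input untouched, and returns the result (or None when the communicator is disabled). A usage sketch under those semantics, assuming `comm` is an initialized, enabled PyNcclCommunicator on this rank:

```python
# Sketch of the new out-of-place semantics.
import torch

x = torch.ones(1024, device=comm.device)
y = comm.all_reduce(x)           # new tensor, or None if disabled
assert y is not None
assert x.mean().item() == 1.0    # input is not mutated
assert y.mean().item() == comm.world_size
```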
110 changes: 36 additions & 74 deletions vllm/distributed/parallel_state.py
@@ -96,42 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)


-if supports_custom_op():
-
-    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce_in_place(tensor)
-
-    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
-        return
+def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce_out_place(tensor)

-    direct_register_custom_op(
-        op_name="inplace_all_reduce",
-        op_func=inplace_all_reduce,
-        mutates_args=["tensor"],
-        fake_impl=inplace_all_reduce_fake,
-    )

-    def outplace_all_reduce(tensor: torch.Tensor,
-                            group_name: str) -> torch.Tensor:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce_out_place(tensor)
+def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)

-    def outplace_all_reduce_fake(tensor: torch.Tensor,
-                                 group_name: str) -> torch.Tensor:
-        return torch.empty_like(tensor)

+if supports_custom_op():
     direct_register_custom_op(
-        op_name="outplace_all_reduce",
-        op_func=outplace_all_reduce,
+        op_name="all_reduce",
+        op_func=all_reduce,
         mutates_args=[],
-        fake_impl=outplace_all_reduce_fake,
+        fake_impl=all_reduce_fake,
     )
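The single `all_reduce` custom op replaces the old in-place/out-of-place pair. Its fake implementation just returns `torch.empty_like(tensor)`, which is what lets Dynamo trace through the collective with FakeTensors instead of running NCCL. A hedged sketch of how such an op is exercised under `torch.compile`; the group setup is elided, and `"tp:0"` is a hypothetical placeholder that must match a registered GroupCoordinator's `unique_name`:

```python
# Sketch: the registered op lives under torch.ops.vllm and is traceable
# by torch.compile via all_reduce_fake (shape-only semantics).
import torch

def layer(x: torch.Tensor) -> torch.Tensor:
    x = x * 2.0
    # "tp:0" is a placeholder group name, not a guaranteed default
    x = torch.ops.vllm.all_reduce(x, group_name="tp:0")
    return x + 1.0

compiled_layer = torch.compile(layer)  # traced with the fake impl
```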


@@ -317,30 +299,13 @@ def graph_capture(
         stream.wait_stream(curr_stream)

         with torch.cuda.stream(stream), maybe_ca_context:
-            # In graph mode, we have to be very careful about the collective
-            # operations. The current status is:
-            # allreduce \ Mode   |  Eager  |  Graph  |
-            # --------------------------------------------
-            # custom allreduce   | enabled | enabled |
-            # PyNccl             | disabled| enabled |
-            # torch.distributed  | enabled | disabled|
-            #
-            # Note that custom allreduce will have a runtime check, if the
-            # tensor size is too large, it will fallback to the next
-            # available option.
-            # In summary: When using CUDA graph, we use
-            # either custom all-reduce kernel or pynccl. When not using
-            # CUDA graph, we use either custom all-reduce kernel or
-            # PyTorch NCCL. We always prioritize using custom all-reduce
-            # kernel but fall back to PyTorch or pynccl if it is
-            # disabled or not supported.
             pynccl_comm = self.pynccl_comm
             maybe_pynccl_context: Any
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
                 maybe_pynccl_context = pynccl_comm.change_state(
-                    enable=True, stream=torch.cuda.current_stream())
+                    stream=torch.cuda.current_stream())
             with maybe_pynccl_context:
                 yield graph_capture_context

@@ -356,8 +321,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         coordinator.

         In addition, PyTorch custom ops do not support mutation or returning
-        a new tensor in the same op. So we need to figure out if the op is
-        in-place or out-of-place ahead of time.
+        a new tensor in the same op. So we always make the all-reduce operation
+        out-of-place.
         """
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
@@ -368,10 +333,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
             return input_

-        if not supports_custom_op():
-            self._all_reduce_in_place(input_)
-            return input_
-
         if self.tpu_communicator is not None and \
                 not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
@@ -385,30 +346,31 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
                 not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)

-        if self.ca_comm is not None and \
-                not self.ca_comm.disabled and \
-                self.ca_comm.should_custom_ar(input_):
-            return torch.ops.vllm.outplace_all_reduce(
-                input_, group_name=self.unique_name)
-        else:
-            torch.ops.vllm.inplace_all_reduce(input_,
-                                              group_name=self.unique_name)
-            return input_
+        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)

     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        # always try custom allreduce first,
+        # and then pynccl.
         ca_comm = self.ca_comm
-        assert ca_comm is not None
-        assert not ca_comm.disabled
-        out = ca_comm.custom_all_reduce(input_)
-        assert out is not None
-        return out
-
-    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
+        if ca_comm is not None and not ca_comm.disabled and \
+                ca_comm.should_custom_ar(input_):
+            out = ca_comm.custom_all_reduce(input_)
+            assert out is not None
+            return out
         pynccl_comm = self.pynccl_comm
-        if (pynccl_comm is not None and not pynccl_comm.disabled):
-            pynccl_comm.all_reduce(input_)
-        else:
-            torch.distributed.all_reduce(input_, group=self.device_group)
+        assert pynccl_comm is not None
+        # TODO: pynccl should not use `stream=`
+        # it can just always use the current stream.
+        out = pynccl_comm.all_reduce(input_,
+                                     stream=torch.cuda.current_stream())
Comment on lines +362 to +365:

[Collaborator] I was a little confused about what this TODO meant, so I had to dig a bit. It looks like PyNcclCommunicator creates a new stream in its __init__ method and uses it by default, so we always have to pass in the current stream. Do you know why it behaves this way?

[Member] Mostly historical. We can remove it, but I don't want to do it in this PR.

[Contributor Author] I completely agree.
+        if out is None:
+            # fall back to the default all-reduce using PyTorch.
+            # this usually happens during testing.
+            # when we run the model, allreduce only happens for the TP
+            # group, where we always have either custom allreduce or pynccl.
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+        return out

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
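The net effect of `_all_reduce_out_place` is a fixed run-time fallback chain: the custom allreduce kernel when it is enabled and the tensor qualifies, then pynccl, then plain `torch.distributed`. A compressed sketch of that dispatch, not the actual method, with names following the diff above:

```python
# Sketch of the run-time dispatch order in _all_reduce_out_place.
import torch
import torch.distributed

def dispatch_all_reduce(group, input_: torch.Tensor) -> torch.Tensor:
    ca_comm = group.ca_comm
    if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
        return ca_comm.custom_all_reduce(input_)      # 1. custom kernel
    out = group.pynccl_comm.all_reduce(               # 2. pynccl
        input_, stream=torch.cuda.current_stream())
    if out is None:
        # 3. pynccl disabled: PyTorch fallback, mostly hit in tests
        out = input_.clone()
        torch.distributed.all_reduce(out, group=group.device_group)
    return out
```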
6 changes: 4 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -10,6 +10,7 @@

 from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -570,8 +571,9 @@ def capture_model(self) -> None:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        for num_tokens in reversed(self.cudagraph_batch_sizes):
-            self._dummy_run(self.model, num_tokens, self.kv_caches)
+        with graph_capture():
+            for num_tokens in reversed(self.cudagraph_batch_sizes):
+                self._dummy_run(self.model, num_tokens, self.kv_caches)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
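`capture_model` now wraps the dummy runs in the `graph_capture` context so pynccl is switched onto the capture stream while CUDA graphs are recorded. A sketch of the resulting pattern, with method and attribute names as in the diff and the rest of the runner elided:

```python
# Sketch: CUDA graph capture loop wrapped in graph_capture(). Larger
# batch sizes are captured first so smaller ones can reuse the pool.
from vllm.distributed.parallel_state import graph_capture

def capture_model(self) -> None:
    with graph_capture():
        for num_tokens in reversed(self.cudagraph_batch_sizes):
            self._dummy_run(self.model, num_tokens, self.kv_caches)
```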