custom allreduce + torch.compile #10121

Merged · 15 commits · Nov 26, 2024
21 changes: 13 additions & 8 deletions vllm/distributed/device_communicators/pynccl.py
@@ -102,24 +102,29 @@ def __init__(
self.disabled = True

def all_reduce(self,
tensor: torch.Tensor,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
stream=None) -> torch.Tensor:
if self.disabled:
return
return None
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
assert in_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
f"but the input tensor is on {in_tensor.device}")

out_tensor = torch.empty_like(in_tensor)

if stream is None:
stream = self.stream
self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
return out_tensor

def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
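The pynccl.py change above turns `PyNcclCommunicator.all_reduce` from an in-place collective that mutates its argument into an out-of-place one that allocates `out_tensor` and returns it (note the separate send/recv buffers in the `ncclAllReduce` call). A functional signature like this is what lets the op be registered with `mutates_args=[]` and traced cleanly by torch.compile. Below is a minimal standalone sketch of the two styles using plain `torch.distributed` instead of vLLM's NCCL wrapper; it assumes a process group is already initialized and is an illustration, not vLLM code.

```python
import torch
import torch.distributed as dist


def all_reduce_inplace(tensor: torch.Tensor) -> None:
    # Old style: the input buffer doubles as the output buffer; nothing is returned.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)


def all_reduce_outplace(in_tensor: torch.Tensor) -> torch.Tensor:
    # New style, mirroring the pynccl change in this PR: read from in_tensor,
    # write the reduced result into a freshly allocated tensor, and return it.
    out_tensor = in_tensor.clone()
    dist.all_reduce(out_tensor, op=dist.ReduceOp.SUM)
    return out_tensor
```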
107 changes: 37 additions & 70 deletions vllm/distributed/parallel_state.py
@@ -96,43 +96,24 @@
_groups[group.unique_name] = weakref.ref(group)


if supports_custom_op():
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)

def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce_in_place(tensor)

def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
return
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)

direct_register_custom_op(
op_name="inplace_all_reduce",
op_func=inplace_all_reduce,
mutates_args=["tensor"],
fake_impl=inplace_all_reduce_fake,
)

def outplace_all_reduce(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)

def outplace_all_reduce_fake(tensor: torch.Tensor,
group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)

direct_register_custom_op(
op_name="outplace_all_reduce",
op_func=outplace_all_reduce,
mutates_args=[],
fake_impl=outplace_all_reduce_fake,
)
direct_register_custom_op(
op_name="all_reduce",
op_func=all_reduce,
mutates_args=[],
fake_impl=all_reduce_fake,
)


class GroupCoordinator:
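The hunk above collapses the previous `inplace_all_reduce` / `outplace_all_reduce` pair into a single functional `all_reduce` op whose fake implementation simply returns `torch.empty_like(tensor)`. vLLM registers it via its `direct_register_custom_op` helper; the sketch below shows the same pattern with the plain `torch.library.custom_op` API (assumes PyTorch ≥ 2.4 and a hypothetical `mylib` namespace), illustrating why the fake impl is all torch.compile needs to trace through the collective.

```python
import torch


# A functional custom op: takes a tensor in, returns a new tensor, mutates nothing.
@torch.library.custom_op("mylib::all_reduce", mutates_args=())
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # A real implementation would look up the communicator for `group_name`
    # and run the collective; a copy stands in for that here.
    return tensor.clone()


# Fake (meta) implementation used while torch.compile traces the graph:
# it only has to produce an output with the right shape, dtype, and device.
@all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    return torch.empty_like(tensor)


@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.mylib.all_reduce(x, group_name="tp") + 1


print(fn(torch.ones(4)))  # traces through the op with no graph break
```

With an out-of-place op, `mutates_args` can stay empty and a single registration covers both the custom-all-reduce and NCCL paths.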
@@ -369,8 +350,8 @@
return input_

if not supports_custom_op():
self._all_reduce_in_place(input_)
return input_
return torch.ops.vllm.all_reduce(input_,
group_name=self.unique_name)

if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
@@ -384,31 +365,23 @@
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)

if self.ca_comm is not None and \
not self.ca_comm.disabled and \
self.ca_comm.should_custom_ar(input_):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
torch.ops.vllm.inplace_all_reduce(input_,
group_name=self.unique_name)
return input_
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)

def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
ca_comm = self.ca_comm
assert ca_comm is not None
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out

def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
# if supports_custom_op():
# ca_comm = self.ca_comm
# if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar(

[CI check: GitHub Actions / ruff (3.12) — Ruff (E501) at vllm/distributed/parallel_state.py:373:81: Line too long (91 > 80)]
# input_):
# out = ca_comm.custom_all_reduce(input_)
# assert out is not None
# return out
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
assert pynccl_comm is not None
with pynccl_comm.change_state(enable=True,
stream=torch.cuda.current_stream()):
out = pynccl_comm.all_reduce(input_)
assert out is not None
return out

[Reviewer comment (Member) on the change_state call: "we can change pynccl to be always enabled."]

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
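With the wrapper now always routing through `torch.ops.vllm.all_reduce`, the backend choice has to live inside the op's eager implementation (`_all_reduce_out_place`) rather than in traced Python. In the revision shown above the custom-all-reduce branch is still commented out and only the pynccl path is active; the following is a hedged sketch of the overall dispatch shape, with duck-typed stand-ins rather than vLLM's actual communicator classes.

```python
from typing import Optional

import torch


def all_reduce_out_place(input_: torch.Tensor, ca_comm, pynccl_comm) -> torch.Tensor:
    """Sketch of an out-of-place all-reduce dispatch.

    `ca_comm` and `pynccl_comm` are hypothetical duck-typed stand-ins for a
    custom-all-reduce communicator and an NCCL communicator.
    """
    # Prefer the custom all-reduce kernel when it is enabled and supports this
    # tensor (size/dtype checks are assumed to live in should_custom_ar).
    if (ca_comm is not None and not ca_comm.disabled
            and ca_comm.should_custom_ar(input_)):
        out: Optional[torch.Tensor] = ca_comm.custom_all_reduce(input_)
        assert out is not None
        return out
    # Otherwise fall back to NCCL, which (after this PR) also returns a new tensor.
    out = pynccl_comm.all_reduce(input_)
    assert out is not None
    return out
```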
Expand Down Expand Up @@ -448,8 +421,7 @@
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
(world_size * input_size[dim], ) +
input_size[dim + 1:])
return output_tensor

@@ -473,8 +445,8 @@
dim += input_.dim()
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.gather(input_, self.rank_in_group,
dst, dim)
return self.xpu_communicator.gather(input_, self.rank_in_group, dst,
dim)
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -585,8 +557,7 @@
assert src < self.world_size, f"Invalid src rank ({src})"

assert src != self.rank_in_group, (
"Invalid source rank. Source rank is the same as the current rank."
)
"Invalid source rank. Source rank is the same as the current rank.")

size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

@@ -748,9 +719,7 @@
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor,
dst=self.ranks[dst],
group=group)
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None

def recv_tensor_dict(
@@ -934,8 +903,7 @@


def get_pp_group() -> GroupCoordinator:
assert _PP is not None, (
"pipeline model parallel group is not initialized")
assert _PP is not None, ("pipeline model parallel group is not initialized")
return _PP


@@ -1076,8 +1044,7 @@
num_pipeline_model_parallel_groups: int = (world_size //
pipeline_model_parallel_size)
global _PP
assert _PP is None, (
"pipeline model parallel group is already initialized")
assert _PP is None, ("pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
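As a side note on the unchanged `ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))` context line, here is a small worked example of how that stride partitions ranks into pipeline-parallel groups (illustrative sizes, not taken from the PR).

```python
# Illustrative sizes: 8 ranks, pipeline_model_parallel_size = 4.
world_size = 8
pipeline_model_parallel_size = 4
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size  # 2

group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
    # Each group takes every num_pipeline_model_parallel_groups-th rank,
    # starting at offset i.
    ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
    group_ranks.append(ranks)

print(group_ranks)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```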