PP comm optimization: replace send with partial send + allgather (#6695)
Co-authored-by: Aurick Qiao <[email protected]>
aurickq and sfc-gh-aqiao authored Aug 1, 2024
1 parent 630dd9e commit 0437492
Showing 2 changed files with 41 additions and 5 deletions.
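
What the change does: when a pipeline-parallel (PP) stage hands its intermediate tensors to the next stage, those activations are replicated across the tensor-parallel (TP) ranks of the stage, so having every TP rank send the full tensor is redundant. After this commit, each TP rank sends only its 1/TP slice over the PP link, and the receiving stage rebuilds the full tensor with an all-gather over its own TP group. Below is a minimal sketch of that pattern in plain torch.distributed; the helper names partial_send/partial_recv, the device argument, and the rank-to-peer mapping are illustrative assumptions, not part of the diff.

import math

import torch
import torch.distributed as dist


def partial_send(tensor: torch.Tensor, dst: int, tp_rank: int,
                 tp_size: int) -> None:
    # Send only this TP rank's 1/tp_size slice when the element count
    # divides evenly; otherwise fall back to sending the whole tensor.
    if tp_size > 1 and tensor.numel() % tp_size == 0:
        tensor = tensor.reshape(tp_size, -1)[tp_rank]
    dist.send(tensor.contiguous(), dst=dst)


def partial_recv(shape: torch.Size, dtype: torch.dtype, src: int,
                 device: torch.device,
                 tp_group: dist.ProcessGroup, tp_size: int) -> torch.Tensor:
    # Receive the peer's slice, then all-gather across the TP group to
    # rebuild the full tensor in TP-rank order.
    numel = math.prod(shape)
    if tp_size == 1 or numel % tp_size != 0:
        full = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(full, src=src)
        return full
    part = torch.empty(numel // tp_size, dtype=dtype, device=device)
    dist.recv(part, src=src)
    chunks = [torch.empty_like(part) for _ in range(tp_size)]
    dist.all_gather(chunks, part, group=tp_group)
    return torch.cat(chunks).reshape(shape)

The vLLM implementation folds this into GroupCoordinator.send_tensor_dict / recv_tensor_dict behind the new all_gather_group argument, as the diff below shows.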
38 changes: 36 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -569,7 +569,8 @@ def broadcast_tensor_dict(
     def send_tensor_dict(
         self,
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
-        dst: Optional[int] = None
+        dst: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
     ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
@@ -578,6 +579,11 @@ def send_tensor_dict(
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict

+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
+
         group = self.device_group
         metadata_group = self.cpu_group

@@ -598,6 +604,12 @@ def send_tensor_dict(
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue
+
+            # send-allgather: send only a slice, then do allgather.
+            if (all_gather_group is not None
+                    and tensor.numel() % all_gather_size == 0):
+                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
+
             if tensor.is_cpu:
                 # use metadata_group for CPU tensors
                 torch.distributed.send(tensor,
@@ -612,7 +624,8 @@ def send_tensor_dict(

     def recv_tensor_dict(
         self,
-        src: Optional[int] = None
+        src: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
     ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
@@ -621,6 +634,11 @@ def recv_tensor_dict(
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None

+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
+
         group = self.device_group
         metadata_group = self.cpu_group

@@ -639,6 +657,16 @@ def recv_tensor_dict(
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
+
+                # send-allgather: send only a slice, then do allgather.
+                use_all_gather = (all_gather_group is not None
+                                  and tensor.numel() % all_gather_size == 0)
+
+                if use_all_gather:
+                    orig_shape = tensor.shape
+                    tensor = tensor.reshape(all_gather_size,
+                                            -1)[all_gather_rank]
+
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
                     torch.distributed.recv(tensor,
@@ -649,6 +677,12 @@ def recv_tensor_dict(
                     torch.distributed.recv(tensor,
                                            src=self.ranks[src],
                                            group=group)
+                if use_all_gather:
+                    # do the allgather
+                    tensor = all_gather_group.all_gather(  # type: ignore
+                        tensor, dim=0)
+                    tensor = tensor.reshape(orig_shape)
+
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
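Note the symmetric guard: both send_tensor_dict and recv_tensor_dict slice a tensor only when tensor.numel() % all_gather_size == 0, so any tensor that does not divide evenly by the TP size is still transferred whole and both sides always agree on the message size. A small single-process illustration of the slice/reassemble arithmetic (the shapes are made up):

import torch

tp_size, tp_rank = 4, 1
hidden_states = torch.randn(8, 4096)   # numel = 32768, divisible by tp_size

# Sender-side view: TP rank 1 keeps the second 8192-element chunk of the
# flattened tensor.
local_slice = hidden_states.reshape(tp_size, -1)[tp_rank]
assert local_slice.shape == (8192,)

# Receiver side: concatenating the chunks in TP-rank order (what the
# all-gather along dim 0 does) restores the original tensor exactly.
restored = torch.cat([hidden_states.reshape(tp_size, -1)[r]
                      for r in range(tp_size)]).reshape(hidden_states.shape)
assert torch.equal(restored, hidden_states)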
8 changes: 5 additions & 3 deletions vllm/worker/worker_base.py
@@ -6,7 +6,7 @@

 import torch

-from vllm.distributed import broadcast_tensor_dict, get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
@@ -267,7 +267,8 @@ def execute_model(
         intermediate_tensors = None
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
-                get_pp_group().recv_tensor_dict())
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))

         output = self.model_runner.execute_model(
             model_input, self.kv_cache[worker_input.virtual_engine]
@@ -276,7 +277,8 @@ def execute_model(

         if not get_pp_group().is_last_rank:
             # output is IntermediateTensors
-            get_pp_group().send_tensor_dict(output.tensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
             return [None]

         # output is List[SamplerOutput]
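In worker_base.py the call sites simply pass the TP group as all_gather_group when moving IntermediateTensors between pipeline stages. A sketch of the resulting call pattern on a middle pipeline stage (it assumes vLLM's distributed state is already initialized and a GPU is available; the dict passed to send_tensor_dict here is made up, whereas the real code sends output.tensors):

import torch

from vllm.distributed import get_pp_group, get_tp_group

pp_group = get_pp_group()
tp_group = get_tp_group()

if not pp_group.is_first_rank:
    # Receive a 1/TP slice from the previous stage and all-gather it
    # back into full intermediate tensors.
    intermediate = pp_group.recv_tensor_dict(all_gather_group=tp_group)

# ... run this stage's layers on `intermediate` ...

if not pp_group.is_last_rank:
    # Send only this TP rank's slice of each evenly divisible tensor.
    pp_group.send_tensor_dict(
        {"hidden_states": torch.randn(8, 4096, device="cuda")},
        all_gather_group=tp_group)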
