From 0437492ea97f0650a8b2ca39121be8864625fd70 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Wed, 31 Jul 2024 20:15:42 -0700
Subject: [PATCH] PP comm optimization: replace send with partial send + allgather (#6695)

Co-authored-by: Aurick Qiao
---
 vllm/distributed/parallel_state.py | 38 ++++++++++++++++++++++++++++++++++++--
 vllm/worker/worker_base.py         |  8 +++++---
 2 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index bf7a7de0724af..d7ca8fd82e1a2 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -569,7 +569,8 @@ def broadcast_tensor_dict(
     def send_tensor_dict(
         self,
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
-        dst: Optional[int] = None
+        dst: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
     ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
@@ -578,6 +579,11 @@ def send_tensor_dict(
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
 
+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
+
         group = self.device_group
         metadata_group = self.cpu_group
 
@@ -598,6 +604,12 @@ def send_tensor_dict(
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue
+
+            # send-allgather: send only a slice, then do allgather.
+            if (all_gather_group is not None
+                    and tensor.numel() % all_gather_size == 0):
+                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
+
             if tensor.is_cpu:
                 # use metadata_group for CPU tensors
                 torch.distributed.send(tensor,
@@ -612,7 +624,8 @@ def send_tensor_dict(
 
     def recv_tensor_dict(
         self,
-        src: Optional[int] = None
+        src: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
     ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
@@ -621,6 +634,11 @@ def recv_tensor_dict(
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None
 
+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
+
         group = self.device_group
         metadata_group = self.cpu_group
 
@@ -639,6 +657,16 @@ def recv_tensor_dict(
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
+
+                # send-allgather: send only a slice, then do allgather.
+                use_all_gather = (all_gather_group is not None
+                                  and tensor.numel() % all_gather_size == 0)
+
+                if use_all_gather:
+                    orig_shape = tensor.shape
+                    tensor = tensor.reshape(all_gather_size,
+                                            -1)[all_gather_rank]
+
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
                     torch.distributed.recv(tensor,
@@ -649,6 +677,12 @@ def recv_tensor_dict(
                     torch.distributed.recv(tensor,
                                            src=self.ranks[src],
                                            group=group)
+                if use_all_gather:
+                    # do the allgather
+                    tensor = all_gather_group.all_gather(  # type: ignore
+                        tensor, dim=0)
+                    tensor = tensor.reshape(orig_shape)
+
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 03e3857e23c4b..8a4d1958c65a0 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -6,7 +6,7 @@
 
 import torch
 
-from vllm.distributed import broadcast_tensor_dict, get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
@@ -267,7 +267,8 @@ def execute_model(
         intermediate_tensors = None
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
-                get_pp_group().recv_tensor_dict())
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
 
         output = self.model_runner.execute_model(
             model_input, self.kv_cache[worker_input.virtual_engine]
@@ -276,7 +277,8 @@ def execute_model(
 
         if not get_pp_group().is_last_rank:
             # output is IntermediateTensors
-            get_pp_group().send_tensor_dict(output.tensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
             return [None]
 
         # output is List[SamplerOutput]
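
A standalone sketch of the idea behind this change (illustrative only: the tp_size value and the torch.cat stand-in for the tensor-parallel all-gather are assumptions for the example, not vLLM APIs). Slicing an intermediate tensor into TP shards before the pipeline-parallel send and reassembling it with an all-gather on the receiving side is lossless whenever tensor.numel() is divisible by the TP world size, which is the condition both send_tensor_dict and recv_tensor_dict check above:

import torch

# Assumed tensor-parallel world size, for illustration only.
tp_size = 4

# An intermediate tensor whose element count is divisible by tp_size.
hidden = torch.randn(8, 16)

# Sender side: TP rank r pushes only its 1/tp_size slice over the PP link.
slices = [hidden.reshape(tp_size, -1)[r] for r in range(tp_size)]

# Receiver side: the TP all-gather (simulated here with torch.cat so the
# sketch runs in a single process) restores the flattened tensor, which is
# then reshaped back to its original shape.
gathered = torch.cat(slices, dim=0).reshape(hidden.shape)

assert torch.equal(gathered, hidden)
print(f"per-rank PP payload: {slices[0].numel()} of {hidden.numel()} elements")

Each pipeline-parallel send/recv therefore moves only 1/tp_size of every qualifying intermediate tensor, on the assumption that the all-gather inside the tensor-parallel group is cheaper than shipping the full tensor across the pipeline-parallel link.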