diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 20608cdaadae4..8b51e0f9bda47 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -45,11 +45,16 @@ def should_slice(shape) -> bool:
             and shape[0] >= FLUX_TILE_SIZE * n_slices)
 
 
-# This is really inefficient. Should only pick the slice required.
-def residual_slice_shape(residual, rank) -> List[torch.Size]:
+def residual_slice_shape(residual, rank) -> int:
+    n_slices = get_tensor_model_parallel_world_size()
+    chunk, rem = divmod(residual.shape[0], n_slices)
+    return chunk if rank < n_slices - 1 or rem == 0 else rem
+
+
+def residual_slice_shape_fake(residual, rank) -> int:
     n_slices = get_tensor_model_parallel_world_size()
     slices = torch.chunk(residual, n_slices, dim=0)
-    return slices[rank].shape
+    return slices[rank].shape[0]
 
 
 def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool):
@@ -170,7 +175,7 @@ def gemm_rs_ag_gemm(
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
         if first_layer and should_slice(residual.shape):
-            slice_shape = residual_slice_shape(residual, rank)[0]
+            slice_shape = residual_slice_shape(residual, rank)
             residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape)
             my_residual = residual_chunk[0]
         else:
@@ -218,7 +223,7 @@ def gemm_rs_ag_gemm_fake(
         first_layer: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if first_layer and should_slice(gemm_1_activations.shape):
-            slice_shape = residual_slice_shape(residual, rank)[0]
+            slice_shape = residual_slice_shape_fake(residual, rank)
            split_1 = torch.ops.aten.split.Tensor(residual, slice_shape)
            my_residual = split_1[0]
        else:
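
Note: the rewritten residual_slice_shape computes the per-rank slice length arithmetically instead of materializing every chunk with torch.chunk, while the new residual_slice_shape_fake keeps the torch.chunk-based derivation for the fake implementation used during tracing. The two agree whenever the leading dimension splits evenly across ranks, which the should_slice() gate is intended to ensure. Below is a minimal standalone sketch (not part of the PR) checking that equivalence; world_size stands in for get_tensor_model_parallel_world_size(), and only evenly divisible sizes are tested, since torch.chunk rounds chunk sizes up when the split is uneven.

import torch

def slice_len_arithmetic(dim0: int, rank: int, world_size: int) -> int:
    # Mirrors the new residual_slice_shape: floor-divide the leading
    # dimension, giving any remainder to the last rank.
    chunk, rem = divmod(dim0, world_size)
    return chunk if rank < world_size - 1 or rem == 0 else rem

def slice_len_chunk(dim0: int, rank: int, world_size: int) -> int:
    # Mirrors residual_slice_shape_fake: derive the size from torch.chunk.
    slices = torch.chunk(torch.empty(dim0, 1), world_size, dim=0)
    return slices[rank].shape[0]

world_size = 4
for dim0 in (128, 4096):  # evenly divisible, as the slicing path assumes
    for rank in range(world_size):
        assert (slice_len_arithmetic(dim0, rank, world_size)
                == slice_len_chunk(dim0, rank, world_size))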