From 6d20979691b418d790493343c5d3b7e3bf8d4604 Mon Sep 17 00:00:00 2001
From: Bill Nell
Date: Sat, 9 Nov 2024 19:37:36 +0000
Subject: [PATCH] perf improvements

Signed-off-by: Bill Nell
---
 vllm/compilation/collective_fusion.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 20608cdaadae4..8b51e0f9bda47 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -45,11 +45,16 @@ def should_slice(shape) -> bool:
             and shape[0] >= FLUX_TILE_SIZE * n_slices)
 
 
-# This is really inefficient. Should only pick the slice required.
-def residual_slice_shape(residual, rank) -> List[torch.Size]:
+def residual_slice_shape(residual, rank) -> int:
+    n_slices = get_tensor_model_parallel_world_size()
+    chunk, rem = divmod(residual.shape[0], n_slices)
+    return chunk if rank < n_slices - 1 or rem == 0 else rem
+
+
+def residual_slice_shape_fake(residual, rank) -> int:
     n_slices = get_tensor_model_parallel_world_size()
     slices = torch.chunk(residual, n_slices, dim=0)
-    return slices[rank].shape
+    return slices[rank].shape[0]
 
 
 def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool):
@@ -170,7 +175,7 @@ def gemm_rs_ag_gemm(
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
         if first_layer and should_slice(residual.shape):
-            slice_shape = residual_slice_shape(residual, rank)[0]
+            slice_shape = residual_slice_shape(residual, rank)
             residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape)
             my_residual = residual_chunk[0]
         else:
@@ -218,7 +223,7 @@ def gemm_rs_ag_gemm_fake(
         first_layer: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if first_layer and should_slice(gemm_1_activations.shape):
-            slice_shape = residual_slice_shape(residual, rank)[0]
+            slice_shape = residual_slice_shape_fake(residual, rank)
            split_1 = torch.ops.aten.split.Tensor(residual, slice_shape)
             my_residual = split_1[0]
         else:
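
Note for reviewers: below is a minimal standalone sketch (not part of the patch) of the change to residual_slice_shape. The new helper computes the per-rank slice length with integer arithmetic instead of materializing every chunk view via torch.chunk just to read one size. The sketch passes n_slices explicitly rather than calling get_tensor_model_parallel_world_size(); that parameter is an assumption made only to keep the example self-contained. It checks that the two approaches agree in the evenly divisible case that should_slice() permits.

import torch


def slice_shape_arith(residual: torch.Tensor, rank: int, n_slices: int) -> int:
    # New approach from the patch: pure integer arithmetic, no tensor views.
    chunk, rem = divmod(residual.shape[0], n_slices)
    return chunk if rank < n_slices - 1 or rem == 0 else rem


def slice_shape_chunk(residual: torch.Tensor, rank: int, n_slices: int) -> int:
    # Previous approach (kept for the fake/meta path): build all chunk views,
    # then read the leading dimension of the one this rank owns.
    slices = torch.chunk(residual, n_slices, dim=0)
    return slices[rank].shape[0]


if __name__ == "__main__":
    # should_slice() only allows slicing when shape[0] divides evenly across
    # ranks, so that is the case the two helpers must agree on.
    residual = torch.empty(128, 4096)
    n_slices = 4
    for rank in range(n_slices):
        assert (slice_shape_arith(residual, rank, n_slices)
                == slice_shape_chunk(residual, rank, n_slices)
                == 32)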