
Commit

perf improvements
Signed-off-by: Bill Nell <[email protected]>
bnellnm committed Nov 9, 2024
1 parent b327c4a commit 6d20979
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions vllm/compilation/collective_fusion.py
@@ -45,11 +45,16 @@ def should_slice(shape) -> bool:
             and shape[0] >= FLUX_TILE_SIZE * n_slices)


-# This is really inefficient. Should only pick the slice required.
-def residual_slice_shape(residual, rank) -> List[torch.Size]:
+def residual_slice_shape(residual, rank) -> int:
     n_slices = get_tensor_model_parallel_world_size()
+    chunk, rem = divmod(residual.shape[0], n_slices)
+    return chunk if rank < n_slices - 1 or rem == 0 else rem
+
+
+def residual_slice_shape_fake(residual, rank) -> int:
+    n_slices = get_tensor_model_parallel_world_size()
     slices = torch.chunk(residual, n_slices, dim=0)
-    return slices[rank].shape
+    return slices[rank].shape[0]


 def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool):
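For context, not part of the commit: the old helper materialized every per-rank chunk view with `torch.chunk` just to read one chunk's length, and returned a `torch.Size` that callers then indexed with `[0]`. The rewritten helper computes the slice length with plain integer arithmetic. A minimal standalone sketch of the two approaches, with `world_size` standing in for `get_tensor_model_parallel_world_size()`; the two agree whenever dim 0 divides evenly across ranks, which is presumably the case the `should_slice` guard is meant to admit:

```python
import torch


def slice_len_via_chunk(residual: torch.Tensor, rank: int, world_size: int) -> int:
    # Old approach: build every per-rank chunk view just to read one length.
    return torch.chunk(residual, world_size, dim=0)[rank].shape[0]


def slice_len_via_divmod(dim0: int, rank: int, world_size: int) -> int:
    # New approach: pure integer arithmetic, no tensor views created.
    chunk, rem = divmod(dim0, world_size)
    return chunk if rank < world_size - 1 or rem == 0 else rem


# The two agree when dim 0 divides evenly (for uneven splits torch.chunk
# rounds chunk sizes up while divmod rounds down, so they would differ).
residual = torch.empty(4096, 16)
for rank in range(4):
    assert slice_len_via_chunk(residual, rank, 4) \
        == slice_len_via_divmod(4096, rank, 4) == 1024
```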
@@ -170,7 +175,7 @@ def gemm_rs_ag_gemm(
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

         if first_layer and should_slice(residual.shape):
-            slice_shape = residual_slice_shape(residual, rank)[0]
+            slice_shape = residual_slice_shape(residual, rank)
             residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape)
             my_residual = residual_chunk[0]
         else:
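Aside, not part of the commit: dropping the `[0]` follows from the helper's new `int` return type — the slice length that `aten.split.Tensor` expects now comes back directly instead of as the first element of a `torch.Size`. A tiny illustrative check of the split-and-keep-first pattern used above:

```python
import torch

# Illustrative only: aten.split.Tensor takes an integer split size and
# returns consecutive dim-0 pieces; the first piece plays the role of
# my_residual above.
residual = torch.arange(12).reshape(6, 2)
pieces = torch.ops.aten.split.Tensor(residual, 2)  # three (2, 2) views
my_residual = pieces[0]
assert my_residual.shape == (2, 2)
```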
@@ -218,7 +223,7 @@ def gemm_rs_ag_gemm_fake(
         first_layer: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if first_layer and should_slice(gemm_1_activations.shape):
-            slice_shape = residual_slice_shape(residual, rank)[0]
+            slice_shape = residual_slice_shape_fake(residual, rank)
             split_1 = torch.ops.aten.split.Tensor(residual, slice_shape)
             my_residual = split_1[0]
         else:
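Not from the repo, a note on why a chunk-based variant survives at all: `gemm_rs_ag_gemm_fake` is presumably the fake (meta) implementation used when `torch.compile` traces the custom op with FakeTensors, which carry shape metadata but no storage; `torch.chunk` only produces views, so it stays safe in that mode. A hypothetical check, with `n_slices` passed explicitly instead of queried from the TP group:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


# Hypothetical sketch: FakeTensors have shapes but no storage, and
# torch.chunk only manipulates metadata, so the helper still reports
# the correct per-rank slice length under fake-tensor tracing.
def residual_slice_shape_fake(residual: torch.Tensor, rank: int, n_slices: int) -> int:
    return torch.chunk(residual, n_slices, dim=0)[rank].shape[0]


with FakeTensorMode():
    fake_residual = torch.empty(4096, 16)
    assert residual_slice_shape_fake(fake_residual, rank=0, n_slices=4) == 1024
```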
