From 76f1658fd1c04a81c653e4a40c0b8393d42d384e Mon Sep 17 00:00:00 2001
From: Bill Nell
Date: Fri, 8 Nov 2024 22:07:41 +0000
Subject: [PATCH] find max m for flux kernels

Signed-off-by: Bill Nell
---
 vllm/compilation/collective_fusion.py | 45 ++++++++++++++++++++-------
 1 file changed, 34 insertions(+), 11 deletions(-)

diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 20e0499ff4bbc..1917b3c7ccf46 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -87,9 +87,14 @@ def match_gemm_rs_ag_gemm(
     return match_gemm_rs_ag_gemm
 
 
-def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
-                        gemm_1_weights: torch.Size, gemm_2_type,
-                        gemm_2_weights: torch.Size, tp_group_name: str):
+def get_gemm_rs_ag_gemm(use_flux: bool,
+                        gemm_1_type,
+                        gemm_1_weights: torch.Size,
+                        gemm_1_max_m: int,
+                        gemm_2_type,
+                        gemm_2_weights: torch.Size,
+                        gemm_2_max_m: int,
+                        tp_group_name: str):
 
     group = get_group_from_group_name(tp_group_name)
     device_group = group.device_group
@@ -99,7 +104,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         gemm_rs_op = flux.GemmRS(
             device_group,
             1,  # One node
-            8192,  # Max M. TODO: Pass in correctly.
+            gemm_1_max_m,
             gemm_1_weights[0],  # N
             # TODO: It would be nicer to modify flux to dispatch based on dtype
             # at run time, but I don't know what the downside would be.
@@ -114,7 +119,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         ag_gemm_op = flux.AGKernel(
             device_group,
             1,  # One node
-            8192,  # Max M. TODO: Pass in correctly.
+            gemm_2_max_m,
             gemm_2_weights[0],  # N
             gemm_2_weights[1],  # K
             # TODO: It would be nicer to modify flux to dispatch based on dtype
@@ -138,8 +143,9 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         gemm_2_str = str(gemm_2_type).removeprefix("torch.")
         group_str = tp_group_name.replace(":", "_")
         name = (
-            f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_"
-            f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{group_str}"
+            f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_{gemm_1_max_m}_"
+            f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{gemm_2_max_m}_"
+            f"{group_str}"
         )
     else:
         world_group_name = get_world_name()
@@ -275,7 +281,7 @@ def match_final(
 
 
 # Register this as a custom op since all reduce cannot be torch.compiled yet.
-def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
+def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                   gemm_1_activations: torch.Tensor,
                   rms_norm_weights: torch.Tensor) -> torch.Tensor:
     permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
@@ -296,7 +302,7 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
     return reduced
 
 
-def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
+def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                        gemm_1_activations: torch.Tensor,
                        rms_norm_weights: torch.Tensor) -> torch.Tensor:
     return torch.empty([gemm_1_activations.shape[0], my_residual.shape[1]],
@@ -305,9 +311,9 @@ def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
 
 
 direct_register_custom_op("gemm_ag_final",
-                          replace_final,
+                          gemm_ag_final,
                           mutates_args=[],
-                          fake_impl=replace_final_fake)
+                          fake_impl=gemm_ag_final_fake)
 
 
 class CollectiveFusionPass(InductorPass):
@@ -360,6 +366,18 @@ def record_match(self, match: Match) -> bool:
         # Return False to prevent automatic replacement.
         return False
 
+    def find_max_m(self, matches) -> Tuple[int, int]:
+        gemm_1_max_m = 0
+        gemm_2_max_m = 0
+        for m in matches:
+            gemm_1 = m.kwargs["gemm_1_weights"].meta["val"]
+            gemm_2 = m.kwargs["gemm_2_weights"].meta["val"]
+            gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[0])
+            gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[0])
+        assert gemm_1_max_m > 0
+        assert gemm_2_max_m > 0
+        return gemm_1_max_m, gemm_2_max_m
+
     def process_matches(self, graph: fx.Graph):
         nodes = list(graph.nodes)
 
@@ -372,6 +390,9 @@ def find_min_index(match: Match) -> int:
         res_replacements: List[fx.Node] = []
         my_res_replacements: List[fx.Node] = []
 
+        gemm_1_max_m, gemm_2_max_m = self.find_max_m(matches)
+        logger.info("max m = %d, %d", gemm_1_max_m, gemm_2_max_m)
+
        for match in matches:
            last_node = last_node_in_match(match)
 
@@ -402,8 +423,10 @@ def find_min_index(match: Match) -> int:
                 fused_gemm_func = get_gemm_rs_ag_gemm(use_flux,
                                                       gemm_1.dtype,
                                                       gemm_1.shape,
+                                                      gemm_1_max_m,
                                                       gemm_2.dtype,
                                                       gemm_2.shape,
+                                                      gemm_2_max_m,
                                                       tp_group_name)
 
                 fused_node = graph.call_function(fused_gemm_func,
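Note for reviewers: below is a minimal, self-contained sketch of the max-M scan the patch introduces, runnable outside of torch.compile. The FakeTensor/FakeNode/FakeMatch classes are hypothetical test doubles, not vLLM or torch.fx types; the real pass reads the same kwargs/meta fields off pattern-match results, and the two values it returns replace the hard-coded 8192 previously passed to flux.GemmRS and flux.AGKernel.

from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass
class FakeTensor:
    # Stand-in for an fx node's meta["val"]; only .shape is used here.
    shape: Tuple[int, ...]


@dataclass
class FakeNode:
    meta: Dict[str, FakeTensor] = field(default_factory=dict)


@dataclass
class FakeMatch:
    kwargs: Dict[str, FakeNode] = field(default_factory=dict)


def find_max_m(matches) -> Tuple[int, int]:
    # Largest leading dimension seen across all matched gemm_1/gemm_2 values,
    # mirroring the helper added in this patch.
    gemm_1_max_m = 0
    gemm_2_max_m = 0
    for m in matches:
        gemm_1 = m.kwargs["gemm_1_weights"].meta["val"]
        gemm_2 = m.kwargs["gemm_2_weights"].meta["val"]
        gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[0])
        gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[0])
    assert gemm_1_max_m > 0
    assert gemm_2_max_m > 0
    return gemm_1_max_m, gemm_2_max_m


if __name__ == "__main__":
    matches = [
        FakeMatch(kwargs={
            "gemm_1_weights": FakeNode(meta={"val": FakeTensor((4096, 2048))}),
            "gemm_2_weights": FakeNode(meta={"val": FakeTensor((2048, 1024))}),
        }),
        FakeMatch(kwargs={
            "gemm_1_weights": FakeNode(meta={"val": FakeTensor((8192, 2048))}),
            "gemm_2_weights": FakeNode(meta={"val": FakeTensor((4096, 1024))}),
        }),
    ]
    print(find_max_m(matches))  # -> (8192, 4096)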