Commit
find max m for flux kernels
Signed-off-by: Bill Nell <[email protected]>
bnellnm committed Nov 8, 2024
1 parent e4ddf6d commit 76f1658
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions vllm/compilation/collective_fusion.py
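
Summary of the change (editor's note, based only on the diff below): the flux GemmRS and AGKernel ops were previously constructed with a hard-coded Max M of 8192, flagged with a TODO. This commit derives per-GEMM maxima from the matched patterns via a new CollectiveFusionPass.find_max_m helper, threads them through get_gemm_rs_ag_gemm, and folds them into the generated op name so differently sized ops get distinct names. A small, hypothetical harness for the new helper appears after the @@ -360,6 +366,18 @@ hunk below.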
@@ -87,9 +87,14 @@ def match_gemm_rs_ag_gemm(
     return match_gemm_rs_ag_gemm
 
 
-def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
-                        gemm_1_weights: torch.Size, gemm_2_type,
-                        gemm_2_weights: torch.Size, tp_group_name: str):
+def get_gemm_rs_ag_gemm(use_flux: bool,
+                        gemm_1_type,
+                        gemm_1_weights: torch.Size,
+                        gemm_1_max_m: int,
+                        gemm_2_type,
+                        gemm_2_weights: torch.Size,
+                        gemm_2_max_m: int,
+                        tp_group_name: str):
 
     group = get_group_from_group_name(tp_group_name)
     device_group = group.device_group
@@ -99,7 +104,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         gemm_rs_op = flux.GemmRS(
             device_group,
             1,  # One node
-            8192,  # Max M. TODO: Pass in correctly.
+            gemm_1_max_m,
             gemm_1_weights[0],  # N
             # TODO: It would be nicer to modify flux to dispatch based on dtype
             # at run time, but I don't know what the downside would be.
@@ -114,7 +119,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         ag_gemm_op = flux.AGKernel(
             device_group,
             1,  # One node
-            8192,  # Max M. TODO: Pass in correctly.
+            gemm_2_max_m,
             gemm_2_weights[0],  # N
             gemm_2_weights[1],  # K
             # TODO: It would be nicer to modify flux to dispatch based on dtype
@@ -138,8 +143,9 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         gemm_2_str = str(gemm_2_type).removeprefix("torch.")
         group_str = tp_group_name.replace(":", "_")
         name = (
-            f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_"
-            f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{group_str}"
+            f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_{gemm_1_max_m}_"
+            f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{gemm_2_max_m}_"
+            f"{group_str}"
         )
     else:
         world_group_name = get_world_name()
@@ -275,7 +281,7 @@ def match_final(
 
 
 # Register this as a custom op since all reduce cannot be torch.compiled yet.
-def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
+def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                   gemm_1_activations: torch.Tensor,
                   rms_norm_weights: torch.Tensor) -> torch.Tensor:
     permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
@@ -296,7 +302,7 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
     return reduced
 
 
-def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
+def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                        gemm_1_activations: torch.Tensor,
                        rms_norm_weights: torch.Tensor) -> torch.Tensor:
     return torch.empty([gemm_1_activations.shape[0], my_residual.shape[1]],
@@ -305,9 +311,9 @@ def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
 
 
 direct_register_custom_op("gemm_ag_final",
-                          replace_final,
+                          gemm_ag_final,
                           mutates_args=[],
-                          fake_impl=replace_final_fake)
+                          fake_impl=gemm_ag_final_fake)
 
 
 class CollectiveFusionPass(InductorPass):
@@ -360,6 +366,18 @@ def record_match(self, match: Match) -> bool:
         # Return False to prevent automatic replacement.
         return False
 
+    def find_max_m(self, matches) -> Tuple[int, int]:
+        gemm_1_max_m = 0
+        gemm_2_max_m = 0
+        for m in matches:
+            gemm_1 = m.kwargs["gemm_1_weights"].meta["val"]
+            gemm_2 = m.kwargs["gemm_2_weights"].meta["val"]
+            gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[0])
+            gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[0])
+        assert gemm_1_max_m > 0
+        assert gemm_2_max_m > 0
+        return gemm_1_max_m, gemm_2_max_m
+
     def process_matches(self, graph: fx.Graph):
         nodes = list(graph.nodes)
 
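Editor's note: a minimal, hypothetical harness (not part of the commit) for the find_max_m logic added above. It builds stand-in matches whose kwargs carry objects with the same meta["val"] layout the pattern matcher records, then takes the per-GEMM maxima the same way the helper does. The fake_match helper and the shapes below are invented for illustration.

# Hypothetical harness; fake_match and all shapes are made up.
import torch
from types import SimpleNamespace


def fake_match(gemm_1_shape, gemm_2_shape):
    def node(shape):
        # Mimic an FX node carrying a fake tensor in its "val" meta entry.
        return SimpleNamespace(meta={"val": torch.empty(shape, device="meta")})
    return SimpleNamespace(kwargs={
        "gemm_1_weights": node(gemm_1_shape),
        "gemm_2_weights": node(gemm_2_shape),
    })


matches = [
    fake_match((4096, 1024), (1024, 4096)),
    fake_match((8192, 1024), (2048, 4096)),
]
# Same reduction find_max_m performs: max over shape[0] of each recorded tensor.
gemm_1_max_m = max(m.kwargs["gemm_1_weights"].meta["val"].shape[0] for m in matches)
gemm_2_max_m = max(m.kwargs["gemm_2_weights"].meta["val"].shape[0] for m in matches)
print(gemm_1_max_m, gemm_2_max_m)  # 8192 2048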
@@ -372,6 +390,9 @@ def find_min_index(match: Match) -> int:
         res_replacements: List[fx.Node] = []
         my_res_replacements: List[fx.Node] = []
 
+        gemm_1_max_m, gemm_2_max_m = self.find_max_m(matches)
+        logger.info("max m = %d, %d", gemm_1_max_m, gemm_2_max_m)
+
         for match in matches:
             last_node = last_node_in_match(match)
 
@@ -402,8 +423,10 @@ def find_min_index(match: Match) -> int:
 
             fused_gemm_func = get_gemm_rs_ag_gemm(use_flux, gemm_1.dtype,
                                                   gemm_1.shape,
+                                                  gemm_1_max_m,
                                                   gemm_2.dtype,
                                                   gemm_2.shape,
+                                                  gemm_2_max_m,
                                                   tp_group_name)
 
             fused_node = graph.call_function(fused_gemm_func,
