From e4ddf6df5dfe2180086cd12354834b8947bacdd0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 Nov 2024 22:58:24 +0000 Subject: [PATCH] fix some todos Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 103 +++++++++++++------------- 1 file changed, 51 insertions(+), 52 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index ab64476499b20..20e0499ff4bbc 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -14,8 +14,7 @@ from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( - get_group_from_group_name, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_group_from_group_name, get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.utils import direct_register_custom_op @@ -31,12 +30,12 @@ use_flux = False -# how to do this properly? +# TODO: is this right? def get_world_name() -> str: return torch.distributed.group.WORLD.group_name -# This check is a hack +# TODO: This check is a hack def should_slice(shape) -> bool: n_slices = get_tensor_model_parallel_world_size() return (shape[0] % n_slices == 0 and shape[0] >= 128) @@ -59,6 +58,7 @@ def match_gemm_rs_ag_gemm( gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + # It would be nice to do this instead of having two separate patterns #all_reduce = tensor_model_parallel_all_reduce(mm_1) if custom_ar: all_reduce = torch.ops.vllm.outplace_all_reduce.default( @@ -87,47 +87,20 @@ def match_gemm_rs_ag_gemm( return match_gemm_rs_ag_gemm -def gemm_rs_ag_gemm_fake( - residual: torch.Tensor, - my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, - first_layer: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - if first_layer and should_slice(gemm_1_activations.shape): - res_slices = slice_residual(residual) - # is this rank ok? - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - split_1 = torch.ops.aten.split.Tensor(residual, slice_size) - my_residual = split_1[0] - else: - my_residual = residual - - # verify the type is always correct - mm_res = torch.empty( - (gemm_1_activations.shape[0], gemm_2_weights.shape[0]), - device=gemm_1_activations.device, - dtype=gemm_1_activations.dtype) - - return (mm_res, my_residual, residual) - - -# TODO: factor out groupnames, etc. def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_1_weights: torch.Size, gemm_2_type, gemm_2_weights: torch.Size, tp_group_name: str): + group = get_group_from_group_name(tp_group_name) + device_group = group.device_group + rank = group.rank_in_group + if use_flux: - device_group = get_group_from_group_name(tp_group_name).device_group gemm_rs_op = flux.GemmRS( device_group, 1, # One node 8192, # Max M. TODO: Pass in correctly. gemm_1_weights[0], # N - # TODO: Pass in input dtype correctly. # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. # Similar comment for max m. @@ -144,7 +117,6 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, 8192, # Max M. TODO: Pass in correctly. gemm_2_weights[0], # N gemm_2_weights[1], # K - # TODO: Pass in input dtype correctly. # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. # Similar comment for max m. @@ -192,12 +164,11 @@ def gemm_rs_ag_gemm( if first_layer and should_slice(residual.shape): res_slices = slice_residual(residual) - # is this rank ok? - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] + slice_size = res_slices[rank].shape[0] residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) my_residual = residual_chunk[0] else: - my_residual = residual #.clone() + my_residual = residual slice_size = residual.shape[0] if not should_slice(residual.shape): @@ -225,14 +196,37 @@ def gemm_rs_ag_gemm( slice_scatter = torch.ops.aten.slice_scatter.default( residual_1, my_residual, 0, 0, slice_size) split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) - - # TODO: can we avoid clone here? - new_residual = split_2[0] #.clone() + new_residual = split_2[0] mm_2 = ag_gemm(output, gemm_2_weights) return mm_2[0], new_residual, slice_scatter + def gemm_rs_ag_gemm_fake( + residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, + first_layer: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if first_layer and should_slice(gemm_1_activations.shape): + res_slices = slice_residual(residual) + slice_size = res_slices[rank].shape[0] + split_1 = torch.ops.aten.split.Tensor(residual, slice_size) + my_residual = split_1[0] + else: + my_residual = residual + + # TODO: verify the type is always correct + mm_res = torch.empty( + (gemm_1_activations.shape[0], gemm_2_weights.shape[0]), + device=gemm_1_activations.device, + dtype=gemm_1_activations.dtype) + + return (mm_res, my_residual, residual) + if not hasattr(torch.ops.vllm, name): logger.info("registering torch.ops.vllm.%s", name) direct_register_custom_op(name, @@ -255,6 +249,7 @@ def match_final( gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + # TODO: it would be nice to be able to use the official api directly. #all_reduce = tensor_model_parallel_all_reduce(mm_1) if use_custom_ar: all_reduce = torch.ops.vllm.outplace_all_reduce.default( @@ -322,6 +317,8 @@ def __init__(self): self.final_pattern = PatternMatcherPass() self.matches: List[Match] = [] + # Run in fake mode so that we don't call real functions + # when tracing the patterns. with torch._dynamo.utils.detect_fake_mode(): x = torch.empty([4, 4], device='cuda') w = torch.empty([4, 4], device='cuda') @@ -351,13 +348,9 @@ def __init__(self): get_match_final(group_name, False), get_match_final(group_name, True) ]: - register_replacement( - m, - torch.ops.vllm.gemm_ag_final, - #replace_final, - final_inputs, - fwd_only, - [self.final_pattern]) + register_replacement(m, torch.ops.vllm.gemm_ag_final, + final_inputs, fwd_only, + [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and @@ -394,6 +387,8 @@ def find_min_index(match: Match) -> int: gemm_1 = kwargs["gemm_1_weights"].meta["val"] gemm_2 = kwargs["gemm_2_weights"].meta["val"] + # Extract group_name from matched code. Use to + # generate proper replacement code. ar_node = find_auto_fn( match.nodes, torch.ops.vllm.inplace_all_reduce.default) if ar_node is not None: @@ -405,9 +400,13 @@ def find_min_index(match: Match) -> int: assert ar_node is not None tp_group_name = ar_node.args[1] - fused_node = graph.call_function(get_gemm_rs_ag_gemm( - use_flux, gemm_1.dtype, gemm_1.shape, gemm_2.dtype, - gemm_2.shape, tp_group_name), + fused_gemm_func = get_gemm_rs_ag_gemm(use_flux, gemm_1.dtype, + gemm_1.shape, + gemm_2.dtype, + gemm_2.shape, + tp_group_name) + + fused_node = graph.call_function(fused_gemm_func, kwargs=kwargs) graph.inserting_after(fused_node)