diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 9be16bdb42ac6..4dfad8893d969 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -215,7 +215,7 @@ def fix_functionalization(graph: fx.Graph):
 
 def wrap_inductor(graph,
                   example_inputs,
-                  additional_inductor_config=None,
+                  additional_inductor_config: Optional[Dict] = None,
                   do_logging=False,
                   runtime_shape: Optional[int] = None,
                   use_inductor: bool = True):
@@ -233,7 +233,7 @@ def wrap_inductor(graph,
     from torch._inductor import config
     torch._inductor.config._micro_pipeline_tp = True
 
-    # Set to False to avoid infinite recursion logging
+    # Set to False to avoid infinite recursion logging?
    torch._inductor.config.implicit_fallbacks = True
 
     current_config = config.shallow_copy_dict()
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 8b51e0f9bda47..0f76e7a4c1260 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -1,12 +1,11 @@
 import operator
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 import torch.fx as fx
 from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                              fwd_only, register_replacement)
 
-import vllm._custom_ops as ops
 import vllm.envs as envs
 from vllm.compilation.config import CompilationConfig
 from vllm.compilation.inductor_pass import InductorPass
@@ -33,31 +32,31 @@
 FLUX_TILE_SIZE: int = 128
 
 
-# TODO: is this right?
+# TODO: is this ok?
 def get_world_name() -> str:
     return torch.distributed.group.WORLD.group_name
 
 
 # Note: this heuristic is unique to flux
-def should_slice(shape) -> bool:
+def should_slice(shape: torch.Size) -> bool:
     n_slices = get_tensor_model_parallel_world_size()
     return (shape[0] % (FLUX_TILE_SIZE * n_slices) == 0
             and shape[0] >= FLUX_TILE_SIZE * n_slices)
 
 
-def residual_slice_shape(residual, rank) -> int:
+def residual_slice_shape(residual: torch.Tensor, rank: int) -> int:
     n_slices = get_tensor_model_parallel_world_size()
     chunk, rem = divmod(residual.shape[0], n_slices)
     return chunk if rank < n_slices - 1 or rem == 0 else rem
 
 
-def residual_slice_shape_fake(residual, rank) -> int:
+def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int:
     n_slices = get_tensor_model_parallel_world_size()
     slices = torch.chunk(residual, n_slices, dim=0)
     return slices[rank].shape[0]
 
 
-def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool):
+def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool) -> Callable:
 
     def match_gemm_rs_ag_gemm(
         residual: torch.Tensor,
@@ -69,8 +68,8 @@ def match_gemm_rs_ag_gemm(
         gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
         mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
 
-        # It would be nice to do this instead of having two separate patterns
-        #all_reduce = tensor_model_parallel_all_reduce(mm_1)
+        # It would be nice to do this instead of having two separate patterns.
+        # all_reduce = tensor_model_parallel_all_reduce(mm_1)
         if custom_ar:
             all_reduce = torch.ops.vllm.outplace_all_reduce.default(
                 mm_1, tp_group_name)
@@ -98,10 +97,10 @@ def match_gemm_rs_ag_gemm(
     return match_gemm_rs_ag_gemm
 
 
-def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
-                        gemm_1_weights: torch.Size, gemm_1_max_m: int,
-                        gemm_2_type, gemm_2_weights: torch.Size,
-                        gemm_2_max_m: int, tp_group_name: str):
+def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype,
+                        gemm_1_weights: torch.Size, gemm_2_type: torch.dtype,
+                        gemm_2_weights: torch.Size,
+                        tp_group_name: str) -> Callable:
 
     group = get_group_from_group_name(tp_group_name)
     device_group = group.device_group
@@ -111,7 +110,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         gemm_rs_op = flux.GemmRS(
             device_group,
             1,  # One node
-            gemm_1_max_m,  # M
+            max_m,  # max M
             gemm_1_weights[0],  # N
             # TODO: It would be nicer to modify flux to dispatch based on dtype
             # at run time, but I don't know what the downside would be.
@@ -126,7 +125,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         ag_gemm_op = flux.AGKernel(
             device_group,
             1,  # One node
-            gemm_2_max_m,  # M
+            max_m,  # max M
             gemm_2_weights[0],  # N
             gemm_2_weights[1],  # K
             # TODO: It would be nicer to modify flux to dispatch based on dtype
@@ -149,10 +148,9 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
         gemm_1_str = str(gemm_1_type).removeprefix("torch.")
         gemm_2_str = str(gemm_2_type).removeprefix("torch.")
         group_str = tp_group_name.replace(":", "_")
-        name = (
-            f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_{gemm_1_max_m}_"
-            f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{gemm_2_max_m}_"
-            f"{group_str}")
+        name = (f"gemm_rs_ag_gemm_{max_m}_{gemm_1_str}_{gemm_1_weights[0]}_"
+                f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_"
+                f"{group_str}")
     else:
         world_group_name = get_world_name()
 
@@ -187,10 +185,10 @@ def gemm_rs_ag_gemm(
                                              gemm_1_weights.transpose(1, 0))
             reduced_output = tensor_model_parallel_all_reduce(output)
 
-            ops.fused_add_rms_norm(input=reduced_output,
-                                   residual=my_residual,
-                                   weight=rms_norm_weight,
-                                   epsilon=1e-05)
+            torch.ops._C.fused_add_rms_norm.default(input=reduced_output,
+                                                    residual=my_residual,
+                                                    weight=rms_norm_weight,
+                                                    epsilon=1e-05)
 
             mm_2 = torch.ops.aten.mm.default(reduced_output,
                                              gemm_2_weights.transpose(1, 0))
@@ -198,16 +196,20 @@ def gemm_rs_ag_gemm(
         else:
             output = gemm_rs(gemm_1_activations, gemm_1_weights)
 
-            ops.fused_add_rms_norm(input=output,
-                                   residual=my_residual,
-                                   weight=rms_norm_weight,
-                                   epsilon=1e-05)
+            torch.ops._C.fused_add_rms_norm.default(input=output,
+                                                    residual=my_residual,
+                                                    weight=rms_norm_weight,
+                                                    epsilon=1e-05)
 
             residual_1 = residual if first_layer else old_my_residual
-            slice_scatter = torch.ops.aten.slice_scatter.default(
-                residual_1, my_residual, 0, 0, slice_shape)
-            split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape)
-            new_residual = split_2[0]
+            #if False:
+            #slice_scatter = torch.ops.aten.slice_scatter.default(
+            #    residual_1, my_residual, 0, 0, slice_shape)
+            #split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape)
+            #new_residual = split_2[0]
+            #else:
+            slice_scatter = my_residual
+            new_residual = residual_1
 
             mm_2 = ag_gemm(output, gemm_2_weights)
 
@@ -248,7 +250,7 @@ def gemm_rs_ag_gemm_fake(
     return getattr(torch.ops.vllm, name).default
 
 
-def get_match_final(tp_group_name: str, use_custom_ar: bool):
+def get_match_final(tp_group_name: str, use_custom_ar: bool) -> Callable:
 
     def match_final(
         my_residual: torch.Tensor,
@@ -260,7 +262,7 @@ def match_final(
         mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
 
         # TODO: it would be nice to be able to use the official api directly.
-        #all_reduce = tensor_model_parallel_all_reduce(mm_1)
+        # all_reduce = tensor_model_parallel_all_reduce(mm_1)
         if use_custom_ar:
             all_reduce = torch.ops.vllm.outplace_all_reduce.default(
                 mm_1, tp_group_name)
@@ -288,8 +290,8 @@ def match_final(
 def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
                   gemm_1_activations: torch.Tensor,
                   rms_norm_weights: torch.Tensor) -> torch.Tensor:
-    permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
-    mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_254)
+    mm_1 = torch.ops.aten.mm.default(gemm_1_activations,
+                                     gemm_1_weights.transpose(1, 0))
 
     reduced = tensor_model_parallel_all_reduce(mm_1)
 
@@ -298,10 +300,10 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
     else:
         wait_tensor = my_residual
 
-    ops.fused_add_rms_norm(input=reduced,
-                           residual=wait_tensor,
-                           weight=rms_norm_weights,
-                           epsilon=1e-05)
+    torch.ops._C.fused_add_rms_norm.default(input=reduced,
+                                            residual=wait_tensor,
+                                            weight=rms_norm_weights,
+                                            epsilon=1e-05)
 
     return reduced
 
@@ -325,7 +327,7 @@ class CollectiveFusionPass(InductorPass):
     _instance: 'Optional[CollectiveFusionPass]' = None
 
     @classmethod
-    def instance(cls, config: CompilationConfig):
+    def instance(cls, config: CompilationConfig) -> "CollectiveFusionPass":
         """
         Get the singleton instance of the CollectiveFusionPass.
         If the instance exists, the config is updated but
@@ -358,8 +360,8 @@ def __init__(self, config):
         final_inputs = [x, w, resid, resid_w]
 
         # register multiple patterns for all group names.
-        max_gpus = 8  # TODO: get this officially
-        group_names = [f"tp:{rank}" for rank in range(max_gpus)]
+        world_size = get_tensor_model_parallel_world_size()
+        group_names = [f"tp:{rank}" for rank in range(world_size)]
 
         for group_name in group_names:
             for m in [
@@ -389,24 +391,15 @@ def record_match(self, match: Match) -> bool:
         # Return False to prevent automatic replacement.
         return False
 
-    def find_max_m(self, matches) -> Tuple[int, int]:
-        gemm_1_max_m = 0
-        gemm_2_max_m = 0
+    def find_max_m(self, matches: List[Match]) -> int:
+        max_m = 0
         for m in matches:
-            #gemm_1 = m.kwargs["gemm_1_weights"].meta["val"]
-            #gemm_2 = m.kwargs["gemm_2_weights"].meta["val"]
-            #gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1])
-            #gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1])
-            gemm_1 = m.kwargs["residual"].meta["val"]
-            gemm_2 = m.kwargs["residual"].meta["val"]
-            gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1])
-            gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1])
-
-        assert gemm_1_max_m > 0
-        assert gemm_2_max_m > 0
-        return gemm_1_max_m, gemm_2_max_m
-
-    def process_matches(self, graph: fx.Graph):
+            residual = m.kwargs["residual"].meta["val"]
+            max_m = max(max_m, residual.shape[1])
+        assert max_m > 0
+        return max_m
+
+    def process_matches(self, graph: fx.Graph) -> None:
         nodes = list(graph.nodes)
 
         def find_min_index(match: Match) -> int:
@@ -418,8 +411,8 @@ def find_min_index(match: Match) -> int:
         res_replacements: List[fx.Node] = []
         my_res_replacements: List[fx.Node] = []
 
-        gemm_1_max_m, gemm_2_max_m = self.find_max_m(matches)
-        logger.info("max m = %d, %d", gemm_1_max_m, gemm_2_max_m)
+        max_m = self.find_max_m(matches)
+        logger.info("max m = %d", max_m)
 
         for match in matches:
             last_node = last_node_in_match(match)
@@ -451,8 +444,8 @@ def find_min_index(match: Match) -> int:
             tp_group_name = ar_node.args[1]
 
             fused_gemm_func = get_gemm_rs_ag_gemm(
-                use_flux, gemm_1.dtype, gemm_1.shape, gemm_1_max_m,
-                gemm_2.dtype, gemm_2.shape, gemm_2_max_m, tp_group_name)
+                use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
+                gemm_2.shape, tp_group_name)
 
             fused_node = graph.call_function(fused_gemm_func,
                                              kwargs=kwargs)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 4b72dfe9bf105..e1446192ce3d6 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1135,9 +1135,8 @@ def load_model(self) -> None:
 
         if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
             and supports_dynamo():
-            from vllm.compilation.backends import wrap_inductor
             from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or wrap_inductor #"eager"
+            backend = get_torch_compile_backend() or "eager"
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
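
Note, for illustration only and not part of the patch: after this change a single max_m is shared by both flux kernels, so the registered fused op is keyed by one max_m plus the gemm dtypes, weight shapes, and TP group. A minimal sketch of the resulting naming scheme, using made-up example values for the dtypes, shapes, and group name:

    import torch

    # Example values only; the real ones come from the matched subgraph.
    max_m = 8192
    gemm_1_type = torch.float16
    gemm_1_weights = torch.Size([4096, 4096])
    gemm_2_type = torch.float16
    gemm_2_weights = torch.Size([4096, 4096])
    tp_group_name = "tp:0"

    gemm_1_str = str(gemm_1_type).removeprefix("torch.")  # "float16"
    gemm_2_str = str(gemm_2_type).removeprefix("torch.")  # "float16"
    group_str = tp_group_name.replace(":", "_")           # "tp_0"
    name = (f"gemm_rs_ag_gemm_{max_m}_{gemm_1_str}_{gemm_1_weights[0]}_"
            f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_"
            f"{group_str}")
    print(name)  # gemm_rs_ag_gemm_8192_float16_4096_float16_4096_4096_tp_0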