NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Line breaks, indentation and the position of blank context lines were
inferred from Python layout and from the hunk line counts; verify against
the tree before applying. The "index" blob ids below predate the review
fixes and no longer describe the post-image exactly.

Review fixes folded into the added lines:
  * wrap_inductor: `additional_inductor_config=None` (PEP 8: no spaces
    around `=` for a keyword default).
  * add_passes_to_config: dropped the `True or` debug override so the
    collective fusion pass is only added when `config.enable_fusion` is set.
  * CollectiveFusionPass.__init__: the singleton assertion message now
    names CollectiveFusionPass instead of FusionPass.

diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 343679fe4442e..cf9a33af5d2db 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -214,11 +214,9 @@ def fix_functionalization(graph: fx.Graph):
     #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
 
 
-collective_fusion_pass: Optional[CollectiveFusionPass] = None
-
 def wrap_inductor(graph,
                   example_inputs,
-                  additional_inductor_config,
+                  additional_inductor_config=None,
                   do_logging=False,
                   runtime_shape: Optional[int] = None,
                   use_inductor: bool = True):
@@ -338,7 +336,8 @@ def run(self, *args):
             self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
             for t in args
         ]
-        return super().run(*fake_args)
+        with self.fake_mode:
+            return super().run(*fake_args)
 
     def call_module(self, target: torch.fx.node.Target,
                     args: Tuple[torch.fx.node.Argument,
@@ -418,11 +417,8 @@ def add_passes_to_config(self):
 
         passes = passes + [RedundantReshapesPass(config)]
 
         if config.enable_fusion:
-            global collective_fusion_pass
-            if not collective_fusion_pass:
-                collective_fusion_pass = CollectiveFusionPass()
-            passes = passes + [collective_fusion_pass]
+            passes = passes + [CollectiveFusionPass.instance(config)]
             passes = passes + [FusionPass.instance(config)]
 
         inductor_config = config.inductor_compile_config
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 1917b3c7ccf46..0eacf35d793e3 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -8,6 +8,7 @@
 
 import vllm._custom_ops as ops
 import vllm.envs as envs
+from vllm.compilation.config import CompilationConfig
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem,
                                     last_node_in_match)
@@ -318,7 +319,26 @@ def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
 
 class CollectiveFusionPass(InductorPass):
 
-    def __init__(self):
+    _instance: 'Optional[CollectiveFusionPass]' = None
+
+    @classmethod
+    def instance(cls, config: CompilationConfig):
+        """
+        Get the singleton instance of the CollectiveFusionPass.
+        If the instance exists, the config is updated but
+        initialization is not repeated.
+        """
+        if cls._instance is None:
+            cls._instance = CollectiveFusionPass(config)
+        else:
+            cls._instance.config = config
+        return cls._instance
+
+    def __init__(self, config):
+        assert self.__class__._instance is None, \
+            "CollectiveFusionPass singleton instance already exists"
+        super().__init__(config)
+
         self.gemm_rs_ag_gemm_pattern = PatternMatcherPass()
         self.final_pattern = PatternMatcherPass()
         self.matches: List[Match] = []