Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
Signed-off-by: Bill Nell <[email protected]>
  • Loading branch information
bnellnm committed Nov 8, 2024
1 parent 76f1658 commit 1c9d79c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
14 changes: 5 additions & 9 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,9 @@ def fix_functionalization(graph: fx.Graph):
# print(graph.python_code(root_module="self", verbose=True).src, file=f)


collective_fusion_pass: Optional[CollectiveFusionPass] = None

def wrap_inductor(graph,
example_inputs,
additional_inductor_config,
additional_inductor_config = None,
do_logging=False,
runtime_shape: Optional[int] = None,
use_inductor: bool = True):
Expand Down Expand Up @@ -338,7 +336,8 @@ def run(self, *args):
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
return super().run(*fake_args)
with self.fake_mode:
return super().run(*fake_args)

def call_module(self, target: torch.fx.node.Target,
args: Tuple[torch.fx.node.Argument,
Expand Down Expand Up @@ -418,11 +417,8 @@ def add_passes_to_config(self):

passes = passes + [RedundantReshapesPass(config)]

if config.enable_fusion:
global collective_fusion_pass
if not collective_fusion_pass:
collective_fusion_pass = CollectiveFusionPass()
passes = passes + [collective_fusion_pass]
if True or config.enable_fusion:

Check failure on line 420 in vllm/compilation/backends.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (SIM222)

vllm/compilation/backends.py:420:12: SIM222 Use `True` instead of `True or ...`
passes = passes + [CollectiveFusionPass.instance(config)]
passes = passes + [FusionPass.instance(config)]

inductor_config = config.inductor_compile_config
Expand Down
22 changes: 21 additions & 1 deletion vllm/compilation/collective_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem,
last_node_in_match)
Expand Down Expand Up @@ -318,7 +319,26 @@ def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,

class CollectiveFusionPass(InductorPass):

def __init__(self):
_instance: 'Optional[CollectiveFusionPass]' = None

Check failure on line 322 in vllm/compilation/collective_fusion.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (F821)

vllm/compilation/collective_fusion.py:322:17: F821 Undefined name `Optional`

Check failure on line 322 in vllm/compilation/collective_fusion.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Name "Optional" is not defined [name-defined]

Check failure on line 322 in vllm/compilation/collective_fusion.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Name "Optional" is not defined [name-defined]

Check failure on line 322 in vllm/compilation/collective_fusion.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Name "Optional" is not defined [name-defined]

Check failure on line 322 in vllm/compilation/collective_fusion.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Name "Optional" is not defined [name-defined]

@classmethod
def instance(cls, config: CompilationConfig):
"""
Get the singleton instance of the CollectiveFusionPass.
If the instance exists, the config is updated but
initialization is not repeated.
"""
if cls._instance is None:
cls._instance = CollectiveFusionPass(config)
else:
cls._instance.config = config
return cls._instance

def __init__(self, config):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)

self.gemm_rs_ag_gemm_pattern = PatternMatcherPass()
self.final_pattern = PatternMatcherPass()
self.matches: List[Match] = []
Expand Down

0 comments on commit 1c9d79c

Please sign in to comment.