[Kernels] Add an inductor pass to rewrite and fuse collective communication ops with gemms #9886

Open

wants to merge 72 commits into base: main

Changes from all commits (72 commits)
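In essence, the pass looks for a tensor-parallel GEMM whose output immediately feeds a collective communication op (such as an all-reduce) and rewrites the pair into a single fused compute-plus-communication kernel, backed by the ByteDance Flux kernels referenced in the commits below. A minimal conceptual sketch of the before/after shape of that rewrite, with a hypothetical fused op standing in for the real kernel (this is not the PR's code):

```python
# Conceptual before/after of the rewrite this pass performs. The fused op is
# hypothetical; the PR itself dispatches to Flux (or naive fallback) kernels.
import torch
import torch.distributed as dist


def unfused(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Pattern being matched: a GEMM whose output is immediately all-reduced
    # across the tensor-parallel group.
    y = torch.matmul(x, w)
    dist.all_reduce(y)
    return y


def fused(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Replacement: one kernel that overlaps the GEMM with the communication
    # instead of running them back to back.
    return fused_gemm_all_reduce(x, w)  # hypothetical fused kernel
```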
61c79b3
Prototype integration of bytedance Flux kernels
tlrmchlsmth Jun 12, 2024
62b5ab6
wip
bnellnm Oct 7, 2024
93fe660
fix
bnellnm Oct 8, 2024
c317610
working naive
bnellnm Oct 14, 2024
57b3e74
working real
bnellnm Oct 14, 2024
296d65d
working real
bnellnm Oct 14, 2024
7c43068
work w/torch.compile
bnellnm Oct 14, 2024
aa61f87
work w/torch.compile
bnellnm Oct 14, 2024
17020db
add fuse_gemms flag to turn it on/off
bnellnm Oct 15, 2024
1f5fe34
pattern wip
bnellnm Oct 15, 2024
25400bf
wip
bnellnm Oct 15, 2024
1d3b3aa
final pattern
bnellnm Oct 16, 2024
91462a8
progress
bnellnm Oct 18, 2024
ab68b65
wip
bnellnm Oct 24, 2024
f516431
wip
bnellnm Oct 24, 2024
a5c9f8d
wip
bnellnm Oct 24, 2024
269e7f9
wip
bnellnm Oct 24, 2024
786bcc0
wip
bnellnm Oct 24, 2024
570de57
wip
bnellnm Oct 28, 2024
4aa4ab6
working
bnellnm Oct 29, 2024
f6435dc
fix matcher. naive working
bnellnm Oct 30, 2024
b654a8e
move collective fusion to separate file
bnellnm Oct 30, 2024
54dde90
move collective fusion to separate file
bnellnm Oct 30, 2024
e4b3871
fix fake function
bnellnm Oct 30, 2024
3468420
use InductorPass from @ProExpertProg's PR
bnellnm Oct 30, 2024
e2c9ef0
rebase
bnellnm Oct 31, 2024
81465d2
cleanups
bnellnm Oct 31, 2024
0dd9ca6
cleanups
bnellnm Oct 31, 2024
bb2f2d0
cleanups
bnellnm Oct 31, 2024
c78ce79
cleanups
bnellnm Oct 31, 2024
82fc807
cleanups renames
bnellnm Oct 31, 2024
2c0e799
fix formatting
bnellnm Oct 31, 2024
0dc3c04
cleanups
bnellnm Oct 31, 2024
59d2100
revert some hacks
bnellnm Oct 31, 2024
e0d7203
cleanups
bnellnm Oct 31, 2024
70c6250
back out llama model changes
bnellnm Oct 31, 2024
96d9756
back out models/utils changes
bnellnm Oct 31, 2024
707df6a
remove cruft
bnellnm Oct 31, 2024
04ec8ca
move utilities
bnellnm Nov 1, 2024
c3bb875
add flux support
bnellnm Nov 1, 2024
e422562
add flux support
bnellnm Nov 1, 2024
4edbcab
add types to flux kernels
bnellnm Nov 2, 2024
689d819
wip
bnellnm Nov 4, 2024
30d4fb9
improve perf.
bnellnm Nov 4, 2024
9bc764d
support custom ar
bnellnm Nov 5, 2024
72bb3b6
tweaks
bnellnm Nov 5, 2024
da034eb
factor out group names
bnellnm Nov 5, 2024
aa40131
factor out group names
bnellnm Nov 5, 2024
a54883b
factor out group names, cleanups, etc.
bnellnm Nov 5, 2024
0e2a024
fix some todos
bnellnm Nov 5, 2024
b01205d
find max m for flux kernels
bnellnm Nov 8, 2024
9f90853
rebase
bnellnm Nov 8, 2024
a21fb98
add error check
bnellnm Nov 8, 2024
7aa7546
review comments
bnellnm Nov 9, 2024
65fcaf5
format
bnellnm Nov 9, 2024
6ce19bd
fix cudagraph support
bnellnm Nov 9, 2024
ddc0b20
perf improvements
bnellnm Nov 9, 2024
039d285
cleanups
bnellnm Nov 9, 2024
515f56c
wip
bnellnm Nov 22, 2024
5a6be3c
rebase
bnellnm Nov 22, 2024
d4b0aa2
fixing
bnellnm Nov 22, 2024
2c15cd3
fix merge problems. make dump graph nicer
bnellnm Nov 25, 2024
da18a92
disable collective fusion when chunk size is too small
bnellnm Nov 25, 2024
bead129
fix mypy
bnellnm Nov 25, 2024
72953cc
fix yapf
bnellnm Nov 25, 2024
8724fab
disable collective fusion if TP is not on
bnellnm Nov 25, 2024
ec07de1
remove cruft
bnellnm Nov 25, 2024
6e26b9a
disable collective fusion pass if TP is not enabled
bnellnm Nov 25, 2024
f69ae53
wip
bnellnm Nov 26, 2024
41ab065
rebase + simplify
bnellnm Nov 26, 2024
b75cbba
rebase + simplify
bnellnm Nov 26, 2024
7e2c490
cleanup
bnellnm Nov 26, 2024
23 changes: 19 additions & 4 deletions vllm/compilation/backends.py
@@ -15,14 +15,15 @@
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
 from .pass_manager import PostGradPassManager
+from .utils import dump_graph

 logger = init_logger(__name__)


-def wrap_inductor(graph,
-                  example_inputs,
-                  additional_inductor_config,
-                  do_logging=False,
+def wrap_inductor(graph: fx.GraphModule,
+                  example_inputs: Sequence[Any],
+                  additional_inductor_config: Optional[Dict] = None,
+                  do_logging: bool = False,
                   runtime_shape: Optional[int] = None,
                   use_inductor: bool = True):
     if not use_inductor:
@@ -37,6 +38,10 @@ def wrap_inductor(graph,
         logger.info("Compiling a graph for shape %s", runtime_shape)

     from torch._inductor import config
+
+    # Enable support for symmetric memory ops in the inductor.
+    torch._inductor.config._micro_pipeline_tp = True
+
     current_config = config.shallow_copy_dict()
     from torch._inductor.compile_fx import compile_fx

@@ -248,9 +253,19 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.compilation_configs.init_during_runtime()
         self.configure_post_pass()

+        if ("before_split_graph"
+                in self.compilation_configs.pass_config.dump_graph_stages):
+            dump_graph(self.compilation_configs.pass_config, graph.graph,
+                       "before_split_graph")
+
         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_configs.splitting_ops)

+        if ("after_split_graph"
+                in self.compilation_configs.pass_config.dump_graph_stages):
+            dump_graph(self.compilation_configs.pass_config,
+                       self.split_gm.graph, "after_split_graph")
+
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
         logger.debug("%s", lazy_format_graph_code("after split",
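For reference, a hedged sketch of how the two new dump stages added to backends.py might be enabled. Only pass_config.dump_graph_stages is visible in the diff; the surrounding config class names are assumptions about vLLM's compilation config:

```python
# Hedged sketch: enable the new graph-dump stages added in this PR. Class and
# field names other than `dump_graph_stages` are assumptions, not confirmed
# by the diff above.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    pass_config=CompilationConfig.PassConfig(
        # Dump the fx graph right before and after it is split into
        # piecewise subgraphs for compilation.
        dump_graph_stages=["before_split_graph", "after_split_graph"], ))
```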