[Kernels] Add an inductor pass to rewrite and fuse collective communication ops with gemms #9886
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from b3200f8 to 5183999
looking forward to this one!
Force-pushed from 0a1f637 to 1c9d79c
Force-pushed from e164973 to 1683f80
Force-pushed from 1683f80 to 34de3a4
Force-pushed from ef2be0d to 7ebd94c
Signed-off-by: Bill Nell <[email protected]>
Force-pushed from d713a7d to 7e2c490
# Note: this heuristic is unique to flux
def use_cc_kernels(m_shape: int, n_slices: Optional[int] = None) -> bool:
Maybe add _flux at the end of the function name to make the note clear?
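A rough sketch of that rename (the new name is hypothetical, and the heuristic body is elided because it is not shown in this excerpt):

from typing import Optional


# Hypothetical rename: the trailing _flux makes it explicit that this
# heuristic only applies to the flux kernels.
def use_cc_kernels_flux(m_shape: int, n_slices: Optional[int] = None) -> bool:
    raise NotImplementedError  # original heuristic body not shown here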
def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
    for node in nodes:
        if node.op == "call_function" and node.target == op:
            return node
    return None


def find_auto_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
    for node in nodes:
        if (node.op == "call_function" and node.target == auto_functionalized
                and node.args[0] == op):
            return node
    return None


def find_getitem(node: fx.Node, idx: int) -> Optional[fx.Node]:
    for user in node.users:
        if (user.op == "call_function" and user.target == operator.getitem
                and user.args[1] == idx):
            return user
    return None
These should be available from fx_utils after #10906.
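If those helpers do become shared, the local copies above could be replaced by a single import; a sketch, with the module path and helper names assumed from the snippet above rather than confirmed:

# Assumed location of the shared helpers after #10906; adjust the path if
# fx_utils ends up living elsewhere.
from vllm.compilation.fx_utils import find_auto_fn, find_fn, find_getitem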
FLUX_TILE_SIZE: int = 128


def use_cc_kernels(m_shape: int) -> bool:
Why is there a separate function with the same name? The other one is flux-only?
Also, what does use_cc_kernels even mean?
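One way to answer both questions at the definition site would be a more descriptive name plus a docstring. A hedged sketch, under the assumption (from the usage later in the diff) that the heuristic gates the fused collective-communication kernels on the per-rank m dimension; the name and the threshold below are illustrative, not the PR's actual logic:

FLUX_TILE_SIZE: int = 128


def should_use_fused_collective_gemm(m_shape: int) -> bool:
    """Return True if the fused GEMM + collective-communication kernels
    should be used for a problem with m_shape rows on this rank.

    Hypothetical rename of use_cc_kernels; the threshold below is only a
    placeholder for the PR's real heuristic.
    """
    return m_shape >= FLUX_TILE_SIZE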
device_group = group.device_group
rank = group.rank_in_group

if use_flux:
Could we maybe use a better abstraction than if statements based on use_flux?
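One possible shape for such an abstraction is a small backend interface chosen once at setup time, so the rewrite pass never branches on use_flux directly. Everything below is a hypothetical sketch, not code from the PR:

from abc import ABC, abstractmethod

import torch


class FusedCollectiveGemmBackend(ABC):
    """Owns the fused GEMM + collective kernels for one implementation."""

    @abstractmethod
    def gemm_reduce_scatter(self, x: torch.Tensor,
                            weight: torch.Tensor) -> torch.Tensor:
        ...


class FluxBackend(FusedCollectiveGemmBackend):
    def gemm_reduce_scatter(self, x, weight):
        raise NotImplementedError  # would call the flux kernels here


class DefaultBackend(FusedCollectiveGemmBackend):
    def gemm_reduce_scatter(self, x, weight):
        raise NotImplementedError  # would call the non-flux fallback here


def select_backend(use_flux: bool) -> FusedCollectiveGemmBackend:
    # The use_flux decision is made exactly once, here, instead of being
    # re-checked with if statements at every call site.
    return FluxBackend() if use_flux else DefaultBackend()

Keeping the flux-specific heuristics, such as the use_cc_kernels note above, inside FluxBackend would also localize them.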
    rms_norm_weights: torch.Tensor,
    gemm_2_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
Does the permutation need to be in the match? As in the replacement won't be permuted?
fused_node = graph.call_function(fused_gemm_func, kwargs=kwargs)

graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem, (fused_node, 0))
residual_node_new = graph.call_function(operator.getitem, (fused_node, 1))
my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2))
I think multi-output match has a utility that emits a function and tuple accessors.
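If such a utility is not available, a small helper along those lines is easy to write. A sketch using only standard torch.fx APIs; the helper name is made up:

import operator
from typing import Callable, List

import torch.fx as fx


def call_function_with_getitems(graph: fx.Graph, fn: Callable, kwargs: dict,
                                n_outputs: int) -> List[fx.Node]:
    # Insert the call_function node, then one operator.getitem accessor per
    # output, so call sites don't spell out each getitem by hand.
    call_node = graph.call_function(fn, kwargs=kwargs)
    outputs: List[fx.Node] = []
    with graph.inserting_after(call_node):
        for i in range(n_outputs):
            outputs.append(
                graph.call_function(operator.getitem, (call_node, i)))
    return outputs

With something like this, the three getitem accessors above become a single call returning result_node_new, residual_node_new, and my_residual_node_new.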
res_replacements.append(residual_node_new)
my_res_replacements.append(my_residual_node_new)
Any reason we save all of the residuals instead of just the previous one?
    raise ValueError("No nodes in graph")


def dump_graph(pass_config, graph: fx.Graph, name: str) -> None:
I think this is going to get phased out in favor of @youkaichao's depyf
if self.compilation_config.pass_config.enable_collective_fusion:
    n_slices = self.parallel_config.world_size
    max_tokens = self.scheduler_config.max_num_batched_tokens
    if not use_cc_kernels(int(max_tokens / n_slices), n_slices):
        logger.info(
            ("Disabling collective fusion pass since chunked prefill "
             "size %d is too small."), max_tokens)
        self.compilation_config.pass_config.enable_collective_fusion = \
            False
    if n_slices == 1:
        logger.info("Disabling collective fusion pass since tensor "
                    "parallelism is not enabled.")
        self.compilation_config.pass_config.enable_collective_fusion = \
            False
Why does this only live under V1? Shouldn't it also happen for V0?
(so maybe put this under PassConfig.__post_init__)
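One wrinkle: PassConfig.__post_init__ only sees the pass config's own fields, so the world size and token count would still have to come from wherever the configs are wired together. A hedged alternative is a small method on the config that both V0 and V1 call; the class shape, method name, and the callable parameter below are illustrative, not from the PR:

from dataclasses import dataclass
from typing import Callable
import logging

logger = logging.getLogger(__name__)


@dataclass
class PassConfig:
    enable_collective_fusion: bool = False

    def adjust_collective_fusion(
            self, max_num_batched_tokens: int, world_size: int,
            use_cc_kernels: Callable[[int, int], bool]) -> None:
        # Shared by V0 and V1: turn the pass off when tensor parallelism is
        # disabled or the per-rank chunked-prefill size is too small.
        if not self.enable_collective_fusion:
            return
        if world_size == 1:
            logger.info("Disabling collective fusion pass since tensor "
                        "parallelism is not enabled.")
            self.enable_collective_fusion = False
        elif not use_cc_kernels(max_num_batched_tokens // world_size,
                                world_size):
            logger.info(
                "Disabling collective fusion pass since chunked prefill "
                "size %d is too small.", max_num_batched_tokens)
            self.enable_collective_fusion = False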
if gemm_1 is None or gemm_2 is None:
    raise ValueError("Missing 'val' in gemm weights meta data")
Wouldn't it be simpler if you just did meta["val"]?
Add an inductor pass to rewrite and fuse collective communication ops with gemms
See #9883 for a version that includes the llama hacks.
TODO:
torch._inductor.ir.ExternKernel.__str__
pytorch/pytorch#139501

cc @tlrmchlsmth, @ProExpertProg, @SageMoore, @youkaichao
Requires a special config to run:
Some benchmark results: