
[Kernels] Add an inductor pass to rewrite and fuse collective communication ops with gemms #9886

Open · bnellnm wants to merge 72 commits into main from collective-fusion

Conversation

@bnellnm (Contributor) commented Oct 31, 2024

Add an inductor pass to rewrite and fuse collective communication ops with gemms

See #9883 for a version that includes the llama hacks.
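
Roughly, the pass pattern-matches tensor-parallel gemm → all_reduce → rms_norm → gemm chains in the fx graph and rewrites them to use fused GEMM+collective kernels (flux's GEMM+reduce-scatter and all-gather+GEMM). A conceptual before/after sketch with hypothetical names, not the actual pattern code:

# Before: each rank computes a partial GEMM result, then all-reduces it.
#   y   = x @ w1                                  # partial sum on each rank
#   y   = tensor_model_parallel_all_reduce(y)     # full result everywhere
#   h   = rms_norm(y + residual, rms_norm_weight)
#   out = h @ w2
#
# After: the collectives are fused into the GEMMs, so each rank only
# materializes its slice of the intermediate between the two GEMMs.
#   y_slice = gemm_reduce_scatter(x, w1)          # fused GEMM + reduce-scatter
#   h_slice = rms_norm(y_slice + residual_slice, rms_norm_weight)
#   out     = all_gather_gemm(h_slice, w2)        # fused all-gather + GEMM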

TODO:

cc @tlrmchlsmth, @ProExpertProg, @SageMoore, @youkaichao

Requires a special config to run:

import torch
from vllm import LLM
from vllm.config import CompilationConfig

config = CompilationConfig(
    level=3,
    custom_ops=["+rms_norm"],
    splitting_ops=[],
)

# model, eager, tp_size and custom_ar are set elsewhere in the benchmark script.
llm = LLM(model=model,
          enforce_eager=eager,
          tensor_parallel_size=tp_size,
          disable_custom_all_reduce=not custom_ar,
          dtype=torch.float16,
          max_num_batched_tokens=2048,
          compilation_config=config)
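
(Presumably the +rms_norm entry keeps the custom rms_norm op in the graph so the pattern matcher can find it, and the empty splitting_ops list stops piecewise compilation from splitting the graph at the collectives, so they remain visible to the pass.)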

Some benchmark results:

model = meta-llama/Llama-3.1-70B-Instruct
tp_size = 4
chunked prefill size = 2048
batch_size = 1
input_len=2048
output_len=1
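
For reference, a run along these lines could be reproduced with something like vLLM's benchmarks/benchmark_latency.py (hypothetical invocation; the PR's actual harness may differ):

python benchmarks/benchmark_latency.py \
    --model meta-llama/Llama-3.1-70B-Instruct \
    --tensor-parallel-size 4 \
    --input-len 2048 \
    --output-len 1 \
    --batch-size 1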
Eager mode + torch.compile

Avg latency: 0.16625802051508798 seconds
10% percentile latency: 0.16468927392270416 seconds
25% percentile latency: 0.16511811560485512 seconds
50% percentile latency: 0.16571794101037085 seconds
75% percentile latency: 0.16671031567966565 seconds
90% percentile latency: 0.1675790420267731 seconds
99% percentile latency: 0.17226817809045325 seconds

Eager mode + torch.compile + flux

Avg latency: 0.1583265809295699 seconds
10% percentile latency: 0.15630255101714283 seconds
25% percentile latency: 0.15688058221712708 seconds
50% percentile latency: 0.15789097198285162 seconds
75% percentile latency: 0.15932484721997753 seconds
90% percentile latency: 0.16147575441282241 seconds
99% percentile latency: 0.16223905643215403 seconds

cudagraphs + torch.compile

Avg latency: 0.17894838895183057 seconds
10% percentile latency: 0.17591054290533065 seconds
25% percentile latency: 0.176349236513488 seconds
50% percentile latency: 0.17722250788938254 seconds
75% percentile latency: 0.17862555047031492 seconds
90% percentile latency: 0.18074012212455273 seconds
99% percentile latency: 0.2171030258946121 seconds

cudagraphs + torch.compile + flux

Avg latency: 0.17262270329520107 seconds
10% percentile latency: 0.17164990142919123 seconds
25% percentile latency: 0.17196793673792854 seconds
50% percentile latency: 0.1724927049363032 seconds
75% percentile latency: 0.1730666920193471 seconds
90% percentile latency: 0.17406681017018855 seconds
99% percentile latency: 0.1758251654729247 seconds


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

mergify bot commented Oct 31, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @bnellnm, please rebase it.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@tlrmchlsmth (Collaborator) left a comment

looking forward to this one!

bnellnm force-pushed the collective-fusion branch 2 times, most recently from 0a1f637 to 1c9d79c (November 8, 2024 23:36)
mergify bot removed the needs-rebase label (Nov 8, 2024)
bnellnm marked this pull request as ready for review (November 9, 2024 23:10)
mergify bot commented Nov 11, 2024 (and again on Nov 25 and Nov 26)

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label (Nov 26, 2024)
tlrmchlsmth and others added 6 commits November 26, 2024 19:49 (all signed off by Bill Nell <[email protected]>)


# Note: this heuristic is unique to flux
def use_cc_kernels(m_shape: int, n_slices: Optional[int] = None) -> bool:

Maybe add _flux at the end of the function name to make the note clear?

Comment on lines +27 to +47
def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
    for node in nodes:
        if node.op == "call_function" and node.target == op:
            return node
    return None


def find_auto_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]:
    for node in nodes:
        if (node.op == "call_function" and node.target == auto_functionalized
                and node.args[0] == op):
            return node
    return None


def find_getitem(node: fx.Node, idx: int) -> Optional[fx.Node]:
    for user in node.users:
        if (user.op == "call_function" and user.target == operator.getitem
                and user.args[1] == idx):
            return user
    return None

These should be available from fx_utils after #10906.
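
For context, these helpers just do linear scans over fx nodes. A minimal, self-contained usage sketch on a toy traced function (nothing vLLM-specific assumed):

import torch
import torch.fx as fx

def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

gm = fx.symbolic_trace(toy)

# find_fn returns the first call_function node whose target is torch.relu.
relu_node = find_fn(gm.graph.nodes, torch.relu)
assert relu_node is not None and relu_node.op == "call_function"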

FLUX_TILE_SIZE: int = 128


def use_cc_kernels(m_shape: int) -> bool:

Why is there a separate function with the same name? The other one is flux-only?


Also what does use_cc_kernels even mean?

device_group = group.device_group
rank = group.rank_in_group

if use_flux:

Could we maybe use a better abstraction than if statements based on use_flux?

    rms_norm_weights: torch.Tensor,
    gemm_2_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])

Does the permutation need to be in the match? As in the replacement won't be permuted?

Comment on lines +402 to +411
fused_node = graph.call_function(fused_gemm_func, kwargs=kwargs)

graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem, (fused_node, 0))
residual_node_new = graph.call_function(operator.getitem, (fused_node, 1))
my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2))

I think multi-output match has a utility that emits a function and tuple accessors.

Comment on lines +412 to +413
res_replacements.append(residual_node_new)
my_res_replacements.append(my_residual_node_new)

Any reason we save all of the residuals instead of just the previous one?

raise ValueError("No nodes in graph")


def dump_graph(pass_config, graph: fx.Graph, name: str) -> None:

I think this is going to get phased out in favor of @youkaichao's depyf

Comment on lines +2425 to +2438
if self.compilation_config.pass_config.enable_collective_fusion:
    n_slices = self.parallel_config.world_size
    max_tokens = self.scheduler_config.max_num_batched_tokens
    if not use_cc_kernels(int(max_tokens / n_slices), n_slices):
        logger.info(
            ("Disabling collective fusion pass since chunked prefill "
             "size %d is too small."), max_tokens)
        self.compilation_config.pass_config.enable_collective_fusion = \
            False
    if n_slices == 1:
        logger.info("Disabling collective fusion pass since tensor "
                    "parallelism is not enabled.")
        self.compilation_config.pass_config.enable_collective_fusion = \
            False

Why does this only live under V1? Shouldn't it also happen for V0?


(so maybe put this under PassConfig.__post_init__)
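
A rough sketch of that suggestion, assuming PassConfig is a dataclass and the relevant sizes can be threaded into it (the field names here are hypothetical):

from dataclasses import dataclass

@dataclass
class PassConfig:
    enable_collective_fusion: bool = True
    world_size: int = 1                  # hypothetical: from parallel_config
    max_num_batched_tokens: int = 2048   # hypothetical: from scheduler_config

    def __post_init__(self):
        # Disable the fusion pass up front when it cannot help.
        if self.world_size == 1 or not use_cc_kernels(
                self.max_num_batched_tokens // self.world_size,
                self.world_size):
            self.enable_collective_fusion = False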

Comment on lines +388 to +389
if gemm_1 is None or gemm_2 is None:
    raise ValueError("Missing 'val' in gemm weights meta data")

Wouldn't it be simpler if you just do meta["val"]?

mergify bot commented Dec 19, 2024

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label (Dec 19, 2024)