[Proposal] Support root->logical transforms in Fusion inputs #3366

Closed

jacobhinkle opened this issue Nov 7, 2024 · 5 comments

jacobhinkle (Collaborator) commented Nov 7, 2024

NOTICE: See #3372

This is a proposal to fully support Fusion input TensorViews that contain non-trivial root domains. The ATen tensor passed in should then match the root domain of the fusion input, not the logical domain.

Motivation

The primary motivation for this proposal is basically #1628. Usually for Hopper matmul we want to load both operands to smem using TMA and then directly call the mma instruction on those smem operands. If the Fusion inputs are [M, K] and [N, K], they must be broadcast to [M, 1, K] and [1, N, K] before they can pass through the MmaOp, which we currently do with a BroadcastOp in mma_utils::MatmulPattern::translateToMmaOp(). This introduces a tensor that we can't get rid of in our current system.
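For concreteness, the broadcasts described above look roughly like this in the nvFuser C++ API (a sketch only, not the actual translateToMmaOp code; makeSymbolicTensor is the helper used in nvFuser's C++ tests, and the variable names here are illustrative):

// Operands as 2-D fusion inputs.
TensorView* tv_a = makeSymbolicTensor(2);  // [ M, K ]
TensorView* tv_b = makeSymbolicTensor(2);  // [ N, K ]

// The BroadcastOps that currently have to be inserted so the operands can
// feed MmaOp; their outputs are the tensors we can't get rid of today.
TensorView* tv_a_bcast = broadcast(tv_a, {false, true, false});  // [ M, 1, K ]
TensorView* tv_b_bcast = broadcast(tv_b, {true, false, false});  // [ 1, N, K ]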

Approach

I propose that we do the following:

  • Create a utility function that will push broadcasts in a c2p (consumer-to-producer) direction as far as possible, similar to zipping up CatOp.
  • Update ExpressionEvaluator to bind tv->getMaybeRootDomain() instead of tv->getLogicalDomain() to the received shapes of input tensors (see the sketch after this list).
  • Update SchedulerRuntimeInfo, which also handles the at::Tensor and needs to know about the root and/or logical domain.
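A minimal sketch of the ExpressionEvaluator change, assuming the binding loop looks roughly like this (illustrative only, not the actual nvFuser implementation; expr_eval and aten_tensor are placeholder names):

// Bind the received ATen sizes against the root domain of the input rather
// than its logical domain, so a root->logical broadcast on an input is fine.
const auto& dom = tv->getMaybeRootDomain();  // was: tv->getLogicalDomain()
for (int64_t i = 0; i < (int64_t)dom.size(); ++i) {
  expr_eval.bind(dom[i]->extent(), aten_tensor.size(i));
}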

I believe this is all that is needed: we don't actually use the root domain of input tensors for anything else, and broadcasts should not affect the actual memory layout, so it is not a problem that the allocation domain matches the logical domain rather than what is in the ATen tensor.

Details

Suppose we have

Inputs:
  tv0 [ i0 ]
  tv1 [ i0, i1 ]

tv2 [ i0 ] = neg(tv0)
tv3 [ i0, b1 ] = broadcast(tv2)
tv4 [ i0, i1 ] = mul(tv3, tv1)
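Written with the nvFuser C++ API, the fusion above would look roughly like this (a sketch for concreteness; makeSymbolicTensor is the test helper, and the tensor names follow the listing):

Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeSymbolicTensor(1);  // [ i0 ]
TensorView* tv1 = makeSymbolicTensor(2);  // [ i0, i1 ]
fusion.addInput(tv0);
fusion.addInput(tv1);

TensorView* tv2 = neg(tv0);
TensorView* tv3 = broadcast(tv2, {false, true});  // the op we want to zip away
TensorView* tv4 = mul(tv3, tv1);
fusion.addOutput(tv4);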

We can translate this to the following:

Inputs:
  tv7 [ i0, b1 ] (root = [ i0 ])
  tv1 [ i0, i1 ]

tv6 [ i0, b1 ] = neg(tv7)
tv5 [ i0, i1 ] = mul(tv6, tv1)

Specifically, what was done (a rough outline of such a pass is sketched after this list):

  • Detect that tv3 is the result of a BroadcastOp.
  • Replace tv3 with a new TV (tv6 above) that is already broadcast (i.e. the broadcast ID is part of its logical domain and there is no separate root domain) and that has the same definition as tv2; its consumer tv4 is recreated as tv5.
  • Propagate backwards: tv6 is just like tv2 but has the new broadcast domain in its logical domain. Like tv2, it has no root->logical transform.
  • When we reach the input tv0, we recreate it as tv7: its logical domain gains the broadcast, and the original logical domain becomes its root domain.
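As a rough outline, the zipping pass could be structured like this. Everything below is hypothetical pseudocode: propagateBroadcastTowardInputs is an invented helper, not an existing nvFuser function.

// Hypothetical helper: push the broadcast IDs c2p through the producer chain,
// recreating each producer with the broadcast in its logical domain; at a
// fusion input, the original logical domain becomes the root domain.
void propagateBroadcastTowardInputs(BroadcastOp* bcast);  // not implemented here

void zipUpBroadcasts(Fusion* fusion) {
  // Iterate a snapshot of the fusion's exprs and zip up each BroadcastOp.
  for (Expr* expr : fusion->exprs()) {
    if (auto* bcast = dynamic_cast<BroadcastOp*>(expr)) {
      propagateBroadcastTowardInputs(bcast);
    }
  }
}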

Possible challenges

Allreduce

One challenge is "allreduce", a pattern we detect at lowering/codegen time in which we reduce a dimension and then immediately broadcast a new dimension in its place.

tv0 [ i0, i1 ]
tv1 [ i0, r1 ] = sum(tv0)
tv2 [ i0, b1] = broadcast(tv1)

If we ignore this pattern while zipping up BroadcastOps, then we might translate this to

tv0 [ i0, i1, b2 ]
tv1 [ i0, r1, b2 ] = sum(tv0)

I think patterns like this are easy to detect, and we can leave the BroadcastOp in place in these cases, but we should be careful.
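A hedged sketch of such a check, assuming a helper like the following (the function name is invented, and a real check would also verify that the broadcast dimension replaces the reduced one):

// Skip zipping a BroadcastOp whose input is produced directly by a
// ReductionOp, since that pair is the allreduce pattern described above.
bool isAllreducePattern(const BroadcastOp* bcast) {
  auto* in_tv = bcast->in()->as<TensorView>();
  Expr* def = in_tv->definition();
  return def != nullptr && def->isA<ReductionOp>();
}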

I think this is the only way we could still have a BroadcastOp in the fusion if we implement this proposal as a preseg pass. In that case, we could also go ahead and be done with BroadcastOp once and for all if we did something like introduce IterType::AllReduce to replace the reduced-then-broadcast axis.

Aliasing

If an input tensor has a root domain and it is aliased with an output tensor, should this be allowed? I think so, but I haven't thought very deeply about it, so I'd probably refuse to do such aliasing until it's needed.

Summary

Initially we can make light use of this and apply it only to the prologue of translated matmuls. However, if it works well, it might be a nice simplifying step that we could run as a preseg pass.

Related:

  • This is kind of the opposite of the idea of removing broadcast axes altogether ([Proposal] Remove Broadcast IterType #1778), but with a similar goal: avoiding having to handle the imaginary point of "broadcasting" while preserving the important, real point of "concretizing a broadcast" in a binary op.
jacobhinkle self-assigned this on Nov 7, 2024
jacobhinkle changed the title from "Support root->logical transforms in Fusion inputs" to "[Proposal] Support root->logical transforms in Fusion inputs" on Nov 7, 2024
naoyam (Collaborator) commented Nov 7, 2024

My general concern is that these approaches would not retain the same information that BroadcastOp has; specifically, BroadcastOp::getBroadcastDimFlags would be lost. More generally, I think anything represented with an Expr can be moved around without losing information. Reordering and adding broadcast IDs are really TensorDomain ops, and they are not recorded like split and merge, so they may not be replayable as precisely as split and merge.

I'd feel more comfortable if this scheduling were done only by a scheduler rather than more globally as a preseg pass. I'm doing something similar for slice and concat.

jjsjann123 (Collaborator) commented Nov 7, 2024

> Update ExpressionEvaluator to bind tv->getMaybeRootDomain() instead of tv->getLogicalDomain() to the received shapes of input tensors.

FYI: instead of the root domain, I think @wujingyue was thinking about binding the allocation domain for distributed support: #3282

wujingyue (Collaborator) commented Nov 7, 2024

Thanks for tagging me! I think we are trying to overload this poor at::Tensor with too many meanings :) I was thinking of letting at::Tensor match the allocation domain, because the at::Tensor has limitations representing more "abstract" tensor domains like the logical domain. I suspect allocation would also work for this case as long as transforms don't have to flow in only one direction (today, they typically flow from logical to allocation). Wdyt?

jacobhinkle (Collaborator, Author) commented

We can revisit this later if needed. For now, because of simplicity and smaller scope, I'm going to pursue #3372 instead.

jacobhinkle closed this as not planned on Nov 7, 2024
jacobhinkle (Collaborator, Author) commented

> I was thinking of letting at::Tensor match the allocation domain, because the at::Tensor has limitations representing more "abstract" tensor domains like the logical domain. I suspect allocation would also work for this case as long as transforms don't have to flow in only one direction (today, they typically flow from logical to allocation). Wdyt?

Yeah, I like that. The allocation domain is really telling us how the input should look in memory, which is all we need. Really, once the fusion is defined, I think the only reason we care at all about the logical sizes of input at::Tensors is that they let us bind some values in the ExpressionEvaluator.
