This is a proposal to enable MmaOp to receive inputs shaped like [M, K] and [N, K] instead of [M, 1, K] and [1, N, K].
This is an alternative to #3366.
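To make the shape difference concrete, here is a hedged sketch in plain Python (not nvFuser code) of the same mma-style contraction C[m, n] = sum_k A[m, k] * B[n, k] written against both input layouts: the current broadcast-padded [M, 1, K] / [1, N, K] form and the proposed broadcast-free [M, K] / [N, K] form. Function names are illustrative only.

```python
# Hedged sketch (plain Python, no nvFuser): the two operand layouts for an
# MmaOp-style computation C[m, n] = sum_k A[m, k] * B[n, k].

def mma_with_broadcasts(a, b):
    """Current form: A is [M, 1, K], B is [1, N, K]; every output
    dimension lines up with a dimension in each operand."""
    M = len(a)
    N = len(b[0])
    K = len(a[0][0])
    return [[sum(a[m][0][k] * b[0][n][k] for k in range(K))
             for n in range(N)] for m in range(M)]

def mma_without_broadcasts(a, b):
    """Proposed form: A is [M, K], B is [N, K]; no broadcast dims."""
    M, K = len(a), len(a[0])
    N = len(b)
    return [[sum(a[m][k] * b[n][k] for k in range(K))
             for n in range(N)] for m in range(M)]

a2d = [[1, 2], [3, 4]]          # A: [M=2, K=2]
b2d = [[5, 6], [7, 8]]          # B: [N=2, K=2]
a3d = [[row] for row in a2d]    # A reshaped to [M, 1, K]
b3d = [b2d]                     # B reshaped to [1, N, K]

# Both layouts compute the same result; the 3D form only exists so that
# dimensions "line up" for the default exact domain mapping.
assert mma_with_broadcasts(a3d, b3d) == mma_without_broadcasts(a2d, b2d)
```

The proposal is to let MmaOp consume the 2D form directly, so translation never has to materialize the broadcast dimensions at all.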
Motivation
Currently, MmaOp requires at least 3D inputs in which all of the dimensions "line up". That means that the M dimension should be Iteration in the A operand and Broadcast in the B operand, for example. This lets us use the default exact domain mapping between operands and the MmaOp output. However, it means that if we are translating a Fusion that has MatmulOp or LinearOp to use MmaOp, we need to introduce BroadcastOp nodes, which interferes with the optimal gmem->smem->mma pipeline on Hopper.
Proposed Approach
I propose to do the following:
1. Add attributes to the MmaOp.
2. Add a special case in PairwiseLogicalDomainMap that will map the output domains to domains in the inputs. This is similar to what we do for SdpaFwdOp and SdpaBwdOp currently.
3. Update mma_utils::MatmulPattern::translateToMmaOp to skip inserting broadcasts.
4. Update the Ampere and Hopper matmul schedulers to not assume there is a broadcast M or N dimension in the ab and bb tensors.