Implement conversion from FMA dot operand to linear layout #5469
base: main
Conversation
lib/Dialect/TritonGPU/IR/Dialect.cpp
Outdated
// Returns ["dim0", "dim1", ..., "dim<rank-1>"] in given order.
SmallVector<StringAttr> orderedOutDimNames(MLIRContext *ctx,
                                           ArrayRef<unsigned> order) {
  auto rank = order.size();
  SmallVector<StringAttr> ret;
  for (int i = 0; i < rank; i++) {
    ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(order[i])));
  }
  return ret;
}
This function is already defined somewhere else, could you move it from there?
cc @Mogball: a candidate to have in that file of LL utils.
I did not find exactly this function, but found permuteDimNames, will use it.
It's standardOutDimNames.
Also see #5470 for a nice place to put that utility function
Ah, you wanted the dims permuted; yeah, use permuteDimNames and the standard one, perhaps.
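For reference, a minimal sketch of what the helper boils down to with those two utilities (assuming standardOutDimNames(ctx, rank) returns ["dim0", ..., "dim<rank-1>"] and permuteDimNames(names, order) reorders them; the exact signatures used here are assumptions, and ctx is the MLIRContext in scope):
// Sketch: build the out-dim names in the requested order using the
// existing utilities instead of a local helper.
SmallVector<unsigned> order = {1, 0};  // e.g. the operand's rep order
SmallVector<StringAttr> repDimNames =
    permuteDimNames(standardOutDimNames(ctx, order.size()), order);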
Yes, combineCtaCgaWithShape implicitly calculates the order of repetitions from its ctaLayout argument, so I transpose all CTA components into this order with the transposeOuts method.
This PR introduces an FMA dot operand converter and related tests.
- Fix compiler crashes in FMA.cpp
- Fix lit test
140727d to 547f7f8
@@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {
Made this tensor larger, because with the introduction of the linear layout the input and output tensors turned out to be compatible.
Nice! A few comments but overall looks good!
Do we have any test that exercises the FMA to LLVM lowering on the AMD side?
if (!verifyCTALayout(dLayout.getCTALayout()))
  return Value();
Does this mean that this path is still preferred over LLs? Could you also make LLs the preferred path, or is there anything blocking us from doing so?
I believe LL is already preferred, but I don't want to remove the legacy converter yet, just in case.
but in which case did you hit the hard error? Can we just revert these changes?
I did not hit any errors so far, but I want to compare the code generated by the legacy and the new converters to see if there are differences in perf, register usage, etc.
Perhaps now this can be reverted before merging?
@@ -292,6 +293,11 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
   auto numBTiles = std::max(1u, B / shapePerCTABTile);
   auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);

+  // Found discrepancy in this case,
You meant this is a TODO?
yes, will reword this.
// TODO: use operandLayout.getThreadOrder()
auto threadOrder = blocked.getThreadOrder();
auto warpOrder = blocked.getWarpOrder();
auto repOrder = blocked.getRepOrder();
Should it be operandLayout.getRepOrder?
It is not implemented at the moment for a blocked parent.
We had a conversation about dot operand order functions with @lezcano and did not come to a definite decision.
So for now I am just using the parent's order functions everywhere.
auto regOrder = blocked.getOrder();
// TODO: use operandLayout.getThreadOrder()
auto threadOrder = blocked.getThreadOrder();
auto warpOrder = blocked.getWarpOrder();
I feel like there's something wrong here. Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?
My point is that warps have to be broadcast along the k dimension.
Agreed. I think you have to use warpsDotOperand here.
> Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?

Yes, this converter works with any warp shape. The trick is that I create the "thread" part of the layout using the whole K dimension, so any number of warps or threads across the K dimension will be squeezed by the combineCtaCgaWithShape call.
Let's take an example with a dot A operand of shape [m=32, k=32]:
parent layout has perThreadShape=[1, 1], threads=[8, 4], warps=[2, 2]
- the "per-thread" layout identityStandardND(kReg, threadSize, regOrder) will cover shape [m=1, k=32]
- the "per-warp" layout ... * identityStandardND(kLane, threadShape, threadOrder) will cover shape [m=8, k=32*4=128]
- the "full" layout ... * warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx) will cover shape [m=16, k=32*8=256]
Then I apply combineCtaCgaWithShape: it repeats the m dimension two times, but "broadcasts" the k dimension down to 32, so all threads and warps across the K dimension hold the same values.
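To make the steps above concrete, here is a minimal sketch of that construction for the [m=32, k=32] A-operand example, using the first (identityStandardND-based) variant of the CTA tile shown further down in this thread. The helper names (identityStandardND, permuteDimNames, standardOutDimNames, combineCtaCgaWithShape) come from the existing linear-layout utilities, but the exact signatures as used here are assumptions:
// Sketch: build the linear layout for an FMA dot A operand of shape
// [m=32, k=32] whose blocked parent has threads=[8, 4] and warps=[2, 2].
LinearLayout buildExampleFmaALayout(MLIRContext *ctx,
                                    DotOperandEncodingAttr operandLayout) {
  auto blocked = mlir::cast<BlockedEncodingAttr>(operandLayout.getParent());
  StringAttr kReg = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kWarp = StringAttr::get(ctx, "warp");

  auto regOrder = blocked.getOrder();
  auto threadOrder = blocked.getThreadOrder();
  auto warpOrder = blocked.getWarpOrder();
  auto repDimNames =
      permuteDimNames(standardOutDimNames(ctx, 2), blocked.getRepOrder());

  // Per-thread part covers the whole K dimension: [m=1, k=32].
  LinearLayout ctaLayout =
      identityStandardND(kReg, {1, 32}, regOrder).transposeOuts(repDimNames) *
      // Lanes extend it to [m=8, k=32*4=128].
      identityStandardND(kLane, {8, 4}, threadOrder).transposeOuts(repDimNames) *
      // Warps extend it to [m=16, k=32*8=256].
      identityStandardND(kWarp, {2, 2}, warpOrder).transposeOuts(repDimNames);

  // Repeats m up to 32 and "broadcasts" k back down to the tensor's K=32, so
  // all threads and warps along K end up holding the same values.
  return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), {32, 32});
}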
> Agreed. I think you have to use warpsDotOperand here.
I thought about this; the only reason I chose to go without it for now is aesthetic. At the moment the CTA tile construction looks like this:
LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                             .transposeOuts(repDimNames) *
                         identityStandardND(kLane, threadShape, threadOrder)
                             .transposeOuts(repDimNames) *
                         identityStandardND(kWarp, warpShape, warpOrder)
                             .transposeOuts(repDimNames);
with warpsDotOperand:
LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                             .transposeOuts(repDimNames) *
                         identityStandardND(kLane, threadShape, threadOrder)
                             .transposeOuts(repDimNames) *
                         warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx)
                             .transposeOuts(repDimNames);
In the second variant the layout still extends beyond the K dimension due to the lane component. To make it uniform I could introduce something like laneDotOperand, but that function would be used in only one place.
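For illustration, a hypothetical sketch of what such a helper could look like (laneDotOperand does not exist in the codebase; this also assumes LinearLayout::identity1D / LinearLayout::zeros1D and standardOutDimNames behave as in the current utilities, with zeros1D expressing the broadcast along K the same way warpsDotOperand does for warps):
// Hypothetical helper: identity mapping of lanes onto the non-K dims and a
// broadcast (all-zero bases) along the K dim, analogous to warpsDotOperand.
LinearLayout laneDotOperand(MLIRContext *ctx, ArrayRef<unsigned> threadShape,
                            ArrayRef<unsigned> threadOrder, unsigned kDimIdx) {
  StringAttr kLane = StringAttr::get(ctx, "lane");
  auto dimNames = standardOutDimNames(ctx, threadShape.size());
  LinearLayout layout = LinearLayout::empty();
  for (unsigned dim : threadOrder) {
    layout = layout * (dim == kDimIdx
                           ? LinearLayout::zeros1D(threadShape[dim], kLane,
                                                   dimNames[dim])
                           : LinearLayout::identity1D(threadShape[dim], kLane,
                                                      dimNames[dim]));
  }
  return layout;
}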
I think the per thread layout would be
[r0, r1]
[r2, r3]
Oh, the [2*2] tile is a property of the blocked parent layout.
The dot operand is slightly different: it inherits all attributes of the parent except along the k dimension. The dot operand layout implicitly extends the per-thread size to [2, K] for the A operand and [K, 2] for the B operand.
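To make that concrete, a small assumed example (K = 4; the register numbering order below is an assumption, only the tile shapes come from the description above):

Parent blocked layout, sizePerThread = [2, 2], one thread holds:
  [r0 r1]
  [r2 r3]
FMA dot A operand derived from it, per-thread tile implicitly extended to [2, K] with K = 4:
  [r0 r1 r2 r3]
  [r4 r5 r6 r7]
For the B operand the per-thread tile becomes [K, 2] instead.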
The picture you mentioned is related to the intermediate layout, before it is expanded with combineCtaCgaWithShape.
Got it. I think it's different from the cases where the parent is mma; we don't do implicit broadcasting on only the register dimension. That being said, I think some code cleanups have to happen later. Right now, several methods crash on this dot operand, like getElemsPerThread.
The update looks good to me now. Thanks!
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
  return fmaDotToLinearLayout(*this, shape);
}
if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
nit:
- if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
+ else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
Yes, sadly we actually test this only for AMD at this moment: https://github.com/triton-lang/triton/blob/main/python/test/unit/language/test_core.py#L3240. I verified that it worked for Nvidia manually, but I don't think this is tested in CI at the moment.
- cleanup hash function in FMA.cpp
- add more details in TODO in SharedToDotOperandFMA.cpp
- cleanup DotOperandEncodingAttr::toLinearLayout
auto regOrder = blocked.getOrder();
// TODO: use operandLayout.getThreadOrder()
auto threadOrder = blocked.getThreadOrder();
auto warpOrder = blocked.getWarpOrder();
The update looks good to me now. Thanks!
This PR fixes #5423.