Implement conversion from FMA dot operand to linear layout #5469
base: main
Changes from all commits
0ac924a
69c3354
547f7f8
2157f20
a7c978b
eb33d00
7f2d2a6
c372e5a
@@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
   }
 }

-void verifyCTALayout(CTALayoutAttr ctaLayout) {
+bool verifyCTALayout(CTALayoutAttr ctaLayout) {
   auto ctaSplit = ctaLayout.getCTASplitNum();
   for (auto split : ctaSplit) {
     if (split != 1)
-      llvm::report_fatal_error("tensors splited in CGA(thread group clusters) "
-                               "are not supported in FMA dot yet.");
+      return false;
   }
+  return true;
 }

 /// Get a linear offset of first element loaded by thread.
@@ -216,7 +216,8 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
                 Value thread, Location loc,
                 const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, const int dotOpNo) {
-  verifyCTALayout(dLayout.getCTALayout());
+  if (!verifyCTALayout(dLayout.getCTALayout()))
+    return Value();

   DimIdx dim;
   dim.batch = 0;
@@ -292,6 +293,15 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
   auto numBTiles = std::max(1u, B / shapePerCTABTile);
   auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);

+  // Found a discrepancy in this case,
+  // use the linear layout based converter for this case.
+  // TODO: break batch and non-K dimension iterations into
+  // "repeat" and "inside-repeat" parts, pack them in an llvm structure
+  // according to repeat and register order.
+  // See FMA.cpp:getValueTableFromStructFMA for reference.
+  if (numBTiles != 1 || numNonKTiles != 1)
+    return Value();
+
   auto perThreadShape =
       getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread);

Review thread on the "Found a discrepancy" comment:
Reviewer: You meant this is a TODO?
Author: Yes, will reword this.
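To make the bail-out condition above concrete, here is a small standalone arithmetic sketch. The numbers are illustrative assumptions, not taken from the PR; the point is that whenever the tensor spans more than one CTA tile in the batch or non-K dimension, the legacy FMA path now returns an empty Value so the linear-layout-based converter can handle the case instead.

// Standalone sketch of the fallback condition; all values are assumed for
// illustration only.
#include <algorithm>
#include <cstdio>

int main() {
  unsigned B = 4, shapePerCTABTile = 2;         // batch spans two CTA tiles
  unsigned NonK = 64, shapePerCTANonKTile = 64; // non-K fits in one CTA tile
  unsigned numBTiles = std::max(1u, B / shapePerCTABTile);          // 2
  unsigned numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile); // 1
  // Mirrors `if (numBTiles != 1 || numNonKTiles != 1) return Value();`
  bool fallBackToLinearLayout = (numBTiles != 1 || numNonKTiles != 1);
  printf("numBTiles=%u numNonKTiles=%u fallback=%d\n", numBTiles, numNonKTiles,
         (int)fallBackToLinearLayout); // prints: numBTiles=2 numNonKTiles=1 fallback=1
  return 0;
}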
@@ -239,8 +239,11 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }

-LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
-                             ArrayRef<unsigned> warpOrder, unsigned inner) {
+/// Function to generate lane and warp layout for dot operands.
+LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
+                                         ArrayRef<unsigned> shape,
+                                         ArrayRef<unsigned> order,
+                                         unsigned kDim, StringAttr inDimName) {
   // Let warpsPerCTAMma = {2, 2}, then
   // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
   // assume warpOrder = {1, 0}
@@ -255,24 +258,23 @@ LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
   //   -  - | -  -  -  - | -  -
   //   2  3 | 2  3  0  2 | 1  3
   // In other words, we need to broadcast along K
-  auto rank = warpShape.size();
+  auto rank = shape.size();
   auto dimNames = standardOutDimNames(ctx, rank);
-  LinearLayout warpLayout = LinearLayout::empty();
+  LinearLayout layout = LinearLayout::empty();

   // We have to broadcast along the inner dimension
   // For A, when moving along M we go from 0 to 2.
   // For B, when moving along N we go from 0 to 1.
   // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
   // Same happens if the warpOrder is {0, 1}, like in Hopper
-  for (auto d : warpOrder) {
-    if (d == inner) {
-      warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]);
+  for (auto d : order) {
+    if (d == kDim) {
+      layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]);
     } else {
-      warpLayout *=
-          LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]);
+      layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
     }
   }
-  return warpLayout;
+  return layout;
 }

 } // anonymous namespace
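As a concrete illustration of the broadcast this helper builds, the following standalone toy (plain integers rather than the LinearLayout class; the warp counts are assumptions for illustration, not taken from the PR) enumerates which tile of the A operand each warp reads when the layout is identity along M and zeros along K.

#include <cstdio>

int main() {
  // Assumed parent warp arrangement: warpsPerCTA = {2, 2}. For operand A the
  // K dimension is the second one, so the warp coordinate along K must not
  // change which data is read (zeros1D along K, identity1D along M).
  const int warpsM = 2, warpsK = 2;
  for (int wK = 0; wK < warpsK; ++wK)
    for (int wM = 0; wM < warpsM; ++wM) {
      int mTile = wM; // identity along M: different wM -> different rows
      int kTile = 0;  // zeros along K: wK is ignored -> broadcast
      printf("warp (m=%d, k=%d) -> A tile (%d, %d)\n", wM, wK, mTile, kTile);
    }
  // Warps that differ only in their K coordinate print the same tile,
  // which is the broadcast described in the comment block above.
  return 0;
}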
@@ -620,7 +622,8 @@ wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
   // Generate warp layout
   auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout);
-  LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim);
+  LinearLayout warpLayout =
+      broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp"));

   // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
   // extension of layout
@@ -650,6 +653,48 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

+std::optional<LinearLayout>
+fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout,
+                     ArrayRef<int64_t> shape) {
+  int rank = shape.size();
+  auto blocked = cast<BlockedEncodingAttr>(operandLayout.getParent());
+  MLIRContext *ctx = operandLayout.getContext();
+
+  // TODO: introduce registerOrder or use getOrder(operandLayout)
+  // Currently this order is used in legacy converter, because we do not
+  // have access to full dot operand layout, only parent part.
+  auto regOrder = blocked.getOrder();
+  // TODO: use operandLayout.getThreadOrder()
+  auto threadOrder = blocked.getThreadOrder();
+  auto warpOrder = blocked.getWarpOrder();
Review thread on this line:
Reviewer: I feel like there's something wrong here. Have you tested a warp shape of …?
Reviewer: My point is that warps have to be broadcast along the K dimension.
Reviewer: Agreed. I think you have to use …
Author: Yes, this converter works with any warp shape. The trick is that I create the "thread" part of the layout using the whole K dimension, so any number of warps or threads across the K dimension will be squeezed in … Let's take an example with a dot A operand of shape [m=32, k=32]: … Then I apply …
Author: I thought about this; the only reason I chose to go without it for now is aesthetic. At the moment the CTA tile construction looks like this: … with … In the second variant the layout still extends beyond the K dimension due to the lane component; to make it uniform I could introduce something like …
Reviewer: I think the per-thread layout would be [r0, r1] …
Author: Oh, [2*2] is a property of the blocked parent layout. A dot operand is slightly different: it inherits all attributes of the parent except the K dimension. The dot operand layout implicitly extends the per-thread size to [2, K] for the A operand and [K, 2] for the B operand.
Author: The picture you mentioned relates to the intermediate layout, before it is expanded with combineCtaCgaWithShape.
Reviewer: Got it. I think it's different from the cases where the parent is mma; we don't do implicit broadcasting on only the register dimension. That said, some code cleanups will have to happen later; right now several methods crash on this dot operand, like …
Reviewer: The update looks good to me now. Thanks!
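To make the "whole K dimension lives in the register part" argument concrete, here is a standalone toy model (plain arithmetic, not the LinearLayout API; all sizes below are assumptions for illustration, not the author's elided example). Each lane's registers cover sizePerThread[m] x K elements, so lanes that differ only along K end up owning exactly the same elements.

#include <cstdio>

int main() {
  // Assumed toy A-operand tile: M x K = 4 x 8, with 2 lanes along M,
  // 2 lanes along K, and sizePerThread along M = 2.
  const int K = 8, lanesM = 2, lanesK = 2, sizePerThreadM = 2;

  // Register part covers sizePerThreadM x the whole K dimension
  // (threadSize[kDimIdx] = shape[kDimIdx] in fmaDotToLinearLayout).
  // Lane part: identity along M, zeros along K.
  for (int laneK = 0; laneK < lanesK; ++laneK)
    for (int laneM = 0; laneM < lanesM; ++laneM) {
      printf("lane (m=%d, k=%d) owns:", laneM, laneK);
      for (int r = 0; r < sizePerThreadM * K; ++r) {
        int m = laneM * sizePerThreadM + r / K; // identity along M
        int k = r % K;                          // laneK ignored: broadcast
        printf(" (%d,%d)", m, k);
      }
      printf("\n");
    }
  // Lanes (m=0,k=0) and (m=0,k=1) print identical element lists: any number
  // of lanes or warps along K is "squeezed" onto the same data.
  return 0;
}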
+  auto repOrder = blocked.getRepOrder();

Review thread on this line:
Reviewer: Should it be …?
Author: It is not implemented at the moment for a blocked parent. We had a conversation about the dot operand order functions with @lezcano and did not come to a definite decision, so for now I am just using the parent order functions everywhere.
+
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  SmallVector<unsigned> threadSize = blocked.getSizePerThread();
+  auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  threadSize[kDimIdx] = shape[kDimIdx];
+  auto threadShape = blocked.getThreadsPerWarp();
+  auto warpShape = blocked.getWarpsPerCTA();
+
+  SmallVector<StringAttr> repDimNames =
+      permuteDimNames(standardOutDimNames(ctx, rank), repOrder);
+
+  auto registersLayout = identityStandardND(kReg, threadSize, regOrder);
+  auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder,
+                                                 kDimIdx, kLane);
+  auto warpsLayout =
+      broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp);
+
+  LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) *
+                           lanesLayout.transposeOuts(repDimNames) *
+                           warpsLayout.transposeOuts(repDimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape);
+}
+
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
@@ -740,19 +785,21 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
   auto ctaLayout =
       nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
   auto kDim = isA ? rank - 1 : rank - 2;
-  ctaLayout *=
-      warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim)
-          .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
+  ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(),
+                                           mma.getWarpOrder(), kDim, S("warp"))
+                   .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }

 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto parent = getParent();
-  if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
+  if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
+    return fmaDotToLinearLayout(*this, shape);
+  } else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
-  } else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
+  } else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     return wmmaDotOperandToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
     return nvidiaDotToLinearLayout(shape, *this);
@@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {
     // CHECK-NOT: ttg.convert_layout
     // CHECK: ttg.local_alloc
     // CHECK: ttg.local_load
-    %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
+    %0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
     tt.return
   }
 }

Review comment on the tensor size change:
Author: Made this tensor larger because, with the introduction of the linear layout, the input and output tensors turned out to be compatible.
Review thread:
Reviewer: Does this mean that this path is still preferred over LLs? Could you also make LLs the preferred path, or is there anything blocking us from doing so?
Author: I believe the LL path is already preferred, but I don't want to remove the legacy converter yet, just in case.
Reviewer: But in which case did you hit the hard error? Can we just revert these changes?
Author: I have not hit any errors so far, but I want to compare the code generated by the legacy and new converters to see if there are differences in performance, register usage, etc.
Reviewer: Perhaps now this can be reverted before merging?