
Implement conversion from FMA dot operand to linear layout #5469

Open
wants to merge 8 commits into main from fma_operand_linearlayout

Conversation

Contributor
@binarman binarman commented Dec 19, 2024

This PR

  • Introduces an FMA dot operand to linear layout converter, plus related tests
  • Fixes FMA generation; the previous version had repetitions incompatible with the blocked layout

Fixes #5423

Comment on lines 624 to 623
// Returns ["dim0", "dim1", ..., "dim<rank-1>"] in given order.
SmallVector<StringAttr> orderedOutDimNames(MLIRContext *ctx,
                                           ArrayRef<unsigned> order) {
  auto rank = order.size();
  SmallVector<StringAttr> ret;
  for (int i = 0; i < rank; i++) {
    ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(order[i])));
  }
  return ret;
}

Contributor

This function already exists somewhere else; could you move it from there?
cc @Mogball, a candidate for that file of LL utils

Contributor Author
@binarman binarman Dec 20, 2024

I did not find exactly this function, but I found permuteDimNames; I will use it.

Contributor

It's standardOutDimNames.
Also see #5470 for a nice place to put that utility function

Contributor

Ah, you wanted the dims permuted; yeah, use permuteDimNames and the standard one, perhaps.

Contributor Author

Yes, combineCtaCgaWithShape implicitly calculates the order of repetitions from the ctaLayout argument, so I transpose all cta components into this order with the transposeOuts method.
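
For reference, a minimal sketch of what the helper above could look like when expressed through the existing utilities suggested in this thread; the exact signatures of standardOutDimNames and permuteDimNames are assumed here, not taken from the PR:

  // Hypothetical sketch, not part of this PR: produce
  // ["dim<order[0]>", "dim<order[1]>", ...] by permuting the standard
  // dimension names instead of rebuilding them by hand. Assumes
  // standardOutDimNames(ctx, rank) returns ["dim0", ..., "dim<rank-1>"]
  // and permuteDimNames(names, order) reorders them by `order`.
  SmallVector<StringAttr> orderedOutDimNames(MLIRContext *ctx,
                                             ArrayRef<unsigned> order) {
    auto names = standardOutDimNames(ctx, order.size());
    return permuteDimNames(names, SmallVector<unsigned>(order.begin(),
                                                        order.end()));
  }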

This PR introduces FMA dot operand converter and related tests.
- Fix compiler crashes in FMA.cpp
- Fix lit test
@binarman binarman force-pushed the fma_operand_linearlayout branch from 140727d to 547f7f8 on December 20, 2024 14:43
@binarman binarman changed the title from "[WIP] Implement conversion from FMA dot operand to linear layout" to "Implement conversion from FMA dot operand to linear layout" on Dec 20, 2024
@@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {
Contributor Author

Made this tensor larger because, with the introduction of the linear layout, the input and output tensors turned out to be compatible.

@binarman binarman marked this pull request as ready for review December 20, 2024 15:08
@binarman binarman requested a review from ptillet as a code owner December 20, 2024 15:08
Contributor
@lezcano lezcano left a comment

Nice! A few comments but overall looks good!

Do we have any test that exercises the FMA to LLVM lowering on the AMD side?

Comment on lines +219 to +220
if (!verifyCTALayout(dLayout.getCTALayout()))
return Value();
Contributor

Does this mean that this path is still preferred over LLs? Could you also make LLs the preferred path, or is there anything blocking us from doing so?

Contributor Author

I believe LL is already preferred, but I don't want to remove the legacy converter yet, just in case.

Contributor

but in which case did you hit the hard error? Can we just revert these changes?

Contributor Author

I have not hit any errors so far, but I want to compare the code generated by the legacy and the new converter to see if there are differences in performance, register usage, etc.

Contributor

Perhaps now this can be reverted before merging?

lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp (outdated, resolved)
lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp (outdated, resolved)
lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp (outdated, resolved)
@@ -292,6 +293,11 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
auto numBTiles = std::max(1u, B / shapePerCTABTile);
auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);

// Found discrepancy in this case,
Contributor

You meant this is a TODO?

Contributor Author

Yes, I will reword this.

// TODO: use operandLayout.getThreadOrder()
auto threadOrder = blocked.getThreadOrder();
auto warpOrder = blocked.getWarpOrder();
auto repOrder = blocked.getRepOrder();
Contributor

Should it be operandLayout.getRepOrder?

Contributor Author
@binarman binarman Dec 20, 2024

It is not implemented for the blocked parent at the moment.

@lezcano and I had a conversation about the dot operand order functions and did not come to a definite decision.

So for now I am just using the parent order functions everywhere.

auto regOrder = blocked.getOrder();
// TODO: use operandLayout.getThreadOrder()
auto threadOrder = blocked.getThreadOrder();
auto warpOrder = blocked.getWarpOrder();
Contributor

I feel like there's something wrong here. Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?

Contributor

My point is that warps have to be broadcasted along the k dimension

Contributor

Agreed. I think you have to use warpsDotOperand here.

Contributor Author
@binarman binarman Dec 20, 2024

Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?

Yes, this converter works with any warp shape. The trick is that I create the "thread" part of the layout using the whole K dimension, so any number of warps or threads across the k dimension will be squeezed by the combineCtaCgaWithShape call.

Let's take an example with a dot A operand of shape [m=32, k=32]:
the parent layout has perThreadShape=[1, 1], threads=[8, 4], warps=[2, 2].

  1. "per-thread" layout identityStandardND(kReg, threadSize, regOrder) will cover shape [m=1, k=32]
  2. "per-warp" layout ... * identityStandardND(kLane, threadShape, threadOrder) will cover shape [m=8, k=32*4=128]
  3. "full" layout ... * warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx) will cover shape [m=16, k=3282=256]

Then I apply combineCtaCgaWithShape: it repeats the m dimension two times, but "broadcasts" the k dimension down to 32, so all threads and warps across the K dimension hold the same values.
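
To make the broadcasting concrete, here is a small standalone sketch (a toy model with the numbers from the example above, not Triton's actual implementation; folding with % K stands in for what combineCtaCgaWithShape does in this power-of-two case):

  // Toy model of the K broadcasting described above: A operand [m=32, k=32],
  // parent blocked layout with threads=[8, 4] and warps=[2, 2]. Registers of
  // one thread cover the whole K dimension, and the k offsets contributed by
  // lanes and warps fold back into K=32, so every lane/warp position along K
  // holds the same elements.
  #include <cstdio>

  int main() {
    const int K = 32;          // operand K extent
    const int kPerThread = 32; // registers of one thread span all of K
    const int lanesK = 4;      // threads along K in the parent layout
    const int warpsK = 2;      // warps along K in the parent layout
    for (int warpK = 0; warpK < warpsK; ++warpK) {
      for (int laneK = 0; laneK < lanesK; ++laneK) {
        // k index held by register 0 of this (laneK, warpK) position,
        // before and after folding into the real K extent.
        int unfolded = laneK * kPerThread + warpK * kPerThread * lanesK;
        int folded = unfolded % K;
        std::printf("warpK=%d laneK=%d: unfolded k=%3d -> folded k=%d\n",
                    warpK, laneK, unfolded, folded);
      }
    }
    return 0;
  }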

Contributor Author

Agreed. I think you have to use warpsDotOperand here.

I thought about this; the only reason I chose to go without it for now is aesthetic.

At the moment the cta tile construction looks like this:

  LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kLane, threadShape, threadOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kWarp, warpShape, warpOrder)
                               .transposeOuts(repDimNames);

with warpsDotOperand:

  LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kLane, threadShape, threadOrder)
                               .transposeOuts(repDimNames) *
                           warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx)
                               .transposeOuts(repDimNames);

In the second variant the layout still extends beyond the K dimension due to the lane component. To make it uniform I could introduce something like laneDotOperand, but that function would be used in only one place.

Contributor

I think the per thread layout would be

[r0, r1]
[r2, r3]

Contributor Author

Oh, the [2, 2] per-thread tile is a property of the blocked parent layout.

The dot operand is slightly different: it inherits all attributes of the parent except the k dimension. The dot operand layout implicitly extends the per-thread size to [2, K] for the A operand and [K, 2] for the B operand.
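For example (assuming the [2, 2] per-thread tile above and K = 32), the A operand would behave as if its per-thread tile were [2, 32] and the B operand as if it were [32, 2]; only the non-K dimension keeps the parent's per-thread size.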

Contributor Author
@binarman binarman Dec 20, 2024

The picture you mentioned is related to the intermediate layout, before it is expanded with combineCtaCgaWithShape.

Contributor
@Jokeren Jokeren Dec 20, 2024

Got it. I think it's different from the cases where the parent is mma; we don't do implicit broadcasting on only the register dimension. That being said, I think some code cleanups have to happen later. Right now, several methods crash on this dot operand, like getElemsPerThread.

Contributor

The update looks good to me now. Thanks!

if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
return fmaDotToLinearLayout(*this, shape);
}
if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
Contributor

nit:

Suggested change
if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {

Contributor Author

binarman commented Dec 20, 2024

@lezcano

Do we have any test that exercises the FMA to LLVM lowering on the AMD side?

Yes, sadly we test this only for AMD at the moment: https://github.com/triton-lang/triton/blob/main/python/test/unit/language/test_core.py#L3240

I verified manually that it works for Nvidia, but I don't think this is tested in CI at the moment.

- cleanup hash function in FMA.cpp
- add more details in TODO in SharedToDotOperandFMA.cpp
- cleanup DotOperandEncodingAttr::toLinearLayout
@binarman binarman requested review from lezcano and Jokeren December 20, 2024 21:34
@binarman
Contributor Author

@lezcano @Jokeren Could you take a look again, please?

auto regOrder = blocked.getOrder();
// TODO: use operandLayout.getThreadOrder()
auto threadOrder = blocked.getThreadOrder();
auto warpOrder = blocked.getWarpOrder();
Contributor

The update looks good to me now. Thanks!

Successfully merging this pull request may close these issues.

Support blocked dot operand layout conversion to linear layout

3 participants