Implement conversion from FMA dot operand to linear layout #5469

Open · wants to merge 8 commits into main
Changes shown from 3 commits
1 change: 1 addition & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -766,6 +766,7 @@
     // Block encoding is dense stride layout. The elements per thread are contiguous.
     return getSizePerThread();
   };
+  SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
 }];

 let hasCustomAssemblyFormat = 1;
@@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
   }
 }

-void verifyCTALayout(CTALayoutAttr ctaLayout) {
+bool verifyCTALayout(CTALayoutAttr ctaLayout) {
   auto ctaSplit = ctaLayout.getCTASplitNum();
   for (auto split : ctaSplit) {
     if (split != 1)
-      llvm::report_fatal_error("tensors splited in CGA(thread group clusters) "
-                               "are not supported in FMA dot yet.");
+      return false;
   }
+  return true;
 }

 /// Get a linear offset of first element loaded by thread.
@@ -216,7 +216,8 @@
 Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
                 Value thread, Location loc,
                 const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, const int dotOpNo) {
-  verifyCTALayout(dLayout.getCTALayout());
+  if (!verifyCTALayout(dLayout.getCTALayout()))
+    return Value();
Review thread on lines +219 to +220:

Contributor: Does this mean that this path is still preferred over LLs? Could you also make LLs the preferred path, or is there anything blocking us from doing so?

Author: I believe LL is already preferred, but I don't want to remove the legacy converter yet, just in case.

Contributor: But in which case did you hit the hard error? Can we just revert these changes?

Author: I did not hit any errors so far, but I want to compare the code generated by the legacy and the new converter to see if there are differences in perf, register usage, etc.

Contributor: Perhaps now this can be reverted before merging?

   DimIdx dim;
   dim.batch = 0;
@@ -292,6 +293,11 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
   auto numBTiles = std::max(1u, B / shapePerCTABTile);
   auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);

+  // Found discrepancy in this case,
+  // use linear layout based converter for this case
+  if (numBTiles != 1 || numNonKTiles != 1)
+    return Value();

Review thread on the new comment above:

Contributor: You meant this is a TODO?

Author: Yes, will reword this.
   auto perThreadShape =
       getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread);
105 changes: 75 additions & 30 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -13,24 +13,54 @@
 using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;

-using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;
+/// \brief Spatial position of repetition and register of a given value.
+struct OperandValueKey {
+  unsigned bRepIdx, nonKRepIdx;
+  unsigned bIdx, nonKIdx, kIdx;
+
+  bool operator==(const OperandValueKey &other) const {
+    return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
+            bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
+            kIdx == other.kIdx);
+  }
+};
+
+template <> struct std::hash<OperandValueKey> {
+  std::size_t operator()(const OperandValueKey &k) const {
+    return std::hash<unsigned>()(k.bRepIdx) ^
+           (std::hash<unsigned>()(k.nonKRepIdx) << 1) ^
+           (std::hash<unsigned>()(k.bIdx) << 2) ^
+           (std::hash<unsigned>()(k.nonKIdx) << 3) ^
+           (std::hash<unsigned>()(k.kIdx) << 4);
+  }
+};
+
+using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;

-static ValueTableFMA
-getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
-                           unsigned kDim, unsigned nonKDim,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<unsigned> order) {
+static ValueTableFMA getValueTableFromStructFMA(
+    Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
+    unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
+    Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
   ValueTableFMA res;
   auto elems = unpackLLElements(loc, val, rewriter);
-  assert(perTileShape.size() == 3);
-  assert(elems.size() == product(perTileShape));
+  assert(perRepShape.size() == 3);
+  auto numElemsRep = product(perRepShape);
+  assert(elems.size() == numElemsRep * product(repetitions));
   assert(kDim == 1 || kDim == 2);
   assert(nonKDim == 1 || nonKDim == 2);
   const unsigned bDim = 0;

   for (unsigned idx = 0; idx < elems.size(); ++idx) {
-    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
-    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
+    auto inRepLinearIdx = idx % numElemsRep;
+    auto repLinearIdx = idx / numElemsRep;
+    auto inRepSpatialIdx =
+        mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
+    auto repSpatialIdx =
+        mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
+    OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
+                        inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
+                        inRepSpatialIdx[kDim]};
+    res[key] = elems[idx];
   }
   return res;
 }
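As an aside, the index split at the heart of the new value table can be checked in isolation. Below is a minimal standalone sketch, not part of the PR: the delinearize helper is a simplified stand-in for mlir::LLVM::delinearize, and the toy shapes are assumptions chosen for illustration.

  #include <array>
  #include <cstdio>

  // Simplified stand-in for mlir::LLVM::delinearize: unpack `linear` into 3D
  // coordinates, walking `order` from the fastest-varying dimension to the
  // slowest.
  static std::array<unsigned, 3>
  delinearize(unsigned linear, const std::array<unsigned, 3> &shape,
              const std::array<unsigned, 3> &order) {
    std::array<unsigned, 3> coords{};
    for (unsigned dim : order) {
      coords[dim] = linear % shape[dim];
      linear /= shape[dim];
    }
    return coords;
  }

  int main() {
    // Assumed toy sizes: per-repetition shape [b=1, nonK=2, k=4], three
    // repetitions along the nonK dimension.
    const std::array<unsigned, 3> perRepShape{1, 2, 4};
    const std::array<unsigned, 3> repetitions{1, 3, 1};
    const std::array<unsigned, 3> order{2, 1, 0}; // k fastest, then nonK, then b
    const unsigned numElemsRep = 1 * 2 * 4;

    // Mirrors the loop in getValueTableFromStructFMA: split the flat register
    // index into an in-repetition part and a repetition part, then
    // delinearize each part separately.
    for (unsigned idx = 0; idx < numElemsRep * 3; ++idx) {
      const unsigned inRepLinearIdx = idx % numElemsRep;
      const unsigned repLinearIdx = idx / numElemsRep;
      const auto inRep = delinearize(inRepLinearIdx, perRepShape, order);
      const auto rep = delinearize(repLinearIdx, repetitions, order);
      std::printf("idx=%2u -> nonKRep=%u, nonK=%u, k=%u\n", idx, rep[1],
                  inRep[1], inRep[2]);
    }
    return 0;
  }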
@@ -54,46 +84,61 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,

   BlockedEncodingAttr dLayout =
       cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
+  // TODO process A and B operand separately
+  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
+  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
   auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

   Value llA = adaptor.getA();
   Value llB = adaptor.getB();

   auto sizePerThread =
       expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto numElemsPerThread = product(sizePerThread);
   auto shapePerCTATile =
       expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

   unsigned K = aShapePerCTA[2];

-  unsigned perThreadShape[3];
-  unsigned threadTileShape[3];
+  unsigned repetitions[3];
   for (int i = 0; i < 3; ++i) {
-    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
-    numRep = std::max(static_cast<unsigned>(1), numRep);
-    perThreadShape[i] = numRep * sizePerThread[i];
+    repetitions[i] =
+        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
   }

   auto has = getValueTableFromStructFMA(
-      llA, {perThreadShape[0], perThreadShape[1], K},
-      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
+      llA, {sizePerThread[0], sizePerThread[1], K},
+      {repetitions[0], repetitions[1], 1},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
   auto hbs = getValueTableFromStructFMA(
-      llB, {perThreadShape[0], K, perThreadShape[2]},
-      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
+      llB, {sizePerThread[0], K, sizePerThread[2]},
+      {repetitions[0], 1, repetitions[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);

   SmallVector<Value> acc = cc;

-  for (unsigned b = 0; b < perThreadShape[0]; ++b)
-    for (unsigned m = 0; m < perThreadShape[1]; ++m)
-      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
-        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
-        unsigned linearAccumIdx =
-            linearize(multiDimAccumIdx, perThreadShape, order);
-        for (unsigned k = 0; k < K; ++k) {
-          acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
-              loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]);
-        }
-      }
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
+        for (unsigned b = 0; b < sizePerThread[0]; ++b)
+          for (unsigned m = 0; m < sizePerThread[1]; ++m)
+            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
+              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+              unsigned linearInRepIdx =
+                  linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
+              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
+              unsigned linearRepIdx =
+                  linearize(multiDimRepIdx, repetitions, repOrder);
+              unsigned linearAccumIdx =
+                  linearInRepIdx + linearRepIdx * numElemsPerThread;
+              for (unsigned k = 0; k < K; ++k) {
+                auto aOp = has[{bRep, mRep, b, m, k}];
+                auto bOp = hbs[{bRep, nRep, b, n, k}];
+                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
+                    loc, aOp, bOp, acc[linearAccumIdx]);
+              }
+            }

   auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
   rewriter.replaceOp(op, res);
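The accumulator indexing in the rewritten loop nest follows the same repetition/in-repetition split. A minimal standalone sketch of just that arithmetic (the linearize helper is a simplified stand-in for the MLIR utility; the toy shapes are assumptions):

  #include <array>
  #include <cstdio>

  // Simplified stand-in for the MLIR linearize utility: fold 3D coordinates
  // into a flat index, walking `order` from fastest to slowest dimension.
  static unsigned linearize(const std::array<unsigned, 3> &coords,
                            const std::array<unsigned, 3> &shape,
                            const std::array<unsigned, 3> &order) {
    unsigned linear = 0, stride = 1;
    for (unsigned dim : order) {
      linear += coords[dim] * stride;
      stride *= shape[dim];
    }
    return linear;
  }

  int main() {
    // Assumed toy sizes: [b, m, n] per repetition, repeated twice along m.
    const std::array<unsigned, 3> sizePerThread{1, 2, 2};
    const std::array<unsigned, 3> repetitions{1, 2, 1};
    const std::array<unsigned, 3> order{2, 1, 0};
    const unsigned numElemsPerThread = 1 * 2 * 2;

    // Mirrors linearAccumIdx = linearInRepIdx + linearRepIdx * numElemsPerThread.
    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
      for (unsigned m = 0; m < sizePerThread[1]; ++m)
        for (unsigned n = 0; n < sizePerThread[2]; ++n) {
          const unsigned inRep = linearize({0, m, n}, sizePerThread, order);
          const unsigned rep = linearize({0, mRep, 0}, repetitions, order);
          std::printf("mRep=%u m=%u n=%u -> acc=%u\n", mRep, m, n,
                      inRep + rep * numElemsPerThread);
        }
    return 0;
  }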
4 changes: 3 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -126,7 +126,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
                   LinearEncodingAttr>(dstLayout))
       return true;
     if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
-      if (isa<MmaEncodingTrait>(dot.getParent()))
+      if (isa<MmaEncodingTrait, BlockedEncodingAttr>(dot.getParent()))
         return true;
     }
     return false;
@@ -164,6 +164,8 @@
     Value res = SharedToDotOperandFMA::convertLayout(
         dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
         thread, loc, getTypeConverter(), rewriter);
+    if (!res)
+      return failure();
     rewriter.replaceOp(op, res);
     return success();
   }
8 changes: 8 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1000,6 +1000,14 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   return product<unsigned>(getElemsPerThread(shape, eltTy));
 }

+// Blocked encoding
+
+SmallVector<unsigned>
+BlockedEncodingAttr::getRepOrderForOperand(int opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
+}
+
 //

 SmallVector<unsigned>
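A note on what this returns (my reading of getOrderForDotOperand, whose body is not shown in this diff): with kMajor = true the operand's k dimension is made the fastest-varying, so for a rank-2 operand the result would be:

  // Hypothetical illustration, assuming kMajor=true places k fastest:
  //   opIdx = 0 (A operand, k = dim 1) -> rep order {1, 0}
  //   opIdx = 1 (B operand, k = dim 0) -> rep order {0, 1}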
46 changes: 44 additions & 2 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -650,6 +650,45 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

+std::optional<LinearLayout>
+fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout,
+                     ArrayRef<int64_t> shape) {
+  int rank = shape.size();
+  auto blocked = cast<BlockedEncodingAttr>(operandLayout.getParent());
+  MLIRContext *ctx = operandLayout.getContext();
+
+  // TODO: introduce registerOrder or use getOrder(operandLayout)
+  // Currently this order is used in legacy converter, because we do not
+  // have access to full dot operand layout, only parent part.
+  auto regOrder = blocked.getOrder();
+  // TODO: use operandLayout.getThreadOrder()
+  auto threadOrder = blocked.getThreadOrder();
+  auto warpOrder = blocked.getWarpOrder();
Review thread on the warpOrder line above:

Contributor: I feel like there's something wrong here. Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?

Contributor: My point is that warps have to be broadcasted along the k dimension.

Contributor: Agreed. I think you have to use warpsDotOperand here.

Author (binarman, Dec 20, 2024):

> Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?

Yes, this converter works with any warp shape. The trick is that I create the "thread" part of the layout using the whole K dimension, so any number of warps or threads across the k dimension will be squeezed in the combineCtaCgaWithShape call.

Let's take an example with a dot A operand of shape [m=32, k=32], whose parent layout is perThreadShape=[1, 1], threads=[8, 4], warps=[2, 2]:

1. The "per-thread" layout identityStandardND(kReg, threadSize, regOrder) will cover shape [m=1, k=32].
2. The "per-warp" layout ... * identityStandardND(kLane, threadShape, threadOrder) will cover shape [m=8, k=32*4=128].
3. The "full" layout ... * warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx) will cover shape [m=16, k=32*4*2=256].

Then I apply combineCtaCgaWithShape: it repeats the m dimension two times, but "broadcasts" the k dimension down to 32, so all threads and warps across the K dimension hold the same values.

Author (binarman):

> Agreed. I think you have to use warpsDotOperand here.

I thought about this; the only reason I chose to go without it for now is aesthetic.

At the moment the CTA tile construction looks like this:

  LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kLane, threadShape, threadOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kWarp, warpShape, warpOrder)
                               .transposeOuts(repDimNames);

with warpsDotOperand:

  LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kLane, threadShape, threadOrder)
                               .transposeOuts(repDimNames) *
                           warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx)
                               .transposeOuts(repDimNames);

In the second variant the layout still extends beyond the K dimension due to the lane component. To make it uniform I could introduce something like laneDotOperand, but that function would be used in only one place.

Contributor: I think the per-thread layout would be

  [r0, r1]
  [r2, r3]

Author (binarman): Oh, [2, 2] is a blocked parent layout property. The dot operand is slightly different: it inherits all attributes of its parent except the k dimension. The dot operand layout implicitly extends the per-thread size to [2, K] for the A operand and [K, 2] for the B operand.

Author (binarman, Dec 20, 2024): The picture you mentioned is related to the intermediate layout, before it is expanded with combineCtaCgaWithShape.

Contributor (Jokeren, Dec 20, 2024): Got it. I think it's different from the cases where the parent is mma; we don't do implicit broadcasting on only the register dimension. That being said, I think some code cleanups have to happen later. Right now, several methods crash on this dot operand, like getElemsPerThread.

Contributor: The update looks good to me now. Thanks!
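A quick sanity check of the sizes traded in this thread (a minimal standalone sketch of the arithmetic, under the assumption stated above that each layout factor multiplies the shape it covers):

  #include <cstdio>

  int main() {
    // Sizes from the example above: A operand [m=32, k=32], parent layout
    // sizePerThread=[1, 1], threadsPerWarp=[8, 4], warpsPerCTA=[2, 2].
    unsigned m = 1, k = 32; // register factor covers the whole K dimension
    std::printf("registers: [m=%u, k=%u]\n", m, k); // [1, 32]
    m *= 8; k *= 4;                                 // lane factor
    std::printf("lanes:     [m=%u, k=%u]\n", m, k); // [8, 128]
    m *= 2; k *= 2;                                 // warp factor
    std::printf("warps:     [m=%u, k=%u]\n", m, k); // [16, 256]
    // combineCtaCgaWithShape maps [16, 256] onto the tensor [32, 32]:
    // m is repeated 32/16 = 2 times, and k is broadcast 256/32 = 8 ways, so
    // all threads and warps along K hold the same values.
    return 0;
  }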

+  auto repOrder = blocked.getRepOrder();
Review thread on the repOrder line above:

Contributor: Should it be operandLayout.getRepOrder?

Author (binarman, Dec 20, 2024): It is not implemented at the moment for a blocked parent. We had a conversation about dot operand order functions with @lezcano and did not come to a certain decision, so for now I am just using the parent order functions everywhere.

+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  SmallVector<unsigned> threadSize = blocked.getSizePerThread();
+  auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  threadSize[kDimIdx] = shape[kDimIdx];
+  auto threadShape = blocked.getThreadsPerWarp();
+  auto warpShape = blocked.getWarpsPerCTA();
+
+  SmallVector<StringAttr> repDimNames =
+      permuteDimNames(standardOutDimNames(ctx, rank), repOrder);
+
+  LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
+                               .transposeOuts(repDimNames) *
+                           identityStandardND(kLane, threadShape, threadOrder)
+                               .transposeOuts(repDimNames) *
+                           identityStandardND(kWarp, warpShape, warpOrder)
+                               .transposeOuts(repDimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape);
+}
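For orientation, a sketch of how this new path gets exercised end to end (mirroring the unit tests added at the bottom of this PR; blocked, dot, and toLinearLayout are helpers from the LinearLayoutConversionsTest fixture there):

  // Inside the LinearLayoutConversionsTest fixture (see the new tests below):
  auto parent = blocked(/*size*/ {2, 4}, /*threads*/ {8, 4}, /*warps*/ {2, 4},
                        /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0},
                        /*cta order*/ {1, 0});
  auto aOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0);
  // Dispatches to fmaDotToLinearLayout because the parent is blocked:
  auto ll = toLinearLayout({32, 16}, aOperand);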

 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
@@ -750,9 +789,12 @@
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto parent = getParent();
-  if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
+  if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
+    return fmaDotToLinearLayout(*this, shape);
+  }
+  if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {

Review thread on the line above:

Contributor: nit, suggested change:

-  if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
+  else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {

     return mfmaDotToLinearLayout(*this, shape);
-  } else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
+  } else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     return wmmaDotOperandToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
     return nvidiaDotToLinearLayout(shape, *this);
4 changes: 2 additions & 2 deletions test/Conversion/amd/decompose-unsupported-conversions.mlir
@@ -97,11 +97,11 @@
 #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {

Review thread on the line above:

Author: Made this tensor larger because, with the introduction of linear layout, the input and output tensors turned out to be compatible.

     // CHECK-NOT: ttg.convert_layout
     // CHECK: ttg.local_alloc
     // CHECK: ttg.local_load
-    %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
+    %0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
     tt.return
   }
 }
52 changes: 52 additions & 0 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -311,6 +311,58 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) {
                 {S("dim0"), S("dim1"), S("dim2"), S("dim3")}));
 }

+TEST_F(LinearLayoutConversionsTest, BlockedDotOperandLhs) {
+  auto parent = blocked(/*size*/ {2, 4}, /*threads*/ {8, 4}, /*warps*/ {2, 4},
+                        /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0},
+                        /*cta order*/ {1, 0});
+  auto dotOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0);
+  EXPECT_EQ(
+      toLinearLayout({32, 16}, dotOperand),
+      LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
+                    {S("lane"), {{0, 0}, {0, 0}, {2, 0}, {4, 0}, {8, 0}}},
+                    {S("warp"), {{0, 0}, {0, 0}, {16, 0}}},
+                    {S("block"), {}}},
+                   {S("dim0"), S("dim1")}));
+}
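To connect this expectation with the construction above (my reading, not stated in the PR): the register factor covers sizePerThread [2, 4] with the operand's K extended to the full 16, giving the k bases {0,1}, {0,2}, {0,4}, {0,8} plus {1,0} for the second row each thread owns. The four lanes and the first two warp bits along K would extend past the tensor's K extent, so combineCtaCgaWithShape broadcasts them to zero bases, while the remaining lane and warp bits tile dim0 ({2,0}, {4,0}, {8,0} and {16,0}).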

+TEST_F(LinearLayoutConversionsTest, BlockedDotOperandLhs1) {
+  auto parent = blocked(/*size*/ {1, 1}, /*threads*/ {1, 32}, /*warps*/ {1, 1},
+                        /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0},
+                        /*cta order*/ {1, 0});
+  auto dotOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0);
+  EXPECT_EQ(toLinearLayout({1, 64}, dotOperand),
+            LinearLayout({{S("register"),
+                           {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}},
+                          {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}}},
+                          {S("warp"), {}},
+                          {S("block"), {}}},
+                         {S("dim0"), S("dim1")}));
+
+  auto dotOperand1 = dot(parent, /*idx*/ 1, /*kWidth*/ 0);
+  EXPECT_EQ(toLinearLayout({64, 64}, dotOperand1),
+            LinearLayout(
+                {{S("register"),
+                  {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {32, 0}, {0, 32}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}}},
+                 {S("warp"), {}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+}
+
+TEST_F(LinearLayoutConversionsTest, BlockedDotOperandRhs) {
+  auto parent = blocked(/*size*/ {2, 4}, /*threads*/ {8, 4}, /*warps*/ {2, 4},
+                        /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0},
+                        /*cta order*/ {1, 0});
+  auto dotOperand = dot(parent, /*idx*/ 1, /*kWidth*/ 0);
+  EXPECT_EQ(toLinearLayout({16, 64}, dotOperand),
+            LinearLayout({{S("register"),
+                           {{0, 1}, {0, 2}, {1, 0}, {2, 0}, {4, 0}, {8, 0}}},
+                          {S("lane"), {{0, 4}, {0, 8}, {0, 0}, {0, 0}, {0, 0}}},
+                          {S("warp"), {{0, 16}, {0, 32}, {0, 0}}},
+                          {S("block"), {}}},
+                         {S("dim0"), S("dim1")}));
+}

 TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) {
   EXPECT_EQ(toLinearLayout({16, 16},
                            mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})),