From 0ac924ad9f1ea9b8ac507a417cb1219115876bca Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 19 Dec 2024 15:21:11 +0000 Subject: [PATCH 1/8] Implement conversion from FMA dot operand to linear layout This PR introduces FMA dot operand converter and related tests. --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 3 ++ .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 1 + .../SharedToDotOperandFMA.cpp | 10 ++-- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 4 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 19 +++++++ .../TritonGPU/IR/LinearLayoutConversions.cpp | 45 +++++++++++++++- .../TritonGPU/LinearLayoutConversionsTest.cpp | 52 +++++++++++++++++++ 7 files changed, 127 insertions(+), 7 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index b81ecf103a05..27e2462067f8 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -217,6 +217,9 @@ LinearLayout ensureLayoutNotSmallerThan( // Return a vector of the standard out dimension names for tensor layouts. These // are "dim0", "dim1", etc. SmallVector standardOutDimNames(MLIRContext *ctx, int rank); +// Returns ["dim0", "dim1", ..., "dim"] in given order. +SmallVector orderedOutDimNames(MLIRContext *ctx, + ArrayRef order); // Return an identity mapping from `inDimName` to the standard out dimensions, // with the dimensions sized according to the shape. The bases are sorted // according to `order`, with the most minor dimension first. diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 23a59f383341..92298e80763e 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -766,6 +766,7 @@ for // Block encoding is dense stride layout. The elements per thread are contiguous. return getSizePerThread(); }; + SmallVector getRepOrderForOperand(int opIdx) const; }]; let hasCustomAssemblyFormat = 1; diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index 0ea294d53bb4..f0fb1eecc22e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc, } } -void verifyCTALayout(CTALayoutAttr ctaLayout) { +bool verifyCTALayout(CTALayoutAttr ctaLayout) { auto ctaSplit = ctaLayout.getCTASplitNum(); for (auto split : ctaSplit) { if (split != 1) - llvm::report_fatal_error("tensors splited in CGA(thread group clusters) " - "are not supported in FMA dot yet."); + return false; } + return true; } /// Get a linear offset of first element loaded by thread. 
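Note on the hunk above: verifyCTALayout now reports an unsupported CTA layout to the caller instead of calling report_fatal_error. An equivalent, slightly more compact formulation (a sketch only, not part of this change) could use llvm::all_of:

bool verifyCTALayout(CTALayoutAttr ctaLayout) {
  // Tensors split across a CGA (thread block cluster) are not supported by the
  // FMA dot path, so every entry of CTASplitNum must be 1.
  return llvm::all_of(ctaLayout.getCTASplitNum(),
                      [](auto split) { return split == 1; });
}

The following hunk makes loadFMAOp propagate that result by returning an empty Value, so the conversion pattern can fail gracefully instead of aborting compilation.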
@@ -216,7 +216,9 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, Value thread, Location loc, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, const int dotOpNo) { - verifyCTALayout(dLayout.getCTALayout()); + if (!verifyCTALayout(dLayout.getCTALayout())) + return Value(); + return Value(); DimIdx dim; dim.batch = 0; diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 2b7026eaee6f..298ddc105fc6 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -126,7 +126,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { LinearEncodingAttr>(dstLayout)) return true; if (auto dot = dyn_cast(dstLayout)) { - if (isa(dot.getParent())) + if (isa(dot.getParent())) return true; } return false; @@ -164,6 +164,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { Value res = SharedToDotOperandFMA::convertLayout( dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, thread, loc, getTypeConverter(), rewriter); + if (!res) + return failure(); rewriter.replaceOp(op, res); return success(); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 337328f650c6..550b369d084e 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -621,6 +621,17 @@ SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { return ret; } +// Returns ["dim0", "dim1", ..., "dim"] in given order. +SmallVector orderedOutDimNames(MLIRContext *ctx, + ArrayRef order) { + auto rank = order.size(); + SmallVector ret; + for (int i = 0; i < rank; i++) { + ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(order[i]))); + } + return ret; +} + // Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to // creating a 1D -> 1D mapping of size product(shape) and then reshaping to // permute(shape, order). @@ -1000,6 +1011,14 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } +// Blocked encoding + +SmallVector +BlockedEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); +} + // SmallVector diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index e7174f0f9b9a..9cf525bc3c57 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -650,6 +650,44 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } +std::optional +fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, + ArrayRef shape) { + int rank = shape.size(); + auto blocked = cast(operandLayout.getParent()); + MLIRContext *ctx = operandLayout.getContext(); + + // TODO: introduce registerOrder or use getOrder(operandLayout) + // Currently this order is used in legacy converter, because we do not + // have access to full dot operand layout, only parent part. 
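+  // The construction below mirrors the register * lane * warp pattern of the
+  // blocked-layout conversion above; the FMA-specific part is that the K
+  // dimension of the operand is owned entirely by registers
+  // (threadSize[kDimIdx] is widened to the full K extent of the tensor).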
+ auto regOrder = blocked.getOrder(); + // TODO: use operandLayout.getThreadOrder() + auto threadOrder = blocked.getThreadOrder(); + auto warpOrder = blocked.getWarpOrder(); + auto repOrder = blocked.getRepOrder(); + + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + SmallVector threadSize = blocked.getSizePerThread(); + auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + threadSize[kDimIdx] = shape[kDimIdx]; + auto threadShape = blocked.getThreadsPerWarp(); + auto warpShape = blocked.getWarpsPerCTA(); + + SmallVector repDimNames = orderedOutDimNames(ctx, repOrder); + + LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder) + .transposeOuts(repDimNames) * + identityStandardND(kLane, threadShape, threadOrder) + .transposeOuts(repDimNames) * + identityStandardND(kWarp, warpShape, warpOrder) + .transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape); +} + LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, unsigned kWidth, ArrayRef order, ArrayRef repOrder) { @@ -750,9 +788,12 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, std::optional DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { auto parent = getParent(); - if (auto mfmaLayout = llvm::dyn_cast(parent)) { + if (auto blockedLayout = mlir::dyn_cast(parent)) { + return fmaDotToLinearLayout(*this, shape); + } + if (auto mfmaLayout = mlir::dyn_cast(parent)) { return mfmaDotToLinearLayout(*this, shape); - } else if (auto wmmaLayout = llvm::dyn_cast(parent)) { + } else if (auto wmmaLayout = mlir::dyn_cast(parent)) { return wmmaDotOperandToLinearLayout(*this, shape); } else if (auto mma = mlir::dyn_cast(parent)) { return nvidiaDotToLinearLayout(shape, *this); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 7850b87ac595..30c9f3376f4b 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -311,6 +311,58 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } +TEST_F(LinearLayoutConversionsTest, BlockedDotOperandLhs) { + auto parent = blocked(/*size*/ {2, 4}, /*threads*/ {8, 4}, /*warps*/ {2, 4}, + /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0}, + /*cta order*/ {1, 0}); + auto dotOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0); + EXPECT_EQ( + toLinearLayout({32, 16}, dotOperand), + LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("lane"), {{0, 0}, {0, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, BlockedDotOperandLhs1) { + auto parent = blocked(/*size*/ {1, 1}, /*threads*/ {1, 32}, /*warps*/ {1, 1}, + /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0}, + /*cta order*/ {1, 0}); + auto dotOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0); + EXPECT_EQ(toLinearLayout({1, 64}, dotOperand), + LinearLayout({{S("register"), + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}}, + {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto dotOperand1 = dot(parent, /*idx*/ 1, /*kWidth*/ 0); + EXPECT_EQ(toLinearLayout({64, 64}, dotOperand1), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {32, 0}, {0, 32}}}, + 
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("warp"), {}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, BlockedDotOperandRhs) { + auto parent = blocked(/*size*/ {2, 4}, /*threads*/ {8, 4}, /*warps*/ {2, 4}, + /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0}, + /*cta order*/ {1, 0}); + auto dotOperand = dot(parent, /*idx*/ 1, /*kWidth*/ 0); + EXPECT_EQ(toLinearLayout({16, 64}, dotOperand), + LinearLayout({{S("register"), + {{0, 1}, {0, 2}, {1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 4}, {0, 8}, {0, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { EXPECT_EQ(toLinearLayout({16, 16}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), From 69c335482ad5a82485ab979636bb65472c330a5c Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 19 Dec 2024 21:21:01 +0000 Subject: [PATCH 2/8] fix repetitions in FMA dot inputs and outputs --- .../SharedToDotOperandFMA.cpp | 6 +- .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp | 105 +++++++++++++----- 2 files changed, 80 insertions(+), 31 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index f0fb1eecc22e..e5191d091913 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -218,7 +218,6 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, ConversionPatternRewriter &rewriter, const int dotOpNo) { if (!verifyCTALayout(dLayout.getCTALayout())) return Value(); - return Value(); DimIdx dim; dim.batch = 0; @@ -294,6 +293,11 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, auto numBTiles = std::max(1u, B / shapePerCTABTile); auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile); + // Found discrepancy in this case, + // use linear layout based converter for this case + if (numBTiles != 1 || numNonKTiles != 1) + return Value(); + auto perThreadShape = getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread); diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index e32b3e0d6ed1..9e9c2009f4d2 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -13,24 +13,54 @@ using ::mlir::triton::gpu::expandMatrixShapeWithBatch; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; -using ValueTableFMA = std::map, Value>; +/// \brief spatial position of repetition and register of a given value +struct OperandValueKey { + unsigned bRepIdx, nonKRepIdx; + unsigned bIdx, nonKIdx, kIdx; + + bool operator==(const OperandValueKey &other) const { + return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && + bIdx == other.bIdx && nonKIdx == other.nonKIdx && + kIdx == other.kIdx); + } +}; + +template <> struct std::hash { + std::size_t operator()(const OperandValueKey &k) const { + return std::hash()(k.bRepIdx) ^ + (std::hash()(k.nonKRepIdx) << 1) ^ + (std::hash()(k.bIdx) << 2) ^ + (std::hash()(k.nonKIdx) << 3) ^ + (std::hash()(k.kIdx) << 4); + } +}; + +using ValueTableFMA = std::unordered_map; -static ValueTableFMA -getValueTableFromStructFMA(Value val, ArrayRef perTileShape, - unsigned kDim, 
unsigned nonKDim, - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef order) { +static ValueTableFMA getValueTableFromStructFMA( + Value val, ArrayRef perRepShape, ArrayRef repetitions, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, + Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { ValueTableFMA res; auto elems = unpackLLElements(loc, val, rewriter); - assert(perTileShape.size() == 3); - assert(elems.size() == product(perTileShape)); + assert(perRepShape.size() == 3); + auto numElemsRep = product(perRepShape); + assert(elems.size() == numElemsRep * product(repetitions)); assert(kDim == 1 || kDim == 2); assert(nonKDim == 1 || nonKDim == 2); const unsigned bDim = 0; for (unsigned idx = 0; idx < elems.size(); ++idx) { - auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order); - res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx]; + auto inRepLinearIdx = idx % numElemsRep; + auto repLinearIdx = idx / numElemsRep; + auto inRepSpatialIdx = + mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); + auto repSpatialIdx = + mlir::LLVM::delinearize(repLinearIdx, perRepShape, repOrder); + OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], + inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], + inRepSpatialIdx[kDim]}; + res[key] = elems[idx]; } return res; } @@ -54,7 +84,9 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, BlockedEncodingAttr dLayout = cast(dTensorTy.getEncoding()); - auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); + // TODO process A and B operand separately + auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); Value llA = adaptor.getA(); @@ -62,38 +94,51 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto sizePerThread = expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + auto numElemsPerThread = product(sizePerThread); auto shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); unsigned K = aShapePerCTA[2]; - unsigned perThreadShape[3]; + unsigned threadTileShape[3]; + unsigned repetitions[3]; for (int i = 0; i < 3; ++i) { - unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i]; - numRep = std::max(static_cast(1), numRep); - perThreadShape[i] = numRep * sizePerThread[i]; + repetitions[i] = + ceil(dShapePerCTA[i], static_cast(shapePerCTATile[i])); } auto has = getValueTableFromStructFMA( - llA, {perThreadShape[0], perThreadShape[1], K}, - /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order); + llA, {sizePerThread[0], sizePerThread[1], K}, + {repetitions[0], repetitions[1], 1}, + /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder); auto hbs = getValueTableFromStructFMA( - llB, {perThreadShape[0], K, perThreadShape[2]}, - /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order); + llB, {sizePerThread[0], K, sizePerThread[2]}, + {repetitions[0], 1, repetitions[2]}, + /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder); SmallVector acc = cc; - for (unsigned b = 0; b < perThreadShape[0]; ++b) - for (unsigned m = 0; m < perThreadShape[1]; ++m) - for (unsigned n = 0; n < perThreadShape[2]; ++n) { - SmallVector multiDimAccumIdx = {b, m, n}; - unsigned linearAccumIdx = - linearize(multiDimAccumIdx, perThreadShape, order); - for (unsigned k = 0; k < K; ++k) { - acc[linearAccumIdx] = rewriter.create( - loc, has[{b, m, k}], 
hbs[{b, n, k}], acc[linearAccumIdx]); - } - } + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) { + SmallVector multiDimAccumIdx = {b, m, n}; + unsigned linearInRepIdx = + linearize(multiDimAccumIdx, sizePerThread, inRepOrder); + SmallVector multiDimRepIdx = {bRep, mRep, nRep}; + unsigned linearRepIdx = + linearize(multiDimRepIdx, repetitions, repOrder); + unsigned linearAccumIdx = + linearInRepIdx + linearRepIdx * numElemsPerThread; + for (unsigned k = 0; k < K; ++k) { + auto aOp = has[{bRep, mRep, b, m, k}]; + auto bOp = hbs[{bRep, nRep, b, n, k}]; + acc[linearAccumIdx] = rewriter.create( + loc, aOp, bOp, acc[linearAccumIdx]); + } + } auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); rewriter.replaceOp(op, res); From 547f7f8658e540d56cf0259892fad38566a50069 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 20 Dec 2024 12:56:26 +0000 Subject: [PATCH 3/8] - Remove orderedOutDimNames function - Fix compiler crashes in FMA.cpp - Fix lit test --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 3 --- lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 11 ----------- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 3 ++- .../amd/decompose-unsupported-conversions.mlir | 4 ++-- 5 files changed, 5 insertions(+), 18 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 27e2462067f8..b81ecf103a05 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -217,9 +217,6 @@ LinearLayout ensureLayoutNotSmallerThan( // Return a vector of the standard out dimension names for tensor layouts. These // are "dim0", "dim1", etc. SmallVector standardOutDimNames(MLIRContext *ctx, int rank); -// Returns ["dim0", "dim1", ..., "dim"] in given order. -SmallVector orderedOutDimNames(MLIRContext *ctx, - ArrayRef order); // Return an identity mapping from `inDimName` to the standard out dimensions, // with the dimensions sized according to the shape. The bases are sorted // according to `order`, with the most minor dimension first. diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index 9e9c2009f4d2..a01e0023d7d1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -56,7 +56,7 @@ static ValueTableFMA getValueTableFromStructFMA( auto inRepSpatialIdx = mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); auto repSpatialIdx = - mlir::LLVM::delinearize(repLinearIdx, perRepShape, repOrder); + mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], inRepSpatialIdx[kDim]}; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 550b369d084e..aa3c5b994140 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -621,17 +621,6 @@ SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { return ret; } -// Returns ["dim0", "dim1", ..., "dim"] in given order. 
-SmallVector orderedOutDimNames(MLIRContext *ctx, - ArrayRef order) { - auto rank = order.size(); - SmallVector ret; - for (int i = 0; i < rank; i++) { - ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(order[i]))); - } - return ret; -} - // Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to // creating a 1D -> 1D mapping of size product(shape) and then reshaping to // permute(shape, order). diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 9cf525bc3c57..a4ebd5b31207 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -676,7 +676,8 @@ fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, auto threadShape = blocked.getThreadsPerWarp(); auto warpShape = blocked.getWarpsPerCTA(); - SmallVector repDimNames = orderedOutDimNames(ctx, repOrder); + SmallVector repDimNames = + permuteDimNames(standardOutDimNames(ctx, rank), repOrder); LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder) .transposeOuts(repDimNames) * diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index 983d16e8d6bf..d578bd3c47a4 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} { - tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) { // CHECK-NOT: ttg.convert_layout // CHECK: ttg.local_alloc // CHECK: ttg.local_load - %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + %0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> tt.return } } From 2157f208189f9665bc031dba1f235051458f446d Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 20 Dec 2024 15:07:58 +0000 Subject: [PATCH 4/8] remove redundant changes --- include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 3 +-- lib/Dialect/TritonGPU/IR/Dialect.cpp | 8 -------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 92298e80763e..4ce67b4a70b9 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -763,10 +763,9 @@ for let extraClassDeclaration = extraDistributedDeclaration # [{ SmallVector getContigPerThread() { - // Block encoding is dense stride layout. The elements per thread are contiguous. + // Block encoding is dense str ide layout. The elements per thread are contiguous. 
return getSizePerThread(); }; - SmallVector getRepOrderForOperand(int opIdx) const; }]; let hasCustomAssemblyFormat = 1; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index aa3c5b994140..337328f650c6 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1000,14 +1000,6 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } -// Blocked encoding - -SmallVector -BlockedEncodingAttr::getRepOrderForOperand(int opIdx) const { - auto rank = getWarpsPerCTA().size(); - return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); -} - // SmallVector From a7c978bcc7f57b46d957e7a785c656296552703f Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 20 Dec 2024 16:53:42 +0100 Subject: [PATCH 5/8] fix typo --- include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 4ce67b4a70b9..23a59f383341 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -763,7 +763,7 @@ for let extraClassDeclaration = extraDistributedDeclaration # [{ SmallVector getContigPerThread() { - // Block encoding is dense str ide layout. The elements per thread are contiguous. + // Block encoding is dense stride layout. The elements per thread are contiguous. return getSizePerThread(); }; }]; From eb33d00cf40c4d9a69ad61db521c56149dada142 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 20 Dec 2024 19:42:39 +0000 Subject: [PATCH 6/8] generate warp and lane layout in broadcast form --- .../TritonGPU/IR/LinearLayoutConversions.cpp | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index a4ebd5b31207..f2c593da1f74 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -239,8 +239,11 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape); } -LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef warpShape, - ArrayRef warpOrder, unsigned inner) { +/// Function to generate lane and warp layout for dot operands. +LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx, + ArrayRef shape, + ArrayRef order, + unsigned kDim, StringAttr inDimName) { // Let warpsPerCTAMma = {2, 2}, then // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB // assume warpOrder = {1, 0} @@ -255,24 +258,23 @@ LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef warpShape, // - - | - - - - | - - // 2 3 | 2 3 0 2 | 1 3 // In other words, we need to broadcast along K - auto rank = warpShape.size(); + auto rank = shape.size(); auto dimNames = standardOutDimNames(ctx, rank); - LinearLayout warpLayout = LinearLayout::empty(); + LinearLayout layout = LinearLayout::empty(); // We have to broadcast along the inner dimension // For A, when moving along M we go from 0 to 2. // For B, when moving along N we go from 0 to 1. 
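  // (Concretely: with shape = {2, 2}, order = {1, 0} and kDim = 1, the bases
  //  produced here are {0, 0} followed by {1, 0}: the bit that moves along K
  //  maps to zero, i.e. it is broadcast, while the bit along M advances dim0.)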
// As such, choosing the order of A {1, 0}, gives us the correct broadcasting // Same happens if the warpOrder is {0, 1}, like in Hopper - for (auto d : warpOrder) { - if (d == inner) { - warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]); + for (auto d : order) { + if (d == kDim) { + layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]); } else { - warpLayout *= - LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]); + layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]); } } - return warpLayout; + return layout; } } // anonymous namespace @@ -620,7 +622,8 @@ wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout, // Generate warp layout auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout); - LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim); + LinearLayout warpLayout = + broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp")); // reorder dim names in rep order, so combineCtaCgaWithShape generate proper // extension of layout @@ -679,12 +682,15 @@ fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, SmallVector repDimNames = permuteDimNames(standardOutDimNames(ctx, rank), repOrder); - LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder) - .transposeOuts(repDimNames) * - identityStandardND(kLane, threadShape, threadOrder) - .transposeOuts(repDimNames) * - identityStandardND(kWarp, warpShape, warpOrder) - .transposeOuts(repDimNames); + auto registersLayout = identityStandardND(kReg, threadSize, regOrder); + auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder, + kDimIdx, kLane); + auto warpsLayout = + broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp); + + LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) * + lanesLayout.transposeOuts(repDimNames) * + warpsLayout.transposeOuts(repDimNames); return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape); } @@ -779,9 +785,9 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, auto ctaLayout = nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder()); auto kDim = isA ? 
rank - 1 : rank - 2; - ctaLayout *= - warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim) - .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(), + mma.getWarpOrder(), kDim, S("warp")) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); } From 7f2d2a61b6e27ea71553bd584a23464d4446c8f7 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 20 Dec 2024 20:03:52 +0000 Subject: [PATCH 7/8] - remove legacy converter from pattern - cleanup hash function in FMA.cpp - add more details in TODO in SharedToDotOperandFMA.cpp - cleanup DotOperandEncodingAttr::toLinearLayout --- .../SharedToDotOperandFMA.cpp | 4 ++ .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp | 7 +-- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 45 +------------------ .../TritonGPU/IR/LinearLayoutConversions.cpp | 3 +- 4 files changed, 8 insertions(+), 51 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index e5191d091913..214a12b2e8af 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -295,6 +295,10 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout, // Found discrepancy in this case, // use linear layout based converter for this case + // TODO: break batch and non-k dimension iterations in + // "repeat" and "inside-repeate" parts, pack them in llvm structure + // according repeat and register order. + // See FMA.cpp:getValueTableFromStructFMA for reference if (numBTiles != 1 || numNonKTiles != 1) return Value(); diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index a01e0023d7d1..ecf1d12914fa 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -27,11 +27,8 @@ struct OperandValueKey { template <> struct std::hash { std::size_t operator()(const OperandValueKey &k) const { - return std::hash()(k.bRepIdx) ^ - (std::hash()(k.nonKRepIdx) << 1) ^ - (std::hash()(k.bIdx) << 2) ^ - (std::hash()(k.nonKIdx) << 3) ^ - (std::hash()(k.kIdx) << 4); + return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, + k.kIdx); } }; diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 298ddc105fc6..b4397bc78d55 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -119,56 +119,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { } - // FIXME [Dot LL] - // Do for all DotOperandEncodingAttr once we have LLs for all of them - static bool isSupportedLayout(Attribute dstLayout) { - if (isa(dstLayout)) - return true; - if (auto dot = dyn_cast(dstLayout)) { - if (isa(dot.getParent())) - return true; - } - return false; - }; - LogicalResult matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType dstTy = op.getType(); - Attribute dstLayout = dstTy.getEncoding(); - if (isSupportedLayout(dstLayout)) { - return lowerSharedToDistributed(op, adaptor, getTypeConverter(), - rewriter); - } - 
if (isa(dstLayout) && - isa( - cast(dstLayout).getParent())) { - return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter); - } - return failure(); + return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); } private: - LogicalResult - lowerSharedToDotOpFMA(LocalLoadOp op, LocalLoadOpAdaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - RankedTensorType dstTy = op.getType(); - Attribute dstLayout = dstTy.getEncoding(); - auto dotLayout = cast(dstLayout); - auto blockedLayout = cast( - cast(dstLayout).getParent()); - auto thread = getThreadId(rewriter, loc); - Value res = SharedToDotOperandFMA::convertLayout( - dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, - thread, loc, getTypeConverter(), rewriter); - if (!res) - return failure(); - rewriter.replaceOp(op, res); - return success(); - } LogicalResult lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor, const LLVMTypeConverter *typeConverter, diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index f2c593da1f74..e411d6470170 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -797,8 +797,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { auto parent = getParent(); if (auto blockedLayout = mlir::dyn_cast(parent)) { return fmaDotToLinearLayout(*this, shape); - } - if (auto mfmaLayout = mlir::dyn_cast(parent)) { + } else if (auto mfmaLayout = mlir::dyn_cast(parent)) { return mfmaDotToLinearLayout(*this, shape); } else if (auto wmmaLayout = mlir::dyn_cast(parent)) { return wmmaDotOperandToLinearLayout(*this, shape); From c372e5a700aa6e72db3154efa4c8e2931f4da95b Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 20 Dec 2024 20:47:09 +0000 Subject: [PATCH 8/8] add dot 3d test --- .../TritonGPU/LinearLayoutConversionsTest.cpp | 63 ++++++++++++------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 30c9f3376f4b..503b3bb0422a 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -325,28 +325,26 @@ TEST_F(LinearLayoutConversionsTest, BlockedDotOperandLhs) { {S("dim0"), S("dim1")})); } -TEST_F(LinearLayoutConversionsTest, BlockedDotOperandLhs1) { - auto parent = blocked(/*size*/ {1, 1}, /*threads*/ {1, 32}, /*warps*/ {1, 1}, - /*ctas*/ {1, 1}, /*splits*/ {1, 1}, /*order*/ {1, 0}, - /*cta order*/ {1, 0}); +TEST_F(LinearLayoutConversionsTest, BlockedDot3dOperandLhs) { + auto parent = + blocked(/*size*/ {2, 2, 4}, /*threads*/ {2, 4, 4}, /*warps*/ {2, 2, 2}, + /*ctas*/ {1, 1, 1}, /*splits*/ {1, 1, 1}, /*order*/ {2, 1, 0}, + /*cta order*/ {2, 1, 0}); auto dotOperand = dot(parent, /*idx*/ 0, /*kWidth*/ 0); - EXPECT_EQ(toLinearLayout({1, 64}, dotOperand), - LinearLayout({{S("register"), - {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}}, - {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}}}, - {S("warp"), {}}, - {S("block"), {}}}, - {S("dim0"), S("dim1")})); - - auto dotOperand1 = dot(parent, /*idx*/ 1, /*kWidth*/ 0); - EXPECT_EQ(toLinearLayout({64, 64}, dotOperand1), - LinearLayout( - {{S("register"), - {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {32, 0}, {0, 32}}}, - {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 
16}}}, - {S("warp"), {}}, - {S("block"), {}}}, - {S("dim0"), S("dim1")})); + EXPECT_EQ( + toLinearLayout({16, 32, 4}, dotOperand), + LinearLayout( + {{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 1, 0}, + {1, 0, 0}, + {0, 16, 0}, + {8, 0, 0}}}, + {S("lane"), {{0, 0, 0}, {0, 0, 0}, {0, 2, 0}, {0, 4, 0}, {2, 0, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 8, 0}, {4, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); } TEST_F(LinearLayoutConversionsTest, BlockedDotOperandRhs) { @@ -363,6 +361,29 @@ TEST_F(LinearLayoutConversionsTest, BlockedDotOperandRhs) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, BlockedDot3dOperandRhs) { + auto parent = + blocked(/*size*/ {2, 2, 4}, /*threads*/ {2, 4, 4}, /*warps*/ {2, 2, 2}, + /*ctas*/ {1, 1, 1}, /*splits*/ {1, 1, 1}, /*order*/ {2, 1, 0}, + /*cta order*/ {2, 1, 0}); + auto dotOperand = dot(parent, /*idx*/ 1, /*kWidth*/ 0); + EXPECT_EQ( + toLinearLayout({16, 4, 64}, dotOperand), + LinearLayout( + {{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 1, 0}, + {0, 2, 0}, + {1, 0, 0}, + {0, 0, 32}, + {8, 0, 0}}}, + {S("lane"), {{0, 0, 4}, {0, 0, 8}, {0, 0, 0}, {0, 0, 0}, {2, 0, 0}}}, + {S("warp"), {{0, 0, 16}, {0, 0, 0}, {4, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { EXPECT_EQ(toLinearLayout({16, 16}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})),
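A short note for readers of the expectations added above: every input dimension ("register", "lane", "warp", "block") carries one basis per bit of its index, each basis listed as output coordinates {dim0, dim1, ...}, and the mapped coordinate is the XOR of the bases selected by the set bits. The standalone sketch below (plain arrays rather than the real LinearLayout class; names are illustrative only) shows the arithmetic:

#include <array>
#include <vector>

// Apply a 2-D basis list to an input index: bit i of the index selects
// bases[i], and the output coordinate is the XOR of all selected bases.
static std::array<unsigned, 2>
applyBases(const std::vector<std::array<unsigned, 2>> &bases, unsigned idx) {
  std::array<unsigned, 2> out = {0, 0};
  for (unsigned bit = 0; bit < bases.size(); ++bit)
    if (idx & (1u << bit)) {
      out[0] ^= bases[bit][0];
      out[1] ^= bases[bit][1];
    }
  return out;
}

// Example from BlockedDotOperandRhs: the lane bases are
// {{0, 4}, {0, 8}, {0, 0}, {0, 0}, {0, 0}}. Lane 3 therefore maps to
// {0, 4 ^ 8} = {0, 12}, and lanes 4..31 repeat lanes 0..3, which is exactly
// the broadcast along K that the FMA dot operand layout relies on.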