Implement conversion from FMA dot operand to linear layout #5469

Open · wants to merge 8 commits into base: main
@@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
   }
 }

-void verifyCTALayout(CTALayoutAttr ctaLayout) {
+bool verifyCTALayout(CTALayoutAttr ctaLayout) {
   auto ctaSplit = ctaLayout.getCTASplitNum();
   for (auto split : ctaSplit) {
     if (split != 1)
-      llvm::report_fatal_error("tensors splited in CGA(thread group clusters) "
-                               "are not supported in FMA dot yet.");
+      return false;
   }
+  return true;
 }

 /// Get a linear offset of first element loaded by thread.
@@ -216,7 +216,8 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
                 Value thread, Location loc,
                 const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, const int dotOpNo) {
-  verifyCTALayout(dLayout.getCTALayout());
+  if (!verifyCTALayout(dLayout.getCTALayout()))
+    return Value();
Comment on lines +219 to +220

Contributor: Does this mean that this path is still preferred over LLs? Could you also make LLs the preferred path, or is there anything blocking us from doing so?

Author: I believe LL is already preferred, but I don't want to remove the legacy converter yet, just in case.

Contributor: But in which case did you hit the hard error? Can we just revert these changes?

Author: I did not hit any errors so far, but I want to compare the code generated by the legacy and new converters to see if there are differences in perf/register usage, etc.

Contributor: Perhaps now this can be reverted before merging?


   DimIdx dim;
   dim.batch = 0;
@@ -292,6 +293,15 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
   auto numBTiles = std::max(1u, B / shapePerCTABTile);
   auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);

+  // Found a discrepancy in this case;
+  // use linear layout based converter for this case.
+  // TODO: break batch and non-k dimension iterations in
+  // "repeat" and "inside-repeat" parts, pack them in an llvm structure
+  // according to repeat and register order.
+  // See FMA.cpp:getValueTableFromStructFMA for reference
+  if (numBTiles != 1 || numNonKTiles != 1)
+    return Value();

Contributor: You meant this is a TODO?

Author: Yes, will reword this.

   auto perThreadShape =
       getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread);

102 changes: 72 additions & 30 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -13,24 +13,51 @@ using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;

-using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;
+/// \brief spatial position of repetition and register of a given value
+struct OperandValueKey {
+  unsigned bRepIdx, nonKRepIdx;
+  unsigned bIdx, nonKIdx, kIdx;
+
+  bool operator==(const OperandValueKey &other) const {
+    return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
+            bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
+            kIdx == other.kIdx);
+  }
+};
+
+template <> struct std::hash<OperandValueKey> {
+  std::size_t operator()(const OperandValueKey &k) const {
+    return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
+                              k.kIdx);
+  }
+};
+
+using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;

-static ValueTableFMA
-getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
-                           unsigned kDim, unsigned nonKDim,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<unsigned> order) {
+static ValueTableFMA getValueTableFromStructFMA(
+    Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
+    unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
+    Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
   ValueTableFMA res;
   auto elems = unpackLLElements(loc, val, rewriter);
-  assert(perTileShape.size() == 3);
-  assert(elems.size() == product(perTileShape));
+  assert(perRepShape.size() == 3);
+  auto numElemsRep = product(perRepShape);
+  assert(elems.size() == numElemsRep * product(repetitions));
   assert(kDim == 1 || kDim == 2);
   assert(nonKDim == 1 || nonKDim == 2);
   const unsigned bDim = 0;

   for (unsigned idx = 0; idx < elems.size(); ++idx) {
-    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
-    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
+    auto inRepLinearIdx = idx % numElemsRep;
+    auto repLinearIdx = idx / numElemsRep;
+    auto inRepSpatialIdx =
+        mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
+    auto repSpatialIdx =
+        mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
+    OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
+                        inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
+                        inRepSpatialIdx[kDim]};
+    res[key] = elems[idx];
   }
   return res;
 }
@@ -54,46 +81,61 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,

   BlockedEncodingAttr dLayout =
       cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
+  // TODO process A and B operand separately
+  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
+  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
   auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

   Value llA = adaptor.getA();
   Value llB = adaptor.getB();

   auto sizePerThread =
       expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto numElemsPerThread = product(sizePerThread);
   auto shapePerCTATile =
       expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

   unsigned K = aShapePerCTA[2];

-  unsigned perThreadShape[3];
-  unsigned threadTileShape[3];
+  unsigned repetitions[3];
   for (int i = 0; i < 3; ++i) {
-    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
-    numRep = std::max(static_cast<unsigned>(1), numRep);
-    perThreadShape[i] = numRep * sizePerThread[i];
+    repetitions[i] =
+        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
   }

   auto has = getValueTableFromStructFMA(
-      llA, {perThreadShape[0], perThreadShape[1], K},
-      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
+      llA, {sizePerThread[0], sizePerThread[1], K},
+      {repetitions[0], repetitions[1], 1},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
   auto hbs = getValueTableFromStructFMA(
-      llB, {perThreadShape[0], K, perThreadShape[2]},
-      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
+      llB, {sizePerThread[0], K, sizePerThread[2]},
+      {repetitions[0], 1, repetitions[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);

   SmallVector<Value> acc = cc;

-  for (unsigned b = 0; b < perThreadShape[0]; ++b)
-    for (unsigned m = 0; m < perThreadShape[1]; ++m)
-      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
-        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
-        unsigned linearAccumIdx =
-            linearize(multiDimAccumIdx, perThreadShape, order);
-        for (unsigned k = 0; k < K; ++k) {
-          acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
-              loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]);
-        }
-      }
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
+        for (unsigned b = 0; b < sizePerThread[0]; ++b)
+          for (unsigned m = 0; m < sizePerThread[1]; ++m)
+            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
+              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+              unsigned linearInRepIdx =
+                  linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
+              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
+              unsigned linearRepIdx =
+                  linearize(multiDimRepIdx, repetitions, repOrder);
+              unsigned linearAccumIdx =
+                  linearInRepIdx + linearRepIdx * numElemsPerThread;
+              for (unsigned k = 0; k < K; ++k) {
+                auto aOp = has[{bRep, mRep, b, m, k}];
+                auto bOp = hbs[{bRep, nRep, b, n, k}];
+                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
+                    loc, aOp, bOp, acc[linearAccumIdx]);
+              }
+            }

   auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
   rewriter.replaceOp(op, res);
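A minimal standalone sketch of the two-level index arithmetic above (hypothetical tile sizes, not code from this PR; it only mirrors the index math, assuming `order[0]` is the fastest-varying dimension as in `mlir::LLVM::delinearize`):

    #include <cstdio>

    int main() {
      // Assumed shapes: one operand tile of sizePerThread = {b=1, m=2, k=2}
      // (4 elements per repetition), repeated repetitions = {1, 2, 1} times,
      // i.e. 8 unpacked elements in total.
      const unsigned numElemsRep = 1 * 2 * 2;
      const unsigned numReps = 1 * 2 * 1;
      for (unsigned idx = 0; idx < numElemsRep * numReps; ++idx) {
        // Split as in getValueTableFromStructFMA: position inside one
        // repetition plus the index of the repetition itself.
        unsigned inRepLinearIdx = idx % numElemsRep;
        unsigned repLinearIdx = idx / numElemsRep;
        // Re-linearize as in convertFMADot's accumulator indexing; with
        // matching in-rep and rep orders this reproduces idx, i.e. each
        // repetition's values stay contiguous in the packed LLVM struct.
        unsigned linearAccumIdx = inRepLinearIdx + repLinearIdx * numElemsRep;
        std::printf("idx=%u -> rep=%u, inRep=%u, accum=%u\n", idx,
                    repLinearIdx, inRepLinearIdx, linearAccumIdx);
      }
      return 0;
    }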
43 changes: 1 addition & 42 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

@@ -119,54 +119,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }

-  // FIXME [Dot LL]
-  // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedLayout(Attribute dstLayout) {
-    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
-            LinearEncodingAttr>(dstLayout))
-      return true;
-    if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
-      if (isa<MmaEncodingTrait>(dot.getParent()))
-        return true;
-    }
-    return false;
-  };
-
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    if (isSupportedLayout(dstLayout)) {
-      return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
-                                      rewriter);
-    }
-    if (isa<DotOperandEncodingAttr>(dstLayout) &&
-        isa<BlockedEncodingAttr>(
-            cast<DotOperandEncodingAttr>(dstLayout).getParent())) {
-      return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter);
-    }
-    return failure();
+    return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
   }

 private:
-  LogicalResult
-  lowerSharedToDotOpFMA(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
-                        const LLVMTypeConverter *typeConverter,
-                        ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
-    auto blockedLayout = cast<BlockedEncodingAttr>(
-        cast<DotOperandEncodingAttr>(dstLayout).getParent());
-    auto thread = getThreadId(rewriter, loc);
-    Value res = SharedToDotOperandFMA::convertLayout(
-        dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
-        thread, loc, getTypeConverter(), rewriter);
-    rewriter.replaceOp(op, res);
-    return success();
-  }
   LogicalResult
   lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
                            const LLVMTypeConverter *typeConverter,
79 changes: 63 additions & 16 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

@@ -239,8 +239,11 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }

-LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
-                             ArrayRef<unsigned> warpOrder, unsigned inner) {
+/// Function to generate lane and warp layout for dot operands.
+LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
+                                         ArrayRef<unsigned> shape,
+                                         ArrayRef<unsigned> order,
+                                         unsigned kDim, StringAttr inDimName) {
   // Let warpsPerCTAMma = {2, 2}, then
   // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
   // assume warpOrder = {1, 0}
   // Then the warp tilings of A and B are:
   // A: 0 1 | 0 1    B: 0 2 | 1 3
   //    - - | - -       - - | - -
   //    2 3 | 2 3       0 2 | 1 3
   // In other words, we need to broadcast along K
-  auto rank = warpShape.size();
+  auto rank = shape.size();
   auto dimNames = standardOutDimNames(ctx, rank);
-  LinearLayout warpLayout = LinearLayout::empty();
+  LinearLayout layout = LinearLayout::empty();

   // We have to broadcast along the inner dimension
   // For A, when moving along M we go from 0 to 2.
   // For B, when moving along N we go from 0 to 1.
   // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
   // Same happens if the warpOrder is {0, 1}, like in Hopper
-  for (auto d : warpOrder) {
-    if (d == inner) {
-      warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]);
+  for (auto d : order) {
+    if (d == kDim) {
+      layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]);
     } else {
-      warpLayout *=
-          LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]);
+      layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
     }
   }
-  return warpLayout;
+  return layout;
 }

 } // anonymous namespace
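To illustrate the zeros1D/identity1D split above, here is a standalone sketch (assumed values: opA with warpsPerCTA = {2, 2}, order = {1, 0}, kDim = 1; plain C++ standing in for the LinearLayout algebra): the warp bit consumed along K contributes nothing to the offset, so pairs of warps hold identical K slices.

    #include <cstdio>

    int main() {
      // order = {1, 0}: dim 1 (K) consumes the low warp bit first,
      // dim 0 (non-K) consumes the next one.
      for (unsigned warp = 0; warp < 4; ++warp) {
        unsigned kBit = warp & 1;           // zeros1D: mapped to offset 0
        unsigned nonKBit = (warp >> 1) & 1; // identity1D: selects the row tile
        unsigned kOffset = 0 * kBit;        // the K bit is dropped
        std::printf("warp %u -> (row=%u, kOffset=%u)\n", warp, nonKBit,
                    kOffset);
      }
      // Warps {0,1} and {2,3} land on the same rows: broadcast along K,
      // matching the "A: 0 1 | 0 1 / 2 3 | 2 3" picture in the comment above.
      return 0;
    }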
@@ -620,7 +622,8 @@ wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
   // Generate warp layout
   auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout);
-  LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim);
+  LinearLayout warpLayout =
+      broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp"));

   // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
   // extension of layout
@@ -650,6 +653,48 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

+std::optional<LinearLayout>
+fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout,
+                     ArrayRef<int64_t> shape) {
+  int rank = shape.size();
+  auto blocked = cast<BlockedEncodingAttr>(operandLayout.getParent());
+  MLIRContext *ctx = operandLayout.getContext();
+
+  // TODO: introduce registerOrder or use getOrder(operandLayout)
+  // Currently this order is used in legacy converter, because we do not
+  // have access to full dot operand layout, only parent part.
+  auto regOrder = blocked.getOrder();
+  // TODO: use operandLayout.getThreadOrder()
+  auto threadOrder = blocked.getThreadOrder();
+  auto warpOrder = blocked.getWarpOrder();

Contributor: I feel like there's something wrong here. Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?

Contributor: My point is that warps have to be broadcasted along the k dimension.

Contributor: Agreed. I think you have to use warpsDotOperand here.

Author (binarman, Dec 20, 2024):

> Have you tested a warp shape of [2, 2] with more than 1 warp on the k dimension?

Yes, this converter works with any warp shape. The trick is that I create the "thread" part of the layout using the whole K dimension, so any number of warps or threads across the k dimension will be squeezed in the combineCtaCgaWithShape call.

Let's take an example with a dot A operand of shape [m=32, k=32]:
parent layout is perThreadShape=[1, 1], threads=[8, 4], warps=[2, 2]

1. The "per-thread" layout identityStandardND(kReg, threadSize, regOrder) will cover the shape [m=1, k=32]
2. The "per-warp" layout ... * identityStandardND(kLane, threadShape, threadOrder) will cover the shape [m=8, k=32*4=128]
3. The "full" layout ... * warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx) will cover the shape [m=16, k=32*4*2=256]

Then I apply combineCtaCgaWithShape: it repeats the m dimension two times, but "broadcasts" the k dimension down to 32, so all threads and warps across the K dimension hold the same values.
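A quick way to sanity-check the coverage arithmetic in this example (a standalone sketch with the shapes assumed above, not code from the PR):

    #include <cstdio>

    int main() {
      // Per-thread ("register") part covers the whole K extent: [m=1, k=32].
      unsigned m = 1, k = 32;
      // Lane part: threads = [8, 4] extends both dimensions: [m=8, k=128].
      m *= 8, k *= 4;
      // Warp part: warps = [2, 2] extends again: [m=16, k=256].
      m *= 2, k *= 2;
      std::printf("covered tile: [m=%u, k=%u]\n", m, k);
      // combineCtaCgaWithShape then repeats m (32 / 16 = 2 times) and
      // squeezes k back down to the tensor's k=32, broadcasting across
      // lanes and warps.
      return 0;
    }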

Author (binarman):

> Agreed. I think you have to use warpsDotOperand here.

I thought about this; the only reason I chose to go without it for now is aesthetic.

At this moment the CTA tile construction looks like this:

  LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kLane, threadShape, threadOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kWarp, warpShape, warpOrder)
                               .transposeOuts(repDimNames);

with warpsDotOperand:

  LinearLayout ctaLayout = identityStandardND(kReg, threadSize, regOrder)
                               .transposeOuts(repDimNames) *
                           identityStandardND(kLane, threadShape, threadOrder)
                               .transposeOuts(repDimNames) *
                           warpsDotOperand(ctx, warpShape, warpOrder, kDimIdx)
                               .transposeOuts(repDimNames);

In the second variant the layout still extends beyond the K dimension due to the lane component. To make it uniform I could introduce something like laneDotOperand, but this function would be used in only one place.

Contributor: I think the per thread layout would be

  [r0, r1]
  [r2, r3]

Author: Oh, [2*2] is a blocked parent layout property.

Dot operand is slightly different: it inherits all attributes of the parent, except the k dimension. The dot operand layout implicitly extends the per-thread size to [2, K] for the A operand and [K, 2] for the B operand.

Author (binarman, Dec 20, 2024): The picture you mentioned is related to the intermediate layout, before it is expanded with combineCtaCgaWithShape.

Contributor (Jokeren, Dec 20, 2024): Got it. I think it's different from the cases where the parent is mma; we don't do implicit broadcasting on only the register dimension. That being said, I think some code cleanups have to happen later. Right now, several methods crash on this dot operand, like getElemsPerThread.

Contributor: The update looks good to me now. Thanks!

+  auto repOrder = blocked.getRepOrder();

Contributor: Should it be operandLayout.getRepOrder?

Author (binarman, Dec 20, 2024): It is not implemented at the moment for a blocked parent. We had a conversation about dot operand order functions with @lezcano and did not come to a certain decision. So for now I am just using the parent order functions everywhere.

+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  SmallVector<unsigned> threadSize = blocked.getSizePerThread();
+  auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  threadSize[kDimIdx] = shape[kDimIdx];
+  auto threadShape = blocked.getThreadsPerWarp();
+  auto warpShape = blocked.getWarpsPerCTA();
+
+  SmallVector<StringAttr> repDimNames =
+      permuteDimNames(standardOutDimNames(ctx, rank), repOrder);
+
+  auto registersLayout = identityStandardND(kReg, threadSize, regOrder);
+  auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder,
+                                                 kDimIdx, kLane);
+  auto warpsLayout =
+      broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp);
+
+  LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) *
+                           lanesLayout.transposeOuts(repDimNames) *
+                           warpsLayout.transposeOuts(repDimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape);
+}
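To make the register-dimension trick above concrete, here is a tiny standalone sketch (hypothetical shapes, plain C++ only) of how the k dimension is picked per operand and widened to the full K extent before broadcasting squeezes it back:

    #include <cstdio>

    int main() {
      const int rank = 2;
      const long shape[2] = {32, 32};  // operand shape
      const unsigned sizePerThread[2] = {1, 1}; // parent blocked layout

      for (int opIdx = 0; opIdx < 2; ++opIdx) {
        // A (opIdx == 0) has K as its last dim, B as its second-to-last.
        int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2;
        unsigned widened[2] = {sizePerThread[0], sizePerThread[1]};
        widened[kDimIdx] = shape[kDimIdx]; // registers now span all of K
        std::printf("opIdx=%d kDimIdx=%d threadSize=[%u, %u]\n", opIdx,
                    kDimIdx, widened[0], widened[1]);
      }
      return 0;
    }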

 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
@@ -740,19 +785,21 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
   auto ctaLayout =
       nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
   auto kDim = isA ? rank - 1 : rank - 2;
-  ctaLayout *=
-      warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim)
-          .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
+  ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(),
+                                           mma.getWarpOrder(), kDim, S("warp"))
+                   .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }

 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto parent = getParent();
-  if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
+  if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
+    return fmaDotToLinearLayout(*this, shape);
+  } else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
-  } else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
+  } else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     return wmmaDotOperandToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
     return nvidiaDotToLinearLayout(shape, *this);
4 changes: 2 additions & 2 deletions test/Conversion/amd/decompose-unsupported-conversions.mlir

@@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {

Author (binarman): Made this tensor larger, because with the introduction of linear layout the input and output tensors turned out to be compatible.

     // CHECK-NOT: ttg.convert_layout
     // CHECK: ttg.local_alloc
     // CHECK: ttg.local_load
-    %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
+    %0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
     tt.return
   }
 }