[BACKEND] LL for ldmatrix part1 - fp16 and no slicing shared memory for both operands (#5548)

All limitations of ldmatrix are noted in the comments; those with a TODO label
will be addressed in follow-up PRs.

As discussed with @lezcano, these limitations can be removed in a formal and
generic way instead of relying on heuristics:

1. Divide check: check whether there are enough elements to use `ldmatrix.xn`,
   where `n` is 1, 2, or 4. This could be implemented through `divideLeft`.
2. Tile check: check that the `4 / sizeof(elem)` registers are contiguous,
   that the first four lanes are contiguous, and that the remaining lanes land
   on subsequent rows. For example, given `sizeof(elem)=4`, we check whether
   `layout[kLane] == {(1, 0), (2, 0), (0, 1), (0, 2), (0, 4)}` (see the sketch
   after this list).
3. Address check: check that the elements at the accessed addresses are
   contiguous.
4. Compose layout: spread the lanes across each row of the tile and repeat the
   tile along the original shape.
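
To make the tile check in point 2 concrete, below is a minimal standalone
sketch in plain C++ (not the Triton `LinearLayout` API); the basis pairs
follow the (contiguous dim, row) convention of the example above, and the
function name and test values are hypothetical.

#include <array>
#include <cstdio>
#include <vector>

using Basis = std::array<int, 2>; // one lane bit's (contiguous-dim, row) step

// Hypothetical tile check for sizeof(elem) == 4: the first two lane bits must
// walk contiguous elements (1, 2) within a row and the remaining lane bits
// must step to subsequent rows (1, 2, 4).
bool tileCheckFp32(const std::vector<Basis> &laneBases) {
  const std::vector<Basis> expected = {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}};
  return laneBases == expected;
}

int main() {
  std::vector<Basis> contiguous = {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}};
  std::vector<Basis> scrambled = {{2, 0}, {1, 0}, {0, 1}, {0, 2}, {0, 4}};
  std::printf("contiguous=%d scrambled=%d\n", tileCheckFp32(contiguous),
              tileCheckFp32(scrambled)); // contiguous=1 scrambled=0
}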
Jokeren authored Jan 14, 2025
1 parent 8f6e9d2 commit e1697f6
Showing 7 changed files with 258 additions and 76 deletions.
11 changes: 11 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1119,6 +1119,17 @@ inline Value packLLVector(Location loc, ValueRange vals,
return vec;
}

inline bool
isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
ArrayRef<int64_t> allocShape,
triton::gpu::SharedEncodingAttr sharedEnc) {
auto rank = shape.size();
return /*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
/*swizzling but same shape*/ shape == allocShape ||
/*swizzling and rank-reduced and rank >= 2*/
(shape == allocShape.take_back(rank) && rank >= 2);
}

} // namespace mlir

#endif
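
As a rough, self-contained illustration of the predicate added above, here is
a plain C++ model with no MLIR types; `maxPhase` stands in for
`sharedEnc.getMaxPhase()`, and the shapes in `main` are made-up examples
rather than cases taken from this PR.

#include <cstdio>
#include <vector>

// Mirrors the three conditions: no swizzling, or swizzling with an exact
// shape match, or swizzling where a rank >= 2 view matches the trailing
// dimensions of the allocation (rank-reduced but not sliced).
bool isSimpleSharedMemoryAccess(const std::vector<long> &shape,
                                const std::vector<long> &allocShape,
                                int maxPhase) {
  auto rank = shape.size();
  bool suffixMatch =
      rank >= 2 && allocShape.size() >= rank &&
      std::vector<long>(allocShape.end() - rank, allocShape.end()) == shape;
  return maxPhase == 1 || shape == allocShape || suffixMatch;
}

int main() {
  // Swizzled, but the 2-D view is the trailing part of a 3-D allocation.
  std::printf("%d\n", isSimpleSharedMemoryAccess({64, 32}, {2, 64, 32}, 8)); // 1
  // Swizzled and genuinely sliced: falls outside the simple case.
  std::printf("%d\n", isSimpleSharedMemoryAccess({32, 32}, {64, 32}, 8));    // 0
}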
8 changes: 5 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -242,10 +242,12 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
// bit width of the tensor in the future to support more flexible tensor
// encodings
LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
- ArrayRef<unsigned> repShape,
- ArrayRef<unsigned> paddedRepShape,
- ArrayRef<unsigned> order,
int swizzleByteSize);

// The primary goal of this function is to efficiently store 2D tiles of a
// tensor into shared memory using the `ldmatrix` instruction.
LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
Attribute dotEnc, ArrayRef<int64_t> shape);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
2 changes: 1 addition & 1 deletion include/triton/Tools/LinearLayout.h
@@ -725,4 +725,4 @@ inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) {

} // namespace mlir::triton

- #endif
+ #endif // TRITON_TOOLS_LINEARLAYOUT_H
6 changes: 2 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -468,10 +468,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
scratchConfig.paddedRepShape, scratchConfig.order,
/*swizzleByteSize=*/0);
LinearLayout shmemStoreLayout =
- isStMatrix ? chooseStMatrixLayout(
- ctx, op.getSrc().getType(), scratchConfig.repShape,
- scratchConfig.paddedRepShape, scratchConfig.order,
- /*swizzleByteSize=*/0)
+ isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
+ /*swizzleByteSize=*/0)
: srcLayout.invertAndCompose(sharedLayout);

const int shmemAllocatedNumElems =
5 changes: 1 addition & 4 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -223,10 +223,7 @@ Value getSmemVecAddr(RankedTensorType registerTy,
// We propose case 2 (see comments below), which provides a more general
// solution for all swizzled shared memory scenarios, including the edge case
// mentioned above.
- if (/*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
- /*swizzling but same shape*/ shape == allocShape ||
- /*swizzling and rank-reduced and rank >= 2*/
- (shape == allocShape.take_back(rank) && rank >= 2)) { // Case 1
+ if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
// Get the address to load/store. The multi-dim address is (offsetX1, ...,
// offsetXN, block), where the offsets appear in minor-to-major order, and
// we drop_end to drop block, which we know from above will be 0.
115 changes: 97 additions & 18 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -961,10 +961,9 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
}

namespace {
- LinearLayout chooseStMatrixLayoutLeadingOffset(
- MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
- ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
- int swizzleByteSize) {
+ LinearLayout chooseStMatrixLayoutLeadingOffset(MLIRContext *ctx,
+ RankedTensorType tensorTy,
+ int swizzleByteSize) {
int perPhase;
int maxPhase;
if (swizzleByteSize == 32) {
@@ -1064,9 +1063,9 @@ LinearLayout chooseStMatrixLayoutLeadingOffset(
{{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}});
}

- LinearLayout chooseStMatrixLayoutNoLeadingOffset(
- MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
- ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
+ LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
+ Attribute encoding,
+ ArrayRef<int64_t> shape) {
StringAttr kReg = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
@@ -1081,17 +1080,16 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});

// Expand the `register` dimension so the size of columns matches `n`.
- auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
+ auto mma = cast<NvidiaMmaEncodingAttr>(encoding);
int n = mma.getInstrShape()[1];
layout *=
LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);

// Expand the `warp` dimension according to warpsPerCTA.
layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
- auto ret =
- combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
- auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
+ auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), shape);
+ auto tensorShapePerCTA = getShapePerCTA(mma, shape);
llvm::SmallDenseMap<StringAttr, int64_t> namedTensorShape;
namedTensorShape[kRow] = tensorShapePerCTA[0];
namedTensorShape[kCol] = tensorShapePerCTA[1];
@@ -1102,19 +1100,100 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
SharedEncodingAttr shared,
DotOperandEncodingAttr dot,
ArrayRef<int64_t> shape) {
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
auto rank = shape.size();
auto opIdx = dot.getOpIdx();
int kDim = opIdx == 0 ? rank - 1 : rank - 2;

StringAttr kReg = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
StringAttr kBlock = S("block");
StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");

std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
std::vector<std::vector<int>> basesLane;
auto numRowsPerTile = 16;
auto numColsPerTile = 16;
int vecSize = shared.getVec();
int perPhase = shared.getPerPhase();
int maxPhase = shared.getMaxPhase();
auto warpsPerCTA = mma.getWarpsPerCTA();
// Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
// efficiently. opIdx=0 and opIdx=1 are handled differently.
if (opIdx == 0) {
// The matrix elements of thread 0 are distributed in the following pattern:
//
// col0 col8
// row0 reg[0-1] reg[4-5]
// row8 reg[2-3] reg[6-7]
for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
int row = 1 << logRow;
basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
}
basesLane.push_back({0, numColsPerTile / 2});
// Expand the `register` dimension so the size of columns matches `K`.
for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
logCol++) {
int col = 1 << logCol;
basesReg.push_back({0, numColsPerTile * col});
}
} else {
// The matrix elements of thread 0 are distributed in the following pattern:
//
// col0 col8 col16 col24
// row0 reg[0-1] reg[2-3] reg[4-5] reg[6-7]
// 8x8
for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
int row = 1 << logRow;
basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
}
// 8x16
basesLane.push_back({0, numColsPerTile / 2});
// 8x32
basesLane.push_back({0, numColsPerTile});
// Expand the `register` dimension so the size of columns matches `K`.
for (int logCol = 0;
logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
int col = 1 << logCol;
basesReg.push_back({0, (numColsPerTile * 2) * col});
}
}
auto layout = LinearLayout(
{{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
// Expand the `warp` dimension according to warpsPerCTA.
layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
kDim, kWarp)
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
return ret.transposeOuts({kInner, kOuter})
.reshapeOuts(
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

} // anonymous namespace

LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
- ArrayRef<unsigned> repShape,
- ArrayRef<unsigned> paddedRepShape,
- ArrayRef<unsigned> order,
int swizzleByteSize) {
if (swizzleByteSize == 0)
- return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
- paddedRepShape, order);
+ return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy.getEncoding(),
+ tensorTy.getShape());
else
- return chooseStMatrixLayoutLeadingOffset(
- ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
+ return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
}

LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
Attribute dotEnc, ArrayRef<int64_t> shape) {
auto shared = cast<SharedEncodingAttr>(sharedEnc);
auto dot = cast<DotOperandEncodingAttr>(dotEnc);
assert(!shared.getHasLeadingOffset() &&
"Ldmatrix does not support leading offset yet");
return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
}

} // namespace mlir::triton::gpu
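
To make the base construction above more concrete, the following is a minimal
standalone sketch (plain C++, not the Triton `LinearLayout` class) that
combines the opIdx == 0 lane bases for a 16x16 fp16 tile with no swizzling
(`maxPhase == 1`, so the swizzle term is zero), using the linear-layout rule
of XOR-ing the bases of the set lane bits; the printed mapping is illustrative
only.

#include <cstdio>

int main() {
  // basesLane for opIdx == 0, 16x16 tile, no swizzling: four row bits
  // (1, 2, 4, 8) followed by one column bit (numColsPerTile / 2 == 8).
  const int basesLane[5][2] = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}};
  for (int lane = 0; lane < 32; ++lane) {
    int row = 0, col = 0;
    for (int bit = 0; bit < 5; ++bit) {
      if (lane & (1 << bit)) {
        row ^= basesLane[bit][0];
        col ^= basesLane[bit][1];
      }
    }
    std::printf("lane %2d -> (row %2d, col %2d)\n", lane, row, col);
  }
}

Under these assumptions, lanes 0-15 point at rows 0-15 of the left 16x8 half
of the tile and lanes 16-31 at the right half; with swizzling enabled, each
row basis would additionally pick up the `vecSize * ((row / perPhase) %
maxPhase)` column term from the code above.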
