Pipeline TMA stores (#3881)
Pipeline TMA stores so they are actually asynchronous. This is done by reordering
the ops so that the TMA store wait happens before the local_store, and by reusing
the shared-memory allocation, which is hoisted out of the loop.
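In sketch form, each tt.experimental_descriptor_store inside a pipelined loop becomes the sequence below, with the shared-memory buffer allocated once before the loop and deallocated after it. This matches the CHECK lines of the new lit test; the MLIR fragment itself is only illustrative — operand names, shapes, and the #blocked/#shared layout attributes are assumptions, not taken from this commit:

  %buf = triton_gpu.local_alloc : () -> !tt.memdesc<1xf32, #shared, mutable>
  scf.for %i = %c0 to %ub step %step : i32 {
    // ... compute %src : tensor<1xf32, #blocked> ...
    // wait until the previous iteration's TMA store has finished reading %buf
    triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32}
    triton_gpu.local_store %src, %buf : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, mutable>
    triton_nvidia_gpu.fence_async_shared
    triton_nvidia_gpu.async_tma_copy_local_to_global %desc[%i] %buf : <i8>, <1xf32, #shared>
  }
  triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32}
  triton_gpu.local_dealloc %buf : !tt.memdesc<1xf32, #shared, mutable>

Because the wait sits at the top of the iteration, the TMA copy issued at the bottom can overlap with the next iteration's computation instead of being waited on immediately.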

Co-authored-by: ThomasRaoux <[email protected]>

---------

Co-authored-by: Thomas Raoux <[email protected]>
pawelszczerbuk and ThomasRaoux authored May 11, 2024
1 parent c549281 commit 161f7a4
Showing 15 changed files with 250 additions and 17 deletions.
3 changes: 1 addition & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1439,8 +1439,7 @@ inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
auto srcTy = cast<RankedTensorType>(src.getType());
auto srcShape = srcTy.getShape();
auto rank = srcShape.size();
- assert(rank == 2 ||
-        rank == 3 && "Unexpected rank of storeDistributedToShared");
+ assert(rank <= 3 && "Unexpected rank of storeDistributedToShared");
auto dstTy = cast<MemDescType>(dst.getType());
auto srcDistributedLayout = srcTy.getEncoding();
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcDistributedLayout)) {
14 changes: 14 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -218,4 +218,18 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [MemoryEffects<[MemRead<SharedMemory>
let results = (outs TT_Tensor:$result);
}

def TTG_LocalStoreOp : TTG_Op<"local_store", [MemoryEffects<[MemWrite<SharedMemory>]>]> {
let summary = "Store a distributed tensor into a buffer in local memory";

let description = [{
Store a distributed tensor into a buffer in local memory.
}];
let arguments = (ins TT_Tensor:$src, TT_MemDescType:$result);

// Use qualified() otherwise "!tt.memdesc<X>" is printed as "<X>".
let assemblyFormat = [{
$src `,` $result attr-dict `:` type($src) `->` qualified(type($result))
}];
}

#endif
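A minimal usage sketch of the new op, following the assembly format above (value names and layout attributes are illustrative, not from this diff):

  triton_gpu.local_store %src, %buf : tensor<128x128xf32, #blocked> -> !tt.memdesc<128x128xf32, #shared, mutable>

Here the distributed tensor %src is written into the mutable shared-memory buffer %buf produced by an earlier triton_gpu.local_alloc.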
12 changes: 12 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -214,4 +214,16 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global",
}];
}

def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
let summary = "wait until all the inputs are read.";
let arguments = (ins I32Attr:$pendings);
let description = [{
Wait until the associated asynchronous store operations have finished reading from shared memory.
This is needed before the shared memory can be written to again.
}];

let assemblyFormat = "attr-dict";
}


#endif
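As exercised by the lit test added below, the op prints its pendings attribute inline and lowers to PTX cp.async.bulk.wait_group.read; pendings = 0 asks to wait until no store groups remain in flight:

  triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32}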
47 changes: 38 additions & 9 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -15,29 +15,30 @@ using namespace mlir::triton::gpu;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
- void lowerDistributedToShared(LocalAllocOp op, LocalAllocOpAdaptor adaptor,
+ void lowerDistributedToShared(Operation *op, Value src, Value dst,
+                               Value adaptorSrc,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo) {
- auto loc = op.getLoc();
- auto srcTy = op.getSrc().getType();
- auto dstTy = op.getType();
+ auto loc = op->getLoc();
+ auto srcTy = cast<RankedTensorType>(src.getType());
+ auto dstTy = cast<MemDescType>(dst.getType());
auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy);
auto srcLayout = srcTy.getEncoding();
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
assert(srcTy.getShape().size() == 2 ||
(srcTy.getShape().size() <= 3 && outOrd[2] == 0) &&
"Unexpected rank of ConvertLayout(blocked->shared)");
- Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
+ Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, op);
auto elemTy = typeConverter->convertType(srcTy.getElementType());

int32_t elemSize = elemTy.getIntOrFloatBitWidth();
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
auto dstStrides =
LLVM::getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter);
- auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
- storeDistributedToShared(op.getSrc(), inVals, dstStrides, op.getResult(),
-                          smemBase, elemTy, loc, rewriter, targetInfo);
+ auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
+ storeDistributedToShared(src, inVals, dstStrides, dst, smemBase, elemTy, loc,
+                          rewriter, targetInfo);
}

struct LocalAllocOpConversion
@@ -73,7 +73,8 @@ struct LocalAllocOpConversion

// If there is an initial tensor, store it into the shared memory.
if (op.getSrc()) {
- lowerDistributedToShared(op, adaptor, typeConverter, rewriter,
+ lowerDistributedToShared(op, op.getSrc(), op.getResult(),
+                          adaptor.getSrc(), typeConverter, rewriter,
targetInfo);
}

@@ -103,11 +105,38 @@ struct LocalDeallocOpConversion
}
};

struct LocalStoreOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp> {
public:
using ConvertOpToLLVMPattern<
triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;

LocalStoreOpConversion(const LLVMTypeConverter &converter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
targetInfo(targetInfo) {}

LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
lowerDistributedToShared(op, op.getSrc(), op.getResult(), adaptor.getSrc(),
getTypeConverter(), rewriter, targetInfo);
rewriter.eraseOp(op);

return success();
}

private:
const TargetInfoBase &targetInfo;
};

} // namespace

void mlir::triton::populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
}
18 changes: 18 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2635,6 +2635,23 @@ struct CanonicalizeConvertFromAlloc
}
};

// local_store(cvt) -> local_store
struct CanonicalizeConvertFromLocalStore
: public mlir::OpRewritePattern<triton::gpu::LocalStoreOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op,
PatternRewriter &rewriter) const override {
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalStoreOp>(op, convert.getSrc(),
op.getResult());
return mlir::success();
}
};

struct CanonicalizeConvertFromConvert
: public OpRewritePattern<ConvertLayoutOp> {
using OpRewritePattern::OpRewritePattern;
@@ -2760,6 +2777,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<CanonicalizeConvertFromReshape>(context);
patterns.add<CanonicalizeConvertFromHistogram>(context);
patterns.add<CanonicalizeConvertFromAlloc>(context);
patterns.add<CanonicalizeConvertFromLocalStore>(context);
}

// LocalAllocOp
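Illustratively, the new local_store(cvt) -> local_store pattern registered above folds a layout conversion into the store (layouts and value names are assumptions, not from this diff):

  %a = triton_gpu.convert_layout %x : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #blocked1>
  triton_gpu.local_store %a, %buf : tensor<128x128xf32, #blocked1> -> !tt.memdesc<128x128xf32, #shared, mutable>

becomes

  triton_gpu.local_store %x, %buf : tensor<128x128xf32, #blocked> -> !tt.memdesc<128x128xf32, #shared, mutable>

Since storeDistributedToShared can write from any distributed source layout, the intermediate conversion is unnecessary.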
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_triton_library(TritonGPUTransforms
Pipeliner/OuterLoopPipeline.cpp
Pipeliner/PipelineExpander.cpp
Pipeliner/SoftwarePipeliner.cpp
Pipeliner/TMAStoresPipeline.cpp
Pipeliner/PipeliningUtility.cpp
Prefetch.cpp
RemoveLayoutConversions.cpp
3 changes: 3 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h
@@ -21,6 +21,9 @@ bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages,
mlir::triton::PipeliningOption &options);

/// Pipeline the TMA stores in the loop.
bool pipelineTMAStores(scf::ForOp forOp);

/// This does post-processing on the pipelined loop to try to pipeline wgmma
/// ops.
// TODO: this should be included as part of the pipeline but currently the wgmma
12 changes: 12 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp
@@ -149,6 +149,18 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
// the inner loop.
for (scf::ForOp outerLoop : outerLoops)
tryAndPipelineOuterLoop(outerLoop);

// Re-collect loop ops
loops.clear();
getOperation()->walk([&](scf::ForOp forOp) {
// Bail out for loops with num_stage <= 1.
if (getNumStagesOrDefault(forOp) > 1)
loops.push_back(forOp);
});

for (scf::ForOp forOp : loops) {
mlir::triton::pipelineTMAStores(forOp);
}
}
};
} // anonymous namespace
93 changes: 93 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp
@@ -0,0 +1,93 @@
#include "Schedule.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;

static SmallVector<tt::ExperimentalDescriptorStoreOp>
getTMAStores(scf::ForOp forOp) {
SmallVector<tt::ExperimentalDescriptorStoreOp> tmaStores;

// Do not use walk, as we don't want to walk into nested loops.
std::function<void(Operation *)> collectTMAStores = [&](Operation *op) {
if (auto storeOp = dyn_cast<tt::ExperimentalDescriptorStoreOp>(op)) {
tmaStores.push_back(storeOp);
}
for (Region &region : op->getRegions()) {
for (Operation &op : region.getOps()) {
if (!isa<scf::ForOp>(op))
collectTMAStores(&op);
}
}
};
collectTMAStores(forOp);
return tmaStores;
}

static Value createAlloc(scf::ForOp &forOp,
tt::ExperimentalDescriptorStoreOp storeOp) {
OpBuilder builder(forOp);
auto ty = cast<RankedTensorType>(storeOp.getSrc().getType());
auto order = ttg::getOrder(ty.getEncoding());
auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
Attribute encoding =
ttg::SharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, ctaLayout);
if (ty.getRank() > 1) {
encoding = ttg::SharedEncodingAttr::get(
ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType());
}

Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(),
encoding, /*mutableMemory*/ true);
Value alloc = builder.create<ttg::LocalAllocOp>(storeOp->getLoc(),
memdescType, Value());
return alloc;
}

static void createTMAAsyncCopy(scf::ForOp &forOp,
tt::ExperimentalDescriptorStoreOp storeOp,
Value alloc) {
OpBuilder builder(storeOp);
auto loc = storeOp.getLoc();
auto ty = cast<RankedTensorType>(storeOp.getSrc().getType());
auto order = ttg::getOrder(ty.getEncoding());
auto ctaLayout = ttg::getCTALayout(ty.getEncoding());

// Putting the wait before the local_store makes the store truly async. We know
// that we are the only user of the CopyLocalToGlobal.
builder.create<ttng::TMAStoreWait>(loc, 0);
builder.create<ttg::LocalStoreOp>(loc, storeOp.getSrc(), alloc);
builder.create<ttng::FenceAsyncSharedOp>(loc, false);
builder.create<ttng::AsyncTMACopyLocalToGlobalOp>(
loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc);

storeOp->erase();
}

bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) {
SmallVector<tt::ExperimentalDescriptorStoreOp> tmaStores =
getTMAStores(forOp);
if (tmaStores.empty())
return false;

DenseMap<tt::ExperimentalDescriptorStoreOp, Value> storeToAlloc;
for (tt::ExperimentalDescriptorStoreOp op : tmaStores) {
storeToAlloc[op] = createAlloc(forOp, op);
}

for (tt::ExperimentalDescriptorStoreOp op : tmaStores) {
createTMAAsyncCopy(forOp, op, storeToAlloc[op]);
}

// Deallocate shared memory buffers.
OpBuilder builder(forOp);
builder.setInsertionPointAfter(forOp);
builder.create<ttng::TMAStoreWait>(forOp->getLoc(), 0);
for (auto it : storeToAlloc) {
builder.create<ttg::LocalDeallocOp>(forOp->getLoc(), it.second);
}
return true;
}
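For reference, the input matched by getTMAStores is a loop body containing tt.experimental_descriptor_store, as in the new lit test further below (illustrative fragment; names and layouts are assumptions):

  scf.for %i = %c0 to %ub step %step : i32 {
    // ... compute %src : tensor<1xf32, #blocked> ...
    tt.experimental_descriptor_store %desc[%i], %src : !tt.ptr<i8>, tensor<1xf32, #blocked>
  }

createAlloc hoists one shared-memory buffer per store out of the loop, and createTMAAsyncCopy replaces each store with the async_tma_store_wait / local_store / fence_async_shared / async_tma_copy_local_to_global sequence sketched in the commit description above.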
@@ -249,7 +249,7 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
bool isMMAV3 =
isa<NvidiaMmaEncodingAttr>(encoding) &&
cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
- if (isMMAV3 && isa<LocalAllocOp>(op))
+ if (isMMAV3 && (isa<LocalAllocOp>(op) || isa<LocalStoreOp>(op)))
return true;
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -563,7 +563,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
}
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::HistogramOp,
-              triton::gpu::LocalAllocOp>(op);
+              triton::gpu::LocalAllocOp, triton::gpu::LocalStoreOp>(op);
}

scf::ForOp replaceForOpWithNewSignature(
1 change: 1 addition & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
@@ -84,6 +84,7 @@ class TMAStoreLowering
rewriter.create<triton::nvidia_gpu::FenceAsyncSharedOp>(loc, false);
rewriter.create<triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp>(
loc, op.getDescPtr(), op.getIndices(), alloc);
rewriter.create<triton::nvidia_gpu::TMAStoreWait>(loc, 0);
rewriter.eraseOp(op);
return success();
}
13 changes: 12 additions & 1 deletion test/Conversion/tritonnvidiagpu_to_llvm.mlir
@@ -49,9 +49,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: elect.sync
// CHECK: "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r" {{.*}} : (i1, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void
// CHECK: cp.async.bulk.commit_group
// CHECK: cp.async.bulk.wait_group 0
tt.func @tma_copy_local_to_global(%tma: !tt.ptr<i64>, %alloc: !tt.memdesc<128x128xf32, #shared1>, %x: i32) {
triton_nvidia_gpu.async_tma_copy_local_to_global %tma[%x, %x] %alloc : <i64>, <128x128xf32, #shared1>
tt.return
}
}

// -----

#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: async_tma_store_wait
// CHECK: "cp.async.bulk.wait_group.read 0x0;", "" : () -> !llvm.void
tt.func @async_tma_store_wait() {
triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32}
tt.return
}
}
19 changes: 19 additions & 0 deletions test/TritonGPU/loop-pipeline-hopper.mlir
@@ -691,3 +691,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
tt.return %17#0, %17#2 : tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>
}
}

// -----
// Test pipelining of experimental_descriptor_store
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tma_store_pipeline
tt.func public @tma_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.ptr<i8>, %arg2: i32, %arg3: i32) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 {
%1 = arith.divsi %arg4, %arg2 : i32
// CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32}
// CHECK-NEXT: triton_gpu.local_store
// CHECK-NEXT: triton_nvidia_gpu.fence_async_shared
// CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global
tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.ptr<i8>, tensor<1xf32, #blocked>
}
tt.return
}
}