Skip to content

Commit

Permalink
[DMA] Add combine_tokens op
Browse files Browse the repository at this point in the history
Being able to dynamically combine tokens to await a runtime-dependent number of DMA transfers is a requirement for implementing legalization of DMA transfers in the `dma` dialect.

This PR therefore adds the `combine_tokens` op which combines multiple tokens into one.
The lowering for Snitch leverages the monotonicity guarantee of IDs to combine the tokens. Note that this only works with a single-channel DMA right now.

As this subsumes the multi-token capabilities of `wait_for_transfers`, it has been renamed to just `wait_for_transfer` and only accepts a single token input now.
  • Loading branch information
zero9178 committed Sep 1, 2024
1 parent a35890f commit cb0c2fc
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 81 deletions.
48 changes: 31 additions & 17 deletions codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,22 +393,13 @@ struct StartZeroMemTransferOpOpLowering
}
};

struct WaitForTransfersOpLowering : ConvertOpToLLVMPattern<WaitForTransfersOp> {
struct WaitForTransferOpLowering : ConvertOpToLLVMPattern<WaitForTransferOp> {

using ConvertOpToLLVMPattern<WaitForTransfersOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<WaitForTransferOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(WaitForTransfersOp op, OpAdaptor adaptor,
matchAndRewrite(WaitForTransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getTokens().empty()) {
rewriter.eraseOp(op);
return success();
}

Value current = adaptor.getTokens().front();
for (Value iter : llvm::drop_begin(adaptor.getTokens()))
current = rewriter.create<LLVM::UMaxOp>(op->getLoc(), current, iter);

Block *prev = op->getBlock();
Block *body = rewriter.splitBlock(prev, op->getIterator());
Block *after = rewriter.splitBlock(body, op->getNextNode()->getIterator());
Expand All @@ -417,8 +408,9 @@ struct WaitForTransfersOpLowering : ConvertOpToLLVMPattern<WaitForTransfersOp> {

rewriter.setInsertionPointToEnd(body);
Value lastCompleted = rewriter.create<SnitchDMA::StatOp>(op->getLoc());
Value notDone = rewriter.create<LLVM::ICmpOp>(
op->getLoc(), LLVM::ICmpPredicate::ult, lastCompleted, current);
Value notDone =
rewriter.create<LLVM::ICmpOp>(op->getLoc(), LLVM::ICmpPredicate::ult,
lastCompleted, adaptor.getToken());
rewriter.create<LLVM::CondBrOp>(op->getLoc(), notDone, body, after);

rewriter.setInsertionPointToStart(after);
Expand All @@ -440,6 +432,28 @@ struct CompletedTokenOpLowering : ConvertOpToLLVMPattern<CompletedTokenOp> {
}
};

/// Conversion pattern turning `CombineTokensOp` into LLVM dialect operations.
///
/// An op without operands is replaced by a `CompletedTokenOp`. Otherwise all
/// converted token operands are reduced, left to right, with `LLVM::UMaxOp`
/// and the op is replaced by the reduced value.
struct CombineTokensOpLowering : ConvertOpToLLVMPattern<CombineTokensOp> {

  using ConvertOpToLLVMPattern<CombineTokensOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(CombineTokensOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tokens = adaptor.getTokens();

    // Nothing to combine: substitute a completed token op for the result.
    if (tokens.empty()) {
      rewriter.replaceOpWithNewOp<CompletedTokenOp>(op);
      return success();
    }

    // TODO: Note that this lowering only works for Snitch's single channel DMA!
    // Fold the operands pairwise into a running unsigned maximum.
    auto loc = op->getLoc();
    Value combined = tokens.front();
    for (Value token : llvm::drop_begin(tokens))
      combined = rewriter.create<LLVM::UMaxOp>(loc, combined, token);

    rewriter.replaceOp(op, combined);
    return success();
  }
};

struct StatOpLowering : ConvertOpToLLVMPattern<SnitchDMA::StatOp> {
using ConvertOpToLLVMPattern<SnitchDMA::StatOp>::ConvertOpToLLVMPattern;

Expand Down Expand Up @@ -485,9 +499,9 @@ void quidditch::populateDMAToLLVMConversionPatterns(
i32, ArrayRef<Type>{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT}));
dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr());

patterns.insert<CompletedTokenOpLowering, WaitForTransfersOpLowering,
StartZeroMemTransferOpOpLowering, StatOpLowering>(
typeConverter);
patterns.insert<CompletedTokenOpLowering, WaitForTransferOpLowering,
StartZeroMemTransferOpOpLowering, StatOpLowering,
CombineTokensOpLowering>(typeConverter);
patterns.insert<StartTransferOp1DLowering>(dmaStart1D, typeConverter);
patterns.insert<StartTransferOp2DLowering>(dmaStart2D, typeConverter);
patterns.insert<StartContiguousZeroMemTransferOpOpLowering>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,22 @@ struct StartZeroMemTransferOpDMAImpl
StartZeroMemTransferOpDMAImpl, StartZeroMemTransferOp> {};

//===----------------------------------------------------------------------===//
// WaitForTransfersOpImpl::DMACoreSpecializationOpInterface
// WaitForTransferOpImpl::DMACoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

struct WaitForTransfersOpImpl
: CoreSpecializationOpInterface::ExternalModel<WaitForTransfersOpImpl,
WaitForTransfersOp> {
struct WaitForTransferOpImpl
: CoreSpecializationOpInterface::ExternalModel<WaitForTransferOpImpl,
WaitForTransferOp> {
void replaceWithNoop(Operation *op, RewriterBase &rewriter) const {
rewriter.eraseOp(op);
}

bool needsSynchronization(Operation *op) const { return true; }
};

struct WaitForTransfersOpDMAImpl
: DMACoreSpecializationOpInterface::ExternalModel<WaitForTransfersOpDMAImpl,
WaitForTransfersOp> {};
struct WaitForTransferOpDMAImpl
: DMACoreSpecializationOpInterface::ExternalModel<WaitForTransferOpDMAImpl,
WaitForTransferOp> {};

} // namespace

Expand All @@ -71,6 +71,6 @@ void quidditch::dma::registerDMACoreSpecializationOpInterface(
#define REGISTER_IMPLS(Op) Op::attachInterface<Op##Impl, Op##DMAImpl>(*context)
REGISTER_IMPLS(StartTransferOp);
REGISTER_IMPLS(StartZeroMemTransferOp);
REGISTER_IMPLS(WaitForTransfersOp);
REGISTER_IMPLS(WaitForTransferOp);
});
}
25 changes: 6 additions & 19 deletions codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter,
// memory as the transfer below. This is currently unspecified
// behaviour (both in the DMA dialect and in Snitch as far as we
// know).
rewriter.create<WaitForTransfersOp>(getLoc(), token);
rewriter.create<WaitForTransferOp>(getLoc(), token);
}

// Subview into the original memory without any padding.
Expand Down Expand Up @@ -379,7 +379,7 @@ WaitForTensorCopyOp::bufferize(RewriterBase &rewriter,
if (failed(transferTensorBuffer))
return failure();

rewriter.create<WaitForTransfersOp>(getLoc(), getToken());
rewriter.create<WaitForTransferOp>(getLoc(), getToken());
replaceOpWithBufferizedValues(rewriter, getOperation(),
*transferTensorBuffer);
return success();
Expand Down Expand Up @@ -414,25 +414,12 @@ OpFoldResult StartTransferOp::fold(FoldAdaptor adaptor) {
}

//===----------------------------------------------------------------------===//
// WaitForTransfersOp
// WaitForTransferOp
//===----------------------------------------------------------------------===//

LogicalResult WaitForTransfersOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
bool changed = false;
MutableOperandRange tokens = getTokensMutable();
for (int i = tokens.size() - 1; i >= 0; i--) {
if (adaptor.getTokens()[i]) {
changed = true;
tokens.erase(i);
}
}
return success(changed);
}

LogicalResult WaitForTransfersOp::canonicalize(WaitForTransfersOp op,
PatternRewriter &rewriter) {
if (!op.getTokens().empty())
LogicalResult WaitForTransferOp::canonicalize(WaitForTransferOp op,
PatternRewriter &rewriter) {
if (!matchPattern(op.getToken(), m_Constant<CompletedTokenAttr>(nullptr)))
return failure();

rewriter.eraseOp(op);
Expand Down
28 changes: 21 additions & 7 deletions codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,19 @@ def DMA_StartZeroMemTransferOp : DMA_Op<"start_zero_mem_transfer",
}];
}

def DMA_WaitForTransfersOp : DMA_Op<"wait_for_transfers"> {
def DMA_WaitForTransferOp : DMA_Op<"wait_for_transfer"> {

let description = [{
Operation awaiting for DMA transfers denoted by its tokens to be finished.
Operation awaiting for all DMA transfers denoted by its token to have
finished.
}];

let arguments = (ins
Variadic<DMA_TokenType>:$tokens
);
let arguments = (ins DMA_TokenType:$token);

let assemblyFormat = [{
($tokens^ `:` type($tokens))? attr-dict
$token attr-dict
}];

let hasFolder = 1;
let hasCanonicalizeMethod = 1;
}

Expand All @@ -214,4 +212,20 @@ def DMA_CompletedTokenOp
let hasFolder = 1;
}

// `combine_tokens`: a `Pure` op taking a variadic list of DMA tokens and
// producing a single DMA token as its result.
// The assembly format prints only the operands followed by the attribute
// dictionary; no types are spelled out in the textual form.
def DMA_CombineTokensOp : DMA_Op<"combine_tokens", [Pure]> {

let description = [{
Op combining multiple DMA tokens into one.
Awaiting the token returned by this function is equal in effect as if each
token was awaited independently in unspecified order.
}];

let arguments = (ins Variadic<DMA_TokenType>:$tokens);
let results = (outs DMA_TokenType:$result);

let assemblyFormat = [{
$tokens attr-dict
}];
}

#endif
2 changes: 1 addition & 1 deletion codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
[](OpBuilder &builder, Location loc, Value from, Value to) {
Value token =
builder.create<quidditch::dma::StartTransferOp>(loc, from, to);
builder.create<quidditch::dma::WaitForTransfersOp>(loc, token);
builder.create<quidditch::dma::WaitForTransferOp>(loc, token);
return success();
};

Expand Down
17 changes: 17 additions & 0 deletions codegen/tests/Conversion/ConvertDMAToLLVM/combine_tokens.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: quidditch-opt %s --quidditch-convert-to-llvm | FileCheck %s

// Lowering test for `dma.wait_for_transfer` with a single token argument.
// The expected output is a loop: branch into a body block, read an ID through
// inline assembly, compare it (unsigned less-than) against the token argument
// and branch back into the body while that comparison holds; otherwise fall
// through to the continuation block, which returns.
// CHECK-LABEL: @test
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
func.func private @test(%arg0 : !dma.token) {
// CHECK: llvm.br ^[[BODY:[[:alnum:]]+]]
// CHECK: ^[[BODY]]:
// CHECK-NEXT: %[[ID:.*]] = llvm.inline_asm has_side_effects ".insn r 0x2b, 0, 0b100, $0, zero, zero
// CHECK-SAME: "=r"
// CHECK-SAME: -> i32
// CHECK: %[[COND:.*]] = llvm.icmp "ult" %[[ID]], %[[ARG0]]
// CHECK: llvm.cond_br %[[COND]], ^[[BODY]], ^[[CONT:[[:alnum:]]+]]
// CHECK: ^[[CONT]]:
dma.wait_for_transfer %arg0
// CHECK-NEXT: llvm.return
return
}
26 changes: 14 additions & 12 deletions codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

// CHECK-LABEL: @test
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
func.func private @test(%arg0 : !dma.token) {
// CHECK: llvm.br ^[[BODY:[[:alnum:]]+]]
// CHECK: ^[[BODY]]:
// CHECK-NEXT: %[[ID:.*]] = llvm.inline_asm has_side_effects ".insn r 0x2b, 0, 0b100, $0, zero, zero
// CHECK-SAME: "=r"
// CHECK-SAME: -> i32
// CHECK: %[[COND:.*]] = llvm.icmp "ult" %[[ID]], %[[ARG0]]
// CHECK: llvm.cond_br %[[COND]], ^[[BODY]], ^[[CONT:[[:alnum:]]+]]
// CHECK: ^[[CONT]]:
dma.wait_for_transfers %arg0 : !dma.token
// CHECK-NEXT: llvm.return
return
// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]
func.func private @test(%arg0 : !dma.token, %arg1 : !dma.token) -> !dma.token {
// CHECK: %[[TOKEN:.*]] = llvm.intr.umax(%[[ARG0]], %[[ARG1]])
%token = dma.combine_tokens %arg0, %arg1
// CHECK: llvm.return %[[TOKEN]]
return %token : !dma.token
}

// Lowering test for `dma.combine_tokens` without any operands: the result is
// expected to become an `llvm.mlir.constant` with value 0, which is returned.
// CHECK-LABEL: @test_empty(
func.func private @test_empty() -> !dma.token {
// CHECK: %[[TOKEN:.*]] = llvm.mlir.constant(0 :
%token = dma.combine_tokens
// CHECK: return %[[TOKEN]]
return %token : !dma.token
}
2 changes: 1 addition & 1 deletion codegen/tests/Dialect/DMA/IR/bufferization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func.func @tensor_copy_pad(%arg0 : tensor<?x?xf32>, %pad0 : index, %pad1 : index
// CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], %[[PAD1]]]
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]])
// CHECK: %[[ZERO_TOKEN:.*]] = dma.start_zero_mem_transfer %[[ALLOC]]
// CHECK: dma.wait_for_transfers %[[ZERO_TOKEN]]
// CHECK: dma.wait_for_transfer %[[ZERO_TOKEN]]
// CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1]
// CHECK: %[[TOKEN:.*]] = dma.start_transfer from %[[COPY]]
// CHECK-SAME: to %[[UNPADDED]]
Expand Down
2 changes: 1 addition & 1 deletion codegen/tests/Dialect/DMA/IR/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
func.func @wait_gets_removed() {
// CHECK-NEXT: return
%0 = dma.completed_token
dma.wait_for_transfers %0 : !dma.token
dma.wait_for_transfer %0
return
}

Expand Down
7 changes: 1 addition & 6 deletions codegen/tests/Dialect/DMA/IR/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
// RUN: quidditch-opt %s --verify-roundtrip

func.func @test(%arg0 : memref<f64>) {
dma.wait_for_transfers
return
}

func.func @test3(%arg0 : tensor<?x4xf64>) -> (tensor<?x4xf64>, !dma.token) {
func.func @test(%arg0 : tensor<?x4xf64>) -> (tensor<?x4xf64>, !dma.token) {
%0:2 = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor<?x4xf64>
return %0#0, %0#1 : tensor<?x4xf64>, !dma.token
}
6 changes: 3 additions & 3 deletions codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,20 @@ func.func @test(
^bb0(%arg1: index, %arg2: memref<40x100xf64, #quidditch_snitch.l1_encoding>, %arg3: !dma.token):
// CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP3]](%[[IV]])
// CHECK: memref.subview %{{.*}}[0, %[[STAGE1_IV]]]
// CHECK: wait_for_transfers %[[YIELDED1]]
// CHECK: wait_for_transfer %[[YIELDED1]]
// CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[YIELDED0]] : {{.*}})
// CHECK: yield %[[NEXT_YIELDED]], %{{.*}} :

%subview_3 = memref.subview %alloca[0, %arg1] [1, 40] [1, 1] : memref<1x1200xf64, #quidditch_snitch.l1_encoding> to memref<1x40xf64, strided<[1200, 1], offset: ?>, #quidditch_snitch.l1_encoding>
dma.wait_for_transfers %arg3 : !dma.token
dma.wait_for_transfer %arg3
linalg.matmul_transpose_b
ins(%alloca2, %arg2 : memref<1x100xf64, #quidditch_snitch.l1_encoding>, memref<40x100xf64, #quidditch_snitch.l1_encoding>)
outs(%out : memref<1x40xf64, #quidditch_snitch.l1_encoding>)
}
// CHECK: %[[IV:.*]] = affine.apply #[[$MAP4]]()
// CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP5]]()
// CHECK: memref.subview %{{.*}}[0, %[[STAGE1_IV]]]
// CHECK: wait_for_transfers %[[LAST]]#1
// CHECK: wait_for_transfer %[[LAST]]#1
// CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[LAST]]#0 : {{.*}})
return
}
12 changes: 6 additions & 6 deletions codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
// CHECK-NEXT: quidditch_snitch.barrier
dma.start_transfer from %a : memref<32xf32> to %a_l1 : memref<32xf32>
%t = dma.start_transfer from %b : memref<32xf32> to %b_l1 : memref<32xf32>
dma.wait_for_transfers %t : !dma.token
dma.wait_for_transfer %t

// CHECK-NEXT: microkernel
// CHECK: }
Expand All @@ -34,7 +34,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
// CHECK-NEXT: dma.completed_token
%t2 = dma.start_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32>
// CHECK-NEXT: quidditch_snitch.barrier
dma.wait_for_transfers %t2 : !dma.token
dma.wait_for_transfer %t2


// CHECK: scf.if
Expand All @@ -55,7 +55,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
scf.yield %c, %i : !dma.token, index
}
// CHECK: quidditch_snitch.barrier
dma.wait_for_transfers %r#0 : !dma.token
dma.wait_for_transfer %r#0
// CHECK-NEXT: return
return
}
Expand All @@ -69,12 +69,12 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {

// CHECK-NEXT: dma.start_transfer
// CHECK-NEXT: dma.start_transfer
// CHECK-NEXT: dma.wait_for_transfers
// CHECK-NEXT: dma.wait_for_transfer
// CHECK-NEXT: quidditch_snitch.barrier

// CHECK-NEXT: quidditch_snitch.barrier
// CHECK-NEXT: dma.start_transfer
// CHECK-NEXT: dma.wait_for_transfers
// CHECK-NEXT: dma.wait_for_transfer
// CHECK-NEXT: quidditch_snitch.barrier

// CHECK-NEXT: scf.if
Expand All @@ -85,7 +85,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
// CHECK-NEXT: completed_token
// CHECK-NEXT: arith.constant
// CHECK-NEXT: yield
// CHECK: dma.wait_for_transfers
// CHECK: dma.wait_for_transfer
// CHECK-NEXT: quidditch_snitch.barrier

// CHECK-NEXT: return

0 comments on commit cb0c2fc

Please sign in to comment.