Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DMA] Add combine_token op #130

Merged
merged 1 commit into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,22 +393,13 @@ struct StartZeroMemTransferOpOpLowering
}
};

struct WaitForTransfersOpLowering : ConvertOpToLLVMPattern<WaitForTransfersOp> {
struct WaitForTransferOpLowering : ConvertOpToLLVMPattern<WaitForTransferOp> {

using ConvertOpToLLVMPattern<WaitForTransfersOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<WaitForTransferOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(WaitForTransfersOp op, OpAdaptor adaptor,
matchAndRewrite(WaitForTransferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getTokens().empty()) {
rewriter.eraseOp(op);
return success();
}

Value current = adaptor.getTokens().front();
for (Value iter : llvm::drop_begin(adaptor.getTokens()))
current = rewriter.create<LLVM::UMaxOp>(op->getLoc(), current, iter);

Block *prev = op->getBlock();
Block *body = rewriter.splitBlock(prev, op->getIterator());
Block *after = rewriter.splitBlock(body, op->getNextNode()->getIterator());
Expand All @@ -417,8 +408,9 @@ struct WaitForTransfersOpLowering : ConvertOpToLLVMPattern<WaitForTransfersOp> {

rewriter.setInsertionPointToEnd(body);
Value lastCompleted = rewriter.create<SnitchDMA::StatOp>(op->getLoc());
Value notDone = rewriter.create<LLVM::ICmpOp>(
op->getLoc(), LLVM::ICmpPredicate::ult, lastCompleted, current);
Value notDone =
rewriter.create<LLVM::ICmpOp>(op->getLoc(), LLVM::ICmpPredicate::ult,
lastCompleted, adaptor.getToken());
rewriter.create<LLVM::CondBrOp>(op->getLoc(), notDone, body, after);

rewriter.setInsertionPointToStart(after);
Expand All @@ -440,6 +432,28 @@ struct CompletedTokenOpLowering : ConvertOpToLLVMPattern<CompletedTokenOp> {
}
};

struct CombineTokensOpLowering : ConvertOpToLLVMPattern<CombineTokensOp> {

  using ConvertOpToLLVMPattern<CombineTokensOp>::ConvertOpToLLVMPattern;

  /// Lowers 'dma.combine_tokens' by reducing all operand tokens with an
  /// unsigned max. The wait lowering spins until the completed-transfer id
  /// reaches the token ('ult' comparison), so the maximum of the input tokens
  /// awaits all of them.
  LogicalResult
  matchAndRewrite(CombineTokensOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tokens = adaptor.getTokens();
    if (tokens.empty()) {
      // Nothing to combine: fold to the always-completed token.
      rewriter.replaceOpWithNewOp<CompletedTokenOp>(op);
      return success();
    }

    // TODO: Note that this lowering only works for Snitch's single channel DMA!
    Value combined = tokens.front();
    for (size_t i = 1, e = tokens.size(); i != e; ++i)
      combined =
          rewriter.create<LLVM::UMaxOp>(op->getLoc(), combined, tokens[i]);

    rewriter.replaceOp(op, combined);
    return success();
  }
};

struct StatOpLowering : ConvertOpToLLVMPattern<SnitchDMA::StatOp> {
using ConvertOpToLLVMPattern<SnitchDMA::StatOp>::ConvertOpToLLVMPattern;

Expand Down Expand Up @@ -485,9 +499,9 @@ void quidditch::populateDMAToLLVMConversionPatterns(
i32, ArrayRef<Type>{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT}));
dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr());

patterns.insert<CompletedTokenOpLowering, WaitForTransfersOpLowering,
StartZeroMemTransferOpOpLowering, StatOpLowering>(
typeConverter);
patterns.insert<CompletedTokenOpLowering, WaitForTransferOpLowering,
StartZeroMemTransferOpOpLowering, StatOpLowering,
CombineTokensOpLowering>(typeConverter);
patterns.insert<StartTransferOp1DLowering>(dmaStart1D, typeConverter);
patterns.insert<StartTransferOp2DLowering>(dmaStart2D, typeConverter);
patterns.insert<StartContiguousZeroMemTransferOpOpLowering>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,22 @@ struct StartZeroMemTransferOpDMAImpl
StartZeroMemTransferOpDMAImpl, StartZeroMemTransferOp> {};

//===----------------------------------------------------------------------===//
// WaitForTransfersOpImpl::DMACoreSpecializationOpInterface
// WaitForTransferOpImpl::DMACoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

struct WaitForTransfersOpImpl
: CoreSpecializationOpInterface::ExternalModel<WaitForTransfersOpImpl,
WaitForTransfersOp> {
// Core-specialization behaviour of 'dma.wait_for_transfer' on cores that do
// not run the DMA code.
struct WaitForTransferOpImpl
    : CoreSpecializationOpInterface::ExternalModel<WaitForTransferOpImpl,
                                                   WaitForTransferOp> {
  // On non-DMA cores the wait carries no semantics and is simply erased.
  void replaceWithNoop(Operation *op, RewriterBase &rewriter) const {
    rewriter.eraseOp(op);
  }

  // NOTE(review): returning true forces a barrier after the wait —
  // presumably so compute cores observe the transferred data before using
  // it; confirm against the specialization pass.
  bool needsSynchronization(Operation *op) const { return true; }
};

struct WaitForTransfersOpDMAImpl
: DMACoreSpecializationOpInterface::ExternalModel<WaitForTransfersOpDMAImpl,
WaitForTransfersOp> {};
// Empty marker model — presumably tags 'dma.wait_for_transfer' as an
// operation executed by the DMA core during core specialization; the
// interface adds no methods here.
struct WaitForTransferOpDMAImpl
    : DMACoreSpecializationOpInterface::ExternalModel<WaitForTransferOpDMAImpl,
                                                      WaitForTransferOp> {};

} // namespace

Expand All @@ -71,6 +71,6 @@ void quidditch::dma::registerDMACoreSpecializationOpInterface(
#define REGISTER_IMPLS(Op) Op::attachInterface<Op##Impl, Op##DMAImpl>(*context)
REGISTER_IMPLS(StartTransferOp);
REGISTER_IMPLS(StartZeroMemTransferOp);
REGISTER_IMPLS(WaitForTransfersOp);
REGISTER_IMPLS(WaitForTransferOp);
});
}
25 changes: 6 additions & 19 deletions codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter,
// memory as the transfer below. This is currently unspecified
// behaviour (both in the DMA dialect and in Snitch as far as we
// know).
rewriter.create<WaitForTransfersOp>(getLoc(), token);
rewriter.create<WaitForTransferOp>(getLoc(), token);
}

// Subview into the original memory without any padding.
Expand Down Expand Up @@ -379,7 +379,7 @@ WaitForTensorCopyOp::bufferize(RewriterBase &rewriter,
if (failed(transferTensorBuffer))
return failure();

rewriter.create<WaitForTransfersOp>(getLoc(), getToken());
rewriter.create<WaitForTransferOp>(getLoc(), getToken());
replaceOpWithBufferizedValues(rewriter, getOperation(),
*transferTensorBuffer);
return success();
Expand Down Expand Up @@ -414,25 +414,12 @@ OpFoldResult StartTransferOp::fold(FoldAdaptor adaptor) {
}

//===----------------------------------------------------------------------===//
// WaitForTransfersOp
// WaitForTransferOp
//===----------------------------------------------------------------------===//

LogicalResult WaitForTransfersOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
bool changed = false;
MutableOperandRange tokens = getTokensMutable();
for (int i = tokens.size() - 1; i >= 0; i--) {
if (adaptor.getTokens()[i]) {
changed = true;
tokens.erase(i);
}
}
return success(changed);
}

LogicalResult WaitForTransfersOp::canonicalize(WaitForTransfersOp op,
PatternRewriter &rewriter) {
if (!op.getTokens().empty())
LogicalResult WaitForTransferOp::canonicalize(WaitForTransferOp op,
PatternRewriter &rewriter) {
if (!matchPattern(op.getToken(), m_Constant<CompletedTokenAttr>(nullptr)))
return failure();

rewriter.eraseOp(op);
Expand Down
28 changes: 21 additions & 7 deletions codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,19 @@ def DMA_StartZeroMemTransferOp : DMA_Op<"start_zero_mem_transfer",
}];
}

def DMA_WaitForTransferOp : DMA_Op<"wait_for_transfer"> {

  let description = [{
    Operation that waits until all DMA transfers denoted by its token have
    finished.
  }];

  // Exactly one token; use 'dma.combine_tokens' to await multiple transfers.
  let arguments = (ins DMA_TokenType:$token);

  let assemblyFormat = [{
    $token attr-dict
  }];

  // Canonicalization erases waits on the completed token. The previous
  // 'hasFolder' (token-list pruning) is gone together with the variadic form.
  let hasCanonicalizeMethod = 1;
}

Expand All @@ -214,4 +212,20 @@ def DMA_CompletedTokenOp
let hasFolder = 1;
}

def DMA_CombineTokensOp : DMA_Op<"combine_tokens", [Pure]> {

  let description = [{
    Op combining multiple DMA tokens into one.
    Awaiting the token returned by this operation has the same effect as
    awaiting every input token independently, in an unspecified order.
    Combining zero tokens yields a token that is already complete.
  }];

  let arguments = (ins Variadic<DMA_TokenType>:$tokens);
  let results = (outs DMA_TokenType:$result);

  let assemblyFormat = [{
    $tokens attr-dict
  }];
}

#endif
2 changes: 1 addition & 1 deletion codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
[](OpBuilder &builder, Location loc, Value from, Value to) {
Value token =
builder.create<quidditch::dma::StartTransferOp>(loc, from, to);
builder.create<quidditch::dma::WaitForTransfersOp>(loc, token);
builder.create<quidditch::dma::WaitForTransferOp>(loc, token);
return success();
};

Expand Down
17 changes: 17 additions & 0 deletions codegen/tests/Conversion/ConvertDMAToLLVM/combine_tokens.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: quidditch-opt %s --quidditch-convert-to-llvm | FileCheck %s

// This file is named combine_tokens.mlir, so it must exercise the
// 'dma.combine_tokens' lowering (the wait_for_transfer spin-loop test
// belongs in dma_wait.mlir — the two files' contents were swapped).

// CHECK-LABEL: @test
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]
func.func private @test(%arg0 : !dma.token, %arg1 : !dma.token) -> !dma.token {
  // Tokens combine via unsigned max (single-channel Snitch DMA transfer ids).
  // CHECK: %[[TOKEN:.*]] = llvm.intr.umax(%[[ARG0]], %[[ARG1]])
  %token = dma.combine_tokens %arg0, %arg1
  // CHECK: llvm.return %[[TOKEN]]
  return %token : !dma.token
}

// CHECK-LABEL: @test_empty(
func.func private @test_empty() -> !dma.token {
  // Combining no tokens lowers to the completed token (constant 0).
  // CHECK: %[[TOKEN:.*]] = llvm.mlir.constant(0 :
  %token = dma.combine_tokens
  // CHECK: llvm.return %[[TOKEN]]
  return %token : !dma.token
}
26 changes: 14 additions & 12 deletions codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

// CHECK-LABEL: @test
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
func.func private @test(%arg0 : !dma.token) {
  // dma_wait.mlir must test the 'dma.wait_for_transfer' lowering (the
  // combine_tokens tests belong in combine_tokens.mlir — the two files'
  // contents were swapped). The wait lowers to a spin loop polling the DMA
  // status register until the last-completed transfer id reaches the token.
  // CHECK: llvm.br ^[[BODY:[[:alnum:]]+]]
  // CHECK: ^[[BODY]]:
  // CHECK-NEXT: %[[ID:.*]] = llvm.inline_asm has_side_effects ".insn r 0x2b, 0, 0b100, $0, zero, zero
  // CHECK-SAME: "=r"
  // CHECK-SAME: -> i32
  // CHECK: %[[COND:.*]] = llvm.icmp "ult" %[[ID]], %[[ARG0]]
  // CHECK: llvm.cond_br %[[COND]], ^[[BODY]], ^[[CONT:[[:alnum:]]+]]
  // CHECK: ^[[CONT]]:
  dma.wait_for_transfer %arg0
  // CHECK-NEXT: llvm.return
  return
}
2 changes: 1 addition & 1 deletion codegen/tests/Dialect/DMA/IR/bufferization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func.func @tensor_copy_pad(%arg0 : tensor<?x?xf32>, %pad0 : index, %pad1 : index
// CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], %[[PAD1]]]
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]])
// CHECK: %[[ZERO_TOKEN:.*]] = dma.start_zero_mem_transfer %[[ALLOC]]
// CHECK: dma.wait_for_transfers %[[ZERO_TOKEN]]
// CHECK: dma.wait_for_transfer %[[ZERO_TOKEN]]
// CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1]
// CHECK: %[[TOKEN:.*]] = dma.start_transfer from %[[COPY]]
// CHECK-SAME: to %[[UNPADDED]]
Expand Down
2 changes: 1 addition & 1 deletion codegen/tests/Dialect/DMA/IR/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// A wait on the always-available completed token can never block, so the
// canonicalizer erases it entirely, leaving only the return.
func.func @wait_gets_removed() {
  // CHECK-NEXT: return
  %0 = dma.completed_token
  dma.wait_for_transfer %0
  return
}

Expand Down
7 changes: 1 addition & 6 deletions codegen/tests/Dialect/DMA/IR/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
// RUN: quidditch-opt %s --verify-roundtrip

func.func @test(%arg0 : memref<f64>) {
dma.wait_for_transfers
return
}

func.func @test3(%arg0 : tensor<?x4xf64>) -> (tensor<?x4xf64>, !dma.token) {
// Round-trips 'dma.start_tensor_copy', whose results are both the copied
// tensor and the transfer token.
func.func @test(%arg0 : tensor<?x4xf64>) -> (tensor<?x4xf64>, !dma.token) {
  %0:2 = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor<?x4xf64>
  return %0#0, %0#1 : tensor<?x4xf64>, !dma.token
}
6 changes: 3 additions & 3 deletions codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,20 @@ func.func @test(
^bb0(%arg1: index, %arg2: memref<40x100xf64, #quidditch_snitch.l1_encoding>, %arg3: !dma.token):
// CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP3]](%[[IV]])
// CHECK: memref.subview %{{.*}}[0, %[[STAGE1_IV]]]
// CHECK: wait_for_transfers %[[YIELDED1]]
// CHECK: wait_for_transfer %[[YIELDED1]]
// CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[YIELDED0]] : {{.*}})
// CHECK: yield %[[NEXT_YIELDED]], %{{.*}} :

%subview_3 = memref.subview %alloca[0, %arg1] [1, 40] [1, 1] : memref<1x1200xf64, #quidditch_snitch.l1_encoding> to memref<1x40xf64, strided<[1200, 1], offset: ?>, #quidditch_snitch.l1_encoding>
dma.wait_for_transfers %arg3 : !dma.token
dma.wait_for_transfer %arg3
linalg.matmul_transpose_b
ins(%alloca2, %arg2 : memref<1x100xf64, #quidditch_snitch.l1_encoding>, memref<40x100xf64, #quidditch_snitch.l1_encoding>)
outs(%out : memref<1x40xf64, #quidditch_snitch.l1_encoding>)
}
// CHECK: %[[IV:.*]] = affine.apply #[[$MAP4]]()
// CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP5]]()
// CHECK: memref.subview %{{.*}}[0, %[[STAGE1_IV]]]
// CHECK: wait_for_transfers %[[LAST]]#1
// CHECK: wait_for_transfer %[[LAST]]#1
// CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[LAST]]#0 : {{.*}})
return
}
12 changes: 6 additions & 6 deletions codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
// CHECK-NEXT: quidditch_snitch.barrier
dma.start_transfer from %a : memref<32xf32> to %a_l1 : memref<32xf32>
%t = dma.start_transfer from %b : memref<32xf32> to %b_l1 : memref<32xf32>
dma.wait_for_transfers %t : !dma.token
dma.wait_for_transfer %t

// CHECK-NEXT: microkernel
// CHECK: }
Expand All @@ -34,7 +34,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
// CHECK-NEXT: dma.completed_token
%t2 = dma.start_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32>
// CHECK-NEXT: quidditch_snitch.barrier
dma.wait_for_transfers %t2 : !dma.token
dma.wait_for_transfer %t2


// CHECK: scf.if
Expand All @@ -55,7 +55,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
scf.yield %c, %i : !dma.token, index
}
// CHECK: quidditch_snitch.barrier
dma.wait_for_transfers %r#0 : !dma.token
dma.wait_for_transfer %r#0
// CHECK-NEXT: return
return
}
Expand All @@ -69,12 +69,12 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {

// CHECK-NEXT: dma.start_transfer
// CHECK-NEXT: dma.start_transfer
// CHECK-NEXT: dma.wait_for_transfers
// CHECK-NEXT: dma.wait_for_transfer
// CHECK-NEXT: quidditch_snitch.barrier

// CHECK-NEXT: quidditch_snitch.barrier
// CHECK-NEXT: dma.start_transfer
// CHECK-NEXT: dma.wait_for_transfers
// CHECK-NEXT: dma.wait_for_transfer
// CHECK-NEXT: quidditch_snitch.barrier

// CHECK-NEXT: scf.if
Expand All @@ -85,7 +85,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
// CHECK-NEXT: completed_token
// CHECK-NEXT: arith.constant
// CHECK-NEXT: yield
// CHECK: dma.wait_for_transfers
// CHECK: dma.wait_for_transfer
// CHECK-NEXT: quidditch_snitch.barrier

// CHECK-NEXT: return