From cb0c2fcba0635ecf8c02d18f2e9082fbe39a6b64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Sun, 1 Sep 2024 15:07:48 +0100 Subject: [PATCH] [DMA] Add `combine_tokens` op Being able to dynamically combine tokens to await a runtime-dependent number of DMA transfers is a requirement for implementing legalization of DMA transfers in the `dma` dialect. This PR therefore adds the `combine_tokens` op which combines multiple tokens into one. The lowering for Snitch leverages the monotonicity guarantee of IDs to combine them. Note that this only works with a single channel DMA right now. As this subsumes the multi-token capabilities of `wait_for_transfers`, it has been renamed to just `wait_for_transfer` and only accepts a single token input now. --- .../Quidditch/Conversion/ConvertDMAToLLVM.cpp | 48 ++++++++++++------- .../DMACoreSpecializationOpInterfaceImpl.cpp | 16 +++---- .../src/Quidditch/Dialect/DMA/IR/DMAOps.cpp | 25 +++------- .../src/Quidditch/Dialect/DMA/IR/DMAOps.td | 28 ++++++++--- .../src/Quidditch/Target/QuidditchTarget.cpp | 2 +- .../ConvertDMAToLLVM/combine_tokens.mlir | 17 +++++++ .../Conversion/ConvertDMAToLLVM/dma_wait.mlir | 26 +++++----- .../tests/Dialect/DMA/IR/bufferization.mlir | 2 +- .../Dialect/DMA/IR/canonicalization.mlir | 2 +- codegen/tests/Dialect/DMA/IR/roundtrip.mlir | 7 +-- .../Snitch/Transforms/lower-pipeline.mlir | 6 +-- .../Transforms/specialize-dma-code.mlir | 12 ++--- 12 files changed, 110 insertions(+), 81 deletions(-) create mode 100644 codegen/tests/Conversion/ConvertDMAToLLVM/combine_tokens.mlir diff --git a/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp b/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp index 98430ad..86db828 100644 --- a/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp +++ b/codegen/compiler/src/Quidditch/Conversion/ConvertDMAToLLVM.cpp @@ -393,22 +393,13 @@ struct StartZeroMemTransferOpOpLowering } }; -struct WaitForTransfersOpLowering : 
ConvertOpToLLVMPattern { +struct WaitForTransferOpLowering : ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(WaitForTransfersOp op, OpAdaptor adaptor, + matchAndRewrite(WaitForTransferOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (adaptor.getTokens().empty()) { - rewriter.eraseOp(op); - return success(); - } - - Value current = adaptor.getTokens().front(); - for (Value iter : llvm::drop_begin(adaptor.getTokens())) - current = rewriter.create(op->getLoc(), current, iter); - Block *prev = op->getBlock(); Block *body = rewriter.splitBlock(prev, op->getIterator()); Block *after = rewriter.splitBlock(body, op->getNextNode()->getIterator()); @@ -417,8 +408,9 @@ struct WaitForTransfersOpLowering : ConvertOpToLLVMPattern { rewriter.setInsertionPointToEnd(body); Value lastCompleted = rewriter.create(op->getLoc()); - Value notDone = rewriter.create( - op->getLoc(), LLVM::ICmpPredicate::ult, lastCompleted, current); + Value notDone = + rewriter.create(op->getLoc(), LLVM::ICmpPredicate::ult, + lastCompleted, adaptor.getToken()); rewriter.create(op->getLoc(), notDone, body, after); rewriter.setInsertionPointToStart(after); @@ -440,6 +432,28 @@ struct CompletedTokenOpLowering : ConvertOpToLLVMPattern { } }; +struct CombineTokensOpLowering : ConvertOpToLLVMPattern { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(CombineTokensOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (adaptor.getTokens().empty()) { + rewriter.replaceOpWithNewOp(op); + return success(); + } + + // TODO: Note that this lowering only works for Snitch's single channel DMA! 
+ Value current = adaptor.getTokens().front(); + for (Value iter : llvm::drop_begin(adaptor.getTokens())) + current = rewriter.create(op->getLoc(), current, iter); + + rewriter.replaceOp(op, current); + return success(); + } +}; + struct StatOpLowering : ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -485,9 +499,9 @@ void quidditch::populateDMAToLLVMConversionPatterns( i32, ArrayRef{ptrType, ptrType, sizeT, sizeT, sizeT, sizeT})); dmaStart2D->setAttr("hal.import.bitcode", builder.getUnitAttr()); - patterns.insert( - typeConverter); + patterns.insert(typeConverter); patterns.insert(dmaStart1D, typeConverter); patterns.insert(dmaStart2D, typeConverter); patterns.insert( diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp index 1bb4cb0..8e66ee7 100644 --- a/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/Extensions/DMACoreSpecializationOpInterfaceImpl.cpp @@ -46,12 +46,12 @@ struct StartZeroMemTransferOpDMAImpl StartZeroMemTransferOpDMAImpl, StartZeroMemTransferOp> {}; //===----------------------------------------------------------------------===// -// WaitForTransfersOpImpl::DMACoreSpecializationOpInterface +// WaitForTransferOpImpl::DMACoreSpecializationOpInterface //===----------------------------------------------------------------------===// -struct WaitForTransfersOpImpl - : CoreSpecializationOpInterface::ExternalModel { +struct WaitForTransferOpImpl + : CoreSpecializationOpInterface::ExternalModel { void replaceWithNoop(Operation *op, RewriterBase &rewriter) const { rewriter.eraseOp(op); } @@ -59,9 +59,9 @@ struct WaitForTransfersOpImpl bool needsSynchronization(Operation *op) const { return true; } }; -struct WaitForTransfersOpDMAImpl - : 
DMACoreSpecializationOpInterface::ExternalModel {}; +struct WaitForTransferOpDMAImpl + : DMACoreSpecializationOpInterface::ExternalModel {}; } // namespace @@ -71,6 +71,6 @@ void quidditch::dma::registerDMACoreSpecializationOpInterface( #define REGISTER_IMPLS(Op) Op::attachInterface(*context) REGISTER_IMPLS(StartTransferOp); REGISTER_IMPLS(StartZeroMemTransferOp); - REGISTER_IMPLS(WaitForTransfersOp); + REGISTER_IMPLS(WaitForTransferOp); }); } diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp index 913bfab..8e098b8 100644 --- a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.cpp @@ -296,7 +296,7 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter, // memory as the transfer below. This is currently unspecified // behaviour (both in the DMA dialect and in Snitch as far as we // know). - rewriter.create(getLoc(), token); + rewriter.create(getLoc(), token); } // Subview into the original memory without any padding. 
@@ -379,7 +379,7 @@ WaitForTensorCopyOp::bufferize(RewriterBase &rewriter, if (failed(transferTensorBuffer)) return failure(); - rewriter.create(getLoc(), getToken()); + rewriter.create(getLoc(), getToken()); replaceOpWithBufferizedValues(rewriter, getOperation(), *transferTensorBuffer); return success(); @@ -414,25 +414,12 @@ OpFoldResult StartTransferOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// WaitForTransfersOp +// WaitForTransferOp //===----------------------------------------------------------------------===// -LogicalResult WaitForTransfersOp::fold(FoldAdaptor adaptor, - SmallVectorImpl &results) { - bool changed = false; - MutableOperandRange tokens = getTokensMutable(); - for (int i = tokens.size() - 1; i >= 0; i--) { - if (adaptor.getTokens()[i]) { - changed = true; - tokens.erase(i); - } - } - return success(changed); -} - -LogicalResult WaitForTransfersOp::canonicalize(WaitForTransfersOp op, - PatternRewriter &rewriter) { - if (!op.getTokens().empty()) +LogicalResult WaitForTransferOp::canonicalize(WaitForTransferOp op, + PatternRewriter &rewriter) { + if (!matchPattern(op.getToken(), m_Constant(nullptr))) return failure(); rewriter.eraseOp(op); diff --git a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td index 7307326..b7a8eb5 100644 --- a/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td +++ b/codegen/compiler/src/Quidditch/Dialect/DMA/IR/DMAOps.td @@ -179,21 +179,19 @@ def DMA_StartZeroMemTransferOp : DMA_Op<"start_zero_mem_transfer", }]; } -def DMA_WaitForTransfersOp : DMA_Op<"wait_for_transfers"> { +def DMA_WaitForTransferOp : DMA_Op<"wait_for_transfer"> { let description = [{ - Operation awaiting for DMA transfers denoted by its tokens to be finished. + Operation awaiting for all DMA transfers denoted by its token to have + finished. 
}]; - let arguments = (ins - Variadic:$tokens - ); + let arguments = (ins DMA_TokenType:$token); let assemblyFormat = [{ - ($tokens^ `:` type($tokens))? attr-dict + $token attr-dict }]; - let hasFolder = 1; let hasCanonicalizeMethod = 1; } @@ -214,4 +212,20 @@ def DMA_CompletedTokenOp let hasFolder = 1; } +def DMA_CombineTokensOp : DMA_Op<"combine_tokens", [Pure]> { + + let description = [{ + Op combining multiple DMA tokens into one. + Awaiting the token returned by this function is equal in effect as if each + token was awaited independently in unspecified order. + }]; + + let arguments = (ins Variadic:$tokens); + let results = (outs DMA_TokenType:$result); + + let assemblyFormat = [{ + $tokens attr-dict + }]; +} + #endif diff --git a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp index e49242d..8d79fe8 100644 --- a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp +++ b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp @@ -219,7 +219,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend { [](OpBuilder &builder, Location loc, Value from, Value to) { Value token = builder.create(loc, from, to); - builder.create(loc, token); + builder.create(loc, token); return success(); }; diff --git a/codegen/tests/Conversion/ConvertDMAToLLVM/combine_tokens.mlir b/codegen/tests/Conversion/ConvertDMAToLLVM/combine_tokens.mlir new file mode 100644 index 0000000..f51fb8f --- /dev/null +++ b/codegen/tests/Conversion/ConvertDMAToLLVM/combine_tokens.mlir @@ -0,0 +1,17 @@ +// RUN: quidditch-opt %s --quidditch-convert-to-llvm | FileCheck %s + +// CHECK-LABEL: @test +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +func.func private @test(%arg0 : !dma.token) { + // CHECK: llvm.br ^[[BODY:[[:alnum:]]+]] + // CHECK: ^[[BODY]]: + // CHECK-NEXT: %[[ID:.*]] = llvm.inline_asm has_side_effects ".insn r 0x2b, 0, 0b100, $0, zero, zero + // CHECK-SAME: "=r" + // CHECK-SAME: -> i32 + // CHECK: 
%[[COND:.*]] = llvm.icmp "ult" %[[ID]], %[[ARG0]] + // CHECK: llvm.cond_br %[[COND]], ^[[BODY]], ^[[CONT:[[:alnum:]]+]] + // CHECK: ^[[CONT]]: + dma.wait_for_transfer %arg0 + // CHECK-NEXT: llvm.return + return +} diff --git a/codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir b/codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir index 47ad305..721ba9d 100644 --- a/codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir +++ b/codegen/tests/Conversion/ConvertDMAToLLVM/dma_wait.mlir @@ -2,16 +2,18 @@ // CHECK-LABEL: @test // CHECK-SAME: %[[ARG0:[[:alnum:]]+]] -func.func private @test(%arg0 : !dma.token) { - // CHECK: llvm.br ^[[BODY:[[:alnum:]]+]] - // CHECK: ^[[BODY]]: - // CHECK-NEXT: %[[ID:.*]] = llvm.inline_asm has_side_effects ".insn r 0x2b, 0, 0b100, $0, zero, zero - // CHECK-SAME: "=r" - // CHECK-SAME: -> i32 - // CHECK: %[[COND:.*]] = llvm.icmp "ult" %[[ID]], %[[ARG0]] - // CHECK: llvm.cond_br %[[COND]], ^[[BODY]], ^[[CONT:[[:alnum:]]+]] - // CHECK: ^[[CONT]]: - dma.wait_for_transfers %arg0 : !dma.token - // CHECK-NEXT: llvm.return - return +// CHECK-SAME: %[[ARG1:[[:alnum:]]+]] +func.func private @test(%arg0 : !dma.token, %arg1 : !dma.token) -> !dma.token { + // CHECK: %[[TOKEN:.*]] = llvm.intr.umax(%[[ARG0]], %[[ARG1]]) + %token = dma.combine_tokens %arg0, %arg1 + // CHECK: llvm.return %[[TOKEN]] + return %token : !dma.token +} + +// CHECK-LABEL: @test_empty( +func.func private @test_empty() -> !dma.token { + // CHECK: %[[TOKEN:.*]] = llvm.mlir.constant(0 : + %token = dma.combine_tokens + // CHECK: return %[[TOKEN]] + return %token : !dma.token } diff --git a/codegen/tests/Dialect/DMA/IR/bufferization.mlir b/codegen/tests/Dialect/DMA/IR/bufferization.mlir index c3a1802..42d9fbb 100644 --- a/codegen/tests/Dialect/DMA/IR/bufferization.mlir +++ b/codegen/tests/Dialect/DMA/IR/bufferization.mlir @@ -92,7 +92,7 @@ func.func @tensor_copy_pad(%arg0 : tensor, %pad0 : index, %pad1 : index // CHECK: %[[NEW_DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[DIM1]], 
%[[PAD1]]] // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[NEW_DIM0]], %[[NEW_DIM1]]) // CHECK: %[[ZERO_TOKEN:.*]] = dma.start_zero_mem_transfer %[[ALLOC]] - // CHECK: dma.wait_for_transfers %[[ZERO_TOKEN]] + // CHECK: dma.wait_for_transfer %[[ZERO_TOKEN]] // CHECK: %[[UNPADDED:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[DIM0]], %[[DIM1]]] [1, 1] // CHECK: %[[TOKEN:.*]] = dma.start_transfer from %[[COPY]] // CHECK-SAME: to %[[UNPADDED]] diff --git a/codegen/tests/Dialect/DMA/IR/canonicalization.mlir b/codegen/tests/Dialect/DMA/IR/canonicalization.mlir index 60998f0..7268ce0 100644 --- a/codegen/tests/Dialect/DMA/IR/canonicalization.mlir +++ b/codegen/tests/Dialect/DMA/IR/canonicalization.mlir @@ -4,7 +4,7 @@ func.func @wait_gets_removed() { // CHECK-NEXT: return %0 = dma.completed_token - dma.wait_for_transfers %0 : !dma.token + dma.wait_for_transfer %0 return } diff --git a/codegen/tests/Dialect/DMA/IR/roundtrip.mlir b/codegen/tests/Dialect/DMA/IR/roundtrip.mlir index de6d37b..5b28d66 100644 --- a/codegen/tests/Dialect/DMA/IR/roundtrip.mlir +++ b/codegen/tests/Dialect/DMA/IR/roundtrip.mlir @@ -1,11 +1,6 @@ // RUN: quidditch-opt %s --verify-roundtrip -func.func @test(%arg0 : memref) { - dma.wait_for_transfers - return -} - -func.func @test3(%arg0 : tensor) -> (tensor, !dma.token) { +func.func @test(%arg0 : tensor) -> (tensor, !dma.token) { %0:2 = dma.start_tensor_copy of %arg0 to #quidditch_snitch.l1_encoding -> tensor return %0#0, %0#1 : tensor, !dma.token } diff --git a/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir b/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir index 59254c0..cf88fa1 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/lower-pipeline.mlir @@ -53,12 +53,12 @@ func.func @test( ^bb0(%arg1: index, %arg2: memref<40x100xf64, #quidditch_snitch.l1_encoding>, %arg3: !dma.token): // CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP3]](%[[IV]]) // CHECK: memref.subview 
%{{.*}}[0, %[[STAGE1_IV]]] - // CHECK: wait_for_transfers %[[YIELDED1]] + // CHECK: wait_for_transfer %[[YIELDED1]] // CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[YIELDED0]] : {{.*}}) // CHECK: yield %[[NEXT_YIELDED]], %{{.*}} : %subview_3 = memref.subview %alloca[0, %arg1] [1, 40] [1, 1] : memref<1x1200xf64, #quidditch_snitch.l1_encoding> to memref<1x40xf64, strided<[1200, 1], offset: ?>, #quidditch_snitch.l1_encoding> - dma.wait_for_transfers %arg3 : !dma.token + dma.wait_for_transfer %arg3 linalg.matmul_transpose_b ins(%alloca2, %arg2 : memref<1x100xf64, #quidditch_snitch.l1_encoding>, memref<40x100xf64, #quidditch_snitch.l1_encoding>) outs(%out : memref<1x40xf64, #quidditch_snitch.l1_encoding>) @@ -66,7 +66,7 @@ func.func @test( // CHECK: %[[IV:.*]] = affine.apply #[[$MAP4]]() // CHECK: %[[STAGE1_IV:.*]] = affine.apply #[[$MAP5]]() // CHECK: memref.subview %{{.*}}[0, %[[STAGE1_IV]]] - // CHECK: wait_for_transfers %[[LAST]]#1 + // CHECK: wait_for_transfer %[[LAST]]#1 // CHECK: linalg.matmul_transpose_b ins(%{{.*}}, %[[LAST]]#0 : {{.*}}) return } diff --git a/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir b/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir index 351b2d7..7900efb 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir @@ -19,7 +19,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK-NEXT: quidditch_snitch.barrier dma.start_transfer from %a : memref<32xf32> to %a_l1 : memref<32xf32> %t = dma.start_transfer from %b : memref<32xf32> to %b_l1 : memref<32xf32> - dma.wait_for_transfers %t : !dma.token + dma.wait_for_transfer %t // CHECK-NEXT: microkernel // CHECK: } @@ -34,7 +34,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK-NEXT: dma.completed_token %t2 = dma.start_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32> // CHECK-NEXT: 
quidditch_snitch.barrier - dma.wait_for_transfers %t2 : !dma.token + dma.wait_for_transfer %t2 // CHECK: scf.if @@ -55,7 +55,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { scf.yield %c, %i : !dma.token, index } // CHECK: quidditch_snitch.barrier - dma.wait_for_transfers %r#0 : !dma.token + dma.wait_for_transfer %r#0 // CHECK-NEXT: return return } @@ -69,12 +69,12 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK-NEXT: dma.start_transfer // CHECK-NEXT: dma.start_transfer -// CHECK-NEXT: dma.wait_for_transfers +// CHECK-NEXT: dma.wait_for_transfer // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: dma.start_transfer -// CHECK-NEXT: dma.wait_for_transfers +// CHECK-NEXT: dma.wait_for_transfer // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: scf.if @@ -85,7 +85,7 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK-NEXT: completed_token // CHECK-NEXT: arith.constant // CHECK-NEXT: yield -// CHECK: dma.wait_for_transfers +// CHECK: dma.wait_for_transfer // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: return