From 6293803dd491589fe2dc53a43dd75d6230bdefa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Sat, 17 Aug 2024 19:23:44 +0200 Subject: [PATCH] [quidditch_snitch] Add `CompletedTokenAttr` (#112) This is the attribute version corresponding to the `completed_token` operation. The latter is now a `ConstantLike` that folds to this token value. The main advantage is the more comfortable use of the `fold` API rather than canonicalization patterns. --- .../Dialect/Snitch/IR/QuidditchSnitchAttrs.td | 10 ++++ .../Snitch/IR/QuidditchSnitchDialect.cpp | 10 ++++ .../Snitch/IR/QuidditchSnitchDialect.td | 1 + .../Dialect/Snitch/IR/QuidditchSnitchOps.cpp | 49 ++++++++++++++++--- .../Dialect/Snitch/IR/QuidditchSnitchOps.td | 10 +++- .../Dialect/Snitch/IR/canonicalization.mlir | 21 ++++++++ 6 files changed, 92 insertions(+), 9 deletions(-) diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td index b989279..0e3d1a3 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchAttrs.td @@ -8,6 +8,16 @@ include "mlir/IR/AttrTypeBase.td" class QuidditchSnitch_Attr traits = []> : AttrDef; +def QuidditchSnitch_CompletedTokenAttr : QuidditchSnitch_Attr<"CompletedToken"> { + + let mnemonic = "completed_token"; + + let description = [{ + Attribute representing an instance of a `!quidditch_snitch.dma_token` + signaling a complete transfer. + }]; +} + def QuidditchSnitch_L1EncodingAttr : QuidditchSnitch_Attr<"L1Encoding"> { let mnemonic = "l1_encoding"; diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp index cad250a..2a70184 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.cpp @@ -55,3 +55,13 @@ void QuidditchSnitchDialect::initialize() { #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc" >(); } + +Operation *QuidditchSnitchDialect::materializeConstant(OpBuilder &builder, + Attribute value, + Type type, + Location loc) { + if (isa(value)) + return builder.create(loc); + + return nullptr; +} diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td index c273486..4292a6c 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td @@ -15,6 +15,7 @@ def QuidditchSnitch_Dialect : Dialect { let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let hasConstantMaterializer = 1; } #endif diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp index ba70a46..5722c7e 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.cpp @@ -378,6 +378,24 @@ LogicalResult CallMicrokernelOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// StartTensorCopyOp +//===----------------------------------------------------------------------===// + +LogicalResult StartTensorCopyOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + auto waitOp = getCopy().getDefiningOp(); + if (!waitOp) + return failure(); + auto copyOp = waitOp.getTransferTensor().getDefiningOp(); + if (!copyOp) + return failure(); + + results.emplace_back(waitOp); + results.emplace_back(CompletedTokenAttr::get(getContext())); + return success(); +} + //===----------------------------------------------------------------------===// // StartTensorCopyOp::BufferizableOpInterface //===----------------------------------------------------------------------===// @@ -533,6 +551,17 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter, return success(); } +//===----------------------------------------------------------------------===// +// WaitForTensorCopyOp +//===----------------------------------------------------------------------===// + +OpFoldResult WaitForTensorCopyOp::fold(FoldAdaptor adaptor) { + if (adaptor.getToken()) + return getTransferTensor(); + + return nullptr; +} + //===----------------------------------------------------------------------===// // WaitForTensorCopyOp::BufferizableOpInterface //===----------------------------------------------------------------------===// @@ -598,17 +627,23 @@ bool WaitForTensorCopyOp::isNotConflicting( return false; } +//===----------------------------------------------------------------------===// +// CompletedTokenOp +//===----------------------------------------------------------------------===// + +OpFoldResult CompletedTokenOp::fold(FoldAdaptor adaptor) { + return CompletedTokenAttr::get(getContext()); +} + //===----------------------------------------------------------------------===// // StartDMATransferOp //===----------------------------------------------------------------------===// -LogicalResult StartDMATransferOp::canonicalize(StartDMATransferOp op, - PatternRewriter &rewriter) { - if (op.getSource() != op.getDest()) - return failure(); +OpFoldResult StartDMATransferOp::fold(FoldAdaptor adaptor) { + if (getSource() != getDest()) + return nullptr; - rewriter.replaceOpWithNewOp(op); - return success(); + return CompletedTokenAttr::get(getContext()); } //===----------------------------------------------------------------------===// @@ -621,7 +656,7 @@ WaitForDMATransfersOp::fold(FoldAdaptor adaptor, bool changed = false; MutableOperandRange tokens = getTokensMutable(); for (int i = tokens.size() - 1; i >= 0; i--) { - if (tokens[i].get().getDefiningOp()) { + if (adaptor.getTokens()[i]) { changed = true; tokens.erase(i); } diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td index eb792ee..3eb2946 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.td @@ -190,6 +190,8 @@ def QuidditchSnitch_StartTensorCopyOp : QuidditchSnitch_Op<"start_tensor_copy", let assemblyFormat = [{ $copy `to` `L1` `:` type($copy) attr-dict }]; + + let hasFolder = 1; } def QuidditchSnitch_WaitForTensorCopyOp : QuidditchSnitch_Op<"wait_for_tensor_copy", @@ -228,6 +230,8 @@ def QuidditchSnitch_WaitForTensorCopyOp : QuidditchSnitch_Op<"wait_for_tensor_co let assemblyFormat = [{ `of` $copy `to` $transfer_tensor `using` $token `:` type($transfer_tensor) attr-dict }]; + + let hasFolder = 1; } def FlatI8MemRef : ConfinedType, [HasStaticShapePred, @@ -267,7 +271,7 @@ def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer" `from` $source `:` type($source) `to` $dest `:` type($dest) attr-dict }]; - let hasCanonicalizeMethod = 1; + let hasFolder = 1; } def QuidditchSnitch_WaitForDMATransfersOp @@ -290,7 +294,7 @@ def QuidditchSnitch_WaitForDMATransfersOp } def QuidditchSnitch_CompletedTokenOp - : QuidditchSnitch_Op<"completed_token", [Pure]> { + : QuidditchSnitch_Op<"completed_token", [Pure, ConstantLike]> { let description = [{ Op returning a special value representing a completed DMA transfer. @@ -302,6 +306,8 @@ def QuidditchSnitch_CompletedTokenOp let assemblyFormat = [{ attr-dict }]; + + let hasFolder = 1; } def QuidditchSnitch_BarrierOp : QuidditchSnitch_Op<"barrier"> { diff --git a/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir b/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir index e446c56..31911f6 100644 --- a/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir +++ b/codegen/tests/Dialect/Snitch/IR/canonicalization.mlir @@ -91,3 +91,24 @@ func.func @pipeline_invariant(%tensor : tensor) { } return } + +// CHECK-LABEL: @tensor_wait_gets_removed +// CHECK-SAME: %[[ARG0:[[:alnum:]]+]] +// CHECK-SAME: %[[ARG1:[[:alnum:]]+]] +func.func @tensor_wait_gets_removed(%arg0 : tensor, %arg1 : tensor) -> tensor { + // CHECK-NEXT: return %[[ARG1]] + %t = quidditch_snitch.completed_token + %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 to %arg1 using %t : tensor + return %0 : tensor +} + +// CHECK-LABEL: @tensor_noop_transfer +func.func @tensor_noop_transfer(%arg0 : tensor) -> (tensor, !quidditch_snitch.dma_token) { + // CHECK: %[[T:.*]] = quidditch_snitch.completed_token + // CHECK: %[[R:.*]] = quidditch_snitch.wait_for_tensor_copy + // CHECK-NEXT: return %[[R]], %[[T]] + %r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor + %0 = quidditch_snitch.wait_for_tensor_copy of %arg0 to %r using %t : tensor + %r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 : tensor + return %r2, %t2 : tensor, !quidditch_snitch.dma_token +}