Skip to content

Commit

Permalink
[quidditch_snitch] Add CompletedTokenAttr (#112)
Browse files Browse the repository at this point in the history
This is the attribute version corresponding to the `completed_token`
operation. The latter is now a `ConstantLike` that folds to this token
value.

The main advantage is that simplifications can be written with the more
convenient `fold` API instead of canonicalization patterns.
  • Loading branch information
zero9178 authored Aug 17, 2024
1 parent 9776fed commit 6293803
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ include "mlir/IR/AttrTypeBase.td"
class QuidditchSnitch_Attr<string name, list<Trait> traits = []> :
AttrDef<QuidditchSnitch_Dialect, name, traits>;

// Parameterless attribute printed and parsed with the `completed_token`
// mnemonic.
def QuidditchSnitch_CompletedTokenAttr : QuidditchSnitch_Attr<"CompletedToken"> {
  let description = [{
Attribute representing an instance of a `!quidditch_snitch.dma_token`
signaling a complete transfer.
}];

  let mnemonic = "completed_token";
}

def QuidditchSnitch_L1EncodingAttr : QuidditchSnitch_Attr<"L1Encoding"> {
let mnemonic = "l1_encoding";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,13 @@ void QuidditchSnitchDialect::initialize() {
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.cpp.inc"
>();
}

/// Materializes a constant operation for the attribute `value`.
///
/// Only `CompletedTokenAttr` is handled: it is turned into a
/// `CompletedTokenOp` created at `loc` through `builder`. Any other attribute
/// yields `nullptr`. The requested `type` is not inspected.
Operation *QuidditchSnitchDialect::materializeConstant(OpBuilder &builder,
                                                       Attribute value,
                                                       Type type,
                                                       Location loc) {
  // Bail out early for every attribute kind this dialect cannot materialize.
  if (!isa<CompletedTokenAttr>(value))
    return nullptr;

  return builder.create<CompletedTokenOp>(loc);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def QuidditchSnitch_Dialect : Dialect {

let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
let hasConstantMaterializer = 1;
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,24 @@ LogicalResult CallMicrokernelOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// StartTensorCopyOp
//===----------------------------------------------------------------------===//

/// Folds a `start_tensor_copy` whose input is the result of a
/// `wait_for_tensor_copy` that itself waited on the tensor produced by another
/// `start_tensor_copy`.
///
/// On success two fold results are appended to `results`, in order: the
/// result of that wait operation and a `CompletedTokenAttr`. Returns
/// `failure()` when the operand chain does not have this shape.
LogicalResult StartTensorCopyOp::fold(FoldAdaptor adaptor,
                                      SmallVectorImpl<OpFoldResult> &results) {
  // The copied value must come from a wait whose transfer tensor was in turn
  // produced by a previous start_tensor_copy; otherwise there is nothing to
  // fold.
  auto priorWait = getCopy().getDefiningOp<WaitForTensorCopyOp>();
  if (!priorWait ||
      !priorWait.getTransferTensor().getDefiningOp<StartTensorCopyOp>())
    return failure();

  results.emplace_back(priorWait);
  results.emplace_back(CompletedTokenAttr::get(getContext()));
  return success();
}

//===----------------------------------------------------------------------===//
// StartTensorCopyOp::BufferizableOpInterface
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -533,6 +551,17 @@ StartTensorCopyOp::bufferize(RewriterBase &rewriter,
return success();
}

//===----------------------------------------------------------------------===//
// WaitForTensorCopyOp
//===----------------------------------------------------------------------===//

/// Folds the wait to its transfer tensor when the token operand has a
/// constant attribute value; otherwise returns a null fold result.
OpFoldResult WaitForTensorCopyOp::fold(FoldAdaptor adaptor) {
  // A null attribute means the token operand is not known to be constant.
  if (!adaptor.getToken())
    return nullptr;

  return getTransferTensor();
}

//===----------------------------------------------------------------------===//
// WaitForTensorCopyOp::BufferizableOpInterface
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -598,17 +627,23 @@ bool WaitForTensorCopyOp::isNotConflicting(
return false;
}

//===----------------------------------------------------------------------===//
// CompletedTokenOp
//===----------------------------------------------------------------------===//

/// Always folds to the `CompletedTokenAttr` of the op's context.
OpFoldResult CompletedTokenOp::fold(FoldAdaptor adaptor) {
  auto completedToken = CompletedTokenAttr::get(getContext());
  return completedToken;
}

//===----------------------------------------------------------------------===//
// StartDMATransferOp
//===----------------------------------------------------------------------===//

// NOTE(review): this span appears to interleave the removed `canonicalize`
// implementation with the `fold` implementation that replaces it (the diff
// markers are not visible here) -- confirm against the actual file. The
// `fold` lines return a `CompletedTokenAttr` when the source and destination
// are the same value and a null fold result otherwise.
LogicalResult StartDMATransferOp::canonicalize(StartDMATransferOp op,
PatternRewriter &rewriter) {
if (op.getSource() != op.getDest())
return failure();
OpFoldResult StartDMATransferOp::fold(FoldAdaptor adaptor) {
if (getSource() != getDest())
return nullptr;

rewriter.replaceOpWithNewOp<CompletedTokenOp>(op);
return success();
return CompletedTokenAttr::get(getContext());
}

//===----------------------------------------------------------------------===//
Expand All @@ -621,7 +656,7 @@ WaitForDMATransfersOp::fold(FoldAdaptor adaptor,
bool changed = false;
MutableOperandRange tokens = getTokensMutable();
for (int i = tokens.size() - 1; i >= 0; i--) {
if (tokens[i].get().getDefiningOp<CompletedTokenOp>()) {
if (adaptor.getTokens()[i]) {
changed = true;
tokens.erase(i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def QuidditchSnitch_StartTensorCopyOp : QuidditchSnitch_Op<"start_tensor_copy",
let assemblyFormat = [{
$copy `to` `L1` `:` type($copy) attr-dict
}];

let hasFolder = 1;
}

def QuidditchSnitch_WaitForTensorCopyOp : QuidditchSnitch_Op<"wait_for_tensor_copy",
Expand Down Expand Up @@ -228,6 +230,8 @@ def QuidditchSnitch_WaitForTensorCopyOp : QuidditchSnitch_Op<"wait_for_tensor_co
let assemblyFormat = [{
`of` $copy `to` $transfer_tensor `using` $token `:` type($transfer_tensor) attr-dict
}];

let hasFolder = 1;
}

def FlatI8MemRef : ConfinedType<MemRefOf<[I8]>, [HasStaticShapePred,
Expand Down Expand Up @@ -267,7 +271,7 @@ def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer"
`from` $source `:` type($source) `to` $dest `:` type($dest) attr-dict
}];

let hasCanonicalizeMethod = 1;
let hasFolder = 1;
}

def QuidditchSnitch_WaitForDMATransfersOp
Expand All @@ -290,7 +294,7 @@ def QuidditchSnitch_WaitForDMATransfersOp
}

def QuidditchSnitch_CompletedTokenOp
: QuidditchSnitch_Op<"completed_token", [Pure]> {
: QuidditchSnitch_Op<"completed_token", [Pure, ConstantLike]> {

let description = [{
Op returning a special value representing a completed DMA transfer.
Expand All @@ -302,6 +306,8 @@ def QuidditchSnitch_CompletedTokenOp
let assemblyFormat = [{
attr-dict
}];

let hasFolder = 1;
}

def QuidditchSnitch_BarrierOp : QuidditchSnitch_Op<"barrier"> {
Expand Down
21 changes: 21 additions & 0 deletions codegen/tests/Dialect/Snitch/IR/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,24 @@ func.func @pipeline_invariant(%tensor : tensor<?xf32>) {
}
return
}

// Waiting on a token produced by `completed_token` is expected to fold away,
// so the function returns its second argument (the transfer tensor) directly.
// CHECK-LABEL: @tensor_wait_gets_removed
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]
func.func @tensor_wait_gets_removed(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
// CHECK-NEXT: return %[[ARG1]]
%t = quidditch_snitch.completed_token
%0 = quidditch_snitch.wait_for_tensor_copy of %arg0 to %arg1 using %t : tensor<?xf32>
return %0 : tensor<?xf32>
}

// A second `start_tensor_copy` applied to the result of waiting on a first
// copy is expected to fold: the returned tensor is the wait result and the
// returned token is a `completed_token`.
// CHECK-LABEL: @tensor_noop_transfer
func.func @tensor_noop_transfer(%arg0 : tensor<?xf32>) -> (tensor<?xf32>, !quidditch_snitch.dma_token) {
// CHECK: %[[T:.*]] = quidditch_snitch.completed_token
// CHECK: %[[R:.*]] = quidditch_snitch.wait_for_tensor_copy
// CHECK-NEXT: return %[[R]], %[[T]]
%r, %t = quidditch_snitch.start_tensor_copy %arg0 to L1 : tensor<?xf32>
%0 = quidditch_snitch.wait_for_tensor_copy of %arg0 to %r using %t : tensor<?xf32>
%r2, %t2 = quidditch_snitch.start_tensor_copy %0 to L1 : tensor<?xf32>
return %r2, %t2 : tensor<?xf32>, !quidditch_snitch.dma_token
}

0 comments on commit 6293803

Please sign in to comment.