Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quidditch_snitch] Simplify specialize-dma-code using interfaces #115

Merged
merged 1 commit into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions codegen/compiler/src/Quidditch/Dialect/Snitch/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ iree_cc_library(
"QuidditchSnitchAttrs.h.inc"
"QuidditchSnitchDialect.cpp.inc"
"QuidditchSnitchDialect.h.inc"
"QuidditchSnitchInterfaces.cpp.inc"
"QuidditchSnitchInterfaces.h.inc"
"QuidditchSnitchOps.cpp.inc"
"QuidditchSnitchOps.h.inc"
"QuidditchSnitchTypes.cpp.inc"
"QuidditchSnitchTypes.h.inc"
SRCS
"QuidditchSnitchAttrs.cpp"
"QuidditchSnitchDialect.cpp"
"QuidditchSnitchInterfaces.cpp"
"QuidditchSnitchOps.cpp"
"QuidditchSnitchTypes.cpp"
DEPS
::QuidditchSnitchAttrsGen
::QuidditchSnitchDialectGen
::QuidditchSnitchInterfacesGen
::QuidditchSnitchOpsGen
::QuidditchSnitchTypesGen
LLVMSupport
Expand Down Expand Up @@ -72,3 +76,13 @@ iree_tablegen_library(
--gen-typedef-decls QuidditchSnitchTypes.h.inc
--gen-typedef-defs QuidditchSnitchTypes.cpp.inc
)

# TableGen target for the dialect's op interfaces: processes
# QuidditchSnitchInterfaces.td and emits the op-interface declarations
# (QuidditchSnitchInterfaces.h.inc) and definitions
# (QuidditchSnitchInterfaces.cpp.inc).
iree_tablegen_library(
NAME
QuidditchSnitchInterfacesGen
TD_FILE
"QuidditchSnitchInterfaces.td"
OUTS
--gen-op-interface-decls QuidditchSnitchInterfaces.h.inc
--gen-op-interface-defs QuidditchSnitchInterfaces.cpp.inc
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include "QuidditchSnitchInterfaces.h"

#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchInterfaces.cpp.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

#pragma once

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"

#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchInterfaces.h.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHINTERFACES
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHINTERFACES

include "mlir/IR/Interfaces.td"

// Base op interface for operations tied to one kind of core.
//
// Declares two methods:
//   * `replaceWithNoop`      - no default implementation is given here, so
//                              every op attaching the interface has to
//                              provide it.
//   * `needsSynchronization` - has a default implementation returning `false`;
//                              ops override it to opt in.
def QuidditchSnitch_CoreSpecializationOpInterface
: OpInterface<"CoreSpecializationOpInterface"> {
let cppNamespace = "::quidditch::Snitch";

let description = [{
Interface used as a base class for ops meant to only run on a specific core.
When specializing a function for a specific core, ops implementing this
interface but not supported on a specific core will be removed using
`replaceWithNoop`.
}];

let methods = [
// Pure interface method: only description, return type, name and arguments
// are specified.
InterfaceMethod<
/*desc=*/[{
Method called to replace this operation with a noop in an unsupported
specialization. `rewriter`s insertion point is set right before the
operation.

The op must have been erased when this method returns.
}],
/*retTy=*/"void",
/*methodName=*/"replaceWithNoop",
/*args=*/(ins "mlir::RewriterBase&":$rewriter)
>,
// Positional arguments after the description are: return type, method name,
// arguments, method body (left empty) and the default implementation.
InterfaceMethod<
/*desc=*/[{
Returns true if this operation requires synchronization between all cores.
}],
"bool", "needsSynchronization", (ins), [{}], [{
return false;
}]
>
];
}

// Marker interface for ops supported on the DMA core. It extends
// `CoreSpecializationOpInterface` and adds no methods of its own; it only
// sets the C++ namespace.
def QuidditchSnitch_DMACoreSpecializationOpInterface
: OpInterface<"DMACoreSpecializationOpInterface", [QuidditchSnitch_CoreSpecializationOpInterface]> {
let cppNamespace = "::quidditch::Snitch";
}

// Marker interface for ops supported on the compute cores. It extends
// `CoreSpecializationOpInterface` and adds no methods of its own; it only
// sets the C++ namespace.
def QuidditchSnitch_ComputeCoreSpecializationOpInterface
: OpInterface<"ComputeCoreSpecializationOpInterface", [QuidditchSnitch_CoreSpecializationOpInterface]> {
let cppNamespace = "::quidditch::Snitch";
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,14 @@ void MemRefMicrokernelOp::getCanonicalizationPatterns(
ReplaceIdenticalArguments>(context);
}

//===----------------------------------------------------------------------===//
// MemRefMicrokernelOp::ComputeCoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

/// Noop replacement for a microkernel: the operation is erased through
/// `rewriter` and nothing is created in its place.
void MemRefMicrokernelOp::replaceWithNoop(RewriterBase &rewriter) {
  auto *self = getOperation();
  rewriter.eraseOp(self);
}

//===----------------------------------------------------------------------===//
// CallMicrokernelOp
//===----------------------------------------------------------------------===//
Expand All @@ -379,6 +387,14 @@ LogicalResult CallMicrokernelOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// MicrokernelFenceOp::ComputeCoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

/// Noop replacement for a microkernel fence: the operation is erased through
/// `rewriter` and nothing is created in its place.
void MicrokernelFenceOp::replaceWithNoop(RewriterBase &rewriter) {
  auto *self = getOperation();
  rewriter.eraseOp(self);
}

//===----------------------------------------------------------------------===//
// StartTensorCopyOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -647,6 +663,22 @@ OpFoldResult StartDMATransferOp::fold(FoldAdaptor adaptor) {
return CompletedTokenAttr::get(getContext());
}

//===----------------------------------------------------------------------===//
// StartDMATransferOp::DMACoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

/// Noop replacement for a DMA transfer start: a `CompletedTokenOp` is created
/// in place of this operation and takes over the uses of its result.
void StartDMATransferOp::replaceWithNoop(RewriterBase &rewriter) {
  auto *self = getOperation();
  rewriter.replaceOpWithNewOp<CompletedTokenOp>(self);
}

//===----------------------------------------------------------------------===//
// StartZeroMemTransferOp::DMACoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

/// Noop replacement for a zero-memory transfer start: a `CompletedTokenOp` is
/// created in place of this operation and takes over the uses of its result.
void StartZeroMemTransferOp::replaceWithNoop(RewriterBase &rewriter) {
  auto *self = getOperation();
  rewriter.replaceOpWithNewOp<CompletedTokenOp>(self);
}

//===----------------------------------------------------------------------===//
// WaitForDMATransfersOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -674,6 +706,28 @@ LogicalResult WaitForDMATransfersOp::canonicalize(WaitForDMATransfersOp op,
return success();
}

//===----------------------------------------------------------------------===//
// WaitForDMATransfersOp::DMACoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

/// Noop replacement for a DMA wait: the operation is erased through
/// `rewriter` and nothing is created in its place.
void WaitForDMATransfersOp::replaceWithNoop(RewriterBase &rewriter) {
  auto *self = getOperation();
  rewriter.eraseOp(self);
}

//===----------------------------------------------------------------------===//
// ComputeCoreIndexOp::ComputeCoreSpecializationOpInterface
//===----------------------------------------------------------------------===//

/// Noop replacement for the compute core index: the operation is replaced
/// with the constant index 0.
///
/// Using 0 makes the DMA core follow the control flow of the first compute
/// core. The specialization pass assumes that every operation run on either
/// the DMA core or the compute cores is in non-divergent control flow, so
/// letting the DMA core follow any compute core's control flow is safe.
/// This is mainly required for barriers within a `scf.forall`.
void ComputeCoreIndexOp::replaceWithNoop(RewriterBase &rewriter) {
  auto *self = getOperation();
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(self, 0);
}

//===----------------------------------------------------------------------===//
// PipelineOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "QuidditchSnitchInterfaces.h"
#include "QuidditchSnitchTypes.h"

#define GET_OP_CLASSES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHOPS

include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.td"
include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchInterfaces.td"
include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchTypes.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/IR/CommonTypeConstraints.td"
Expand Down Expand Up @@ -91,7 +92,7 @@ def QuidditchSnitch_SyncTensorOp : QuidditchSnitch_Op<"sync_tensor",

def QuidditchSnitch_MemRefMicrokernelOp
: QuidditchSnitch_Op<"memref.microkernel", [IsolatedFromAbove, SingleBlock,
NoTerminator]> {
NoTerminator, QuidditchSnitch_ComputeCoreSpecializationOpInterface]> {

let description = [{
Operation denoting a region of operations as a microkernel.
Expand All @@ -117,6 +118,8 @@ def QuidditchSnitch_MemRefMicrokernelOp

let extraClassDeclaration = [{
mlir::Block* createEntryBlock();

void replaceWithNoop(mlir::RewriterBase& rewriter);
}];
}

Expand Down Expand Up @@ -149,7 +152,8 @@ def QuidditchSnitch_CallMicrokernelOp
}];
}

def QuidditchSnitch_MicrokernelFenceOp : QuidditchSnitch_Op<"microkernel_fence"> {
def QuidditchSnitch_MicrokernelFenceOp : QuidditchSnitch_Op<"microkernel_fence",
[QuidditchSnitch_ComputeCoreSpecializationOpInterface]> {

let description = [{
Execution of this operation guarantees that the side-effects of all
Expand All @@ -160,6 +164,14 @@ def QuidditchSnitch_MicrokernelFenceOp : QuidditchSnitch_Op<"microkernel_fence">
let assemblyFormat = [{
attr-dict
}];

let extraClassDeclaration = [{
bool needsSynchronization() {
return true;
}

void replaceWithNoop(mlir::RewriterBase& rewriter);
}];
}

def QuidditchSnitch_StartTensorCopyOp : QuidditchSnitch_Op<"start_tensor_copy",
Expand Down Expand Up @@ -247,7 +259,8 @@ def QuidditchSnitch_L1MemoryViewOp : QuidditchSnitch_Op<"l1_memory_view",
}

def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer",
[MemoryEffects<[MemWrite]>, SameOperandsElementType, SameOperandsShape]> {
[MemoryEffects<[MemWrite]>, SameOperandsElementType, SameOperandsShape,
QuidditchSnitch_DMACoreSpecializationOpInterface]> {

let description = [{
Operation performing a DMA transfer from one MemRef to another.
Expand All @@ -271,10 +284,15 @@ def QuidditchSnitch_StartDMATransferOp : QuidditchSnitch_Op<"start_dma_transfer"
}];

let hasFolder = 1;

let extraClassDeclaration = [{
void replaceWithNoop(mlir::RewriterBase& rewriter);
}];
}

def QuidditchSnitch_StartZeroMemTransferOp : QuidditchSnitch_Op<"start_zero_mem_transfer",
[MemoryEffects<[MemWrite]>]> {
[MemoryEffects<[MemWrite]>,
QuidditchSnitch_DMACoreSpecializationOpInterface]> {

let description = [{

Expand All @@ -289,10 +307,16 @@ def QuidditchSnitch_StartZeroMemTransferOp : QuidditchSnitch_Op<"start_zero_mem_
let assemblyFormat = [{
$filled `:` type($filled) attr-dict
}];

let extraClassDeclaration = [{
void replaceWithNoop(mlir::RewriterBase& rewriter);
}];
}

def QuidditchSnitch_WaitForDMATransfersOp
: QuidditchSnitch_Op<"wait_for_dma_transfers"> {
: QuidditchSnitch_Op<"wait_for_dma_transfers", [
QuidditchSnitch_DMACoreSpecializationOpInterface
]> {

let description = [{
Operation awaiting for DMA transfers denoted by its tokens to be finished.
Expand All @@ -308,6 +332,14 @@ def QuidditchSnitch_WaitForDMATransfersOp

let hasFolder = 1;
let hasCanonicalizeMethod = 1;

let extraClassDeclaration = [{
bool needsSynchronization() {
return true;
}

void replaceWithNoop(mlir::RewriterBase& rewriter);
}];
}

def QuidditchSnitch_CompletedTokenOp
Expand All @@ -334,7 +366,8 @@ def QuidditchSnitch_BarrierOp : QuidditchSnitch_Op<"barrier"> {
}

def QuidditchSnitch_ComputeCoreIndexOp
: QuidditchSnitch_Op<"compute_core_index", [Pure]> {
: QuidditchSnitch_Op<"compute_core_index", [Pure,
QuidditchSnitch_ComputeCoreSpecializationOpInterface]> {

let description = [{
Returns the index of the compute core within a given cluster.
Expand All @@ -348,6 +381,10 @@ def QuidditchSnitch_ComputeCoreIndexOp
let assemblyFormat = [{
attr-dict
}];

let extraClassDeclaration = [{
void replaceWithNoop(mlir::RewriterBase& rewriter);
}];
}

def QuidditchSnitch_PipelineOp : QuidditchSnitch_Op<"pipeline",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,49 +28,29 @@ class SpecializeDMACode
using namespace mlir;
using namespace quidditch::Snitch;

static void removeComputeOps(FunctionOpInterface dmaCode) {
dmaCode->walk([&](Operation *operation) {
if (isa<MemRefMicrokernelOp, MicrokernelFenceOp>(operation))
operation->erase();
if (auto index = dyn_cast<ComputeCoreIndexOp>(operation)) {
OpBuilder builder(operation);
// Make the DMA core follow the control flow of the first compute core.
// This whole pass runs under the assumption that any operation that is
// run on either the DMA core or compute cores are in non-divergent
// control flow. Making the DMA core follow any compute cores control
// flow is therefore safe to do.
// This is mainly required for barriers within a `scf.forall`.
operation->replaceAllUsesWith(
builder.create<arith::ConstantIndexOp>(operation->getLoc(), 0));
operation->erase();
}
});
}
/// Removes all operations from 'function' that implement
/// 'CoreSpecializationOpInterface' but not 'Interface'.
template <typename Interface>
static void removeUnsupportedSpecializedOps(FunctionOpInterface function) {
function->walk([&](CoreSpecializationOpInterface operation) {
if (isa<Interface>(*operation))
return;

static void removeDmaCode(FunctionOpInterface computeCode) {
SmallVector<Operation *> toDelete;
computeCode->walk([&](Operation *operation) {
if (isa<WaitForDMATransfersOp>(operation))
operation->erase();
if (isa<StartDMATransferOp>(operation)) {
OpBuilder builder(operation);
operation->replaceAllUsesWith(
builder.create<CompletedTokenOp>(operation->getLoc()));
operation->erase();
}
IRRewriter rewriter(operation);
operation.replaceWithNoop(rewriter);
});
}

/// Inserts a barrier after every operation that requires synchronization
/// according to 'CoreSpecializationOpInterface'.
/// Note: Does not currently support barriers in divergent control flow.
static void insertBarriers(FunctionOpInterface function) {
function->walk([](Operation *operation) {
OpBuilder builder(operation->getContext());
if (isa<WaitForDMATransfersOp, MicrokernelFenceOp>(operation)) {
// Barrier needs to be after the wait to signal to compute ops the
// transfer is done.
builder.setInsertionPointAfter(operation);
} else
function->walk([](CoreSpecializationOpInterface operation) {
if (!operation.needsSynchronization())
return;

OpBuilder builder(operation.getContext());
builder.setInsertionPointAfter(operation);
builder.create<BarrierOp>(operation->getLoc());
});
}
Expand All @@ -92,7 +72,8 @@ void SpecializeDMACode::runOnOperation() {
dialect->getDmaSpecializationAttrHelper().setAttr(
function, FlatSymbolRefAttr::get(clone));

removeComputeOps(clone);
removeDmaCode(function);
removeUnsupportedSpecializedOps<ComputeCoreSpecializationOpInterface>(
function);
removeUnsupportedSpecializedOps<DMACoreSpecializationOpInterface>(clone);
}
}