Skip to content

Commit

Permalink
[quidditch_snitch] Implement canonicalization for memref microkernels (
Browse files Browse the repository at this point in the history
…#39)

These canonicalizations reduce the microkernel to a minimal form, which
helps micro-optimize the xDSL generation and avoids hitting code
generation limits. They mostly remove redundant arguments and results.
  • Loading branch information
zero9178 authored Jun 24, 2024
1 parent 9ea52d4 commit 8a79f26
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 0 deletions.
153 changes: 153 additions & 0 deletions codegen/compiler/src/Quidditch/Dialect/Snitch/QuidditchSnitchOps.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "QuidditchSnitchOps.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"

#define GET_OP_CLASSES
Expand Down Expand Up @@ -185,6 +186,158 @@ LogicalResult MemRefMicrokernelOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// MemRefMicrokernelOp Canonicalization
//===----------------------------------------------------------------------===//

namespace {
/// Canonicalization that drops every result of a `MemRefMicrokernelOp` that
/// has no uses, together with the matching operand of its yield op.
///
/// The op is rebuilt with the reduced result type list, the original body is
/// moved into the new op, and the surviving results are rewired to their
/// counterparts on the new op.
struct RemoveDeadResults : OpRewritePattern<MemRefMicrokernelOp> {
  using OpRewritePattern<MemRefMicrokernelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemRefMicrokernelOp op,
                                PatternRewriter &rewriter) const override {
    // Classify each result and gather the types of the ones that stay.
    SmallVector<bool> isDead;
    SmallVector<Type> liveTypes;
    bool anyDead = false;
    for (OpResult res : op.getResults()) {
      bool unused = res.use_empty();
      isDead.push_back(unused);
      if (unused)
        anyDead = true;
      else
        liveTypes.push_back(res.getType());
    }

    // Nothing to remove: leave the op untouched.
    if (!anyDead)
      return failure();

    // Build the slimmer op and move the existing body into it.
    auto newOp = rewriter.create<MemRefMicrokernelOp>(op.getLoc(), liveTypes,
                                                      op.getInputs());
    rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
                                newOp.getBody().end());

    // Strip the yield operands of dead results. Iterating from the back keeps
    // the not-yet-visited indices valid while operands are being erased.
    MicrokernelYieldOp yield = newOp.getYieldOp();
    for (std::size_t i = isDead.size(); i-- > 0;) {
      if (!isDead[i])
        continue;
      rewriter.modifyOpInPlace(yield,
                               [&] { yield.getResultsMutable().erase(i); });
    }

    // Map each surviving old result onto the next result of the new op.
    unsigned liveIdx = 0;
    for (std::size_t i = 0, e = isDead.size(); i != e; ++i) {
      if (isDead[i])
        continue;
      rewriter.replaceAllUsesWith(op.getResult(i), newOp.getResult(liveIdx));
      ++liveIdx;
    }

    rewriter.eraseOp(op);
    return success();
  }
};

/// Canonicalization that sinks constant inputs of a `MemRefMicrokernelOp`
/// into its body.
///
/// For every input that is produced by a constant op and whose corresponding
/// block argument is still used, the constant is cloned at the start of the
/// body and all uses of the block argument are redirected to the clone. The
/// input/argument pair itself is left in place (now unused) for a separate
/// pattern to clean up.
struct SinkConstantArguments : OpRewritePattern<MemRefMicrokernelOp> {
  using OpRewritePattern<MemRefMicrokernelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemRefMicrokernelOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<std::pair<BlockArgument, Operation *>> constantOps;
    for (auto [input, arg] :
         llvm::zip_equal(op.getInputs(), op.getBody().getArguments())) {
      // An argument without uses has nothing to redirect. Skipping it keeps
      // the pattern from cloning dead constants and from reporting success
      // again on an op it has already processed.
      if (arg.use_empty())
        continue;
      if (matchPattern(input, m_Constant()))
        constantOps.emplace_back(arg, input.getDefiningOp());
    }

    if (constantOps.empty())
      return failure();

    // Restore the caller's insertion point once the clones are created.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.modifyOpInPlace(op, [&] {
      rewriter.setInsertionPointToStart(&op.getBody().front());
      for (auto [repl, constantOp] : constantOps) {
        Operation *clone = rewriter.clone(*constantOp);
        // Go through the rewriter so that users of the block argument are
        // reported as modified.
        rewriter.replaceAllUsesWith(repl, clone->getResult(0));
      }
    });
    return success();
  }
};

/// Canonicalization that forwards inputs straight to results.
///
/// When the yield op returns one of the body's own block arguments unchanged,
/// the corresponding result of the `MemRefMicrokernelOp` is equal to the input
/// bound to that argument, so all uses of the result are replaced with that
/// input. The now-unused result is left for a separate pattern to remove.
struct ReplaceInvariantResults : OpRewritePattern<MemRefMicrokernelOp> {
  using OpRewritePattern<MemRefMicrokernelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemRefMicrokernelOp op,
                                PatternRewriter &rewriter) const override {
    bool changed = false;
    Block *body = &op.getBody().front();
    for (auto [result, yielded] :
         llvm::zip_equal(op.getResults(), op.getYieldOp().getResults())) {
      // Replacing the uses of an unused result changes nothing; reporting
      // success for it would claim a rewrite that did not happen.
      if (result.use_empty())
        continue;

      // Only arguments of this op's own body map onto `getInputs()`. Checking
      // the owning block (rather than "is some entry block") avoids indexing
      // the inputs with the argument number of an unrelated block.
      auto arg = dyn_cast<BlockArgument>(yielded);
      if (!arg || arg.getOwner() != body)
        continue;

      changed = true;
      rewriter.replaceAllUsesWith(result, op.getInputs()[arg.getArgNumber()]);
    }
    return success(changed);
  }
};

/// Canonicalization that removes inputs of a `MemRefMicrokernelOp` whose
/// block argument is never used inside the body.
///
/// A new op is created with only the inputs that are still needed, the body is
/// moved over, and the unused block arguments are erased from it.
struct RemoveDeadArguments : OpRewritePattern<MemRefMicrokernelOp> {
  using OpRewritePattern<MemRefMicrokernelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemRefMicrokernelOp op,
                                PatternRewriter &rewriter) const override {
    // Flag, by argument number, each block argument that has no uses.
    SmallVector<bool> unusedArgs(op.getInputs().size(), false);
    bool foundUnused = false;
    for (BlockArgument blockArg : op.getBody().getArguments()) {
      if (!blockArg.use_empty())
        continue;
      unusedArgs[blockArg.getArgNumber()] = true;
      foundUnused = true;
    }

    if (!foundUnused)
      return failure();

    // Keep only the inputs whose block argument is still referenced.
    SmallVector<Value> keptInputs;
    for (unsigned i = 0, e = unusedArgs.size(); i != e; ++i) {
      if (unusedArgs[i])
        continue;
      keptInputs.push_back(op.getInputs()[i]);
    }

    auto newOp = rewriter.create<MemRefMicrokernelOp>(
        op.getLoc(), op.getResults().getTypes(), keptInputs);
    rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
                                newOp.getBody().end());

    // Drop the flagged arguments from the body that was just moved over.
    Block &entry = newOp.getBody().front();
    rewriter.modifyOpInPlace(newOp, [&] {
      entry.eraseArguments([&](BlockArgument candidate) {
        return unusedArgs[candidate.getArgNumber()];
      });
    });

    rewriter.replaceOp(op, newOp);
    return success();
  }
};

/// Canonicalization that deduplicates inputs of a `MemRefMicrokernelOp`.
///
/// When the same value is passed as an input more than once, every use of a
/// later block argument is redirected to the block argument of the first
/// occurrence. The duplicate input/argument pair is left in place (now
/// unused) for a separate pattern to remove.
struct ReplaceIdenticalArguments : OpRewritePattern<MemRefMicrokernelOp> {
  using OpRewritePattern<MemRefMicrokernelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemRefMicrokernelOp op,
                                PatternRewriter &rewriter) const override {
    bool changed = false;
    // Maps each distinct input value to the first block argument bound to it.
    llvm::SmallDenseMap<Value, BlockArgument> seenPreviously;
    for (auto [input, blockArg] :
         llvm::zip_equal(op.getInputs(), op.getBody().getArguments())) {
      auto [iter, inserted] = seenPreviously.insert({input, blockArg});
      if (inserted)
        continue;

      // A duplicate without uses has nothing to redirect. Counting it as a
      // change would report success without modifying the IR and make the
      // pattern match again on the very same op.
      if (blockArg.use_empty())
        continue;

      changed = true;
      rewriter.replaceAllUsesWith(blockArg, iter->second);
    }
    return success(changed);
  }
};
} // namespace

/// Registers all canonicalization patterns of `MemRefMicrokernelOp` in
/// `results`.
void MemRefMicrokernelOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveDeadResults, RemoveDeadArguments, SinkConstantArguments,
              ReplaceInvariantResults, ReplaceIdenticalArguments>(context);
}

//===----------------------------------------------------------------------===//
// MemRefMicrokernelOp::RegionBranchOpInterface
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def QuidditchSnitch_MemRefMicrokernelOp
}];

let hasVerifier = 1;
let hasCanonicalizer = 1;

let extraClassDeclaration = [{

Expand Down
60 changes: 60 additions & 0 deletions codegen/tests/Dialect/Snitch/canonicalization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: quidditch-opt %s --canonicalize --split-input-file --allow-unregistered-dialect | FileCheck %s

// The microkernel result %0 is never used, so canonicalization is expected to
// leave a microkernel without any results.
// CHECK-LABEL: @dead_result
func.func @dead_result() {
// CHECK: quidditch_snitch.memref.microkernel() : () -> ()
%0 = quidditch_snitch.memref.microkernel() : () -> i32 {
%c = arith.constant 1 : i32
quidditch_snitch.microkernel_yield %c : i32
}
return
}

// The only input is a constant defined outside the microkernel. The expected
// output has no inputs and instead materializes the constant inside the body,
// where it feeds the "test.transform" op directly.
// CHECK-LABEL: @sink_constants
func.func @sink_constants() -> i32 {
%c = arith.constant 1 : i32
// CHECK: quidditch_snitch.memref.microkernel() : () -> i32
%0 = quidditch_snitch.memref.microkernel(%c) : (i32) -> i32 {
^bb0(%arg0 : i32):
// CHECK-NEXT: %[[C:.*]] = arith.constant
// CHECK-NEXT: "test.transform"(%[[C]])
%1 = "test.transform"(%arg0) : (i32) -> i32
quidditch_snitch.microkernel_yield %1 : i32
}
return %0 : i32
}

// The body yields its block argument unchanged, so the result equals the
// input. The function is expected to return its own argument directly.
// CHECK-LABEL: @invariant_result
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
func.func @invariant_result(%arg0 : i32) -> i32 {
%0 = quidditch_snitch.memref.microkernel(%arg0) : (i32) -> i32 {
^bb0(%arg1 : i32):
quidditch_snitch.microkernel_yield %arg1 : i32
}
// CHECK: return %[[ARG0]]
return %0 : i32
}

// The block argument %arg1 is never used inside the body, so the expected
// output is a microkernel without any inputs.
// CHECK-LABEL: @dead_argument
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
func.func @dead_argument(%arg0 : i32) {
// CHECK: quidditch_snitch.memref.microkernel()
quidditch_snitch.memref.microkernel(%arg0) : (i32) -> () {
^bb0(%arg1 : i32):
quidditch_snitch.microkernel_yield
}
return
}

// The same value is passed as both inputs. The expected output passes it only
// once and uses a single block argument for both operands of
// "test.transform".
// CHECK-LABEL: @identical_argument
// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
func.func @identical_argument(%arg0 : i32) -> i32 {
// CHECK: quidditch_snitch.memref.microkernel(%[[ARG0]])
%0 = quidditch_snitch.memref.microkernel(%arg0, %arg0) : (i32, i32) -> i32 {
^bb0(%arg1 : i32, %arg2 : i32):
// CHECK: "test.transform"(%[[ARG1:.*]], %[[ARG1]])
%1 = "test.transform"(%arg1, %arg2) : (i32, i32) -> i32
quidditch_snitch.microkernel_yield %1 : i32
}
return %0 : i32
}
2 changes: 2 additions & 0 deletions codegen/tools/quidditch-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <Quidditch/Target/Passes.h>

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"

namespace quidditch {
#define GEN_PASS_REGISTRATION
Expand All @@ -22,6 +23,7 @@ int main(int argc, char **argv) {

quidditch::registerPasses();
mlir::bufferization::registerBufferizationPasses();
mlir::registerTransformsPasses();

return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "MLIR modular optimizer driver\n", registry));
Expand Down

0 comments on commit 8a79f26

Please sign in to comment.