Skip to content

Commit

Permalink
[FormMicrokernels] Improve and rename from OutlineLinalg... (#47)
Browse files Browse the repository at this point in the history
The pass implementation has been simplified by now running on functions
instead of modules and other cleanups. More importantly, it now properly
traverses into regions to form microkernels.
  • Loading branch information
zero9178 authored Jun 27, 2024
1 parent 1759e81 commit e86161a
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 112 deletions.
2 changes: 1 addition & 1 deletion codegen/compiler/src/Quidditch/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ iree_cc_library(
"Passes.h.inc"
SRCS
"DisableQuidditchVariant.cpp"
"OutlineLinalgOpsToxDSL.cpp"
"FormMicrokernels.cpp"
"LinkExecutables.cpp"
"ReluToMax.cpp"
DEPS
Expand Down
84 changes: 84 additions & 0 deletions codegen/compiler/src/Quidditch/Target/FormMicrokernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include "Passes.h"

#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/Interfaces/FunctionInterfaces.h>

#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h"
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h"

namespace quidditch {
#define GEN_PASS_DEF_FORMMICROKERNELSPASS
#include "Quidditch/Target/Passes.h.inc"
} // namespace quidditch

using namespace mlir;
using namespace quidditch::Snitch;

namespace {
class FormMicrokernels
: public quidditch::impl::FormMicrokernelsPassBase<FormMicrokernels> {
public:
using Base::Base;

protected:
void runOnOperation() override;
};
} // namespace

static void outlineOpsToFunction(MutableArrayRef<linalg::LinalgOp> ops) {
if (ops.empty())
return;

SetVector<Value> escapingResults;
for (Operation *op : ops)
for (Value result : op->getResults())
for (OpOperand &use : result.getUses())
if (!llvm::is_contained(ops, use.getOwner()))
escapingResults.insert(use.get());

auto builder = OpBuilder(ops.front());

auto kernelOp = builder.create<TensorMicrokernelOp>(
ops.front()->getLoc(),
llvm::map_to_vector(escapingResults, std::mem_fn(&Value::getType)));

Block *block = &kernelOp.getBody().emplaceBlock();
builder.setInsertionPointToStart(block);

for (Operation *op : ops) {
op->remove();
builder.insert(op);
}

builder.create<MicrokernelYieldOp>(ops.back().getLoc(),
escapingResults.getArrayRef());

SmallVector<Value> vector = escapingResults.takeVector();
for (auto [index, value] : llvm::enumerate(vector))
value.replaceUsesWithIf(kernelOp.getResult(index), [&](OpOperand &operand) {
return !kernelOp.getBody().isAncestor(
operand.getOwner()->getParentRegion());
});
}

void FormMicrokernels::runOnOperation() {
FunctionOpInterface func = getOperation();

// We add this suffix for tooling to know whether the kernel was xDSL
// compiled. It should have as little semantic impact as possible.
func.setName((func.getName() + "$iree_to_xdsl").str());

SmallVector<linalg::LinalgOp> outlinedOps;
func.walk([&](Block *block) {
for (Operation &op : *block) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp) {
outlineOpsToFunction(outlinedOps);
outlinedOps.clear();
continue;
}
outlinedOps.push_back(linalgOp);
}
});
}
97 changes: 0 additions & 97 deletions codegen/compiler/src/Quidditch/Target/OutlineLinalgOpsToxDSL.cpp

This file was deleted.

16 changes: 5 additions & 11 deletions codegen/compiler/src/Quidditch/Target/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,11 @@

include "mlir/Pass/PassBase.td"

def OutlineLinalgOpsToxDSLPass
: Pass<"quidditch-outline-linalg-ops-to-xdsl", "mlir::ModuleOp"> {
let description = [{
Outlines a series of linalg operation in a function into their own
functions with a suitable calling convention for xDSL.
The generated functions can be identified via `xdsl_generated` attribute.

The original function that the outlining was performed on will have a
`xdsl_kernels` which is an array of symbol references to the outlined
functions.
}];
def FormMicrokernelsPass
: InterfacePass<"quidditch-form-microkernels", "mlir::FunctionOpInterface"> {
let dependentDialects = [
"quidditch::Snitch::QuidditchSnitchDialect",
];
}

def LinkExecutablesPass : Pass<"quidditch-link-executables", "mlir::ModuleOp"> {
Expand Down
5 changes: 2 additions & 3 deletions codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,8 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
.addPass(createFuseTensorPadWithConsumerPass)
.addPass(createConcretizePadResultShapePass);

modulePassManager.addPass(quidditch::createOutlineLinalgOpsToxDSLPass());
.addPass(createConcretizePadResultShapePass)
.addPass(quidditch::createFormMicrokernelsPass);

BufferizationOptions::AllocationFn allocationFn =
[](OpBuilder &builder, Location loc, MemRefType memRefType,
Expand Down
15 changes: 15 additions & 0 deletions codegen/tests/Target/form-microkernels.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: quidditch-opt %s -p "builtin.module(func.func(quidditch-form-microkernels))" --allow-unregistered-dialect | FileCheck %s

// CHECK-LABEL: @linalgs_in_scf
func.func @linalgs_in_scf(%cond : i1) {
%cst0 = arith.constant 0.0 : f32
// CHECK: scf.if
scf.if %cond {
// CHECK: quidditch_snitch.tensor.microkernel
// CHECK-NEXT: linalg.fill
%empty = tensor.empty() : tensor<32xf32>
%fill = linalg.fill ins(%cst0 : f32) outs(%empty : tensor<32xf32>) -> tensor<32xf32>
"test.use"(%fill) : (tensor<32xf32>) -> ()
}
return
}

0 comments on commit e86161a

Please sign in to comment.