-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FormMicrokernels] Improve and rename from
OutlineLinalg...
(#47)
The pass implementation has been simplified by now running on functions instead of modules and other cleanups. More importantly, it now properly traverses into regions to form microkernels.
- Loading branch information
Showing
6 changed files
with
107 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
codegen/compiler/src/Quidditch/Target/FormMicrokernels.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#include "Passes.h" | ||
|
||
#include <mlir/Dialect/Linalg/IR/Linalg.h> | ||
#include <mlir/IR/IRMapping.h> | ||
#include <mlir/Interfaces/FunctionInterfaces.h> | ||
|
||
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h" | ||
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h" | ||
|
||
namespace quidditch { | ||
#define GEN_PASS_DEF_FORMMICROKERNELSPASS | ||
#include "Quidditch/Target/Passes.h.inc" | ||
} // namespace quidditch | ||
|
||
using namespace mlir; | ||
using namespace quidditch::Snitch; | ||
|
||
namespace { | ||
class FormMicrokernels | ||
: public quidditch::impl::FormMicrokernelsPassBase<FormMicrokernels> { | ||
public: | ||
using Base::Base; | ||
|
||
protected: | ||
void runOnOperation() override; | ||
}; | ||
} // namespace | ||
|
||
static void outlineOpsToFunction(MutableArrayRef<linalg::LinalgOp> ops) { | ||
if (ops.empty()) | ||
return; | ||
|
||
SetVector<Value> escapingResults; | ||
for (Operation *op : ops) | ||
for (Value result : op->getResults()) | ||
for (OpOperand &use : result.getUses()) | ||
if (!llvm::is_contained(ops, use.getOwner())) | ||
escapingResults.insert(use.get()); | ||
|
||
auto builder = OpBuilder(ops.front()); | ||
|
||
auto kernelOp = builder.create<TensorMicrokernelOp>( | ||
ops.front()->getLoc(), | ||
llvm::map_to_vector(escapingResults, std::mem_fn(&Value::getType))); | ||
|
||
Block *block = &kernelOp.getBody().emplaceBlock(); | ||
builder.setInsertionPointToStart(block); | ||
|
||
for (Operation *op : ops) { | ||
op->remove(); | ||
builder.insert(op); | ||
} | ||
|
||
builder.create<MicrokernelYieldOp>(ops.back().getLoc(), | ||
escapingResults.getArrayRef()); | ||
|
||
SmallVector<Value> vector = escapingResults.takeVector(); | ||
for (auto [index, value] : llvm::enumerate(vector)) | ||
value.replaceUsesWithIf(kernelOp.getResult(index), [&](OpOperand &operand) { | ||
return !kernelOp.getBody().isAncestor( | ||
operand.getOwner()->getParentRegion()); | ||
}); | ||
} | ||
|
||
void FormMicrokernels::runOnOperation() { | ||
FunctionOpInterface func = getOperation(); | ||
|
||
// We add this suffix for tooling to know whether the kernel was xDSL | ||
// compiled. It should have as little semantic impact as possible. | ||
func.setName((func.getName() + "$iree_to_xdsl").str()); | ||
|
||
SmallVector<linalg::LinalgOp> outlinedOps; | ||
func.walk([&](Block *block) { | ||
for (Operation &op : *block) { | ||
auto linalgOp = dyn_cast<linalg::LinalgOp>(op); | ||
if (!linalgOp) { | ||
outlineOpsToFunction(outlinedOps); | ||
outlinedOps.clear(); | ||
continue; | ||
} | ||
outlinedOps.push_back(linalgOp); | ||
} | ||
}); | ||
} |
97 changes: 0 additions & 97 deletions
97
codegen/compiler/src/Quidditch/Target/OutlineLinalgOpsToxDSL.cpp
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// RUN: quidditch-opt %s -p "builtin.module(func.func(quidditch-form-microkernels))" --allow-unregistered-dialect | FileCheck %s | ||
|
||
// CHECK-LABEL: @linalgs_in_scf | ||
func.func @linalgs_in_scf(%cond : i1) { | ||
%cst0 = arith.constant 0.0 : f32 | ||
// CHECK: scf.if | ||
scf.if %cond { | ||
// CHECK: quidditch_snitch.tensor.microkernel | ||
// CHECK-NEXT: linalg.fill | ||
%empty = tensor.empty() : tensor<32xf32> | ||
%fill = linalg.fill ins(%cst0 : f32) outs(%empty : tensor<32xf32>) -> tensor<32xf32> | ||
"test.use"(%fill) : (tensor<32xf32>) -> () | ||
} | ||
return | ||
} |