-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[compiler] Introduce
quidditch_snitch
dialect (#37)
We are starting to introduce more snitch and xDSL specific concepts into the codegen pipeline that we'll want to encapsulate into operations that are progressively lowered and optimized. As a simple start, the dialect adds the `xdsl_kernel` op which contains linalg operations that are to be compiled by xDSL. Many discardable attributes that we've been using can now also be represented as dialect attribute.
- Loading branch information
Showing
24 changed files
with
480 additions
and
240 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
iree_add_all_subdirs() | ||
|
||
iree_tablegen_library( | ||
NAME | ||
PassesIncGen | ||
TD_FILE | ||
"Passes.td" | ||
OUTS | ||
--gen-pass-decls Passes.h.inc | ||
) | ||
|
||
iree_cc_library( | ||
NAME | ||
ConvertToRISCV | ||
HDRS | ||
"Passes.h" | ||
"Passes.h.inc" | ||
SRCS | ||
"ConvertToRISCV.cpp" | ||
DEPS | ||
::PassesIncGen | ||
Quidditch::Dialect::Snitch::QuidditchSnitchDialect | ||
MLIRFuncDialect | ||
MLIRIR | ||
) |
190 changes: 190 additions & 0 deletions
190
codegen/compiler/src/Quidditch/Conversion/ConvertToRISCV.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
#include "Passes.h" | ||
|
||
#include "llvm/ADT/ScopeExit.h" | ||
#include "llvm/Support/FileUtilities.h" | ||
#include "llvm/Support/FormatVariadic.h" | ||
#include "llvm/Support/MemoryBuffer.h" | ||
#include "llvm/Support/Program.h" | ||
|
||
#include "Quidditch/Dialect/Snitch/QuidditchSnitchDialect.h" | ||
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.h" | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/IRMapping.h" | ||
|
||
namespace quidditch { | ||
#define GEN_PASS_DEF_CONVERTTORISCVPASS | ||
#include "Quidditch/Conversion/Passes.h.inc" | ||
} // namespace quidditch | ||
|
||
using namespace mlir; | ||
using namespace quidditch::Snitch; | ||
|
||
namespace { | ||
class ConvertToRISCV | ||
: public quidditch::impl::ConvertToRISCVPassBase<ConvertToRISCV> { | ||
public: | ||
using Base::Base; | ||
|
||
protected: | ||
void runOnOperation() override; | ||
|
||
private: | ||
FailureOr<StringAttr> convertToRISCVAssembly(XDSLKernelOp kernelOp, | ||
StringAttr kernelName); | ||
}; | ||
} // namespace | ||
|
||
static bool canUseBarepointerCC(Type type) { | ||
auto memRef = dyn_cast<MemRefType>(type); | ||
if (!memRef) | ||
return true; | ||
if (isa<UnrankedMemRefType>(memRef)) | ||
return false; | ||
|
||
int64_t offset = 0; | ||
SmallVector<int64_t, 4> strides; | ||
if (failed(getStridesAndOffset(memRef, strides, offset))) | ||
return false; | ||
|
||
for (int64_t stride : strides) | ||
if (ShapedType::isDynamic(stride)) | ||
return false; | ||
|
||
return !ShapedType::isDynamic(offset); | ||
} | ||
|
||
FailureOr<StringAttr> | ||
ConvertToRISCV::convertToRISCVAssembly(XDSLKernelOp kernelOp, | ||
StringAttr kernelName) { | ||
if (!llvm::all_of(kernelOp.getBody().getArgumentTypes(), | ||
canUseBarepointerCC)) { | ||
auto emit = | ||
assertCompiled ? &XDSLKernelOp::emitError : &XDSLKernelOp::emitWarning; | ||
|
||
(kernelOp.*emit)("function inputs ") | ||
<< kernelOp.getBody().getArgumentTypes() | ||
<< " do not support bare-pointer calling convention required by " | ||
"xDSL."; | ||
return failure(); | ||
} | ||
|
||
OpBuilder builder(&getContext()); | ||
OwningOpRef<func::FuncOp> tempFuncOp = builder.create<func::FuncOp>( | ||
kernelOp.getLoc(), kernelName, | ||
builder.getFunctionType(kernelOp.getBody().getArgumentTypes(), {})); | ||
IRMapping mapping; | ||
kernelOp.getBody().cloneInto(&tempFuncOp->getBody(), mapping); | ||
builder.setInsertionPointToEnd(&tempFuncOp->getBody().back()); | ||
builder.create<func::ReturnOp>(kernelOp.getLoc()); | ||
|
||
SmallString<64> stdinFile; | ||
int stdinFd; | ||
if (llvm::sys::fs::createTemporaryFile("xdsl-in", "mlir", stdinFd, stdinFile)) | ||
return failure(); | ||
|
||
llvm::FileRemover stdinFileRemove(stdinFile); | ||
{ | ||
llvm::raw_fd_ostream ss(stdinFd, /*shouldClose=*/true); | ||
tempFuncOp->print(ss, OpPrintingFlags().useLocalScope()); | ||
} | ||
|
||
SmallString<64> stdoutFile; | ||
if (llvm::sys::fs::createTemporaryFile("xdsl-out", "S", stdoutFile)) | ||
return failure(); | ||
|
||
llvm::FileRemover stdoutFileRemove(stdoutFile); | ||
|
||
SmallString<64> stderrFile; | ||
if (llvm::sys::fs::createTemporaryFile("xdsl-diag", "S", stderrFile)) | ||
return failure(); | ||
|
||
llvm::FileRemover stderrFileRemove(stderrFile); | ||
|
||
std::optional<llvm::StringRef> redirects[3] = {/*stdin=*/stdinFile.str(), | ||
/*stdout=*/stdoutFile.str(), | ||
/*stderr=*/stderrFile.str()}; | ||
int ret = llvm::sys::ExecuteAndWait( | ||
xDSLOptPath, | ||
{xDSLOptPath, "-p", | ||
"convert-linalg-to-memref-stream,memref-streamify,convert-" | ||
"memref-stream-to-loops,scf-for-loop-flatten," | ||
"arith-add-fastmath,loop-hoist-memref,lower-affine,convert-memref-" | ||
"stream-to-snitch,convert-func-to-" | ||
"riscv-func,convert-memref-to-riscv,convert-arith-to-riscv," | ||
"convert-scf-to-riscv-scf,dce,reconcile-unrealized-casts,test-" | ||
"lower-snitch-stream-to-asm", | ||
"-t", "riscv-asm"}, | ||
std::nullopt, redirects); | ||
if (ret != 0) { | ||
auto diagEmit = | ||
assertCompiled ? &Operation::emitError : &Operation::emitWarning; | ||
|
||
InFlightDiagnostic diag = | ||
((kernelOp)->*diagEmit)("Failed to translate kernel with xDSL"); | ||
|
||
if (llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer = | ||
llvm::MemoryBuffer::getFile(stderrFile, /*IsText=*/true)) | ||
diag.attachNote() << "stderr:\n" << buffer.get()->getBuffer(); | ||
|
||
return diag; | ||
} | ||
|
||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer = | ||
llvm::MemoryBuffer::getFile(stdoutFile, /*IsText=*/true); | ||
if (!buffer) | ||
return kernelOp.emitError("failed to open ") << stdoutFile; | ||
|
||
return StringAttr::get(&getContext(), (*buffer)->getBuffer()); | ||
} | ||
|
||
void ConvertToRISCV::runOnOperation() { | ||
ModuleOp module = getOperation(); | ||
SymbolTable symbolTable(module); | ||
auto *dialect = getContext().getLoadedDialect<QuidditchSnitchDialect>(); | ||
|
||
std::size_t kernelIndex = 0; | ||
module.walk([&](XDSLKernelOp kernelOp) { | ||
auto exit = llvm::make_scope_exit([&] { kernelOp.erase(); }); | ||
|
||
auto parentFuncOp = kernelOp->getParentOfType<func::FuncOp>(); | ||
auto kernelName = StringAttr::get( | ||
&getContext(), llvm::formatv("{0}$xdsl_kernel{1}", | ||
parentFuncOp.getSymName(), kernelIndex++) | ||
.str()); | ||
|
||
FailureOr<StringAttr> riscvAssembly = | ||
convertToRISCVAssembly(kernelOp, kernelName); | ||
if (failed(riscvAssembly)) { | ||
if (assertCompiled) { | ||
signalPassFailure(); | ||
return WalkResult::interrupt(); | ||
} | ||
|
||
dialect->getXdslCompilationFailedAttrHelper().setAttr( | ||
kernelOp->getParentOfType<func::FuncOp>(), | ||
UnitAttr::get(&getContext())); | ||
return WalkResult::advance(); | ||
} | ||
|
||
auto builder = OpBuilder::atBlockEnd(module.getBody()); | ||
|
||
auto kernelDecl = builder.create<func::FuncOp>( | ||
kernelOp.getLoc(), kernelName, | ||
builder.getFunctionType(kernelOp.getBody().getArgumentTypes(), {})); | ||
|
||
kernelDecl.setVisibility(SymbolTable::Visibility::Private); | ||
// Required to tell the conversion pass to LLVM that this is actually a | ||
// call into the same linkage unit, and does not have to be rewritten to a | ||
// HAL module call. | ||
kernelDecl->setAttr("hal.import.bitcode", UnitAttr::get(&getContext())); | ||
kernelDecl->setAttr("llvm.bareptr", UnitAttr::get(&getContext())); | ||
|
||
dialect->getRiscvAssemblyAttrHelper().setAttr(kernelDecl, *riscvAssembly); | ||
|
||
builder.setInsertionPoint(kernelOp); | ||
builder.create<func::CallOp>(kernelOp.getLoc(), kernelDecl, | ||
kernelOp.getInputs()); | ||
return WalkResult::advance(); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
|
||
#pragma once | ||
|
||
#include <mlir/IR/BuiltinOps.h> | ||
#include <mlir/Pass/Pass.h> | ||
|
||
namespace quidditch { | ||
#define GEN_PASS_DECL | ||
#include "Quidditch/Conversion/Passes.h.inc" | ||
} // namespace quidditch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#ifndef QUIDDITCH_CONVERSION_PASSES | ||
#define QUIDDITCH_CONVERSION_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def ConvertToRISCVPass : Pass<"quidditch-convert-to-riscv", "mlir::ModuleOp"> { | ||
let options = [ | ||
Option<"xDSLOptPath", "xdsl-opt-path", "std::string", [{""}], | ||
"Path to the 'xdsl-opt' executable to use for kernel compilation.">, | ||
Option<"assertCompiled", "assert-compiled", "bool", "false", | ||
"If true, errors if any kernel could not be compiled with xDSL." | ||
"Otherwise, removes the kernel from the output and emits a warning " | ||
"instead.">, | ||
]; | ||
|
||
let dependentDialects = [ | ||
"mlir::func::FuncDialect", | ||
]; | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
iree_add_all_subdirs() |
45 changes: 45 additions & 0 deletions
45
codegen/compiler/src/Quidditch/Dialect/Snitch/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
iree_add_all_subdirs() | ||
|
||
iree_cc_library( | ||
NAME | ||
QuidditchSnitchDialect | ||
HDRS | ||
"QuidditchSnitchDialect.h" | ||
"QuidditchSnitchOps.h" | ||
TEXTUAL_HDRS | ||
"QuidditchSnitchDialect.cpp.inc" | ||
"QuidditchSnitchDialect.h.inc" | ||
"QuidditchSnitchOps.cpp.inc" | ||
"QuidditchSnitchOps.h.inc" | ||
SRCS | ||
"QuidditchSnitchDialect.cpp" | ||
"QuidditchSnitchOps.cpp" | ||
DEPS | ||
::QuidditchSnitchDialectGen | ||
::QuidditchSnitchOpsGen | ||
LLVMSupport | ||
MLIRIR | ||
MLIRInferTypeOpInterface | ||
MLIRSupport | ||
PUBLIC | ||
) | ||
|
||
iree_tablegen_library( | ||
NAME | ||
QuidditchSnitchOpsGen | ||
TD_FILE | ||
"QuidditchSnitchOps.td" | ||
OUTS | ||
--gen-op-decls QuidditchSnitchOps.h.inc | ||
--gen-op-defs QuidditchSnitchOps.cpp.inc | ||
) | ||
|
||
iree_tablegen_library( | ||
NAME | ||
QuidditchSnitchDialectGen | ||
TD_FILE | ||
"QuidditchSnitchDialect.td" | ||
OUTS | ||
--gen-dialect-decls QuidditchSnitchDialect.h.inc | ||
--gen-dialect-defs QuidditchSnitchDialect.cpp.inc | ||
) |
14 changes: 14 additions & 0 deletions
14
codegen/compiler/src/Quidditch/Dialect/Snitch/QuidditchSnitchDialect.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#include "QuidditchSnitchDialect.h" | ||
|
||
#include "QuidditchSnitchOps.h" | ||
|
||
#include "Quidditch/Dialect/Snitch/QuidditchSnitchDialect.cpp.inc" | ||
|
||
using namespace quidditch::Snitch; | ||
|
||
void QuidditchSnitchDialect::initialize() { | ||
addOperations< | ||
#define GET_OP_LIST | ||
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.cpp.inc" | ||
>(); | ||
} |
7 changes: 7 additions & 0 deletions
7
codegen/compiler/src/Quidditch/Dialect/Snitch/QuidditchSnitchDialect.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
#pragma once | ||
|
||
#include "mlir/IR/Dialect.h" | ||
#include "mlir/IR/Operation.h" | ||
|
||
#include "Quidditch/Dialect/Snitch/QuidditchSnitchDialect.h.inc" |
16 changes: 16 additions & 0 deletions
16
codegen/compiler/src/Quidditch/Dialect/Snitch/QuidditchSnitchDialect.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHDIALECT | ||
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHDIALECT | ||
|
||
include "mlir/IR/DialectBase.td" | ||
|
||
def QuidditchSnitch_Dialect : Dialect { | ||
let name = "quidditch_snitch"; | ||
let cppNamespace = "::quidditch::Snitch"; | ||
|
||
let discardableAttrs = (ins | ||
"mlir::StringAttr":$riscv_assembly, | ||
"mlir::UnitAttr":$xdsl_compilation_failed | ||
); | ||
} | ||
|
||
#endif |
25 changes: 25 additions & 0 deletions
25
codegen/compiler/src/Quidditch/Dialect/Snitch/QuidditchSnitchOps.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include "QuidditchSnitchOps.h" | ||
|
||
#include "mlir/IR/TypeUtilities.h" | ||
|
||
#define GET_OP_CLASSES | ||
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.cpp.inc" | ||
|
||
using namespace mlir; | ||
using namespace quidditch::Snitch; | ||
|
||
Block *XDSLKernelOp::createEntryBlock() { | ||
assert(getBody().getBlocks().empty()); | ||
Block &block = getBody().emplaceBlock(); | ||
block.addArguments(getInputs().getTypes(), | ||
SmallVector<Location>(getInputs().size(), getLoc())); | ||
return █ | ||
} | ||
|
||
LogicalResult XDSLKernelOp::verify() { | ||
// TODO: Weaken this in the future to likely only require element count and | ||
// type but not the shape. TBD. | ||
if (getBody().getArgumentTypes() != getInputs().getTypes()) | ||
return emitOpError("type of arguments and inputs must match"); | ||
return success(); | ||
} |
10 changes: 10 additions & 0 deletions
10
codegen/compiler/src/Quidditch/Dialect/Snitch/QuidditchSnitchOps.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
|
||
#pragma once | ||
|
||
#include "mlir/Bytecode/BytecodeOpInterface.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/OpImplementation.h" | ||
#include "mlir/Interfaces/InferTypeOpInterface.h" | ||
|
||
#define GET_OP_CLASSES | ||
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.h.inc" |
Oops, something went wrong.