Skip to content

Commit

Permalink
[compiler] Introduce quidditch_snitch dialect (#37)
Browse files Browse the repository at this point in the history
We are starting to introduce more snitch and xDSL specific concepts into
the codegen pipeline that we'll want to encapsulate into operations that
are progressively lowered and optimized. As a simple start, the dialect
adds the `xdsl_kernel` op which contains linalg operations that are to
be compiled by xDSL. Many discardable attributes that we've been using
can now also be represented as dialect attribute.
  • Loading branch information
zero9178 authored Jun 20, 2024
1 parent a3ed821 commit 2ce63b1
Show file tree
Hide file tree
Showing 24 changed files with 480 additions and 240 deletions.
25 changes: 25 additions & 0 deletions codegen/compiler/src/Quidditch/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
iree_add_all_subdirs()

iree_tablegen_library(
NAME
PassesIncGen
TD_FILE
"Passes.td"
OUTS
--gen-pass-decls Passes.h.inc
)

iree_cc_library(
NAME
ConvertToRISCV
HDRS
"Passes.h"
"Passes.h.inc"
SRCS
"ConvertToRISCV.cpp"
DEPS
::PassesIncGen
Quidditch::Dialect::Snitch::QuidditchSnitchDialect
MLIRFuncDialect
MLIRIR
)
190 changes: 190 additions & 0 deletions codegen/compiler/src/Quidditch/Conversion/ConvertToRISCV.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
#include "Passes.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Program.h"

#include "Quidditch/Dialect/Snitch/QuidditchSnitchDialect.h"
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"

namespace quidditch {
#define GEN_PASS_DEF_CONVERTTORISCVPASS
#include "Quidditch/Conversion/Passes.h.inc"
} // namespace quidditch

using namespace mlir;
using namespace quidditch::Snitch;

namespace {
class ConvertToRISCV
: public quidditch::impl::ConvertToRISCVPassBase<ConvertToRISCV> {
public:
using Base::Base;

protected:
void runOnOperation() override;

private:
FailureOr<StringAttr> convertToRISCVAssembly(XDSLKernelOp kernelOp,
StringAttr kernelName);
};
} // namespace

static bool canUseBarepointerCC(Type type) {
auto memRef = dyn_cast<MemRefType>(type);
if (!memRef)
return true;
if (isa<UnrankedMemRefType>(memRef))
return false;

int64_t offset = 0;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memRef, strides, offset)))
return false;

for (int64_t stride : strides)
if (ShapedType::isDynamic(stride))
return false;

return !ShapedType::isDynamic(offset);
}

FailureOr<StringAttr>
ConvertToRISCV::convertToRISCVAssembly(XDSLKernelOp kernelOp,
StringAttr kernelName) {
if (!llvm::all_of(kernelOp.getBody().getArgumentTypes(),
canUseBarepointerCC)) {
auto emit =
assertCompiled ? &XDSLKernelOp::emitError : &XDSLKernelOp::emitWarning;

(kernelOp.*emit)("function inputs ")
<< kernelOp.getBody().getArgumentTypes()
<< " do not support bare-pointer calling convention required by "
"xDSL.";
return failure();
}

OpBuilder builder(&getContext());
OwningOpRef<func::FuncOp> tempFuncOp = builder.create<func::FuncOp>(
kernelOp.getLoc(), kernelName,
builder.getFunctionType(kernelOp.getBody().getArgumentTypes(), {}));
IRMapping mapping;
kernelOp.getBody().cloneInto(&tempFuncOp->getBody(), mapping);
builder.setInsertionPointToEnd(&tempFuncOp->getBody().back());
builder.create<func::ReturnOp>(kernelOp.getLoc());

SmallString<64> stdinFile;
int stdinFd;
if (llvm::sys::fs::createTemporaryFile("xdsl-in", "mlir", stdinFd, stdinFile))
return failure();

llvm::FileRemover stdinFileRemove(stdinFile);
{
llvm::raw_fd_ostream ss(stdinFd, /*shouldClose=*/true);
tempFuncOp->print(ss, OpPrintingFlags().useLocalScope());
}

SmallString<64> stdoutFile;
if (llvm::sys::fs::createTemporaryFile("xdsl-out", "S", stdoutFile))
return failure();

llvm::FileRemover stdoutFileRemove(stdoutFile);

SmallString<64> stderrFile;
if (llvm::sys::fs::createTemporaryFile("xdsl-diag", "S", stderrFile))
return failure();

llvm::FileRemover stderrFileRemove(stderrFile);

std::optional<llvm::StringRef> redirects[3] = {/*stdin=*/stdinFile.str(),
/*stdout=*/stdoutFile.str(),
/*stderr=*/stderrFile.str()};
int ret = llvm::sys::ExecuteAndWait(
xDSLOptPath,
{xDSLOptPath, "-p",
"convert-linalg-to-memref-stream,memref-streamify,convert-"
"memref-stream-to-loops,scf-for-loop-flatten,"
"arith-add-fastmath,loop-hoist-memref,lower-affine,convert-memref-"
"stream-to-snitch,convert-func-to-"
"riscv-func,convert-memref-to-riscv,convert-arith-to-riscv,"
"convert-scf-to-riscv-scf,dce,reconcile-unrealized-casts,test-"
"lower-snitch-stream-to-asm",
"-t", "riscv-asm"},
std::nullopt, redirects);
if (ret != 0) {
auto diagEmit =
assertCompiled ? &Operation::emitError : &Operation::emitWarning;

InFlightDiagnostic diag =
((kernelOp)->*diagEmit)("Failed to translate kernel with xDSL");

if (llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer =
llvm::MemoryBuffer::getFile(stderrFile, /*IsText=*/true))
diag.attachNote() << "stderr:\n" << buffer.get()->getBuffer();

return diag;
}

llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer =
llvm::MemoryBuffer::getFile(stdoutFile, /*IsText=*/true);
if (!buffer)
return kernelOp.emitError("failed to open ") << stdoutFile;

return StringAttr::get(&getContext(), (*buffer)->getBuffer());
}

void ConvertToRISCV::runOnOperation() {
ModuleOp module = getOperation();
SymbolTable symbolTable(module);
auto *dialect = getContext().getLoadedDialect<QuidditchSnitchDialect>();

std::size_t kernelIndex = 0;
module.walk([&](XDSLKernelOp kernelOp) {
auto exit = llvm::make_scope_exit([&] { kernelOp.erase(); });

auto parentFuncOp = kernelOp->getParentOfType<func::FuncOp>();
auto kernelName = StringAttr::get(
&getContext(), llvm::formatv("{0}$xdsl_kernel{1}",
parentFuncOp.getSymName(), kernelIndex++)
.str());

FailureOr<StringAttr> riscvAssembly =
convertToRISCVAssembly(kernelOp, kernelName);
if (failed(riscvAssembly)) {
if (assertCompiled) {
signalPassFailure();
return WalkResult::interrupt();
}

dialect->getXdslCompilationFailedAttrHelper().setAttr(
kernelOp->getParentOfType<func::FuncOp>(),
UnitAttr::get(&getContext()));
return WalkResult::advance();
}

auto builder = OpBuilder::atBlockEnd(module.getBody());

auto kernelDecl = builder.create<func::FuncOp>(
kernelOp.getLoc(), kernelName,
builder.getFunctionType(kernelOp.getBody().getArgumentTypes(), {}));

kernelDecl.setVisibility(SymbolTable::Visibility::Private);
// Required to tell the conversion pass to LLVM that this is actually a
// call into the same linkage unit, and does not have to be rewritten to a
// HAL module call.
kernelDecl->setAttr("hal.import.bitcode", UnitAttr::get(&getContext()));
kernelDecl->setAttr("llvm.bareptr", UnitAttr::get(&getContext()));

dialect->getRiscvAssemblyAttrHelper().setAttr(kernelDecl, *riscvAssembly);

builder.setInsertionPoint(kernelOp);
builder.create<func::CallOp>(kernelOp.getLoc(), kernelDecl,
kernelOp.getInputs());
return WalkResult::advance();
});
}
10 changes: 10 additions & 0 deletions codegen/compiler/src/Quidditch/Conversion/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

#pragma once

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/Pass.h>

namespace quidditch {
#define GEN_PASS_DECL
#include "Quidditch/Conversion/Passes.h.inc"
} // namespace quidditch
21 changes: 21 additions & 0 deletions codegen/compiler/src/Quidditch/Conversion/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef QUIDDITCH_CONVERSION_PASSES
#define QUIDDITCH_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def ConvertToRISCVPass : Pass<"quidditch-convert-to-riscv", "mlir::ModuleOp"> {
let options = [
Option<"xDSLOptPath", "xdsl-opt-path", "std::string", [{""}],
"Path to the 'xdsl-opt' executable to use for kernel compilation.">,
Option<"assertCompiled", "assert-compiled", "bool", "false",
"If true, errors if any kernel could not be compiled with xDSL."
"Otherwise, removes the kernel from the output and emits a warning "
"instead.">,
];

let dependentDialects = [
"mlir::func::FuncDialect",
];
}

#endif
1 change: 1 addition & 0 deletions codegen/compiler/src/Quidditch/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
iree_add_all_subdirs()
45 changes: 45 additions & 0 deletions codegen/compiler/src/Quidditch/Dialect/Snitch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
iree_add_all_subdirs()

iree_cc_library(
NAME
QuidditchSnitchDialect
HDRS
"QuidditchSnitchDialect.h"
"QuidditchSnitchOps.h"
TEXTUAL_HDRS
"QuidditchSnitchDialect.cpp.inc"
"QuidditchSnitchDialect.h.inc"
"QuidditchSnitchOps.cpp.inc"
"QuidditchSnitchOps.h.inc"
SRCS
"QuidditchSnitchDialect.cpp"
"QuidditchSnitchOps.cpp"
DEPS
::QuidditchSnitchDialectGen
::QuidditchSnitchOpsGen
LLVMSupport
MLIRIR
MLIRInferTypeOpInterface
MLIRSupport
PUBLIC
)

iree_tablegen_library(
NAME
QuidditchSnitchOpsGen
TD_FILE
"QuidditchSnitchOps.td"
OUTS
--gen-op-decls QuidditchSnitchOps.h.inc
--gen-op-defs QuidditchSnitchOps.cpp.inc
)

iree_tablegen_library(
NAME
QuidditchSnitchDialectGen
TD_FILE
"QuidditchSnitchDialect.td"
OUTS
--gen-dialect-decls QuidditchSnitchDialect.h.inc
--gen-dialect-defs QuidditchSnitchDialect.cpp.inc
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "QuidditchSnitchDialect.h"

#include "QuidditchSnitchOps.h"

#include "Quidditch/Dialect/Snitch/QuidditchSnitchDialect.cpp.inc"

using namespace quidditch::Snitch;

void QuidditchSnitchDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.cpp.inc"
>();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

#pragma once

#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

#include "Quidditch/Dialect/Snitch/QuidditchSnitchDialect.h.inc"
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHDIALECT
#define QUIDDITCH_DIALECT_SNITCH_QUIDDITCHSNITCHDIALECT

include "mlir/IR/DialectBase.td"

def QuidditchSnitch_Dialect : Dialect {
let name = "quidditch_snitch";
let cppNamespace = "::quidditch::Snitch";

let discardableAttrs = (ins
"mlir::StringAttr":$riscv_assembly,
"mlir::UnitAttr":$xdsl_compilation_failed
);
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "QuidditchSnitchOps.h"

#include "mlir/IR/TypeUtilities.h"

#define GET_OP_CLASSES
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.cpp.inc"

using namespace mlir;
using namespace quidditch::Snitch;

Block *XDSLKernelOp::createEntryBlock() {
assert(getBody().getBlocks().empty());
Block &block = getBody().emplaceBlock();
block.addArguments(getInputs().getTypes(),
SmallVector<Location>(getInputs().size(), getLoc()));
return &block;
}

LogicalResult XDSLKernelOp::verify() {
// TODO: Weaken this in the future to likely only require element count and
// type but not the shape. TBD.
if (getBody().getArgumentTypes() != getInputs().getTypes())
return emitOpError("type of arguments and inputs must match");
return success();
}
10 changes: 10 additions & 0 deletions codegen/compiler/src/Quidditch/Dialect/Snitch/QuidditchSnitchOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

#pragma once

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

#define GET_OP_CLASSES
#include "Quidditch/Dialect/Snitch/QuidditchSnitchOps.h.inc"
Loading

0 comments on commit 2ce63b1

Please sign in to comment.