From 4106a01323cf901915a30680114a053b94df92b4 Mon Sep 17 00:00:00 2001 From: Magnus Sjalander Date: Fri, 2 Feb 2024 04:29:57 +0100 Subject: [PATCH] MLIR RVSDG backend (#348) --- jlm/llvm/ir/linkage.hpp | 25 ++ jlm/mlir/Makefile.sub | 3 + jlm/mlir/backend/JlmToMlirConverter.cpp | 221 ++++++++++++++++++ jlm/mlir/backend/JlmToMlirConverter.hpp | 125 ++++++++++ .../mlir/backend/TestJlmToMlirConverter.cpp | 105 +++++++++ 5 files changed, 479 insertions(+) create mode 100644 jlm/mlir/backend/JlmToMlirConverter.cpp create mode 100644 jlm/mlir/backend/JlmToMlirConverter.hpp create mode 100644 tests/jlm/mlir/backend/TestJlmToMlirConverter.cpp diff --git a/jlm/llvm/ir/linkage.hpp b/jlm/llvm/ir/linkage.hpp index 26de9c31a..7a996ff82 100644 --- a/jlm/llvm/ir/linkage.hpp +++ b/jlm/llvm/ir/linkage.hpp @@ -6,6 +6,10 @@ #ifndef JLM_LLVM_IR_LINKAGE_HPP #define JLM_LLVM_IR_LINKAGE_HPP +#include +#include +#include + namespace jlm::llvm { @@ -31,6 +35,27 @@ is_externally_visible(const linkage & lnk) return lnk != linkage::internal_linkage; } +static inline std::string +ToString(const linkage & lnk) +{ + static std::unordered_map strings = { + { linkage::external_linkage, "external_linkage" }, + { linkage::available_externally_linkage, "available_externally_linkage" }, + { linkage::link_once_any_linkage, "link_once_any_linkage" }, + { linkage::link_once_odr_linkage, "link_once_odr_linkage" }, + { linkage::weak_any_linkage, "weak_any_linkage" }, + { linkage::weak_odr_linkage, "weak_odr_linkage" }, + { linkage::appending_linkage, "appending_linkage" }, + { linkage::internal_linkage, "internal_linkage" }, + { linkage::private_linkage, "private_linkage" }, + { linkage::external_weak_linkage, "external_weak_linkage" }, + { linkage::common_linkage, "common_linkage" } + }; + + JLM_ASSERT(strings.find(lnk) != strings.end()); + return strings[lnk]; +} + } #endif diff --git a/jlm/mlir/Makefile.sub b/jlm/mlir/Makefile.sub index 97659f5bc..7dda68b3c 100644 --- a/jlm/mlir/Makefile.sub +++ b/jlm/mlir/Makefile.sub @@ -2,12 +2,15 @@ # See COPYING for terms of redistribution. libmlir_SOURCES = \ + jlm/mlir/backend/JlmToMlirConverter.cpp \ jlm/mlir/frontend/MlirToJlmConverter.cpp \ libmlir_HEADERS = \ + jlm/mlir/backend/JlmToMlirConverter.hpp \ jlm/mlir/frontend/MlirToJlmConverter.hpp \ libmlir_TESTS += \ + tests/jlm/mlir/backend/TestJlmToMlirConverter \ tests/jlm/mlir/frontend/TestMlirToJlmConverter \ libmlir_TEST_LIBS += \ diff --git a/jlm/mlir/backend/JlmToMlirConverter.cpp b/jlm/mlir/backend/JlmToMlirConverter.cpp new file mode 100644 index 000000000..d0dcff797 --- /dev/null +++ b/jlm/mlir/backend/JlmToMlirConverter.cpp @@ -0,0 +1,221 @@ +/* + * Copyright 2023 Magnus Sjalander + * See COPYING for terms of redistribution. + */ + +#include + +#include +#include +#include +#include + +#include +#include + +namespace jlm::mlir +{ + +void +JlmToMlirConverter::Print(::mlir::rvsdg::OmegaNode & omega, const util::filepath & filePath) +{ + if (failed(::mlir::verify(omega))) + { + omega.emitError("module verification error"); + throw util::error("Verification of RVSDG-MLIR failed"); + } + if (filePath == "") + { + ::llvm::raw_os_ostream os(std::cout); + omega.print(os); + } + else + { + std::error_code ec; + ::llvm::raw_fd_ostream os(filePath.to_str(), ec); + omega.print(os); + } +} + +::mlir::rvsdg::OmegaNode +JlmToMlirConverter::ConvertModule(const llvm::RvsdgModule & rvsdgModule) +{ + return ConvertOmega(rvsdgModule.Rvsdg()); +} + +::mlir::rvsdg::OmegaNode +JlmToMlirConverter::ConvertOmega(const rvsdg::graph & graph) +{ + auto omega = Builder_->create<::mlir::rvsdg::OmegaNode>(Builder_->getUnknownLoc()); + auto & omegaBlock = omega.getRegion().emplaceBlock(); + + ::llvm::SmallVector<::mlir::Value> regionResults = ConvertRegion(*graph.root(), omegaBlock); + + auto omegaResult = + Builder_->create<::mlir::rvsdg::OmegaResult>(Builder_->getUnknownLoc(), regionResults); + omegaBlock.push_back(omegaResult); + + return omega; +} + +::llvm::SmallVector<::mlir::Value> +JlmToMlirConverter::ConvertRegion(rvsdg::region & region, ::mlir::Block & block) +{ + for (size_t i = 0; i < region.narguments(); ++i) + { + auto type = ConvertType(region.argument(i)->type()); + block.addArgument(type, Builder_->getUnknownLoc()); + } + + // Create an MLIR operation for each RVSDG node and store each pair in a + // hash map for easy lookup of corresponding MLIR operation + std::unordered_map nodes; + for (rvsdg::node * rvsdgNode : rvsdg::topdown_traverser(®ion)) + { + // TODO + // Get the inputs of the node + // for (size_t i=0; i < rvsdgNode->ninputs(); i++) + //{ + // ::llvm::outs() << rvsdgNode->input(i) << "\n"; + //} + nodes[rvsdgNode] = ConvertNode(*rvsdgNode, block); + } + + ::llvm::SmallVector<::mlir::Value> results; + for (size_t i = 0; i < region.nresults(); ++i) + { + auto output = region.result(i)->origin(); + rvsdg::node * outputNode = rvsdg::node_output::node(output); + if (outputNode == nullptr) + { + // The result is connected directly to an argument + results.push_back(block.getArgument(output->index())); + } + else + { + // The identified node should always exist in the hash map of nodes + JLM_ASSERT(nodes.find(outputNode) != nodes.end()); + results.push_back(nodes[outputNode]); + } + } + return results; +} + +::mlir::Value +JlmToMlirConverter::ConvertNode(const rvsdg::node & node, ::mlir::Block & block) +{ + if (auto simpleNode = dynamic_cast(&node)) + { + return ConvertSimpleNode(*simpleNode, block); + } + else if (auto lambda = dynamic_cast(&node)) + { + return ConvertLambda(*lambda, block); + } + else + { + auto message = util::strfmt("Unimplemented structural node: ", node.operation().debug_string()); + JLM_UNREACHABLE(message.c_str()); + } +} + +::mlir::Value +JlmToMlirConverter::ConvertSimpleNode(const rvsdg::simple_node & node, ::mlir::Block & block) +{ + if (auto bitsOp = dynamic_cast(&(node.operation()))) + { + auto value = bitsOp->value(); + auto constOp = Builder_->create<::mlir::arith::ConstantIntOp>( + Builder_->getUnknownLoc(), + value.to_uint(), + value.nbits()); + block.push_back(constOp); + + return constOp; + } + else + { + auto message = util::strfmt("Unimplemented simple node: ", node.operation().debug_string()); + JLM_UNREACHABLE(message.c_str()); + } +} + +::mlir::Value +JlmToMlirConverter::ConvertLambda(const llvm::lambda::node & lambdaNode, ::mlir::Block & block) +{ + ::llvm::SmallVector<::mlir::Type> arguments; + for (size_t i = 0; i < lambdaNode.nfctarguments(); ++i) + { + arguments.push_back(ConvertType(lambdaNode.fctargument(i)->type())); + } + + ::llvm::SmallVector<::mlir::Type> results; + for (size_t i = 0; i < lambdaNode.nfctresults(); ++i) + { + results.push_back(ConvertType(lambdaNode.fctresult(i)->type())); + } + + ::llvm::SmallVector<::mlir::Type> lambdaRef; + auto refType = Builder_->getType<::mlir::rvsdg::LambdaRefType>( + ::llvm::ArrayRef(arguments), + ::llvm::ArrayRef(results)); + lambdaRef.push_back(refType); + + ::llvm::SmallVector<::mlir::Value> inputs; + // TODO + // Populate the inputs + + // Add function attributes, e.g., the function name and linkage + ::llvm::SmallVector<::mlir::NamedAttribute> attributes; + auto symbolName = Builder_->getNamedAttr( + Builder_->getStringAttr("sym_name"), + Builder_->getStringAttr(lambdaNode.name())); + attributes.push_back(symbolName); + auto linkage = Builder_->getNamedAttr( + Builder_->getStringAttr("linkage"), + Builder_->getStringAttr(llvm::ToString(lambdaNode.linkage()))); + attributes.push_back(linkage); + + auto lambda = Builder_->create<::mlir::rvsdg::LambdaNode>( + Builder_->getUnknownLoc(), + lambdaRef, + inputs, + ::llvm::ArrayRef<::mlir::NamedAttribute>(attributes)); + block.push_back(lambda); + + auto & lambdaBlock = lambda.getRegion().emplaceBlock(); + auto regionResults = ConvertRegion(*lambdaNode.subregion(), lambdaBlock); + auto lambdaResult = + Builder_->create<::mlir::rvsdg::LambdaResult>(Builder_->getUnknownLoc(), regionResults); + lambdaBlock.push_back(lambdaResult); + + return lambda; +} + +::mlir::Type +JlmToMlirConverter::ConvertType(const rvsdg::type & type) +{ + if (auto bt = dynamic_cast(&type)) + { + return Builder_->getIntegerType(bt->nbits()); + } + else if (rvsdg::is(type)) + { + return Builder_->getType<::mlir::rvsdg::LoopStateEdgeType>(); + } + else if (rvsdg::is(type)) + { + return Builder_->getType<::mlir::rvsdg::IOStateEdgeType>(); + } + else if (rvsdg::is(type)) + { + return Builder_->getType<::mlir::rvsdg::MemStateEdgeType>(); + } + else + { + auto message = util::strfmt("Type conversion not implemented: ", type.debug_string()); + JLM_UNREACHABLE(message.c_str()); + } +} + +} // namespace jlm::mlir diff --git a/jlm/mlir/backend/JlmToMlirConverter.hpp b/jlm/mlir/backend/JlmToMlirConverter.hpp new file mode 100644 index 000000000..2e1c61231 --- /dev/null +++ b/jlm/mlir/backend/JlmToMlirConverter.hpp @@ -0,0 +1,125 @@ +/* + * Copyright 2023 Magnus Sjalander + * See COPYING for terms of redistribution. + */ + +#ifndef JLM_MLIR_BACKEND_JLMTOMLIRCONVERTER_HPP +#define JLM_MLIR_BACKEND_JLMTOMLIRCONVERTER_HPP + +// JLM +#include +#include + +// MLIR RVSDG dialects +#include +#include +#include + +// MLIR generic dialects +#include + +namespace jlm::mlir +{ + +class JlmToMlirConverter final +{ +public: + JlmToMlirConverter() + : Context_(std::make_unique<::mlir::MLIRContext>()) + { + Context_->getOrLoadDialect<::mlir::rvsdg::RVSDGDialect>(); + Context_->getOrLoadDialect<::mlir::jlm::JLMDialect>(); + Context_->getOrLoadDialect<::mlir::arith::ArithDialect>(); + Builder_ = std::make_unique<::mlir::OpBuilder>(Context_.get()); + } + + JlmToMlirConverter(const JlmToMlirConverter &) = delete; + + JlmToMlirConverter(JlmToMlirConverter &&) = delete; + + JlmToMlirConverter & + operator=(const JlmToMlirConverter &) = delete; + + JlmToMlirConverter & + operator=(JlmToMlirConverter &&) = delete; + + /** + * Prints MLIR RVSDG to a file. + * \param omega The MLIR RVSDG Omega node to be printed. + * \param filePath The path to the file to print the MLIR to. + */ + static void + Print(::mlir::rvsdg::OmegaNode & omega, const util::filepath & filePath); + + /** + * Converts an RVSDG module to MLIR RVSDG. + * \param rvsdgModule The RVSDG module to be converted. + * \return An MLIR RVSDG OmegaNode containing the whole graph of the rvsdgModule. It is + * the responsibility of the caller to call ->destroy() on the returned omega, once it is no + * longer needed. + */ + ::mlir::rvsdg::OmegaNode + ConvertModule(const llvm::RvsdgModule & rvsdgModule); + +private: + /** + * Converts an omega and all nodes in its (sub)region(s) to an MLIR RVSDG OmegaNode. + * \param graph The root RVSDG graph. + * \return An MLIR RVSDG OmegaNode. + */ + ::mlir::rvsdg::OmegaNode + ConvertOmega(const rvsdg::graph & graph); + + /** + * Converts all nodes in an RVSDG region. Conversion of structural nodes cause their regions to + * also be converted. + * \param region The RVSDG region to be converted + * \param block The MLIR RVSDG block that corresponds to this RVSDG region, and to which + * converted nodes are insterted. + * \return A list of outputs of the converted region/block. + */ + ::llvm::SmallVector<::mlir::Value> + ConvertRegion(rvsdg::region & region, ::mlir::Block & block); + + /** + * Converts an RVSDG node to an MLIR RVSDG operation. + * \param node The RVSDG node to be converted + * \param block The MLIR RVSDG block to insert the converted node. + * \return The converted MLIR RVSDG operation. + */ + ::mlir::Value + ConvertNode(const rvsdg::node & node, ::mlir::Block & block); + + /** + * Converts an RVSDG simple_node to an MLIR RVSDG operation. + * \param node The RVSDG node to be converted + * \param block The MLIR RVSDG block to insert the converted node. + * \return The converted MLIR RVSDG operation. + */ + ::mlir::Value + ConvertSimpleNode(const rvsdg::simple_node & node, ::mlir::Block & block); + + /** + * Converts an RVSDG lambda node to an MLIR RVSDG LambdaNode. + * \param node The RVSDG lambda node to be converted + * \param block The MLIR RVSDG block to insert the lambda node. + * \return The converted MLIR RVSDG LambdaNode. + */ + ::mlir::Value + ConvertLambda(const llvm::lambda::node & node, ::mlir::Block & block); + + /** + * Converts an RVSDG type to an MLIR RVSDG type. + * \param type The RVSDG type to be converted. + * \result The corresponding MLIR RVSDG type. + */ + ::mlir::Type + ConvertType(const rvsdg::type & type); + + std::unique_ptr<::mlir::OpBuilder> Builder_; + std::unique_ptr<::mlir::MLIRContext> Context_; +}; + +} // namespace jlm::mlir + +#endif // JLM_MLIR_BACKEND_JLMTOMLIRCONVERTER_HPP diff --git a/tests/jlm/mlir/backend/TestJlmToMlirConverter.cpp b/tests/jlm/mlir/backend/TestJlmToMlirConverter.cpp new file mode 100644 index 000000000..db862fbaa --- /dev/null +++ b/tests/jlm/mlir/backend/TestJlmToMlirConverter.cpp @@ -0,0 +1,105 @@ +/* + * Copyright 2024 Magnus Själander + * See COPYING for terms of redistribution. + */ + +#include +#include + +#include +#include +#include + +static void +TestLambda() +{ + using namespace jlm::llvm; + using namespace mlir::rvsdg; + + auto rvsdgModule = RvsdgModule::Create(jlm::util::filepath(""), "", ""); + auto graph = &rvsdgModule->Rvsdg(); + + auto nf = graph->node_normal_form(typeid(jlm::rvsdg::operation)); + nf->set_mutable(false); + + { + // Setup the function + iostatetype iOStateType; + MemoryStateType memoryStateType; + loopstatetype loopStateType; + FunctionType functionType( + { &iOStateType, &memoryStateType, &loopStateType }, + { &jlm::rvsdg::bit32, &iOStateType, &memoryStateType, &loopStateType }); + + auto lambda = + lambda::node::create(graph->root(), functionType, "test", linkage::external_linkage); + auto iOStateArgument = lambda->fctargument(0); + auto memoryStateArgument = lambda->fctargument(1); + auto loopStateArgument = lambda->fctargument(2); + + auto constant = jlm::rvsdg::create_bitconstant(lambda->subregion(), 32, 4); + + lambda->finalize({ constant, iOStateArgument, memoryStateArgument, loopStateArgument }); + + // Convert the RVSDG to MLIR + jlm::mlir::JlmToMlirConverter mlirgen; + auto omega = mlirgen.ConvertModule(*rvsdgModule); + + // Validate the generated MLIR + auto & omegaRegion = omega.getRegion(); + assert(omegaRegion.getBlocks().size() == 1); + auto & omegaBlock = omegaRegion.front(); + // Lamda + terminating operation + assert(omegaBlock.getOperations().size() == 2); + auto & mlirLambda = omegaBlock.front(); + assert(mlirLambda.getName().getStringRef() == LambdaNode::getOperationName()); + + // Verify function name + auto functionNameAttribute = mlirLambda.getAttr(::llvm::StringRef("sym_name")); + auto * functionName = static_cast(&functionNameAttribute); + auto string = functionName->getValue().str(); + assert(string == "test"); + + // Verify function signature + auto result = mlirLambda.getResult(0).getType(); + assert(result.getTypeID() == LambdaRefType::getTypeID()); + auto * lambdaRefType = static_cast(&result); + std::vector arguments; + for (auto argumentType : lambdaRefType->getParameterTypes()) + { + arguments.push_back(argumentType); + } + assert(arguments[0].getTypeID() == IOStateEdgeType::getTypeID()); + assert(arguments[1].getTypeID() == MemStateEdgeType::getTypeID()); + assert(arguments[2].getTypeID() == LoopStateEdgeType::getTypeID()); + std::vector results; + for (auto returnType : lambdaRefType->getReturnTypes()) + { + results.push_back(returnType); + } + assert(results[0].getTypeID() == mlir::IntegerType::getTypeID()); + assert(results[1].getTypeID() == IOStateEdgeType::getTypeID()); + assert(results[2].getTypeID() == MemStateEdgeType::getTypeID()); + assert(results[3].getTypeID() == LoopStateEdgeType::getTypeID()); + + auto & lambdaRegion = mlirLambda.getRegion(0); + auto & lambdaBlock = lambdaRegion.front(); + // Bitconstant + terminating operation + assert(lambdaBlock.getOperations().size() == 2); + assert( + lambdaBlock.front().getName().getStringRef() + == mlir::arith::ConstantIntOp::getOperationName()); + + omega->destroy(); + } +} + +static int +Test() +{ + TestLambda(); + + return 0; +} + +JLM_UNIT_TEST_REGISTER("jlm/mlir/backend/TestMlirGen", Test)