Skip to content

Commit

Permalink
MLIR RVSDG backend (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjalander authored Feb 2, 2024
1 parent efc4ea6 commit 4106a01
Show file tree
Hide file tree
Showing 5 changed files with 479 additions and 0 deletions.
25 changes: 25 additions & 0 deletions jlm/llvm/ir/linkage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
#ifndef JLM_LLVM_IR_LINKAGE_HPP
#define JLM_LLVM_IR_LINKAGE_HPP

#include <jlm/util/common.hpp>
#include <string>
#include <unordered_map>

namespace jlm::llvm
{

Expand All @@ -31,6 +35,27 @@ is_externally_visible(const linkage & lnk)
return lnk != linkage::internal_linkage;
}

static inline std::string
ToString(const linkage & lnk)
{
static std::unordered_map<linkage, std::string> strings = {
{ linkage::external_linkage, "external_linkage" },
{ linkage::available_externally_linkage, "available_externally_linkage" },
{ linkage::link_once_any_linkage, "link_once_any_linkage" },
{ linkage::link_once_odr_linkage, "link_once_odr_linkage" },
{ linkage::weak_any_linkage, "weak_any_linkage" },
{ linkage::weak_odr_linkage, "weak_odr_linkage" },
{ linkage::appending_linkage, "appending_linkage" },
{ linkage::internal_linkage, "internal_linkage" },
{ linkage::private_linkage, "private_linkage" },
{ linkage::external_weak_linkage, "external_weak_linkage" },
{ linkage::common_linkage, "common_linkage" }
};

JLM_ASSERT(strings.find(lnk) != strings.end());
return strings[lnk];
}

}

#endif
3 changes: 3 additions & 0 deletions jlm/mlir/Makefile.sub
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
# See COPYING for terms of redistribution.

libmlir_SOURCES = \
jlm/mlir/backend/JlmToMlirConverter.cpp \
jlm/mlir/frontend/MlirToJlmConverter.cpp \

libmlir_HEADERS = \
jlm/mlir/backend/JlmToMlirConverter.hpp \
jlm/mlir/frontend/MlirToJlmConverter.hpp \

libmlir_TESTS += \
tests/jlm/mlir/backend/TestJlmToMlirConverter \
tests/jlm/mlir/frontend/TestMlirToJlmConverter \

libmlir_TEST_LIBS += \
Expand Down
221 changes: 221 additions & 0 deletions jlm/mlir/backend/JlmToMlirConverter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
/*
* Copyright 2023 Magnus Sjalander <[email protected]>
* See COPYING for terms of redistribution.
*/

#include <jlm/mlir/backend/JlmToMlirConverter.hpp>

#include <jlm/rvsdg/bitstring/comparison.hpp>
#include <jlm/rvsdg/bitstring/constant.hpp>
#include <jlm/rvsdg/node.hpp>
#include <jlm/rvsdg/traverser.hpp>

#include <llvm/Support/raw_os_ostream.h>
#include <mlir/IR/Verifier.h>

namespace jlm::mlir
{

void
JlmToMlirConverter::Print(::mlir::rvsdg::OmegaNode & omega, const util::filepath & filePath)
{
if (failed(::mlir::verify(omega)))
{
omega.emitError("module verification error");
throw util::error("Verification of RVSDG-MLIR failed");
}
if (filePath == "")
{
::llvm::raw_os_ostream os(std::cout);
omega.print(os);
}
else
{
std::error_code ec;
::llvm::raw_fd_ostream os(filePath.to_str(), ec);
omega.print(os);
}
}

::mlir::rvsdg::OmegaNode
JlmToMlirConverter::ConvertModule(const llvm::RvsdgModule & rvsdgModule)
{
return ConvertOmega(rvsdgModule.Rvsdg());
}

::mlir::rvsdg::OmegaNode
JlmToMlirConverter::ConvertOmega(const rvsdg::graph & graph)
{
auto omega = Builder_->create<::mlir::rvsdg::OmegaNode>(Builder_->getUnknownLoc());
auto & omegaBlock = omega.getRegion().emplaceBlock();

::llvm::SmallVector<::mlir::Value> regionResults = ConvertRegion(*graph.root(), omegaBlock);

auto omegaResult =
Builder_->create<::mlir::rvsdg::OmegaResult>(Builder_->getUnknownLoc(), regionResults);
omegaBlock.push_back(omegaResult);

return omega;
}

::llvm::SmallVector<::mlir::Value>
JlmToMlirConverter::ConvertRegion(rvsdg::region & region, ::mlir::Block & block)
{
for (size_t i = 0; i < region.narguments(); ++i)
{
auto type = ConvertType(region.argument(i)->type());
block.addArgument(type, Builder_->getUnknownLoc());
}

// Create an MLIR operation for each RVSDG node and store each pair in a
// hash map for easy lookup of corresponding MLIR operation
std::unordered_map<rvsdg::node *, ::mlir::Value> nodes;
for (rvsdg::node * rvsdgNode : rvsdg::topdown_traverser(&region))
{
// TODO
// Get the inputs of the node
// for (size_t i=0; i < rvsdgNode->ninputs(); i++)
//{
// ::llvm::outs() << rvsdgNode->input(i) << "\n";
//}
nodes[rvsdgNode] = ConvertNode(*rvsdgNode, block);
}

::llvm::SmallVector<::mlir::Value> results;
for (size_t i = 0; i < region.nresults(); ++i)
{
auto output = region.result(i)->origin();
rvsdg::node * outputNode = rvsdg::node_output::node(output);
if (outputNode == nullptr)
{
// The result is connected directly to an argument
results.push_back(block.getArgument(output->index()));
}
else
{
// The identified node should always exist in the hash map of nodes
JLM_ASSERT(nodes.find(outputNode) != nodes.end());
results.push_back(nodes[outputNode]);
}
}
return results;
}

::mlir::Value
JlmToMlirConverter::ConvertNode(const rvsdg::node & node, ::mlir::Block & block)
{
if (auto simpleNode = dynamic_cast<const rvsdg::simple_node *>(&node))
{
return ConvertSimpleNode(*simpleNode, block);
}
else if (auto lambda = dynamic_cast<const llvm::lambda::node *>(&node))
{
return ConvertLambda(*lambda, block);
}
else
{
auto message = util::strfmt("Unimplemented structural node: ", node.operation().debug_string());
JLM_UNREACHABLE(message.c_str());
}
}

::mlir::Value
JlmToMlirConverter::ConvertSimpleNode(const rvsdg::simple_node & node, ::mlir::Block & block)
{
if (auto bitsOp = dynamic_cast<const rvsdg::bitconstant_op *>(&(node.operation())))
{
auto value = bitsOp->value();
auto constOp = Builder_->create<::mlir::arith::ConstantIntOp>(
Builder_->getUnknownLoc(),
value.to_uint(),
value.nbits());
block.push_back(constOp);

return constOp;
}
else
{
auto message = util::strfmt("Unimplemented simple node: ", node.operation().debug_string());
JLM_UNREACHABLE(message.c_str());
}
}

::mlir::Value
JlmToMlirConverter::ConvertLambda(const llvm::lambda::node & lambdaNode, ::mlir::Block & block)
{
::llvm::SmallVector<::mlir::Type> arguments;
for (size_t i = 0; i < lambdaNode.nfctarguments(); ++i)
{
arguments.push_back(ConvertType(lambdaNode.fctargument(i)->type()));
}

::llvm::SmallVector<::mlir::Type> results;
for (size_t i = 0; i < lambdaNode.nfctresults(); ++i)
{
results.push_back(ConvertType(lambdaNode.fctresult(i)->type()));
}

::llvm::SmallVector<::mlir::Type> lambdaRef;
auto refType = Builder_->getType<::mlir::rvsdg::LambdaRefType>(
::llvm::ArrayRef(arguments),
::llvm::ArrayRef(results));
lambdaRef.push_back(refType);

::llvm::SmallVector<::mlir::Value> inputs;
// TODO
// Populate the inputs

// Add function attributes, e.g., the function name and linkage
::llvm::SmallVector<::mlir::NamedAttribute> attributes;
auto symbolName = Builder_->getNamedAttr(
Builder_->getStringAttr("sym_name"),
Builder_->getStringAttr(lambdaNode.name()));
attributes.push_back(symbolName);
auto linkage = Builder_->getNamedAttr(
Builder_->getStringAttr("linkage"),
Builder_->getStringAttr(llvm::ToString(lambdaNode.linkage())));
attributes.push_back(linkage);

auto lambda = Builder_->create<::mlir::rvsdg::LambdaNode>(
Builder_->getUnknownLoc(),
lambdaRef,
inputs,
::llvm::ArrayRef<::mlir::NamedAttribute>(attributes));
block.push_back(lambda);

auto & lambdaBlock = lambda.getRegion().emplaceBlock();
auto regionResults = ConvertRegion(*lambdaNode.subregion(), lambdaBlock);
auto lambdaResult =
Builder_->create<::mlir::rvsdg::LambdaResult>(Builder_->getUnknownLoc(), regionResults);
lambdaBlock.push_back(lambdaResult);

return lambda;
}

::mlir::Type
JlmToMlirConverter::ConvertType(const rvsdg::type & type)
{
if (auto bt = dynamic_cast<const rvsdg::bittype *>(&type))
{
return Builder_->getIntegerType(bt->nbits());
}
else if (rvsdg::is<llvm::loopstatetype>(type))
{
return Builder_->getType<::mlir::rvsdg::LoopStateEdgeType>();
}
else if (rvsdg::is<llvm::iostatetype>(type))
{
return Builder_->getType<::mlir::rvsdg::IOStateEdgeType>();
}
else if (rvsdg::is<llvm::MemoryStateType>(type))
{
return Builder_->getType<::mlir::rvsdg::MemStateEdgeType>();
}
else
{
auto message = util::strfmt("Type conversion not implemented: ", type.debug_string());
JLM_UNREACHABLE(message.c_str());
}
}

} // namespace jlm::mlir
Loading

0 comments on commit 4106a01

Please sign in to comment.