Skip to content

Commit

Permalink
Add allocation hoist optimization (#534)
Browse files Browse the repository at this point in the history
* Add allocation hoist optimization

* Fix analysis of nested loops

* Check for irreducible cycles

* Add additional check when analyzing insertvalue

* Add allocation-specific attributes in LLVM IR

* Remove unused calloc function from runtime library

* Add float -> intN and float -> uintN constructors

* Only hoist atomic allocations

* Update codegen

* Simplify codegen

* Change allocation hoist pass to be a function pass

* Fix loop iteration order

* Use 'struct' instead of 'class'

* Add check for phi instructions in header; refactor

* Remove unneeded checks

* Fix C++ benchmark

* Remove annotation
  • Loading branch information
arshajii authored Feb 23, 2024
1 parent 7a787bf commit 4be3bbf
Show file tree
Hide file tree
Showing 9 changed files with 414 additions and 74 deletions.
1 change: 1 addition & 0 deletions bench/set_partition/set_partition.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <algorithm>
#include <chrono>
#include <functional>
#include <iostream>
#include <vector>

Expand Down
71 changes: 60 additions & 11 deletions codon/cir/llvm/llvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,14 +726,8 @@ void LLVMVisitor::writeToPythonExtension(const PyModule &pymod,
/*Initializer=*/nullptr, "PyType_Type");

auto allocUncollectable = llvm::cast<llvm::Function>(
M->getOrInsertFunction("seq_alloc_uncollectable", ptr, i64).getCallee());
allocUncollectable->setDoesNotThrow();
allocUncollectable->setReturnDoesNotAlias();
allocUncollectable->setOnlyAccessesInaccessibleMemory();

auto free = llvm::cast<llvm::Function>(
M->getOrInsertFunction("seq_free", B->getVoidTy(), ptr).getCallee());
free->setDoesNotThrow();
makeAllocFunc(/*atomic=*/false, /*uncollectable=*/true).getCallee());
auto free = llvm::cast<llvm::Function>(makeFreeFunc().getCallee());

// Helpers
auto pyFuncWrap = [&](Func *func, bool wrap) -> llvm::Constant * {
Expand Down Expand Up @@ -1282,16 +1276,57 @@ void LLVMVisitor::run(const std::vector<std::string> &args,
}
}

llvm::FunctionCallee LLVMVisitor::makeAllocFunc(bool atomic) {
auto f = M->getOrInsertFunction(atomic ? "seq_alloc_atomic" : "seq_alloc",
B->getInt8PtrTy(), B->getInt64Ty());
#define ALLOC_FAMILY "seq_alloc"

llvm::FunctionCallee LLVMVisitor::makeAllocFunc(bool atomic, bool uncollectable) {
const std::string name =
atomic ? (uncollectable ? "seq_alloc_atomic_uncollectable" : "seq_alloc_atomic")
: (uncollectable ? "seq_alloc_uncollectable" : "seq_alloc");
auto f = M->getOrInsertFunction(name, B->getInt8PtrTy(), B->getInt64Ty());
auto *g = cast<llvm::Function>(f.getCallee());
g->setDoesNotThrow();
g->setReturnDoesNotAlias();
g->setOnlyAccessesInaccessibleMemory();
g->addRetAttr(llvm::Attribute::AttrKind::NoUndef);
g->addRetAttr(llvm::Attribute::AttrKind::NonNull);
g->addFnAttrs(
llvm::AttrBuilder(*context)
.addAllocKindAttr(llvm::AllocFnKind::Alloc | llvm::AllocFnKind::Uninitialized)
.addAllocSizeAttr(0, {})
.addAttribute("alloc-family", ALLOC_FAMILY));
return f;
}

llvm::FunctionCallee LLVMVisitor::makeReallocFunc() {
// note that seq_realloc takes arguments (ptr, new_size, old_size)
auto f = M->getOrInsertFunction("seq_realloc", B->getInt8PtrTy(), B->getInt8PtrTy(),
B->getInt64Ty(), B->getInt64Ty());
auto *g = cast<llvm::Function>(f.getCallee());
g->setDoesNotThrow();
g->addRetAttr(llvm::Attribute::AttrKind::NoUndef);
g->addRetAttr(llvm::Attribute::AttrKind::NonNull);
g->addParamAttr(0, llvm::Attribute::AttrKind::AllocatedPointer);
g->addFnAttrs(llvm::AttrBuilder(*context)
.addAllocKindAttr(llvm::AllocFnKind::Realloc |
llvm::AllocFnKind::Uninitialized)
.addAllocSizeAttr(1, {})
.addAttribute("alloc-family", ALLOC_FAMILY));
return f;
}

llvm::FunctionCallee LLVMVisitor::makeFreeFunc() {
auto f = M->getOrInsertFunction("seq_free", B->getVoidTy(), B->getInt8PtrTy());
auto *g = cast<llvm::Function>(f.getCallee());
g->setDoesNotThrow();
g->addParamAttr(0, llvm::Attribute::AttrKind::AllocatedPointer);
g->addFnAttrs(llvm::AttrBuilder(*context)
.addAllocKindAttr(llvm::AllocFnKind::Free)
.addAttribute("alloc-family", ALLOC_FAMILY));
return f;
}

#undef ALLOC_FAMILY

llvm::FunctionCallee LLVMVisitor::makePersonalityFunc() {
return M->getOrInsertFunction("seq_personality", B->getInt32Ty(), B->getInt32Ty(),
B->getInt32Ty(), B->getInt64Ty(), B->getInt8PtrTy(),
Expand Down Expand Up @@ -1573,6 +1608,20 @@ void LLVMVisitor::visit(const Module *x) {

B->SetInsertPoint(exitBlock);
B->CreateRet(B->getInt32(0));

// make sure allocation functions have the correct attributes
if (M->getFunction("seq_alloc"))
makeAllocFunc(/*atomic=*/false, /*uncollectable=*/false);
if (M->getFunction("seq_alloc_atomic"))
makeAllocFunc(/*atomic=*/true, /*uncollectable=*/false);
if (M->getFunction("seq_alloc_uncollectable"))
makeAllocFunc(/*atomic=*/false, /*uncollectable=*/true);
if (M->getFunction("seq_alloc_atomic_uncollectable"))
makeAllocFunc(/*atomic=*/true, /*uncollectable=*/true);
if (M->getFunction("seq_realloc"))
makeReallocFunc();
if (M->getFunction("seq_free"))
makeFreeFunc();
}

llvm::DISubprogram *LLVMVisitor::getDISubprogramForFunc(const Func *x) {
Expand Down
6 changes: 5 additions & 1 deletion codon/cir/llvm/llvisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ class LLVMVisitor : public util::ConstVisitor {
std::unordered_map<std::string, llvm::DICompositeType *> &cache);

/// GC allocation functions
llvm::FunctionCallee makeAllocFunc(bool atomic);
llvm::FunctionCallee makeAllocFunc(bool atomic, bool uncollectable = false);
// GC reallocation function
llvm::FunctionCallee makeReallocFunc();
// GC free function
llvm::FunctionCallee makeFreeFunc();
/// Personality function for exception handling
llvm::FunctionCallee makePersonalityFunc();
/// Exception allocation function
Expand Down
4 changes: 4 additions & 0 deletions codon/cir/llvm/llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/CycleAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/RegionPass.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
Expand Down Expand Up @@ -109,5 +112,6 @@
#include "llvm/Transforms/IPO/StripDeadPrototypes.h"
#include "llvm/Transforms/IPO/StripSymbols.h"
#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Debugify.h"
Loading

0 comments on commit 4be3bbf

Please sign in to comment.