Skip to content

Commit

Permalink
function returns, JIT
Browse files Browse the repository at this point in the history
  • Loading branch information
Raekye committed Sep 21, 2019
1 parent c7a6621 commit bdc6007
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 44 deletions.
11 changes: 10 additions & 1 deletion midori/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,22 @@ int test_lang() {
LangASTPrinter p;
program->accept(&p);
CodeGen cg;
std::unique_ptr<LangASTPrototype> proto(new LangASTPrototype("main", "Int", {}));
std::vector<std::unique_ptr<LangASTDecl>> v;
//v.push_back(std::unique_ptr<LangASTDecl>(new LangASTDecl("foo", "Int")));
std::unique_ptr<LangASTPrototype> proto(new LangASTPrototype("main", "Int", std::move(v)));
LangASTFunction g(std::move(proto), std::move(program));
cg.process(&g);
cg.dump("program.bc");
cg.run();
return 0;
}

extern "C" {
void putint(int x) {
std::cout << "int " << x << std::endl;
}
}

void foo(int& x) {
x = 3;
}
Expand Down
3 changes: 2 additions & 1 deletion midori/src/midori/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ find_package(LLVM REQUIRED CONFIG)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})

llvm_map_components_to_libnames(llvm_libs core orcjit native)
llvm_map_components_to_libnames(llvm_libs core interpreter mcjit orcjit native)

add_library(midori SHARED ${SOURCES})
target_compile_options(midori PRIVATE -Wall -Wextra -Wpedantic -Werror -Wno-unknown-pragmas)
target_link_libraries(midori PUBLIC -rdynamic)
target_link_libraries(midori PUBLIC ${llvm_libs})
#target_link_libraries(midori PUBLIC coverage)
128 changes: 104 additions & 24 deletions midori/src/midori/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,37 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/GenericValue.h"
#include "llvm/ExecutionEngine/OrcMCJITReplacement.h"
#include "llvm/ExecutionEngine/MCJIT.h"
//#include "llvm/ExecutionEngine/Interpreter.h"
#include <cassert>

CodeGen::CodeGen() : builder(this->context), module("midori", this->context), type_manager(&(this->context)) {
CodeGen::CodeGen() : builder(this->context), module(new llvm::Module("midori", this->context)), type_manager(&(this->context)) {
llvm::BasicBlock::Create(this->context, "main");
}

void CodeGen::process(LangAST* program) {
TypeChecker tc(&(this->type_manager));
program->accept(&tc);
program->accept(this);
llvm::verifyModule(*(this->module), &(llvm::errs()));
}

std::error_code CodeGen::dump(std::string path) {
std::error_code ec;
llvm::raw_fd_ostream o(path, ec, llvm::sys::fs::OpenFlags::F_None);
llvm::WriteBitcodeToFile(&(this->module), o);
o.close();
llvm::WriteBitcodeToFile(this->module.get(), o);
return ec;
}

void CodeGen::run() {
llvm::Function* f = this->get_function("main");
llvm::ExecutionEngine* ee = llvm::EngineBuilder(std::move(this->module)).create();
ee->runFunction(f, {});
}

void CodeGen::visit(LangASTBlock* v) {
this->push_scope();
for (std::unique_ptr<LangAST> const& l : v->lines) {
Expand Down Expand Up @@ -82,36 +94,78 @@ void CodeGen::visit(LangASTBinOp* v) {
}
v->left->accept(this);
llvm::Value* lhs = this->ret;
llvm::Type* type = lhs->getType();
this->ret = nullptr;
switch (v->op) {
case LangASTBinOp::Op::PLUS:
this->ret = this->builder.CreateAdd(lhs, rhs, "addtmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateAdd(lhs, rhs, "addtmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFAdd(lhs, rhs, "addftmp");
}
break;
case LangASTBinOp::Op::MINUS:
this->ret = this->builder.CreateSub(lhs, rhs, "subtmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateSub(lhs, rhs, "subtmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFSub(lhs, rhs, "subftmp");
}
break;
case LangASTBinOp::Op::STAR:
this->ret = this->builder.CreateMul(lhs, rhs, "multmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateMul(lhs, rhs, "multmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFMul(lhs, rhs, "mulftmp");
}
break;
case LangASTBinOp::Op::SLASH:
this->ret = this->builder.CreateSDiv(lhs, rhs, "divtmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateSDiv(lhs, rhs, "divtmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFDiv(lhs, rhs, "divftmp");
}
break;
case LangASTBinOp::Op::EQ:
this->ret = this->builder.CreateICmpEQ(lhs, rhs, "eqtmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateICmpEQ(lhs, rhs, "eqtmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFCmpOEQ(lhs, rhs, "eqftmp");
}
break;
case LangASTBinOp::Op::NE:
this->ret = this->builder.CreateICmpNE(lhs, rhs, "netmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateICmpNE(lhs, rhs, "netmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFCmpONE(lhs, rhs, "neftmp");
}
break;
case LangASTBinOp::Op::LT:
this->ret = this->builder.CreateICmpSLT(lhs, rhs, "lttmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateICmpSLT(lhs, rhs, "lttmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFCmpOLT(lhs, rhs, "ltftmp");
}
break;
case LangASTBinOp::Op::GT:
this->ret = this->builder.CreateICmpSGT(lhs, rhs, "gttmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateICmpSGT(lhs, rhs, "gttmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFCmpOGT(lhs, rhs, "gtftmp");
}
break;
case LangASTBinOp::Op::LE:
this->ret = this->builder.CreateICmpSLE(lhs, rhs, "letmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateICmpSGE(lhs, rhs, "getmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFCmpOGE(lhs, rhs, "geftmp");
}
break;
case LangASTBinOp::Op::GE:
this->ret = this->builder.CreateICmpSGE(lhs, rhs, "getmp");
if (type->isIntegerTy()) {
this->ret = this->builder.CreateICmpSLE(lhs, rhs, "letmp");
} else if (type->isFloatingPointTy()) {
this->ret = this->builder.CreateFCmpOLE(lhs, rhs, "leftmp");
}
break;
default:
this->ret = nullptr;
Expand Down Expand Up @@ -139,13 +193,17 @@ void CodeGen::visit(LangASTIf* v) {
llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(this->context, "then", f);
llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(this->context, "else");
llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(this->context, "ifcont");
int x = 0;

this->builder.CreateCondBr(cond, then_bb, else_bb);

this->builder.SetInsertPoint(then_bb);
v->block_if->accept(this);

this->builder.CreateBr(merge_bb);
if (this->builder.GetInsertBlock()->getTerminator() == nullptr) {
this->builder.CreateBr(merge_bb);
x++;
}

f->getBasicBlockList().push_back(else_bb);
this->builder.SetInsertPoint(else_bb);
Expand All @@ -154,10 +212,15 @@ void CodeGen::visit(LangASTIf* v) {
v->block_else->accept(this);
}

this->builder.CreateBr(merge_bb);
if (this->builder.GetInsertBlock()->getTerminator() == nullptr) {
this->builder.CreateBr(merge_bb);
x++;
}

f->getBasicBlockList().push_back(merge_bb);
this->builder.SetInsertPoint(merge_bb);
if (x > 0) {
f->getBasicBlockList().push_back(merge_bb);
this->builder.SetInsertPoint(merge_bb);
}
this->pop_scope();
this->ret = nullptr;
}
Expand All @@ -176,7 +239,9 @@ void CodeGen::visit(LangASTWhile* v) {
f->getBasicBlockList().push_back(loop_bb);
this->builder.SetInsertPoint(loop_bb);
v->block->accept(this);
this->builder.CreateBr(cond_bb);
if (this->builder.GetInsertBlock()->getTerminator() == nullptr) {
this->builder.CreateBr(cond_bb);
}
f->getBasicBlockList().push_back(after_bb);
this->builder.SetInsertPoint(after_bb);
this->pop_scope();
Expand All @@ -189,7 +254,7 @@ void CodeGen::visit(LangASTPrototype* v) {
arg_types.push_back(a->type->llvm_type);
}
llvm::FunctionType* ft = llvm::FunctionType::get(this->type_manager.get(v->return_type)->llvm_type, arg_types, false);
llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, v->name, &(this->module));
llvm::Function* f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, v->name, this->module.get());
Int i = 0;
for (llvm::Argument& a : f->args()) {
a.setName(v->args.at(i)->name);
Expand All @@ -216,18 +281,29 @@ void CodeGen::visit(LangASTFunction* v) {
i++;
}
v->body->accept(this);
llvm::Value* z = llvm::ConstantInt::get(this->type_manager.get("Int")->llvm_type, 0, true);
this->builder.CreateRet(z);

llvm::verifyFunction(*f);
if (this->builder.GetInsertBlock()->getTerminator() == nullptr) {
if (v->proto->return_type == this->type_manager.void_type()->name) {
this->builder.CreateRetVoid();
}
}

f->print(llvm::errs());
llvm::verifyFunction(*f, &(llvm::errs()));

this->pop_scope();
this->builder.SetInsertPoint(old);
this->ret = f;
}

void CodeGen::visit(LangASTReturn* v) {
if (v->val == nullptr) {
this->ret = this->builder.CreateRetVoid();
return;
}
v->val->accept(this);
this->ret = this->builder.CreateRet(this->ret);
}

void CodeGen::visit(LangASTCall* v) {
llvm::Function* f = this->get_function(v->function);
if (f->arg_size() != v->args.size()) {
Expand All @@ -239,6 +315,10 @@ void CodeGen::visit(LangASTCall* v) {
a->accept(this);
args.push_back(this->ret);
}
if (f->getReturnType() == this->type_manager.void_type()->llvm_type) {
this->ret = this->builder.CreateCall(f, args);
return;
}
this->ret = this->builder.CreateCall(f, args, "calltmp");
}

Expand All @@ -261,7 +341,7 @@ llvm::Value* CodeGen::named_value(std::string s) {
}

llvm::Function* CodeGen::get_function(std::string name) {
if (llvm::Function* f = this->module.getFunction(name)) {
if (llvm::Function* f = this->module->getFunction(name)) {
return f;
}
return nullptr;
Expand Down
4 changes: 3 additions & 1 deletion midori/src/midori/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class CodeGen : public ILangASTVisitor {
CodeGen();
void process(LangAST*);
std::error_code dump(std::string);
void run();
virtual void visit(LangASTBlock*) override;
virtual void visit(LangASTIdent*) override;
virtual void visit(LangASTDecl*) override;
Expand All @@ -28,12 +29,13 @@ class CodeGen : public ILangASTVisitor {
virtual void visit(LangASTWhile*) override;
virtual void visit(LangASTPrototype*) override;
virtual void visit(LangASTFunction*) override;
virtual void visit(LangASTReturn*) override;
virtual void visit(LangASTCall*) override;

private:
llvm::LLVMContext context;
llvm::IRBuilder<> builder;
llvm::Module module;
std::unique_ptr<llvm::Module> module;
std::list<std::map<std::string, llvm::Value*>> frames;
llvm::Value* ret;
TypeManager type_manager;
Expand Down
Loading

0 comments on commit bdc6007

Please sign in to comment.