From 3d4d5da4c4e3e754abbe96f53faaa7a684d7e892 Mon Sep 17 00:00:00 2001 From: Raekye Date: Mon, 10 Nov 2014 19:36:50 -0500 Subject: [PATCH] user defined functions work --- sayaka/src/ast_node.h | 24 +++++++++++----------- sayaka/src/ast_node_function.cpp | 16 ++++++++------- sayaka/src/ast_node_functioncall.cpp | 1 + sayaka/src/ast_node_functionprototype.cpp | 22 +++++++++++++++----- sayaka/src/ast_node_identifier.cpp | 8 ++++++-- sayaka/src/code_gen_context.cpp | 2 -- sayaka/src/compiler.cpp | 6 +++--- sayaka/src/parser.y | 25 +++++++++++++---------- 8 files changed, 62 insertions(+), 42 deletions(-) diff --git a/sayaka/src/ast_node.h b/sayaka/src/ast_node.h index 6973673..461c74c 100644 --- a/sayaka/src/ast_node.h +++ b/sayaka/src/ast_node.h @@ -110,29 +110,29 @@ class ASTNodeBinaryOperator : public ASTNode { virtual ASTNodeBinaryOperator* pass_types(CodeGenContext*, ASTType*) override; }; -class ASTNodeFunction : public ASTNode { +class ASTNodeFunctionPrototype : public ASTNode { public: - ASTNodeBlock* body; std::string return_type; + std::string function_name; + std::vector* args; - ASTNodeFunction(ASTNodeBlock*, std::string); + ASTNodeFunctionPrototype(std::string, std::string, std::vector*); - virtual ~ASTNodeFunction(); + virtual ~ASTNodeFunctionPrototype(); virtual llvm::Value* gen_code(CodeGenContext*) override; - virtual ASTNodeFunction* pass_types(CodeGenContext*, ASTType*) override; + virtual ASTNodeFunctionPrototype* pass_types(CodeGenContext*, ASTType*) override; }; -class ASTNodeFunctionPrototype : public ASTNode { +class ASTNodeFunction : public ASTNode { public: - std::string return_type; - std::string function_name; - std::vector>* args; + ASTNodeFunctionPrototype* prototype; + ASTNodeBlock* body; - ASTNodeFunctionPrototype(std::string, std::string, std::vector>*); + ASTNodeFunction(ASTNodeFunctionPrototype*, ASTNodeBlock*); - virtual ~ASTNodeFunctionPrototype(); + virtual ~ASTNodeFunction(); virtual llvm::Value* gen_code(CodeGenContext*) override; - virtual ASTNodeFunctionPrototype* pass_types(CodeGenContext*, ASTType*) override; + virtual ASTNodeFunction* pass_types(CodeGenContext*, ASTType*) override; }; class ASTNodeFunctionCall : public ASTNode { diff --git a/sayaka/src/ast_node_function.cpp b/sayaka/src/ast_node_function.cpp index 614b1e2..70ea32f 100644 --- a/sayaka/src/ast_node_function.cpp +++ b/sayaka/src/ast_node_function.cpp @@ -2,30 +2,31 @@ #include -ASTNodeFunction::ASTNodeFunction(ASTNodeBlock* body, std::string return_type) { +ASTNodeFunction::ASTNodeFunction(ASTNodeFunctionPrototype* prototype, ASTNodeBlock* body) { + this->prototype = prototype; this->body = body; - this->return_type = return_type; } ASTNodeFunction::~ASTNodeFunction() { + delete this->prototype; delete this->body; } ASTNodeFunction* ASTNodeFunction::pass_types(CodeGenContext* code_gen_context, ASTType* ignore) { this->type = NULL; // TODO; - this->body = this->body->pass_types(code_gen_context, code_gen_context->ast_types_resolver.get(this->return_type)); + this->prototype = this->prototype->pass_types(code_gen_context, NULL); + this->body = this->body->pass_types(code_gen_context, code_gen_context->ast_types_resolver.get(this->prototype->return_type)); return this; } llvm::Value* ASTNodeFunction::gen_code(CodeGenContext* code_gen_context) { - std::cout << "Generating function" << std::endl; + std::cout << "Generating function " << this->prototype->function_name << std::endl; std::vector arg_types; - ASTType* type = code_gen_context->ast_types_resolver.get(this->return_type); + ASTType* type = code_gen_context->ast_types_resolver.get(this->prototype->return_type); if (type == NULL) { throw std::runtime_error("Unknown type"); } - llvm::FunctionType* fn_type = llvm::FunctionType::get(type->llvm_type, arg_types, false); - llvm::Function* fn = llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage, "", code_gen_context->module); + llvm::Function* fn = (llvm::Function*) this->prototype->gen_code(code_gen_context); llvm::BasicBlock* basic_block = llvm::BasicBlock::Create(llvm::getGlobalContext(), "entry", fn); code_gen_context->push_block(basic_block); @@ -35,6 +36,7 @@ llvm::Value* ASTNodeFunction::gen_code(CodeGenContext* code_gen_context) { code_gen_context->builder.CreateRet(ret_val); llvm::verifyFunction(*fn); code_gen_context->pop_block(); + fn->dump(); return fn; } code_gen_context->pop_block(); diff --git a/sayaka/src/ast_node_functioncall.cpp b/sayaka/src/ast_node_functioncall.cpp index ff6f06e..9dd1c69 100644 --- a/sayaka/src/ast_node_functioncall.cpp +++ b/sayaka/src/ast_node_functioncall.cpp @@ -21,6 +21,7 @@ ASTNodeFunctionCall* ASTNodeFunctionCall::pass_types(CodeGenContext* code_gen_co } llvm::Value* ASTNodeFunctionCall::gen_code(CodeGenContext* code_gen_context) { + std::cout << "Generating function call" << std::endl; llvm::Function* callee = code_gen_context->module->getFunction(this->function_name); if (callee == NULL) { throw std::runtime_error("Unknown function called"); diff --git a/sayaka/src/ast_node_functionprototype.cpp b/sayaka/src/ast_node_functionprototype.cpp index 5554152..d90ec04 100644 --- a/sayaka/src/ast_node_functionprototype.cpp +++ b/sayaka/src/ast_node_functionprototype.cpp @@ -1,6 +1,6 @@ #include "ast_node.h" -ASTNodeFunctionPrototype::ASTNodeFunctionPrototype(std::string return_type, std::string function_name, std::vector>* args) { +ASTNodeFunctionPrototype::ASTNodeFunctionPrototype(std::string return_type, std::string function_name, std::vector* args) { this->return_type = return_type; this->function_name = function_name; this->args = args; @@ -11,17 +11,29 @@ ASTNodeFunctionPrototype::~ASTNodeFunctionPrototype() { } ASTNodeFunctionPrototype* ASTNodeFunctionPrototype::pass_types(CodeGenContext* code_gen_context, ASTType* type) { - this->type = code_gen_context->ast_types_resolver.int_ty(); + this->type = code_gen_context->ast_types_resolver.int_ty(); // TODO + for (std::vector::iterator it = this->args->begin(); it != this->args->end(); it++) { + *it = (*it)->pass_types(code_gen_context, NULL); + } return this; } llvm::Value* ASTNodeFunctionPrototype::gen_code(CodeGenContext* code_gen_context) { + std::cout << "Generating prototype" << std::endl; std::vector args_types_list; - for (std::vector>::iterator it = this->args->begin(); it != this->args->end(); it++) { - args_types_list.push_back(code_gen_context->ast_types_resolver.get(std::get<0>(*it))->llvm_type); + for (std::vector::iterator it = this->args->begin(); it != this->args->end(); it++) { + args_types_list.push_back(code_gen_context->ast_types_resolver.get((*it)->type_name)->llvm_type); } llvm::FunctionType* ft = llvm::FunctionType::get(code_gen_context->ast_types_resolver.get(this->return_type)->llvm_type, args_types_list, false); llvm::Function* fn = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, this->function_name, code_gen_context->module); - return llvm::ConstantInt::get(code_gen_context->ast_types_resolver.int_ty()->llvm_type, 0, true); + int i = 0; + for (llvm::Function::arg_iterator it = fn->arg_begin(); it != fn->arg_end(); it++) { + ASTNodeDeclaration* var = this->args->operator[](i); + it->setName(var->var_name); + code_gen_context->scope.put(var->var_name, new CodeGenVariable(code_gen_context->ast_types_resolver.get(var->type_name), it)); + i++; + } + return fn; + //return llvm::ConstantInt::get(code_gen_context->ast_types_resolver.int_ty()->llvm_type, 0, true); } diff --git a/sayaka/src/ast_node_identifier.cpp b/sayaka/src/ast_node_identifier.cpp index f76a46a..d42dd73 100644 --- a/sayaka/src/ast_node_identifier.cpp +++ b/sayaka/src/ast_node_identifier.cpp @@ -21,10 +21,14 @@ ASTNodeIdentifier* ASTNodeIdentifier::pass_types(CodeGenContext* code_gen_contex } llvm::Value* ASTNodeIdentifier::gen_code(CodeGenContext* code_gen_context) { - std::cout << "Generating identifier" << std::endl; + std::cout << "Generating identifier " << this->name << std::endl; CodeGenVariable* var = code_gen_context->scope.get(this->name); if (var == NULL) { throw std::runtime_error("Undeclared variable"); } - return code_gen_context->builder.CreateLoad(std::get<1>(*var), false, this->name); + llvm::Value* val = std::get<1>(*var); + if (llvm::AllocaInst* alloc = dynamic_cast(val)) { + return code_gen_context->builder.CreateLoad(alloc, false, this->name); + } + return val; } diff --git a/sayaka/src/code_gen_context.cpp b/sayaka/src/code_gen_context.cpp index c58708c..b7383f8 100644 --- a/sayaka/src/code_gen_context.cpp +++ b/sayaka/src/code_gen_context.cpp @@ -12,12 +12,10 @@ CodeGenContext::~CodeGenContext() { void CodeGenContext::push_block(llvm::BasicBlock* block) { std::cout << "Pushing block" << std::endl; this->blocks.push(block); - this->push_scope(); this->builder.SetInsertPoint(this->current_block()); } void CodeGenContext::pop_block() { - this->pop_scope(); std::cout << "Popping block" << std::endl; this->blocks.pop(); this->builder.SetInsertPoint(this->current_block()); diff --git a/sayaka/src/compiler.cpp b/sayaka/src/compiler.cpp index 231c486..8515841 100644 --- a/sayaka/src/compiler.cpp +++ b/sayaka/src/compiler.cpp @@ -29,11 +29,11 @@ ASTNode* Compiler::parse(std::string code) { } void Compiler::run_code(ASTNode* root) { - ASTNodeFunction main_fn((ASTNodeBlock*) root, "Int"); + std::vector* args = new std::vector(); + ASTNodeFunctionPrototype* main_fn_prototype = new ASTNodeFunctionPrototype("Int", "main", args); + ASTNodeFunction main_fn(main_fn_prototype, (ASTNodeBlock*) root); main_fn.pass_types(&this->code_gen_context, NULL); llvm::Function* main_fn_val = (llvm::Function*) main_fn.gen_code(&this->code_gen_context); - std::cout << "Main fn code:" << std::endl; - main_fn_val->dump(); void* fn_ptr = this->execution_engine->getPointerToFunction(main_fn_val); int32_t ret = ((int32_t (*)()) fn_ptr)(); std::cout << "Main fn at " << fn_ptr << "; executed: " << ret << std::endl; diff --git a/sayaka/src/parser.y b/sayaka/src/parser.y index 6d85a19..ef8214d 100644 --- a/sayaka/src/parser.y +++ b/sayaka/src/parser.y @@ -21,6 +21,7 @@ void yyerror(YYLTYPE* llocp, ASTNode**, yyscan_t scanner, const char *s) { class ASTNode; class ASTNodeBlock; class ASTNodeIdentifier; +class ASTNodeDeclaration; #ifndef YY_TYPEDEF_YY_SCANNER_T #define YY_TYPEDEF_YY_SCANNER_T @@ -41,7 +42,7 @@ typedef void* yyscan_t; ASTNode* node; ASTNodeBlock* block; ASTNodeIdentifier* identifier; - std::vector>* typed_args_list; + std::vector* typed_args_list; std::vector* args_list; std::string* str; } @@ -56,7 +57,7 @@ typedef void* yyscan_t; %token TOKEN_COMMA %token TOKEN_NUMBER TOKEN_IDENTIFIER TOKEN_TYPE_NAME -%type program expr number binary_operator_expr assignment_expr variable_declaration_expr cast_expr function_call_expr function_prototype_expr +%type program expr number binary_operator_expr assignment_expr variable_declaration_expr cast_expr function_call_expr function_prototype_expr function_expr %type stmts %type identifier %type typed_args_list @@ -85,6 +86,7 @@ expr | variable_declaration_expr { $$ = $1; } | assignment_expr { $$ = $1; } | cast_expr { $$ = $1; } + | function_expr { $$ = $1; } | function_prototype_expr { $$ = $1; } | function_call_expr { $$ = $1; } | binary_operator_expr { $$ = $1; } @@ -139,18 +141,19 @@ function_prototype_expr } ; +function_expr + : function_prototype_expr TOKEN_LBRACE stmts TOKEN_RBRACE { + $$ = new ASTNodeFunction((ASTNodeFunctionPrototype*) $1, $3); + } + ; typed_args_list - : TOKEN_TYPE_NAME TOKEN_IDENTIFIER { - $$ = new std::vector>(); - $$->push_back(std::make_tuple(*$1, *$2)); - delete $1; - delete $2; + : variable_declaration_expr { + $$ = new std::vector(); + $$->push_back((ASTNodeDeclaration*) $1); } - | typed_args_list TOKEN_COMMA TOKEN_TYPE_NAME TOKEN_IDENTIFIER { - $$->push_back(std::make_tuple(*$3, *$4)); - delete $3; - delete $4; + | typed_args_list TOKEN_COMMA variable_declaration_expr { + $$->push_back((ASTNodeDeclaration*) $3); } ;