From 1b980e5471d139472bb34079f9d71f24ba9bd6b1 Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Fri, 21 Apr 2023 23:29:57 -0400 Subject: [PATCH] feat: Add support for typing annotation --- .makim.yaml | 13 +- examples/average.arx | 4 +- examples/constant.arx | 4 +- examples/fibonacci.arx | 8 +- examples/print-star.arx | 6 +- examples/sum.arx | 4 +- meson.build | 3 +- src/codegen/arx-llvm.cpp | 21 ++ src/codegen/ast-to-object.cpp | 44 +++- src/codegen/ast-to-stdout.cpp | 46 ++-- src/error.h | 17 +- src/lexer.cpp | 57 ++++- src/lexer.h | 9 +- src/main.cpp | 3 +- src/parser.cpp | 207 +++++++++++++----- src/parser.h | 11 +- .../unittests/codegen/test-ast-to-llvm-ir.cpp | 2 +- .../unittests/codegen/test-ast-to-object.cpp | 2 +- .../unittests/codegen/test-ast-to-stdout.cpp | 2 +- tests/unittests/test-lexer.cpp | 14 +- tests/unittests/test-parser.cpp | 26 ++- 21 files changed, 379 insertions(+), 124 deletions(-) diff --git a/.makim.yaml b/.makim.yaml index 071904b..00dd6a0 100644 --- a/.makim.yaml +++ b/.makim.yaml @@ -19,8 +19,9 @@ env: :print_legend=1\ :detect_leaks=1\ " - MESON_EXTRA: "-Db_coverage=true \ - -Doptimization=0 \ + MESON_EXTRA_DEBUG: "-Db_coverage=true \ + --optimization=0 \ + --debug \ -Db_sanitize=address \ " groups: @@ -104,7 +105,7 @@ groups: - target: build.release args: build-type: "debug" - meson-extra: {{ env.MESON_EXTRA }} + meson-extra: {{ env.MESON_EXTRA_DEBUG }} clean: {{ args.clean }} asan-options: {{ env.SAN_OPTIONS_DEFAULT }} lsan-options: {{ env.SAN_OPTIONS_DEFAULT }} @@ -120,7 +121,7 @@ groups: - target: build.release args: build-type: "debug" - meson-extra: {{ env.MESON_EXTRA }} -Ddev=enabled + meson-extra: {{ env.MESON_EXTRA_DEBUG }} -Ddev=enabled clean: {{ args.clean }} asan-options: {{ env.SAN_OPTIONS_DEFAULT }} lsan-options: {{ env.SAN_OPTIONS_DEFAULT }} @@ -231,7 +232,7 @@ groups: if [[ "{{ args.debug }}" == "True" ]]; then GDB="gdb --args" - DEBUG_FLAGS="-g" + DEBUG_FLAGS="-Og" fi TEST_DIR_PATH="./tests" @@ -261,7 +262,7 @@ groups: print_header "${test_name}" OBJECT_FILE="${TMP_DIR}/${test_name}.o" - ${ARX} --output "${OBJECT_FILE}" --input "examples/${test_name}.arx" + ${ARX} --output "${OBJECT_FILE}" --input "examples/${test_name}.arx --build-lib" set -x clang++ \ diff --git a/examples/average.arx b/examples/average.arx index b32feb6..84bfdb6 100644 --- a/examples/average.arx +++ b/examples/average.arx @@ -1,2 +1,2 @@ -function average(x y): - (x + y) * 0.5; +fn average(x: float, y: float) -> float: + return (x + y) * 0.5; diff --git a/examples/constant.arx b/examples/constant.arx index c5b12d1..1b8ca26 100644 --- a/examples/constant.arx +++ b/examples/constant.arx @@ -1,2 +1,2 @@ -function get_constant(x): - x; +fn get_constant(x: float) -> float: + return x; diff --git a/examples/fibonacci.arx b/examples/fibonacci.arx index 7474bb6..c72772b 100644 --- a/examples/fibonacci.arx +++ b/examples/fibonacci.arx @@ -1,5 +1,5 @@ -function fib(x): - if x < 3: - 1 +fn fib(x: float) -> float: + if x <= 1: + return x; else: - fib(x-1)+fib(x-2) + return fib(x-1)+fib(x-2); diff --git a/examples/print-star.arx b/examples/print-star.arx index 2d17fcb..eda69a7 100644 --- a/examples/print-star.arx +++ b/examples/print-star.arx @@ -1,5 +1,5 @@ -extern putchard(char); +extern putchard(char) -> void; -function print_star(n): +fn print_star(n: float) -> void: for i = 1, i < n, 1.0 in - putchard(42); # ascii 42 = '*' + return putchard(42); # ascii 42 = '*' diff --git a/examples/sum.arx b/examples/sum.arx index 10c4bcd..d464f49 100644 --- a/examples/sum.arx +++ b/examples/sum.arx @@ -1,2 +1,2 @@ -function sum(a b): - a + b; +fn sum(a: float, b: float) -> float: + return a + b; diff --git a/meson.build b/meson.build index d64b8c5..765b89c 100644 --- a/meson.build +++ b/meson.build @@ -2,7 +2,8 @@ project('arx', 'cpp', 'c', license : 'Apache-2.0', version : '1.6.0', # semantic-release default_options : [ - 'warning_level=everything', + #'warning_level=everything', + 'warning_level=1', 'cpp_std=c++20', ] ) diff --git a/src/codegen/arx-llvm.cpp b/src/codegen/arx-llvm.cpp index fa93489..97c4fc6 100644 --- a/src/codegen/arx-llvm.cpp +++ b/src/codegen/arx-llvm.cpp @@ -1,3 +1,4 @@ +#include #include // for DIBuilder #include // for IRBuilder @@ -22,6 +23,26 @@ llvm::Type* ArxLLVM::FLOAT_TYPE; llvm::Type* ArxLLVM::DOUBLE_TYPE; llvm::Type* ArxLLVM::INT8_TYPE; llvm::Type* ArxLLVM::INT32_TYPE; +llvm::Type* ArxLLVM::VOID_TYPE; + +auto ArxLLVM::get_data_type(std::string type_name) -> llvm::Type* { + if (type_name == "float") { + return ArxLLVM::FLOAT_TYPE; + } else if (type_name == "double") { + return ArxLLVM::DOUBLE_TYPE; + } else if (type_name == "int8") { + return ArxLLVM::INT8_TYPE; + } else if (type_name == "int32") { + return ArxLLVM::INT32_TYPE; + } else if (type_name == "char") { + return ArxLLVM::INT8_TYPE; + } else if (type_name == "void") { + return ArxLLVM::VOID_TYPE; + } + + llvm::errs() << "[EE] type_name not valid.\n"; + return nullptr; +} /* Debug Information Data types */ llvm::DIType* ArxLLVM::DI_FLOAT_TYPE; diff --git a/src/codegen/ast-to-object.cpp b/src/codegen/ast-to-object.cpp index a0cc541..e404551 100644 --- a/src/codegen/ast-to-object.cpp +++ b/src/codegen/ast-to-object.cpp @@ -97,8 +97,9 @@ auto ASTToObjectVisitor::getFunction(std::string name) -> void { */ auto ASTToObjectVisitor::CreateEntryBlockAlloca( llvm::Function* fn, llvm::StringRef var_name) -> llvm::AllocaInst* { - llvm::IRBuilder<> TmpB(&fn->getEntryBlock(), fn->getEntryBlock().begin()); - return TmpB.CreateAlloca(ArxLLVM::FLOAT_TYPE, nullptr, var_name); + llvm::IRBuilder<> tmp_builder( + &fn->getEntryBlock(), fn->getEntryBlock().begin()); + return tmp_builder.CreateAlloca(ArxLLVM::FLOAT_TYPE, nullptr, var_name); } /** @@ -165,7 +166,7 @@ auto ASTToObjectVisitor::visit(UnaryExprAST& expr) -> void { * */ auto ASTToObjectVisitor::visit(BinaryExprAST& expr) -> void { - // Special case '=' because we don't want to emit the lhs as an + // Special case '=' because we don't want to emit the lhs as an // expression.*/ if (expr.op == '=') { // Assignment requires the lhs to be an identifier. @@ -188,13 +189,13 @@ auto ASTToObjectVisitor::visit(BinaryExprAST& expr) -> void { }; // Look up the name.// - llvm::Value* Variable = ArxLLVM::named_values[var_lhs->get_name()]; - if (!Variable) { + llvm::Value* variable = ArxLLVM::named_values[LHSE->getName()]; + if (!variable) { this->result_val = LogErrorV("Unknown variable name"); return; } - ArxLLVM::ir_builder->CreateStore(val, Variable); + ArxLLVM::ir_builder->CreateStore(val, variable); this->result_val = val; } @@ -520,11 +521,32 @@ auto ASTToObjectVisitor::visit(VarExprAST& expr) -> void { * */ auto ASTToObjectVisitor::visit(PrototypeAST& expr) -> void { - std::vector args_type(expr.args.size(), ArxLLVM::FLOAT_TYPE); - llvm::Type* return_type = ArxLLVM::get_data_type("float"); + // Make the function type: double(double,double) etc. + std::vector args; + llvm::Type* arg_type; + + for (auto& arg : expr.args) { + arg_type = ArxLLVM::get_data_type(arg->type_name); + if (arg_type != nullptr) { + args.emplace_back(arg_type); + } else { + llvm::errs() << "ARX::GEN-OBJECT[ERROR]: PrototypeAST: " + << "Argument data type " << arg->type_name + << " not implemented yet."; + } + } + + llvm::Type* return_type = ArxLLVM::get_data_type(expr.type_name); + + if (return_type == nullptr) { + llvm::errs() << "ARX::GEN-OBJECT[ERROR]: PrototypeAST: " + << "Argument data type " << expr.type_name + << " not implemented yet."; + } - llvm::FunctionType* fn_type = - llvm::FunctionType::get(return_type, args_type, false /* isVarArg */); + llvm::FunctionType* fn_type = llvm::FunctionType::get( + return_type, args, false /* isVarArg */ + ); llvm::Function* fn = llvm::Function::Create( fn_type, @@ -794,7 +816,7 @@ auto compile_object(TreeAST& tree_ast) -> int { std::cout << "ARX[INFO]: " << compiler_cmd << std::endl; int compile_result = system(compiler_cmd.c_str()); - // ArxFile::delete_file(main_cpp_path); + ArxFile::delete_file(main_cpp_path); if (compile_result != 0) { llvm::errs() << "failed to compile and link object file"; diff --git a/src/codegen/ast-to-stdout.cpp b/src/codegen/ast-to-stdout.cpp index 312f4a9..7401d98 100644 --- a/src/codegen/ast-to-stdout.cpp +++ b/src/codegen/ast-to-stdout.cpp @@ -8,9 +8,7 @@ int INDENT_SIZE = 2; -class ASTToOutputVisitor - : public std::enable_shared_from_this, - public Visitor { +class ASTToOutputVisitor : public Visitor { public: int indent = 0; std::string annotation = ""; @@ -53,12 +51,21 @@ void ASTToOutputVisitor::visit(FloatExprAST& expr) { void ASTToOutputVisitor::visit(VariableExprAST& expr) { std::cout << this->indentation() << this->get_annotation() - << "(VariableExprAST " << expr.name << ")"; + << "(VariableExprAST " << expr.name << ":" << expr.type_name + << ")"; } void ASTToOutputVisitor::visit(UnaryExprAST& expr) { - std::cout << "(UnaryExprAST" - << ")" << std::endl; + std::cout << this->indentation() << "(UnaryExprAST op_code:" << expr.op_code + << " operand:"; + + int cur_indent = this->indent; + this->indent = 0; + + expr.operand->accept(*this); + std::cout << ")"; + + this->indent = cur_indent; } void ASTToOutputVisitor::visit(BinaryExprAST& expr) { @@ -197,29 +204,28 @@ void ASTToOutputVisitor::visit(VarExprAST& expr) { void ASTToOutputVisitor::visit(PrototypeAST& expr) { // TODO: implement it - std::cout << "(PrototypeAST " << expr.name << ")" << std::endl; -} - -void ASTToOutputVisitor::visit(FunctionAST& expr) { - std::cout << this->indentation() << '(' << std::endl; - this->indent += INDENT_SIZE; - - // create the function and open the args section - std::cout << this->indentation() << "Function " << expr.proto->name + std::cout << "(PrototypeAST " << expr.name << ") -> " << expr.type_name << " (" << std::endl; this->indent += INDENT_SIZE; - // std::cout << expr.proto->args.front(); - - for (const auto& node : expr.proto->args) { + for (const auto& node : expr.args) { node->accept(*this); std::cout << ", " << std::endl; } // close args section and open body section this->indent -= INDENT_SIZE; - std::cout << this->indentation() << "), " << std::endl - << this->indentation() << " (" << std::endl; + std::cout << this->indentation() << "), " << std::endl; +} + +void ASTToOutputVisitor::visit(FunctionAST& expr) { + std::cout << this->indentation() << '(' << std::endl; + this->indent += INDENT_SIZE; + + // create the function and open the args section + std::cout << this->indentation() << "Function "; + this->visit(*expr.proto); + std::cout << this->indentation() << " (" << std::endl; this->indent += INDENT_SIZE; // TODO: body should be a vector of unique_ptr diff --git a/src/error.h b/src/error.h index 1a03fa3..0258a0a 100644 --- a/src/error.h +++ b/src/error.h @@ -12,12 +12,23 @@ namespace llvm { } // namespace llvm /** - * @brief LogError* - These are little helper functions for error handling. + * @brief LogError* - A little helper function for error handling. * */ template -std::unique_ptr LogError(const char* Str) { - fprintf(stderr, "Error: %s\n", Str); +std::unique_ptr LogError(const char* msg) { + fprintf(stderr, "Error: %s\n", msg); + return nullptr; +} + +/** + * @brief LogError* - A little helper function for error handling with line + * and col information. + * + */ +template +std::unique_ptr LogParserError(const char* msg, int line, int col) { + fprintf(stderr, "ParserError[%i:%i]: %s\n", line, col, msg); return nullptr; } diff --git a/src/lexer.cpp b/src/lexer.cpp index 5f67c59..c330fae 100644 --- a/src/lexer.cpp +++ b/src/lexer.cpp @@ -51,6 +51,55 @@ auto Lexer::get_tok_name(int tok) -> std::string { return "var"; case tok_const: return "const"; + case tok_arrow_right: + return "->"; + } + return std::string(1, static_cast(Tok)); +} + +/** + * @brief Get the Token name to be used in a message. + * @param Tok The token + * @return Token name + * + */ +auto Lexer::get_tok_name_display(int Tok) -> std::string { + switch (Tok) { + case tok_eof: + return ""; + case tok_function: + return ""; + case tok_return: + return ""; + case tok_extern: + return ""; + case tok_identifier: + return ""; + case tok_float_literal: + return ""; + case tok_if: + return ""; + case tok_then: + return ""; + case tok_else: + return ""; + case tok_for: + return ""; + case tok_in: + return ""; + case tok_binary: + return ""; + case tok_unary: + return ""; + case tok_var: + return ""; + case tok_const: + return ""; + case tok_arrow_right: + return "->"; + case tok_expression: + // just used for error message + return ""; } return std::string(1, static_cast(tok)); } @@ -114,7 +163,7 @@ auto Lexer::gettok() -> int { Lexer::identifier_str += last_char; } - if (Lexer::identifier_str == "function") { + if (Lexer::identifier_str == "fn") { return tok_function; } if (Lexer::identifier_str == "return") { @@ -178,6 +227,12 @@ auto Lexer::gettok() -> int { // Otherwise, just return the character as its ascii value. int this_char = last_char; last_char = static_cast(Lexer::advance()); + + if (this_char == (int) '-' && last_char == (int) '>') { + last_char = static_cast(Lexer::advance()); + return tok_arrow_right; + } + return this_char; } diff --git a/src/lexer.h b/src/lexer.h index 70c7f93..256eaae 100644 --- a/src/lexer.h +++ b/src/lexer.h @@ -34,13 +34,15 @@ enum Token { // var definition tok_var = -40, tok_const = -41, + tok_arrow_right = -42, - tok_not_initialized = -9999 + tok_not_initialized = -9999, + tok_expression = -10000 // generic used just for error message }; struct SourceLocation { - int line; - int col; + int line = 1; + int col = 0; }; class Lexer { @@ -52,6 +54,7 @@ class Lexer { static SourceLocation lex_loc; static std::string get_tok_name(int); + static std::string get_tok_name_display(int); static int gettok(); static int advance(); static int get_next_token(); diff --git a/src/main.cpp b/src/main.cpp index e6d053a..4502355 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -9,8 +9,9 @@ // #include #include // for InitGoogleLogging -#include // for exit #include +#include // for exit +#include #include // for string, allocator #include "codegen/arx-llvm.h" // for ArxLLVM #include "codegen/ast-to-llvm-ir.h" // for compile_llvm_ir diff --git a/src/parser.cpp b/src/parser.cpp index 4df55ac..411a441 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -1,6 +1,5 @@ #include "parser.h" // for ExprAST, Parser, PrototypeAST, ForExprAST #include // for isascii -#include // for strcat, strcpy #include // for operator<<, basic_ostream::operator<<, cout #include // for map #include // for unique_ptr, make_unique @@ -11,6 +10,38 @@ #include "error.h" // for LogError #include "lexer.h" // for Lexer, Lexer::cur_tok, Lexer::cur_loc, tok_iden... +static auto get_token_value(int tok) -> std::string { + switch (tok) { + case tok_identifier: + return std::string("(") + Lexer::identifier_str + std::string(")"); + case tok_float_literal: + return std::string("(") + std::to_string(Lexer::num_float) + + std::string(")"); + default: + return std::string(""); + } +} + +static auto parse_semicolon() -> void { + // NOTE: improve the way of parse `;` + while (Lexer::cur_tok == ';') { + Lexer::get_next_token(); + } +} + +template +std::unique_ptr log_full_parser_error( + int tok_expected, int tok_received, std::string msg) { + return LogParserError( + (std::string("Expected ") + Lexer::get_tok_name_display(tok_expected) + + std::string(", but received ") + + Lexer::get_tok_name_display(tok_received) + + get_token_value(tok_received) + std::string(". ") + msg) + .c_str(), + Lexer::cur_loc.line, + Lexer::cur_loc.col); +} + /** * @brief This holds the precedence for each binary operator that * is defined. @@ -136,7 +167,7 @@ std::unique_ptr Parser::parse_identifier_expr() { Lexer::get_next_token(); // eat identifier. if (Lexer::cur_tok != '(') { // Simple variable ref. - return std::make_unique(LitLoc, IdName); + return std::make_unique(LitLoc, IdName, "REF"); } // Call. // @@ -175,7 +206,7 @@ std::unique_ptr Parser::parse_identifier_expr() { */ std::unique_ptr Parser::parse_if_expr() { SourceLocation if_loc = Lexer::cur_loc; - char msg[80]; + std::string msg; Lexer::get_next_token(); // eat the if. @@ -186,10 +217,9 @@ std::unique_ptr Parser::parse_if_expr() { }; if (Lexer::cur_tok != ':') { - strcpy(msg, "Parser: `if` statement expected ':', received: '"); - strcat(msg, std::to_string(Lexer::cur_tok).c_str()); - strcat(msg, "'."); - return LogError(msg); + msg = std::string("Parser: `if` statement expected ':', received: '") + + std::to_string(Lexer::cur_tok) + std::string("'."); + return LogError(msg.c_str()); } Lexer::get_next_token(); // eat the ':' @@ -198,16 +228,18 @@ std::unique_ptr Parser::parse_if_expr() { return nullptr; }; + parse_semicolon(); + if (Lexer::cur_tok != tok_else) { - return LogError("Parser: Expected else"); + return log_full_parser_error( + tok_else, Lexer::cur_tok, std::string("")); } Lexer::get_next_token(); // eat the else token if (Lexer::cur_tok != ':') { - strcpy(msg, "Parser: `else` statement expected ':', received: '"); - strcat(msg, std::to_string(Lexer::cur_tok).c_str()); - strcat(msg, "'."); - return LogError(msg); + msg = std::string("Parser: `else` statement expected ':', received: '") + + std::to_string(Lexer::cur_tok) + std::string("'."); + return LogError(msg.c_str()); } Lexer::get_next_token(); // eat the ':' @@ -353,7 +385,10 @@ std::unique_ptr Parser::parse_var_expr() { * ::= varexpr */ std::unique_ptr Parser::parse_primary() { - char msg[80]; + std::string msg; + int tmp_tok = 0; + + parse_semicolon(); switch (Lexer::cur_tok) { case tok_identifier: @@ -368,15 +403,15 @@ std::unique_ptr Parser::parse_primary() { return static_cast>(parse_for_expr()); case tok_var: return static_cast>(parse_var_expr()); - case ';': - // ignore top-level semicolons. - Lexer::get_next_token(); // eat `;` + case tok_return: + // NOTE: it would be treated in a proper way in a follow-up PR + Lexer::get_next_token(); // eat `return` return Parser::parse_primary(); default: - strcpy(msg, "Parser: Unknown token when expecting an expression: '"); - strcat(msg, std::to_string(Lexer::cur_tok).c_str()); - strcat(msg, "'."); - return LogError(msg); + tmp_tok = Lexer::cur_tok; + Lexer::get_next_token(); // eat the wrong token + return log_full_parser_error( + tok_expression, tmp_tok, std::string("")); } } @@ -388,18 +423,23 @@ std::unique_ptr Parser::parse_primary() { * ::= '!' unary */ std::unique_ptr Parser::parse_unary() { + if (Lexer::cur_tok == ';') { + Lexer::get_next_token(); + return Parser::parse_primary(); + } + // If the current token is not an operator, it must be a primary expr. if ( !isascii(Lexer::cur_tok) || Lexer::cur_tok == '(' || - Lexer::cur_tok == ',') { + Lexer::cur_tok == ',' || Lexer::cur_tok == tok_return) { return Parser::parse_primary(); } // If this is a unary operator, read it. - int Opc = Lexer::cur_tok; + int op_code = Lexer::cur_tok; Lexer::get_next_token(); if (auto operand = Parser::parse_unary()) { - return std::make_unique(Opc, std::move(operand)); + return std::make_unique(op_code, std::move(operand)); } return nullptr; } @@ -416,7 +456,7 @@ std::unique_ptr Parser::parse_bin_op_rhs( int expr_prec, std::unique_ptr lhs) { // If this is a binop, find its precedence. // while (true) { - int tok_prec = get_tok_precedence(); + int tok_prec = Parser::get_tok_precedence(); // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. @@ -486,30 +526,47 @@ std::unique_ptr Parser::parse_extern_prototype() { break; default: - return LogError( - "Parser: Expected function name in prototype"); + return log_full_parser_error( + tok_identifier, + Lexer::cur_tok, + std::string("Expected function name in the function definition.")); } if (Lexer::cur_tok != '(') { - return LogError( - "Parser: Expected '(' in the function definition."); + return log_full_parser_error( + '(', Lexer::cur_tok, std::string("")); } std::vector> args; while (Lexer::get_next_token() == tok_identifier) { - auto arg = std::make_unique( - VariableExprAST(Lexer::cur_loc, Lexer::identifier_str)); - args.emplace_back(std::move(arg)); + args.emplace_back(std::make_unique( + Lexer::cur_loc, Lexer::identifier_str, Lexer::identifier_str)); } + if (Lexer::cur_tok != ')') { - return LogError( - "Parser: Expected ')' in the function definition."); + return log_full_parser_error( + ')', Lexer::cur_tok, std::string("")); } // success. // - Lexer::get_next_token(); // eat ')'. + if (Lexer::get_next_token() != tok_arrow_right) { + return log_full_parser_error( + tok_arrow_right, + Lexer::cur_tok, + std::string("Returning type annotation required. ") + + std::string("If the function doesn't return a value, use `void`.")); + } - return std::make_unique(fn_loc, fn_name, std::move(args)); + if (Lexer::get_next_token() != tok_identifier) { + return log_full_parser_error( + tok_identifier, + Lexer::cur_tok, + std::string("Returning type annotation required. ") + + std::string("If the function doesn't return a value, use `void`.")); + } + + return std::make_unique( + fn_loc, fn_name, Lexer::identifier_str, std::move(args)); } /** @@ -522,6 +579,13 @@ std::unique_ptr Parser::parse_extern_prototype() { */ std::unique_ptr Parser::parse_prototype() { std::string fn_name; + std::string var_type_annotation; + std::string ret_type_annotation; + std::string identifier_name; + std::string msg; + + SourceLocation cur_loc; + SourceLocation fn_loc = Lexer::cur_loc; switch (Lexer::cur_tok) { @@ -531,37 +595,79 @@ std::unique_ptr Parser::parse_prototype() { break; default: - return LogError( - "Parser: Expected function name in prototype"); + return log_full_parser_error( + tok_identifier, + Lexer::cur_tok, + std::string("Expected function name in prototype.")); } if (Lexer::cur_tok != '(') { - return LogError( - "Parser: Expected '(' in the function definition."); + return log_full_parser_error( + '(', Lexer::cur_tok, std::string("")); } std::vector> args; while (Lexer::get_next_token() == tok_identifier) { - auto arg = std::make_unique( - VariableExprAST(Lexer::cur_loc, Lexer::identifier_str)); - args.emplace_back(std::move(arg)); + // note: this is a workaround + identifier_name = Lexer::identifier_str; + cur_loc = Lexer::cur_loc; + + if (Lexer::get_next_token() != ':') { + return log_full_parser_error( + tok_identifier, + Lexer::cur_tok, + std::string("Variable type annotation required.")); + } + if (Lexer::get_next_token() != tok_identifier) { + return log_full_parser_error( + tok_identifier, + Lexer::cur_tok, + std::string("Variable type annotation required.")); + } + var_type_annotation = Lexer::identifier_str; + + args.push_back(std::make_unique( + cur_loc, identifier_name, var_type_annotation)); + + if (Lexer::get_next_token() != ',') { + break; + } } + if (Lexer::cur_tok != ')') { - return LogError( - "Parser: Expected ')' in the function definition."); + return log_full_parser_error( + ')', Lexer::cur_tok, std::string("")); } // success. // - Lexer::get_next_token(); // eat ')'. + if (Lexer::get_next_token() != tok_arrow_right) { + return log_full_parser_error( + tok_arrow_right, + Lexer::cur_tok, + std::string("Returning type annotation required. ") + + std::string("If the function doesn't return a value, use `void`.")); + } - if (Lexer::cur_tok != ':') { - return LogError( - "Parser: Expected ':' in the function definition"); + if (Lexer::get_next_token() != tok_identifier) { + return log_full_parser_error( + tok_identifier, + Lexer::cur_tok, + std::string("Returning type annotation required. ") + + std::string("If the function doesn't return a value, use `void`.")); } - Lexer::get_next_token(); // eat ':'. + // TODO: ret_type_annotation is not used yet + ret_type_annotation = Lexer::identifier_str; + + if (Lexer::get_next_token() != ':') { + return log_full_parser_error( + ':', + Lexer::cur_tok, + std::string("Before starting the function body, use ':'.")); + } - return std::make_unique(fn_loc, fn_name, std::move(args)); + return std::make_unique( + fn_loc, fn_name, ret_type_annotation, std::move(args)); } /** @@ -594,6 +700,7 @@ std::unique_ptr Parser::parse_top_level_expr() { auto proto = std::make_unique( fn_loc, "__anon_expr", + "void", // ANONYMOUS std::move(std::vector>())); return std::make_unique(std::move(proto), std::move(E)); } diff --git a/src/parser.h b/src/parser.h index b3145ce..dde779e 100644 --- a/src/parser.h +++ b/src/parser.h @@ -116,13 +116,14 @@ class FloatExprAST : public ExprAST { class VariableExprAST : public ExprAST { public: std::string name; + std::string type_name; /** * @param _loc The token location * @param _name The variable name */ - VariableExprAST(SourceLocation _loc, std::string _name) - : ExprAST(_loc), name(std::move(_name)) { + VariableExprAST(SourceLocation _loc, std::string _name, std::string _type_name) + : ExprAST(_loc), name(std::move(_name), type_name(std::move(_type_name)) { this->kind = ExprKind::VariableKind; } @@ -253,7 +254,7 @@ class IfExprAST : public ExprAST { ExprAST::dump(out << "if", ind); this->cond->dump(indent(out, ind) << "cond:", ind + 1); this->then->dump(indent(out, ind) << "then:", ind + 1); - this->else_->dump(indent(out, ind) << "else_:", ind + 1); + this->else_->dump(indent(out, ind) << "else:", ind + 1); return out; } }; @@ -340,6 +341,7 @@ class VarExprAST : public ExprAST { class PrototypeAST : public ExprAST { public: std::string name; + std::string type_name; std::vector> args; int line; @@ -351,8 +353,9 @@ class PrototypeAST : public ExprAST { PrototypeAST( SourceLocation _loc, std::string _name, + std::string _type_name, std::vector>&& _args) - : name(std::move(_name)), args(std::move(_args)), line(_loc.line) { + : name(std::move(_name)), type_name(std::move(_type_name)), args(std::move(_args)), line(_loc.line) { this->kind = ExprKind::PrototypeKind; } diff --git a/tests/unittests/codegen/test-ast-to-llvm-ir.cpp b/tests/unittests/codegen/test-ast-to-llvm-ir.cpp index f6a4175..d9832c9 100644 --- a/tests/unittests/codegen/test-ast-to-llvm-ir.cpp +++ b/tests/unittests/codegen/test-ast-to-llvm-ir.cpp @@ -7,7 +7,7 @@ // Check object generation TEST(CodeGenTest, ObjectGeneration) { string_to_buffer((char*) R""""( - function add_one(a): + fn add_one(a: float) -> float: a + 1 add(1); diff --git a/tests/unittests/codegen/test-ast-to-object.cpp b/tests/unittests/codegen/test-ast-to-object.cpp index 3bf2b48..98219d8 100644 --- a/tests/unittests/codegen/test-ast-to-object.cpp +++ b/tests/unittests/codegen/test-ast-to-object.cpp @@ -9,7 +9,7 @@ extern bool IS_BUILD_LIB; // Check object generation TEST(CodeGenTest, ObjectGeneration) { string_to_buffer((char*) R""""( - function add_one(a): + fn add_one(a: float) -> float: a + 1 add(1); diff --git a/tests/unittests/codegen/test-ast-to-stdout.cpp b/tests/unittests/codegen/test-ast-to-stdout.cpp index 13ce262..fa4caa9 100644 --- a/tests/unittests/codegen/test-ast-to-stdout.cpp +++ b/tests/unittests/codegen/test-ast-to-stdout.cpp @@ -7,7 +7,7 @@ // Check object generation TEST(CodeGenTest, ObjectGeneration) { string_to_buffer((char*) R""""( - function add_one(a): + fn add_one(a: float) -> float: a + 1 add(1); diff --git a/tests/unittests/test-lexer.cpp b/tests/unittests/test-lexer.cpp index c3aa0e5..a0e39ce 100644 --- a/tests/unittests/test-lexer.cpp +++ b/tests/unittests/test-lexer.cpp @@ -59,11 +59,11 @@ TEST(LexerTest, GetNextTokenSimpleTest) { TEST(LexerTest, GetTokTest) { /* Test gettok for main tokens */ string_to_buffer((char*) R""""( - function math(x): + fn math(x: float) -> float: if x > 10: - x + 1 + return x + 1; else: - x * 20 + return x * 20; math(1); )""""); @@ -72,21 +72,29 @@ TEST(LexerTest, GetTokTest) { EXPECT_EQ(Lexer::gettok(), tok_identifier); EXPECT_EQ(Lexer::gettok(), (int) '('); EXPECT_EQ(Lexer::gettok(), tok_identifier); + EXPECT_EQ(Lexer::gettok(), (int) ':'); + EXPECT_EQ(Lexer::gettok(), tok_identifier); EXPECT_EQ(Lexer::gettok(), (int) ')'); + EXPECT_EQ(Lexer::gettok(), tok_arrow_right); + EXPECT_EQ(Lexer::gettok(), tok_identifier); EXPECT_EQ(Lexer::gettok(), (int) ':'); EXPECT_EQ(Lexer::gettok(), tok_if); EXPECT_EQ(Lexer::gettok(), tok_identifier); EXPECT_EQ(Lexer::gettok(), (int) '>'); EXPECT_EQ(Lexer::gettok(), tok_float_literal); EXPECT_EQ(Lexer::gettok(), (int) ':'); + EXPECT_EQ(Lexer::gettok(), tok_return); EXPECT_EQ(Lexer::gettok(), tok_identifier); EXPECT_EQ(Lexer::gettok(), (int) '+'); EXPECT_EQ(Lexer::gettok(), tok_float_literal); + EXPECT_EQ(Lexer::gettok(), (int) ';'); EXPECT_EQ(Lexer::gettok(), tok_else); EXPECT_EQ(Lexer::gettok(), (int) ':'); + EXPECT_EQ(Lexer::gettok(), tok_return); EXPECT_EQ(Lexer::gettok(), tok_identifier); EXPECT_EQ(Lexer::gettok(), (int) '*'); EXPECT_EQ(Lexer::gettok(), tok_float_literal); + EXPECT_EQ(Lexer::gettok(), (int) ';'); EXPECT_EQ(Lexer::gettok(), tok_identifier); EXPECT_EQ(Lexer::gettok(), (int) '('); EXPECT_EQ(Lexer::gettok(), tok_float_literal); diff --git a/tests/unittests/test-parser.cpp b/tests/unittests/test-parser.cpp index 3fc4eff..ee795c0 100644 --- a/tests/unittests/test-parser.cpp +++ b/tests/unittests/test-parser.cpp @@ -11,11 +11,11 @@ TEST(ParserTest, GetNextTokenTest) { /* Test gettok for main tokens */ string_to_buffer((char*) R""""( - function math(x): + fn math(x: float) -> float: if x > 10: - x + 1 + return x + 1; else: - x * 20 + return x * 20; math(1); )""""); @@ -29,8 +29,16 @@ TEST(ParserTest, GetNextTokenTest) { Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_identifier); Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, (int) ':'); + Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, tok_identifier); + Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, (int) ')'); Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, tok_arrow_right); + Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, tok_identifier); + Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, (int) ':'); Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_if); @@ -43,22 +51,30 @@ TEST(ParserTest, GetNextTokenTest) { Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, (int) ':'); Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, tok_return); + Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_identifier); Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, (int) '+'); Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_float_literal); Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, (int) ';'); + Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_else); Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, (int) ':'); Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, tok_return); + Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_identifier); Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, (int) '*'); Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_float_literal); Lexer::get_next_token(); + EXPECT_EQ(Lexer::cur_tok, (int) ';'); + Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, tok_identifier); Lexer::get_next_token(); EXPECT_EQ(Lexer::cur_tok, (int) '('); @@ -114,9 +130,9 @@ TEST(ParserTest, ParseIfExprTest) { /* Test gettok for main tokens */ string_to_buffer((char*) R""""( if 1 > 2: - a = 1 + a = 1; else: - a = 2 + a = 2; )""""); Lexer::get_next_token(); // update Lexer::cur_tok