diff --git a/compare_compilers.sh b/compare_compilers.sh index 5f63dd8b..575ce24a 100755 --- a/compare_compilers.sh +++ b/compare_compilers.sh @@ -28,8 +28,9 @@ for arg in "$@"; do done if [ ${#files[@]} = 0 ]; then + # skip compiler_cli, because it has a race condition when two compilers simultaneously run it # TODO: do not skip Advent Of Code files - files=( $(find stdlib examples tests -name '*.jou' | grep -v aoc2023 | sort) ) + files=( $(find stdlib examples tests -name '*.jou' | grep -v aoc2023 | grep -v tests/should_succeed/compiler_cli | grep -v tests/crash | sort) ) fi if [ ${#actions[@]} = 0 ]; then actions=(tokenize parse run) @@ -103,11 +104,11 @@ for action in ${actions[@]}; do # Run compilers in parallel to speed up. ( set +e - ./jou $flag $file 2>&1 | grep -vE 'undefined reference to|multiple definition of|\bld: ' + ./jou $flag $file 2>&1 | grep -vE 'undefined reference to|multiple definition of|\bld: |compiler warning for file' ) > tmp/compare_compilers/compiler_written_in_c.txt & ( set +e - ./self_hosted_compiler $flag $file 2>&1 | grep -vE 'undefined reference to|multiple definition of|\bld: |linking failed' + ./self_hosted_compiler $flag $file 2>&1 | grep -vE 'undefined reference to|multiple definition of|\bld: |linking failed|compiler warning for file' ) > tmp/compare_compilers/self_hosted.txt & wait diff --git a/self_hosted/ast.jou b/self_hosted/ast.jou index 39f522bf..b5e82010 100644 --- a/self_hosted/ast.jou +++ b/self_hosted/ast.jou @@ -145,7 +145,10 @@ class AstExpression: float_or_double_text: byte[100] operands: AstExpression* # Only for operators. Length is arity, see get_arity() - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: printf("[line %d] ", self->location.lineno) if self->kind == AstExpressionKind::String: printf("\"") @@ -181,7 +184,7 @@ class AstExpression: elif self->kind == AstExpressionKind::Array: printf("array\n") for i = 0; i < self->array.length; i++: - self->array.items[i].print(tp.print_prefix(i == self->array.length-1)) + self->array.items[i].print_with_tree_printer(tp.print_prefix(i == self->array.length-1)) elif self->kind == AstExpressionKind::Call: if self->call.uses_arrow_operator: printf("dereference and ") @@ -204,12 +207,12 @@ class AstExpression: if self->class_field.uses_arrow_operator: printf("dereference and ") printf("get class field \"%s\"\n", self->class_field.field_name) - self->class_field.instance->print(tp.print_prefix(True)) + self->class_field.instance->print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstExpressionKind::As: printf("as ") self->as_expression->type.print(True) printf("\n") - self->as_expression->value.print(tp.print_prefix(True)) + self->as_expression->value.print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstExpressionKind::SizeOf: printf("sizeof\n") elif self->kind == AstExpressionKind::AddressOf: @@ -258,7 +261,7 @@ class AstExpression: printf("?????\n") for i = 0; i < self->get_arity(); i++: - self->operands[i].print(tp.print_prefix(i == self->get_arity()-1)) + self->operands[i].print_with_tree_printer(tp.print_prefix(i == self->get_arity()-1)) def free(self) -> None: if self->kind == AstExpressionKind::Call: @@ -367,12 +370,12 @@ class AstCall: if self->method_call_self != NULL: sub = tp.print_prefix(self->nargs == 0) printf("self: ") - self->method_call_self->print(sub) + self->method_call_self->print_with_tree_printer(sub) for i = 0; i < self->nargs; i++: sub = tp.print_prefix(i == self->nargs - 1) printf("argument %d: ", i) - self->args[i].print(sub) + self->args[i].print_with_tree_printer(sub) def free(self) -> None: for i = 0; i < self->nargs; i++: @@ -390,7 +393,7 @@ class AstInstantiation: for i = 0; i < self->nfields; i++: sub = tp.print_prefix(i == self->nfields - 1) printf("field \"%s\": ", self->field_names[i]) - self->field_values[i].print(sub) + self->field_values[i].print_with_tree_printer(sub) def free(self) -> None: for i = 0; i < self->nfields; i++: @@ -398,6 +401,10 @@ class AstInstantiation: free(self->field_names) free(self->field_values) +class AstAssertion: + condition: AstExpression + condition_str: byte* + enum AstStatementKind: ExpressionStatement # Evaluate an expression. Discard the result. Assert @@ -436,21 +443,25 @@ class AstStatement: function: AstFunctionOrMethod classdef: AstClassDef enumdef: AstEnumDef + assertion: AstAssertion - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: printf("[line %d] ", self->location.lineno) if self->kind == AstStatementKind::ExpressionStatement: printf("expression statement\n") - self->expression.print(tp.print_prefix(True)) + self->expression.print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstStatementKind::Assert: - printf("assert\n") - self->expression.print(tp.print_prefix(True)) + printf("assert \"%s\"\n", self->assertion.condition_str) + self->assertion.condition.print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstStatementKind::Pass: printf("pass\n") elif self->kind == AstStatementKind::Return: printf("return\n") if self->return_value != NULL: - self->return_value->print(tp.print_prefix(True)) + self->return_value->print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstStatementKind::If: printf("if\n") self->if_statement.print(tp) @@ -459,51 +470,51 @@ class AstStatement: self->for_loop.print(tp) elif self->kind == AstStatementKind::WhileLoop: printf("while loop\n") - self->while_loop.print(tp, True) + self->while_loop.print_with_tree_printer(tp, True) elif self->kind == AstStatementKind::Break: printf("break\n") elif self->kind == AstStatementKind::Continue: printf("continue\n") elif self->kind == AstStatementKind::DeclareLocalVar: printf("declare local var ") - self->var_declaration.print(&tp) + self->var_declaration.print_with_tree_printer(&tp) elif self->kind == AstStatementKind::Assign: printf("assign\n") - self->assignment.print(tp) + self->assignment.print_with_tree_printer(tp) elif self->kind == AstStatementKind::InPlaceAdd: printf("in-place add\n") - self->assignment.print(tp) + self->assignment.print_with_tree_printer(tp) elif self->kind == AstStatementKind::InPlaceSubtract: printf("in-place sub\n") - self->assignment.print(tp) + self->assignment.print_with_tree_printer(tp) elif self->kind == AstStatementKind::InPlaceMultiply: printf("in-place mul\n") - self->assignment.print(tp) + self->assignment.print_with_tree_printer(tp) elif self->kind == AstStatementKind::InPlaceDivide: printf("in-place div\n") - self->assignment.print(tp) + self->assignment.print_with_tree_printer(tp) elif self->kind == AstStatementKind::InPlaceModulo: printf("in-place mod\n") - self->assignment.print(tp) + self->assignment.print_with_tree_printer(tp) elif self->kind == AstStatementKind::Function: if self->function.body.nstatements == 0: printf("declare a function: ") else: printf("define a function: ") - self->function.print(tp) + self->function.print_with_tree_printer(tp) elif self->kind == AstStatementKind::Class: printf("define a ") - self->classdef.print(tp) + self->classdef.print_with_tree_printer(tp) elif self->kind == AstStatementKind::Enum: printf("define ") - self->enumdef.print(tp) + self->enumdef.print_with_tree_printer(tp) elif self->kind == AstStatementKind::GlobalVariableDeclaration: printf("declare global var ") - self->var_declaration.print(NULL) + self->var_declaration.print_with_tree_printer(NULL) printf("\n") elif self->kind == AstStatementKind::GlobalVariableDefinition: printf("define global var ") - self->var_declaration.print(NULL) + self->var_declaration.print_with_tree_printer(NULL) printf("\n") else: printf("??????\n") @@ -526,14 +537,17 @@ class AstConditionAndBody: condition: AstExpression body: AstBody - def print(self, tp: TreePrinter, body_is_last_sub_item: bool) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}, True) + + def print_with_tree_printer(self, tp: TreePrinter, body_is_last_sub_item: bool) -> None: sub = tp.print_prefix(False) printf("condition: ") - self->condition.print(sub) + self->condition.print_with_tree_printer(sub) sub = tp.print_prefix(body_is_last_sub_item) printf("body:\n") - self->body.print(sub) + self->body.print_with_tree_printer(sub) def free(self) -> None: self->condition.free() @@ -543,9 +557,12 @@ class AstAssignment: target: AstExpression value: AstExpression - def print(self, tp: TreePrinter) -> None: - self->target.print(tp.print_prefix(False)) - self->value.print(tp.print_prefix(True)) + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + self->target.print_with_tree_printer(tp.print_prefix(False)) + self->value.print_with_tree_printer(tp.print_prefix(True)) class AstIfStatement: if_and_elifs: AstConditionAndBody* @@ -554,12 +571,12 @@ class AstIfStatement: def print(self, tp: TreePrinter) -> None: for i = 0; i < self->n_if_and_elifs; i++: - self->if_and_elifs[i].print(tp, i == self->n_if_and_elifs - 1 and self->else_body.nstatements == 0) + self->if_and_elifs[i].print_with_tree_printer(tp, i == self->n_if_and_elifs - 1 and self->else_body.nstatements == 0) if self->else_body.nstatements > 0: sub = tp.print_prefix(True) printf("else body:\n") - self->else_body.print(sub) + self->else_body.print_with_tree_printer(sub) def free(self) -> None: for i = 0; i < self->n_if_and_elifs; i++: @@ -580,19 +597,19 @@ class AstForLoop: def print(self, tp: TreePrinter) -> None: sub = tp.print_prefix(False) printf("init: ") - self->init->print(sub) + self->init->print_with_tree_printer(sub) sub = tp.print_prefix(False) printf("cond: ") - self->cond.print(sub) + self->cond.print_with_tree_printer(sub) sub = tp.print_prefix(False) printf("incr: ") - self->incr->print(sub) + self->incr->print_with_tree_printer(sub) sub = tp.print_prefix(True) printf("body:\n") - self->body.print(sub) + self->body.print_with_tree_printer(sub) def free(self) -> None: self->init->free() @@ -609,8 +626,12 @@ class AstNameTypeValue: type: AstType value: AstExpression* # can be NULL + def print(self) -> None: + tp = TreePrinter{} + self->print_with_tree_printer(&tp) + # tp can be set to NULL, in that case no trailing newline is printed - def print(self, tp: TreePrinter*) -> None: + def print_with_tree_printer(self, tp: TreePrinter*) -> None: printf("%s: ", self->name) self->type.print(True) if tp == NULL: @@ -620,7 +641,7 @@ class AstNameTypeValue: if self->value != NULL: sub = tp->print_prefix(True) printf("initial value: ") - self->value->print(sub) + self->value->print_with_tree_printer(sub) def free(self) -> None: if self->value != NULL: @@ -631,9 +652,12 @@ class AstBody: statements: AstStatement* nstatements: int - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: for i = 0; i < self->nstatements; i++: - self->statements[i].print(tp.print_prefix(i == self->nstatements - 1)) + self->statements[i].print_with_tree_printer(tp.print_prefix(i == self->nstatements - 1)) def free(self) -> None: for i = 0; i < self->nstatements; i++: @@ -656,7 +680,7 @@ class AstSignature: if strcmp(self->args[i].name, "self") == 0: printf("self") else: - self->args[i].print(NULL) + self->args[i].print_with_tree_printer(NULL) if self->takes_varargs: if self->nargs != 0: @@ -695,7 +719,7 @@ class AstFile: for i = 0; i < self->nimports; i++: self->imports[i].print() for i = 0; i < self->body.nstatements; i++: - self->body.statements[i].print(TreePrinter{}) + self->body.statements[i].print() def free(self) -> None: for i = 0; i < self->nimports; i++: @@ -707,9 +731,12 @@ class AstFunctionOrMethod: signature: AstSignature body: AstBody # empty body means declaration, otherwise it's a definition - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: self->signature.print() - self->body.print(tp) + self->body.print_with_tree_printer(tp) def free(self) -> None: self->signature.free() @@ -722,7 +749,7 @@ class AstUnionFields: def print(self, tp: TreePrinter) -> None: for i = 0; i < self->nfields; i++: subprinter = tp.print_prefix(i == self->nfields-1) - self->fields[i].print(&subprinter) # TODO: does this need to be optional/pointer? + self->fields[i].print_with_tree_printer(&subprinter) # TODO: does this need to be optional/pointer? def free(self) -> None: for i = 0; i < self->nfields; i++: @@ -744,7 +771,7 @@ class AstClassMember: def print(self, tp: TreePrinter) -> None: if self->kind == AstClassMemberKind::Field: printf("field ") - self->field.print(NULL) + self->field.print_with_tree_printer(NULL) printf("\n") elif self->kind == AstClassMemberKind::Union: printf("union:\n") @@ -752,7 +779,7 @@ class AstClassMember: elif self->kind == AstClassMemberKind::Method: printf("method ") self->method.signature.print() - self->method.body.print(tp) + self->method.body.print_with_tree_printer(tp) else: assert False @@ -772,7 +799,10 @@ class AstClassDef: members: AstClassMember* nmembers: int - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: printf("class \"%s\" with %d members\n", self->name, self->nmembers) for i = 0; i < self->nmembers; i++: self->members[i].print(tp.print_prefix(i == self->nmembers-1)) @@ -788,7 +818,10 @@ class AstEnumDef: member_count: int member_names: byte[100]* - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: printf("enum \"%s\" with %d members\n", self->name, self->member_count) for i = 0; i < self->member_count; i++: tp.print_prefix(i == self->member_count-1) diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index c2e9f2aa..dd5e7b92 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -5,6 +5,7 @@ import "./types.jou" import "./ast.jou" import "./target.jou" import "./evaluate.jou" +import "./errors_and_warnings.jou" import "stdlib/io.jou" import "stdlib/mem.jou" import "stdlib/str.jou" @@ -131,7 +132,7 @@ class AstToIR: for i = 0; i < self->n_local_vars; i++: if strcmp(self->local_vars[i].name, name) == 0: return self->local_vars[i].pointer - assert False + return NULL def declare_function(self, sig: Signature*) -> LLVMValue*: full_name: byte[200] @@ -321,7 +322,7 @@ class AstToIR: printf("%s %d %s\n", lhs_type->name, op, rhs_type->name) assert False - def do_assert(self, condition: LLVMValue*) -> None: + def do_assert(self, location: Location, condition: LLVMValue*, condition_str: byte*) -> None: true_block = LLVMAppendBasicBlock(self->llvm_function, "assert_true") false_block = LLVMAppendBasicBlock(self->llvm_function, "assert_false") LLVMBuildCondBr(self->builder, condition, true_block, false_block) @@ -336,12 +337,13 @@ class AstToIR: assert assert_fail_func != NULL args = [ - self->make_a_string_constant("foo", -1), - self->make_a_string_constant("bar", -1), - LLVMConstInt(LLVMInt32Type(), 123, False as int), + self->make_a_string_constant(condition_str, -1), + self->make_a_string_constant(location.path, -1), + LLVMConstInt(LLVMInt32Type(), location.lineno, False as int), ] LLVMBuildCall2(self->builder, assert_fail_func_type, assert_fail_func, args, 3, "") + LLVMBuildUnreachable(self->builder) LLVMPositionBuilderAtEnd(self->builder, true_block) def do_incr_decr(self, value_expression: AstExpression*, diff: int) -> LLVMValue*[2]: @@ -363,7 +365,13 @@ class AstToIR: def do_address_of_expression(self, ast: AstExpression*) -> LLVMValue*: if ast->kind == AstExpressionKind::GetVariable: - return self->get_local_var_pointer(ast->varname) + local_var = self->get_local_var_pointer(ast->varname) + if local_var != NULL: + return local_var + + global_var = LLVMGetNamedGlobal(self->module, ast->varname) + assert global_var != NULL + return global_var if ast->kind == AstExpressionKind::GetClassField: lhs_type = self->function_or_method_types->get_expression_types(ast->class_field.instance)->original_type @@ -402,6 +410,10 @@ class AstToIR: index = self->do_expression(&ast->operands[1]) return LLVMBuildGEP(self->builder, pointer, &index, 1, "indexing_ptr") + if ast->kind == AstExpressionKind::Dereference: + return self->do_expression(&ast->operands[0]) + + ast->print() printf("expression kind (taking address): %d\n", ast->kind) assert False @@ -619,14 +631,29 @@ class AstToIR: elif ast->kind == AstExpressionKind::Self: self_ptr = self->get_local_var_pointer("self") + assert self_ptr != NULL result = LLVMBuildLoad(self->builder, self_ptr, "self") elif ast->kind == AstExpressionKind::And: result = self->do_and_or(&ast->operands[0], &ast->operands[1], True) elif ast->kind == AstExpressionKind::Or: result = self->do_and_or(&ast->operands[0], &ast->operands[1], False) + elif ast->kind == AstExpressionKind::Not: + # compile "not x" as "x == False" + value = self->do_expression(&ast->operands[0]) + false = LLVMConstInt(LLVMInt1Type(), 0, False as int) + result = LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, value, false, "not") + + elif ast->kind == AstExpressionKind::Dereference: + result = LLVMBuildLoad(self->builder, self->do_expression(&ast->operands[0]), "ptr_deref") + + elif ast->kind == AstExpressionKind::SizeOf: + types = self->function_or_method_types->get_expression_types(&ast->operands[0]) + assert types != NULL + result = LLVMSizeOf(type_to_llvm(types->get_type_after_implicit_cast())) else: + ast->print() printf("create_llvm_ir: unknown expression kind %d...\n", ast->kind) assert False @@ -716,8 +743,7 @@ class AstToIR: # If more code follows, place it into a new block that never actually runs self->new_block("after_return") elif ast->kind == AstStatementKind::Assert: - condition = self->do_expression(&ast->expression) - self->do_assert(condition) + self->do_assert(ast->location, self->do_expression(&ast->assertion.condition), ast->assertion.condition_str) elif ast->kind == AstStatementKind::Pass: pass elif ast->kind == AstStatementKind::Assign: @@ -727,6 +753,7 @@ class AstToIR: elif ast->kind == AstStatementKind::DeclareLocalVar: if ast->var_declaration.value != NULL: target_pointer = self->get_local_var_pointer(ast->var_declaration.name) + assert target_pointer != NULL value = self->do_expression(ast->var_declaration.value) LLVMBuildStore(self->builder, value, target_pointer) elif ast->kind == AstStatementKind::If: @@ -759,6 +786,18 @@ class AstToIR: for i = 0; i < body->nstatements; i++: self->do_statement(&body->statements[i]) + def call_the_special_startup_function(self) -> None: + if WINDOWS: + name = "_jou_windows_startup" + elif MACOS: + name = "_jou_macos_startup" + else: + return + + functype = LLVMFunctionType(LLVMVoidType(), NULL, 0, False as int) + func = LLVMAddFunction(self->module, name, functype) + LLVMBuildCall2(self->builder, functype, func, NULL, 0, "") + def define_function_or_method(self, funcdef: AstFunctionOrMethod*, self_type: Type*) -> None: assert self->function_or_method_types == NULL self->function_or_method_types = self->file_types->find_defined_function_or_method(funcdef->signature.name, self_type) @@ -769,6 +808,9 @@ class AstToIR: self->llvm_function = self->declare_function(sig) self->new_block("start") + if (WINDOWS or MACOS) and strcmp(sig->name, "main") == 0 and not sig->is_method(): + self->call_the_special_startup_function() + # Allocate all local variables at the start of the function. assert self->n_local_vars == 0 assert self->local_vars == NULL @@ -801,27 +843,46 @@ class AstToIR: self->function_or_method_types = NULL +# This distinguishes defined global variables from: +# - imported global variables +# - declared global variables +def file_defines_global_var(ast: AstFile*, name: byte*) -> bool: + for s = ast->body.statements; s < &ast->body.statements[ast->body.nstatements]; s++: + if s->kind == AstStatementKind::GlobalVariableDefinition and strcmp(s->var_declaration.name, name) == 0: + return True + return False + + def create_llvm_ir(ast: AstFile*, ft: FileTypes*) -> LLVMModule*: module = LLVMModuleCreateWithName(ast->path) LLVMSetTarget(module, target.triple) LLVMSetDataLayout(module, target.data_layout) + for v = ft->globals; v < &ft->globals[ft->nglobals]; v++: + t = type_to_llvm(v->type) + globalptr = LLVMAddGlobal(module, t, v->name) + if file_defines_global_var(ast, v->name): + LLVMSetInitializer(globalptr, LLVMConstNull(t)) + a2i = AstToIR{ module = module, builder = LLVMCreateBuilder(), file_types = ft } - for i = 0; i < ast->body.nstatements; i++: - if ast->body.statements[i].kind == AstStatementKind::Function and ast->body.statements[i].function.body.nstatements > 0: - a2i.define_function_or_method(&ast->body.statements[i].function, NULL) - elif ast->body.statements[i].kind == AstStatementKind::Class: - classdef = &ast->body.statements[i].classdef + for s = ast->body.statements; s < &ast->body.statements[ast->body.nstatements]; s++: + if s->kind == AstStatementKind::Function and s->function.body.nstatements > 0: + a2i.define_function_or_method(&s->function, NULL) + elif s->kind == AstStatementKind::Class: + classdef = &s->classdef class_type = ft->find_type(classdef->name) assert class_type != NULL for k = 0; k < classdef->nmembers; k++: if classdef->members[k].kind == AstClassMemberKind::Method: a2i.define_function_or_method(&classdef->members[k].method, class_type) + else: + # TODO: need to handle some others? + pass LLVMDisposeBuilder(a2i.builder) return module diff --git a/self_hosted/main.jou b/self_hosted/main.jou index 4f39c145..3e29b028 100644 --- a/self_hosted/main.jou +++ b/self_hosted/main.jou @@ -175,6 +175,9 @@ class Compiler: if WINDOWS: self->automagic_files[1] = malloc(strlen(self->stdlib_path) + 40) sprintf(self->automagic_files[1], "%s/_windows_startup.jou", self->stdlib_path) + if MACOS: + self->automagic_files[1] = malloc(strlen(self->stdlib_path) + 40) + sprintf(self->automagic_files[1], "%s/_macos_startup.jou", self->stdlib_path) def parse_all_files(self) -> None: queue: ParseQueueItem* = malloc(50 * sizeof queue[0]) @@ -428,7 +431,8 @@ class Compiler: ret = system(command) if ret != 0: - fprintf(stderr, "%s: running the program failed\n", self->argv0) + # TODO: print something? The shell doesn't print stuff + # like "Segmentation fault" on Windows afaik exit(1) diff --git a/self_hosted/parser.jou b/self_hosted/parser.jou index e3f8bfa3..f9eadf5c 100644 --- a/self_hosted/parser.jou +++ b/self_hosted/parser.jou @@ -1,4 +1,6 @@ +import "stdlib/ascii.jou" import "stdlib/str.jou" +import "stdlib/io.jou" import "stdlib/mem.jou" import "./token.jou" import "./ast.jou" @@ -150,6 +152,46 @@ def check_class_for_duplicate_names(classdef: AstClassDef*) -> None: fail(p2->name_location, message) +# TODO: this function is just bad... +def read_assertion_from_file(start: Location, end: Location) -> byte*: + assert start.path == end.path + + f = fopen(start.path, "rb") + assert f != NULL + + line: byte[1024] + lineno = 1 + while lineno < start.lineno: + assert fgets(line, sizeof(line) as int, f) != NULL + lineno++ + + result: byte* = malloc(2000 * (end.lineno - start.lineno + 1)) + result[0] = '\0' + + while lineno <= end.lineno: + assert fgets(line, sizeof(line) as int, f) != NULL + lineno++ + + # TODO: strings containing '#' ... so much wrong with dis + if strstr(line, "#") != NULL: + *strstr(line, "#") = '\0' + trim_ascii_whitespace(line) + + # Add spaces between lines, but not after '(' or before ')' + if not starts_with(line, ")") and not ends_with(result, "("): + strcat(result, " ") + strcat(result, line) + + fclose(f) + + trim_ascii_whitespace(result) + if starts_with(result, "assert"): + memmove(result, &result[6], strlen(&result[6]) + 1) + trim_ascii_whitespace(result) + + return result + + class Parser: tokens: Token* stdlib_path: byte* @@ -710,7 +752,10 @@ class Parser: elif self->tokens->is_keyword("assert"): self->tokens++ result.kind = AstStatementKind::Assert - result.expression = self->parse_expression() + start = self->tokens->location + result.assertion.condition = self->parse_expression() + end = self->tokens->location + result.assertion.condition_str = read_assertion_from_file(start, end) elif self->tokens->is_keyword("pass"): self->tokens++ result.kind = AstStatementKind::Pass diff --git a/self_hosted/runs_wrong.txt b/self_hosted/runs_wrong.txt index 26580344..02c37c70 100644 --- a/self_hosted/runs_wrong.txt +++ b/self_hosted/runs_wrong.txt @@ -1,52 +1,16 @@ # This is a list of files that don't behave correctly when ran with the self-hosted compiler. -stdlib/errno.jou -tests/404/indirect_import_symbol.jou -tests/crash/null_deref.jou -tests/other_errors/address_of_minusminus.jou -tests/other_errors/array_literal_as_a_pointer.jou -tests/other_errors/array0.jou -tests/other_errors/assert_fail.jou -tests/other_errors/assert_fail_multiline.jou -tests/other_errors/instantiation_address_of_field.jou tests/other_errors/missing_return.jou tests/other_errors/missing_value_in_return.jou tests/other_errors/noreturn_but_return_with_value.jou tests/other_errors/noreturn_but_return_without_value.jou -tests/other_errors/runtime_return_1.jou -tests/other_errors/var_shadow.jou -tests/should_succeed/and_or_not.jou -tests/should_succeed/as.jou tests/should_succeed/compiler_cli.jou -tests/should_succeed/errno_test.jou -tests/should_succeed/file.jou -tests/should_succeed/global.jou -tests/should_succeed/global_bug.jou -tests/should_succeed/if_elif_else.jou -tests/should_succeed/implicit_conversions.jou -tests/should_succeed/imported/point_factory.jou -tests/should_succeed/indirect_method_import.jou tests/should_succeed/linked_list.jou -tests/should_succeed/local_import.jou -tests/should_succeed/loops.jou -tests/should_succeed/plusplus_minusminus.jou tests/should_succeed/pointer.jou tests/should_succeed/printf.jou tests/other_errors/return_void.jou -tests/should_succeed/sizeof.jou tests/should_succeed/stderr.jou -tests/should_succeed/undefined_value_warning.jou -tests/should_succeed/union.jou -tests/should_succeed/unreachable_warning.jou tests/should_succeed/unused_import.jou -tests/syntax_error/bad_addressof.jou -tests/wrong_type/assert.jou tests/wrong_type/cannot_be_indexed.jou tests/wrong_type/index.jou -stdlib/ascii.jou -tests/should_succeed/ascii_test.jou -stdlib/_macos_startup.jou -tests/should_succeed/if_WINDOWS_at_runtime.jou -tests/should_succeed/return_none.jou tests/syntax_error/assign_to_None.jou tests/syntax_error/None_as_value.jou -tests/should_succeed/unused_variable.jou diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index ddfc60e9..27e64fce 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -37,6 +37,8 @@ def can_cast_implicitly(from: Type*, to: Type*) -> bool: and from->size_in_bits < to->size_in_bits and not (from->kind == TypeKind::SignedInteger and to->kind == TypeKind::UnsignedInteger) ) + or (from == &float_type and to == &double_type) + or (from->is_integer_type() and to->kind == TypeKind::FloatingPoint) or (from->is_pointer_type() and to->is_pointer_type() and (from == &void_ptr_type or to == &void_ptr_type)) ) @@ -101,6 +103,18 @@ class ExportSymbol: signature: Signature # ExportSymbolKind::Function type: Type* # ExportSymbolKind::Type, ExportSymbolKind::GlobalVariable + def print(self) -> None: + if self->kind == ExportSymbolKind::Function: + s = self->signature.to_string(True, True) + printf("ExportSymbol: function %s\n", s) + free(s) + elif self->kind == ExportSymbolKind::Type: + printf("ExportSymbol: type %s as \"%s\"\n", self->type->name, self->name) + elif self->kind == ExportSymbolKind::GlobalVariable: + printf("ExportSymbol: variable %s: %s\n", self->name, self->type->name) + else: + assert False + class ExpressionTypes: expression: AstExpression* original_type: Type* @@ -216,11 +230,17 @@ class FileTypes: nglobals: int def add_imported_symbol(self, symbol: ExportSymbol*) -> None: - if symbol->kind != ExportSymbolKind::Function: - # TODO - return - self->all_functions = realloc(self->all_functions, sizeof self->all_functions[0] * (self->n_all_functions + 1)) - self->all_functions[self->n_all_functions++] = symbol->signature.copy() + if symbol->kind == ExportSymbolKind::Type: + self->types = realloc(self->types, (self->ntypes + 1) * sizeof(self->types[0])) + self->types[self->ntypes++] = symbol->type + elif symbol->kind == ExportSymbolKind::Function: + self->all_functions = realloc(self->all_functions, sizeof self->all_functions[0] * (self->n_all_functions + 1)) + self->all_functions[self->n_all_functions++] = symbol->signature.copy() + elif symbol->kind == ExportSymbolKind::GlobalVariable: + pass # TODO + else: + symbol->print() + assert False def find_function(self, name: byte*) -> Signature*: for i = 0; i < self->n_all_functions; i++: @@ -335,6 +355,8 @@ def type_from_ast(ft: FileTypes*, ast_type: AstType*) -> Type*: if ast_type->kind == AstTypeKind::Array: member_type = type_from_ast(ft, ast_type->array.member_type) length = evaluate_array_length(ast_type->array.length) + if length <= 0: + fail(ast_type->array.length->location, "array length must be positive") return member_type->get_array_type(length) ast_type->print(True) @@ -577,7 +599,10 @@ def short_expression_description(expr: AstExpression*) -> byte[200]: return "an indexed value" elif expr->kind == AstExpressionKind::Self: return "self" + elif expr->kind == AstExpressionKind::Array: + return "an array literal" else: + expr->print() printf("*** %d\n", expr->kind) assert False @@ -787,7 +812,9 @@ class Stage3TypeChecker: local_var = self->current_function_or_method->find_local_var(name) if local_var != NULL: return local_var->type - # TODO: check global vars (they don't exist yet) + for i = 0; i < self->file_types->nglobals; i++: + if strcmp(self->file_types->globals[i].name, name) == 0: + return self->file_types->globals[i].type return NULL def find_function_or_method(self, self_type: Type*, name: byte*) -> Signature*: @@ -1114,7 +1141,8 @@ class Stage3TypeChecker: fail(expression->location, message) result = check_class_field(expression->location, lhs_type, expression->class_field.field_name)->type elif expression->kind == AstExpressionKind::AddressOf: - result = self->do_expression(expression->operands)->original_type->get_pointer_type() + ensure_can_take_address(&expression->operands[0], "the '&' operator cannot be used with %s") + result = self->do_expression(&expression->operands[0])->original_type->get_pointer_type() elif expression->kind == AstExpressionKind::Dereference: pointer_type = self->do_expression(expression->operands)->original_type if pointer_type->kind != TypeKind::Pointer: @@ -1139,8 +1167,12 @@ class Stage3TypeChecker: class_type = self->current_function_or_method->signature.get_containing_class() assert class_type != NULL result = class_type->get_pointer_type() + elif expression->kind == AstExpressionKind::SizeOf: + self->do_expression(&expression->operands[0]) + result = long_type else: printf("*** expr %d\n", expression->kind as int) + expression->print() assert False p: ExpressionTypes* = malloc(sizeof *p) @@ -1197,7 +1229,12 @@ class Stage3TypeChecker: assert target_types->implicit_cast_type == NULL def do_statement(self, statement: AstStatement*) -> None: - if statement->kind == AstStatementKind::ExpressionStatement: + if statement->kind == AstStatementKind::Assert: + self->do_expression_and_implicit_cast( + &statement->expression, &bool_type, "assertion must be a boolean, not " + ) + + elif statement->kind == AstStatementKind::ExpressionStatement: self->do_expression_maybe_void(&statement->expression) elif statement->kind == AstStatementKind::Return: @@ -1332,6 +1369,7 @@ class Stage3TypeChecker: self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Modulo, "modulo") else: + statement->print() printf("*** typecheck: unknown statement kind %d\n", statement->kind) assert False diff --git a/self_hosted/types.jou b/self_hosted/types.jou index e393048a..5656fa0e 100644 --- a/self_hosted/types.jou +++ b/self_hosted/types.jou @@ -66,7 +66,7 @@ class Type: # Pointers and arrays of a given type live as long as the type itself. # To make it possible, we just store them within the type. - # These are initially NULL and created in dynamically as needed. + # These are initially NULL and created dynamically as needed. # # Do not access these outside this file. cached_pointer_type: Type* diff --git a/src/build_cfg.c b/src/build_cfg.c index f6604a23..560a46cc 100644 --- a/src/build_cfg.c +++ b/src/build_cfg.c @@ -722,47 +722,9 @@ static void build_if_statement(struct State *st, const AstIfStatement *ifstmt) add_jump(st, NULL, done, done, done); } -// TODO: this function is just bad... -static char *read_assertion_from_file(Location start, Location end) -{ - assert(start.filename == end.filename); - FILE *f = fopen(start.filename, "rb"); - assert(f); - - char line[1024]; - int lineno = 1; - while (lineno < start.lineno) { - fgets(line, sizeof line, f); - lineno++; - } - - List(char) str = {0}; - while (lineno <= end.lineno) { - memset(line, 0, sizeof line); - fgets(line, sizeof line, f); - lineno++; - - if (strstr(line, "#")) - *strstr(line, "#") = '\0'; - trim_whitespace(line); - // Add spaces between the lines, but not after '(' or before ')' - if (line[0] != ')' && str.len >= 1 && str.ptr[str.len-1] != '(') - AppendStr(&str, " "); - AppendStr(&str, line); - } - - fclose(f); - Append(&str, '\0'); - - if(!strncmp(str.ptr, "assert",6)) - memmove(str.ptr, &str.ptr[6], strlen(&str.ptr[6]) + 1); - trim_whitespace(str.ptr); - return str.ptr; -} - static void build_assert(struct State *st, Location assert_location, const AstAssert *assertion) { - const LocalVariable *condvar = build_expression(st, &assertion->expression); + const LocalVariable *condvar = build_expression(st, &assertion->condition); // If the condition is true, we jump to a block where the rest of the code goes. // If the condition is false, we jump to a block that calls _jou_assert_fail(). @@ -785,10 +747,8 @@ static void build_assert(struct State *st, Location assert_location, const AstAs args[i] = add_local_var(st, argtypes[i]); args[3] = NULL; - char *tmp = read_assertion_from_file(assertion->expression_start, assertion->expression_end); - add_constant(st, assert_location, ((Constant){CONSTANT_STRING,{.str=tmp}}), args[0]); - free(tmp); - tmp = strdup(assertion->expression_start.filename); + add_constant(st, assert_location, ((Constant){CONSTANT_STRING,{.str=assertion->condition_str}}), args[0]); + char *tmp = strdup(assertion->condition.location.filename); add_constant(st, assert_location, ((Constant){CONSTANT_STRING,{.str=tmp}}), args[1]); free(tmp); add_constant(st, assert_location, int_constant(intType, assert_location.lineno), args[2]); diff --git a/src/free.c b/src/free.c index 840165bb..e16356b3 100644 --- a/src/free.c +++ b/src/free.c @@ -155,8 +155,11 @@ void free_ast_statement(const AstStatement *stmt) free(stmt->data.forloop.incr); free_ast_body(&stmt->data.forloop.body); break; - case AST_STMT_EXPRESSION_STATEMENT: case AST_STMT_ASSERT: + free_expression(&stmt->data.assertion.condition); + free(stmt->data.assertion.condition_str); + break; + case AST_STMT_EXPRESSION_STATEMENT: free_expression(&stmt->data.expression); break; case AST_STMT_RETURN: diff --git a/src/jou_compiler.h b/src/jou_compiler.h index 90053147..fc6f2f61 100644 --- a/src/jou_compiler.h +++ b/src/jou_compiler.h @@ -274,8 +274,8 @@ struct AstAssignment { AstExpression value; }; struct AstAssert { - AstExpression expression; - Location expression_start, expression_end; + AstExpression condition; + char *condition_str; }; struct AstFunction { diff --git a/src/parse.c b/src/parse.c index 388a12fa..5379e3aa 100644 --- a/src/parse.c +++ b/src/parse.c @@ -721,6 +721,44 @@ static enum AstStatementKind determine_the_kind_of_a_statement_that_starts_with_ return AST_STMT_EXPRESSION_STATEMENT; } +// TODO: this function is just bad... +static char *read_assertion_from_file(Location start, Location end) +{ + assert(start.filename == end.filename); + FILE *f = fopen(start.filename, "rb"); + assert(f); + + char line[1024]; + int lineno = 1; + while (lineno < start.lineno) { + fgets(line, sizeof line, f); + lineno++; + } + + List(char) str = {0}; + while (lineno <= end.lineno) { + memset(line, 0, sizeof line); + fgets(line, sizeof line, f); + lineno++; + + if (strstr(line, "#")) + *strstr(line, "#") = '\0'; + trim_whitespace(line); + // Add spaces between the lines, but not after '(' or before ')' + if (line[0] != ')' && str.len >= 1 && str.ptr[str.len-1] != '(') + AppendStr(&str, " "); + AppendStr(&str, line); + } + + fclose(f); + Append(&str, '\0'); + + if(!strncmp(str.ptr, "assert",6)) + memmove(str.ptr, &str.ptr[6], strlen(&str.ptr[6]) + 1); + trim_whitespace(str.ptr); + return str.ptr; +} + // does not eat a trailing newline static AstStatement parse_oneline_statement(ParserState *ps) { @@ -735,9 +773,10 @@ static AstStatement parse_oneline_statement(ParserState *ps) } else if (is_keyword(ps->tokens, "assert")) { ps->tokens++; result.kind = AST_STMT_ASSERT; - result.data.assertion.expression_start = ps->tokens->location; - result.data.assertion.expression = parse_expression(ps); - result.data.assertion.expression_end = ps->tokens->location; + Location start = ps->tokens->location; + result.data.assertion.condition = parse_expression(ps); + Location end = ps->tokens->location; + result.data.assertion.condition_str = read_assertion_from_file(start, end); } else if (is_keyword(ps->tokens, "pass")) { ps->tokens++; result.kind = AST_STMT_PASS; diff --git a/src/print.c b/src/print.c index 8585b7e7..a74db610 100644 --- a/src/print.c +++ b/src/print.c @@ -311,8 +311,8 @@ static void print_ast_statement(const AstStatement *stmt, struct TreePrinter tp) print_ast_expression(&stmt->data.expression, print_tree_prefix(tp, true)); break; case AST_STMT_ASSERT: - printf("assert\n"); - print_ast_expression(&stmt->data.assertion.expression, print_tree_prefix(tp, true)); + printf("assert \"%s\"\n", stmt->data.assertion.condition_str); + print_ast_expression(&stmt->data.assertion.condition, print_tree_prefix(tp, true)); break; case AST_STMT_RETURN: printf("return\n");