diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index dd5e7b92..21ccdc2b 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -65,11 +65,16 @@ def create_llvm_union_type(types: LLVMType**, ntypes: int) -> LLVMType*: return LLVMArrayType(LLVMInt64Type(), ((size_needed + 7) / 8) as int) # ceil division + +# Pointers in classes are stored as i8*, so that a struct can contain a pointer to itself. +def field_uses_i8_ptr_hack(field: ClassField*) -> bool: + return field->type->kind == TypeKind::Pointer + + def class_type_to_llvm(fields: ClassField*, nfields: int) -> LLVMType*: elem_types: LLVMType** = malloc(nfields * sizeof elem_types[0]) for i = 0; i < nfields; i++: - # Store all pointers in structs as i8*, so that a struct can contain a pointer to itself for example. - if fields[i].type->kind == TypeKind::Pointer: + if field_uses_i8_ptr_hack(&fields[i]): elem_types[i] = LLVMPointerType(LLVMInt8Type(), 0) else: elem_types[i] = type_to_llvm(fields[i].type) @@ -224,6 +229,13 @@ class AstToIR: printf("unimplemented cast: %s --> %s\n", from->name, to->name) assert False + # Makes a temporary pointer, places the value there, then casts and reads the pointer. + def do_cast_through_pointers(self, value: LLVMValue*, to: LLVMType*) -> LLVMValue*: + p1 = LLVMBuildAlloca(self->builder, LLVMTypeOf(value), "cast_through_ptr_temp") + LLVMBuildStore(self->builder, value, p1) + p2 = LLVMBuildBitCast(self->builder, p1, LLVMPointerType(to, 0), "cast_through_ptr_temp") + return LLVMBuildLoad(self->builder, p2, "cast_through_ptr_result") + def do_binop( self, op: AstExpressionKind, @@ -390,19 +402,21 @@ class AstToIR: assert field != NULL field_pointer = LLVMBuildStructGEP2( self->builder, - type_to_llvm(class_type), instance_pointer, + type_to_llvm(class_type), + instance_pointer, field->union_id, field->name, ) - # This cast is needed for two reasons two cases: - # * All pointers are i8* in structs so we can do self-referencing classes. - # * This is how unions work. - return LLVMBuildBitCast( - self->builder, - field_pointer, LLVMPointerType(type_to_llvm(field->type),0), - "struct_member_cast", - ) + if field_uses_i8_ptr_hack(field) or field->belongs_to_union: + field_pointer = LLVMBuildBitCast( + self->builder, + field_pointer, + LLVMPointerType(type_to_llvm(field->type), 0), + "class_field_ptr_cast", + ) + + return field_pointer if ast->kind == AstExpressionKind::Indexing: # &pointer[index] = pointer + some offset @@ -568,6 +582,8 @@ class AstToIR: field = instance_type->class_members.find_field(ast->instantiation.field_names[i]) assert field != NULL value = self->do_expression(&ast->instantiation.field_values[i]) + if field_uses_i8_ptr_hack(field) or field->belongs_to_union: + value = self->do_cast_through_pointers(value, LLVMStructGetTypeAtIndex(type_to_llvm(instance_type), field->union_id)) result = LLVMBuildInsertValue(self->builder, result, value, field->union_id, "instance") elif ast->kind == AstExpressionKind::GetClassField: @@ -584,6 +600,9 @@ class AstToIR: assert field != NULL result = LLVMBuildExtractValue(self->builder, instance, field->union_id, field->name) + if field_uses_i8_ptr_hack(field) or field->belongs_to_union: + result = self->do_cast_through_pointers(result, type_to_llvm(field->type)) + elif ast->kind == AstExpressionKind::GetVariable: v = get_special_constant(ast->varname) if v == -1: diff --git a/self_hosted/llvm.jou b/self_hosted/llvm.jou index 8001c04d..c4423e61 100644 --- a/self_hosted/llvm.jou +++ b/self_hosted/llvm.jou @@ -198,6 +198,7 @@ declare LLVMDisposeModule(M: LLVMModule*) -> None declare LLVMGetSourceFileName(M: LLVMModule*, Len: long*) -> byte* # Return value not owned declare LLVMSetDataLayout(M: LLVMModule*, DataLayoutStr: byte*) -> None declare LLVMSetTarget(M: LLVMModule*, Triple: byte*) -> None +declare LLVMDumpType(Val: LLVMType*) -> None declare LLVMDumpModule(M: LLVMModule*) -> None declare LLVMPrintModuleToString(M: LLVMModule*) -> byte* declare LLVMAddFunction(M: LLVMModule*, Name: byte*, FunctionTy: LLVMType*) -> LLVMValue* @@ -212,6 +213,7 @@ declare LLVMIntType(NumBits: int) -> LLVMType* declare LLVMGetReturnType(FunctionTy: LLVMType*) -> LLVMType* declare LLVMGetParam(Fn: LLVMValue*, Index: int) -> LLVMValue* declare LLVMGetElementType(Ty: LLVMType*) -> LLVMType* +declare LLVMStructGetTypeAtIndex(StructTy: LLVMType*, i: int) -> LLVMType* declare LLVMTypeOf(Val: LLVMValue*) -> LLVMType* declare LLVMConstNull(Ty: LLVMType*) -> LLVMValue* declare LLVMGetUndef(Ty: LLVMType*) -> LLVMValue* diff --git a/self_hosted/runs_wrong.txt b/self_hosted/runs_wrong.txt index 02c37c70..d6ecf9bb 100644 --- a/self_hosted/runs_wrong.txt +++ b/self_hosted/runs_wrong.txt @@ -4,7 +4,6 @@ tests/other_errors/missing_value_in_return.jou tests/other_errors/noreturn_but_return_with_value.jou tests/other_errors/noreturn_but_return_without_value.jou tests/should_succeed/compiler_cli.jou -tests/should_succeed/linked_list.jou tests/should_succeed/pointer.jou tests/should_succeed/printf.jou tests/other_errors/return_void.jou diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index 2b9d18f5..9265f5a8 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -419,6 +419,7 @@ def handle_class_members_stage2(ft: FileTypes*, classdef: AstClassDef*) -> None: name = member->field.name, type = type_from_ast(ft, &member->field.type), union_id = union_id++, + belongs_to_union = False, } elif member->kind == AstClassMemberKind::Union: uid = union_id++ @@ -428,6 +429,7 @@ def handle_class_members_stage2(ft: FileTypes*, classdef: AstClassDef*) -> None: name = member->union_fields.fields[k].name, type = type_from_ast(ft, &member->union_fields.fields[k].type), union_id = uid, + belongs_to_union = True, } elif member->kind == AstClassMemberKind::Method: # Don't handle the method body yet: that is a part of stage 3, not stage 2 diff --git a/self_hosted/types.jou b/self_hosted/types.jou index 5656fa0e..7f7158ed 100644 --- a/self_hosted/types.jou +++ b/self_hosted/types.jou @@ -30,6 +30,7 @@ class ClassField: # If multiple fields have the same union_id, they belong to the same union. # It means that only one of the fields can be used at a time. union_id: int + belongs_to_union: bool # are there more fields with same union_id class ClassMembers: fields: ClassField* diff --git a/src/codegen.c b/src/codegen.c index 313b5de4..e2d75753 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -78,8 +78,8 @@ static LLVMTypeRef codegen_type(const Type *type) LLVMTypeRef *flat_elems = malloc(sizeof(flat_elems[0]) * n); // NOLINT for (int i = 0; i < n; i++) { - // Treat all pointers inside structs as if they were void*. - // This allows structs to contain pointers to themselves. + // Treat all pointers inside classes as if they were void*. + // This allows classes to contain pointers to themselves. if (type->data.classdata.fields.ptr[i].type->kind == TYPE_POINTER) flat_elems[i] = codegen_type(voidPtrType); else @@ -323,7 +323,7 @@ static void codegen_instruction(const struct State *st, const CfInstruction *ins LLVMValueRef val = LLVMBuildStructGEP2(st->builder, codegen_type(classtype), getop(0), f->union_id, ins->data.fieldname); // This cast is needed in two cases: - // * All pointers are i8* in structs so we can do self-referencing classes. + // * All pointers are i8* in classes so we can do self-referencing classes. // * This is how unions work. val = LLVMBuildBitCast(st->builder, val, LLVMPointerType(codegen_type(f->type),0), "struct_member_cast"); setdest(val);