Skip to content

Commit

Permalink
class field pointer hell
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli committed Dec 16, 2023
1 parent 06c74f5 commit 9ab14ca
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 15 deletions.
41 changes: 30 additions & 11 deletions self_hosted/create_llvm_ir.jou
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,16 @@ def create_llvm_union_type(types: LLVMType**, ntypes: int) -> LLVMType*:

return LLVMArrayType(LLVMInt64Type(), ((size_needed + 7) / 8) as int) # ceil division


# Pointers in classes are stored as i8*, so that a struct can contain a pointer to itself.
#
# Returns True when the given field's type is a pointer type, i.e. when the
# field is represented as i8* inside the LLVM struct instead of its real
# pointer type. Code that reads or writes such a field has to cast between
# i8* and the field's actual type.
def field_uses_i8_ptr_hack(field: ClassField*) -> bool:
return field->type->kind == TypeKind::Pointer


def class_type_to_llvm(fields: ClassField*, nfields: int) -> LLVMType*:
elem_types: LLVMType** = malloc(nfields * sizeof elem_types[0])
for i = 0; i < nfields; i++:
# Store all pointers in structs as i8*, so that a struct can contain a pointer to itself for example.
if fields[i].type->kind == TypeKind::Pointer:
if field_uses_i8_ptr_hack(&fields[i]):
elem_types[i] = LLVMPointerType(LLVMInt8Type(), 0)
else:
elem_types[i] = type_to_llvm(fields[i].type)
Expand Down Expand Up @@ -224,6 +229,13 @@ class AstToIR:
printf("unimplemented cast: %s --> %s\n", from->name, to->name)
assert False

# Makes a temporary pointer, places the value there, then casts and reads the pointer.
#
# In other words, the result is the memory of `value` reinterpreted as type `to`.
# No conversion of the value is emitted: only a store, a pointer bitcast and a load.
# NOTE(review): presumably `to` must not be larger than the type of `value`,
# because the temporary is allocated with the type of `value` -- confirm with callers.
def do_cast_through_pointers(self, value: LLVMValue*, to: LLVMType*) -> LLVMValue*:
# Temporary whose type is the type of the incoming value.
p1 = LLVMBuildAlloca(self->builder, LLVMTypeOf(value), "cast_through_ptr_temp")
LLVMBuildStore(self->builder, value, p1)
# Same address, but viewed as a pointer to the target type.
p2 = LLVMBuildBitCast(self->builder, p1, LLVMPointerType(to, 0), "cast_through_ptr_temp")
# Reading through the casted pointer yields a value of type `to`.
return LLVMBuildLoad(self->builder, p2, "cast_through_ptr_result")

def do_binop(
self,
op: AstExpressionKind,
Expand Down Expand Up @@ -390,19 +402,21 @@ class AstToIR:
assert field != NULL
field_pointer = LLVMBuildStructGEP2(
self->builder,
type_to_llvm(class_type), instance_pointer,
type_to_llvm(class_type),
instance_pointer,
field->union_id,
field->name,
)

# This cast is needed in two cases:
# * All pointers are i8* in structs so we can do self-referencing classes.
# * This is how unions work.
return LLVMBuildBitCast(
self->builder,
field_pointer, LLVMPointerType(type_to_llvm(field->type),0),
"struct_member_cast",
)
if field_uses_i8_ptr_hack(field) or field->belongs_to_union:
field_pointer = LLVMBuildBitCast(
self->builder,
field_pointer,
LLVMPointerType(type_to_llvm(field->type), 0),
"class_field_ptr_cast",
)

return field_pointer

if ast->kind == AstExpressionKind::Indexing:
# &pointer[index] = pointer + some offset
Expand Down Expand Up @@ -568,6 +582,8 @@ class AstToIR:
field = instance_type->class_members.find_field(ast->instantiation.field_names[i])
assert field != NULL
value = self->do_expression(&ast->instantiation.field_values[i])
if field_uses_i8_ptr_hack(field) or field->belongs_to_union:
value = self->do_cast_through_pointers(value, LLVMStructGetTypeAtIndex(type_to_llvm(instance_type), field->union_id))
result = LLVMBuildInsertValue(self->builder, result, value, field->union_id, "instance")

elif ast->kind == AstExpressionKind::GetClassField:
Expand All @@ -584,6 +600,9 @@ class AstToIR:
assert field != NULL
result = LLVMBuildExtractValue(self->builder, instance, field->union_id, field->name)

if field_uses_i8_ptr_hack(field) or field->belongs_to_union:
result = self->do_cast_through_pointers(result, type_to_llvm(field->type))

elif ast->kind == AstExpressionKind::GetVariable:
v = get_special_constant(ast->varname)
if v == -1:
Expand Down
2 changes: 2 additions & 0 deletions self_hosted/llvm.jou
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ declare LLVMDisposeModule(M: LLVMModule*) -> None
declare LLVMGetSourceFileName(M: LLVMModule*, Len: long*) -> byte* # Return value not owned
declare LLVMSetDataLayout(M: LLVMModule*, DataLayoutStr: byte*) -> None
declare LLVMSetTarget(M: LLVMModule*, Triple: byte*) -> None
declare LLVMDumpType(Val: LLVMType*) -> None
declare LLVMDumpModule(M: LLVMModule*) -> None
declare LLVMPrintModuleToString(M: LLVMModule*) -> byte*
declare LLVMAddFunction(M: LLVMModule*, Name: byte*, FunctionTy: LLVMType*) -> LLVMValue*
Expand All @@ -212,6 +213,7 @@ declare LLVMIntType(NumBits: int) -> LLVMType*
declare LLVMGetReturnType(FunctionTy: LLVMType*) -> LLVMType*
declare LLVMGetParam(Fn: LLVMValue*, Index: int) -> LLVMValue*
declare LLVMGetElementType(Ty: LLVMType*) -> LLVMType*
declare LLVMStructGetTypeAtIndex(StructTy: LLVMType*, i: int) -> LLVMType*
declare LLVMTypeOf(Val: LLVMValue*) -> LLVMType*
declare LLVMConstNull(Ty: LLVMType*) -> LLVMValue*
declare LLVMGetUndef(Ty: LLVMType*) -> LLVMValue*
Expand Down
1 change: 0 additions & 1 deletion self_hosted/runs_wrong.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ tests/other_errors/missing_value_in_return.jou
tests/other_errors/noreturn_but_return_with_value.jou
tests/other_errors/noreturn_but_return_without_value.jou
tests/should_succeed/compiler_cli.jou
tests/should_succeed/linked_list.jou
tests/should_succeed/pointer.jou
tests/should_succeed/printf.jou
tests/other_errors/return_void.jou
Expand Down
2 changes: 2 additions & 0 deletions self_hosted/typecheck.jou
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def handle_class_members_stage2(ft: FileTypes*, classdef: AstClassDef*) -> None:
name = member->field.name,
type = type_from_ast(ft, &member->field.type),
union_id = union_id++,
belongs_to_union = False,
}
elif member->kind == AstClassMemberKind::Union:
uid = union_id++
Expand All @@ -428,6 +429,7 @@ def handle_class_members_stage2(ft: FileTypes*, classdef: AstClassDef*) -> None:
name = member->union_fields.fields[k].name,
type = type_from_ast(ft, &member->union_fields.fields[k].type),
union_id = uid,
belongs_to_union = True,
}
elif member->kind == AstClassMemberKind::Method:
# Don't handle the method body yet: that is a part of stage 3, not stage 2
Expand Down
1 change: 1 addition & 0 deletions self_hosted/types.jou
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ClassField:
# If multiple fields have the same union_id, they belong to the same union.
# It means that only one of the fields can be used at a time.
union_id: int
belongs_to_union: bool # True if the field was declared inside a union, so other fields may share its union_id

class ClassMembers:
fields: ClassField*
Expand Down
6 changes: 3 additions & 3 deletions src/codegen.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ static LLVMTypeRef codegen_type(const Type *type)

LLVMTypeRef *flat_elems = malloc(sizeof(flat_elems[0]) * n); // NOLINT
for (int i = 0; i < n; i++) {
// Treat all pointers inside structs as if they were void*.
// This allows structs to contain pointers to themselves.
// Treat all pointers inside classes as if they were void*.
// This allows classes to contain pointers to themselves.
if (type->data.classdata.fields.ptr[i].type->kind == TYPE_POINTER)
flat_elems[i] = codegen_type(voidPtrType);
else
Expand Down Expand Up @@ -323,7 +323,7 @@ static void codegen_instruction(const struct State *st, const CfInstruction *ins

LLVMValueRef val = LLVMBuildStructGEP2(st->builder, codegen_type(classtype), getop(0), f->union_id, ins->data.fieldname);
// This cast is needed in two cases:
// * All pointers are i8* in structs so we can do self-referencing classes.
// * All pointers are i8* in classes so we can do self-referencing classes.
// * This is how unions work.
val = LLVMBuildBitCast(st->builder, val, LLVMPointerType(codegen_type(f->type),0), "struct_member_cast");
setdest(val);
Expand Down

0 comments on commit 9ab14ca

Please sign in to comment.