diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index 21ccdc2b..ea9349c4 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -85,9 +85,12 @@ def class_type_to_llvm(fields: ClassField*, nfields: int) -> LLVMType*: end = start + 1 while end < nfields and fields[start].union_id == fields[end].union_id: end++ + assert fields[start].union_id == combined_len elem_types[combined_len++] = create_llvm_union_type(&elem_types[start], end-start) result = LLVMStructType(elem_types, combined_len, False as int) + for i = 0; i < combined_len; i++: + assert LLVMStructGetTypeAtIndex(result, i) == elem_types[i] free(elem_types) return result @@ -234,7 +237,7 @@ class AstToIR: p1 = LLVMBuildAlloca(self->builder, LLVMTypeOf(value), "cast_through_ptr_temp") LLVMBuildStore(self->builder, value, p1) p2 = LLVMBuildBitCast(self->builder, p1, LLVMPointerType(to, 0), "cast_through_ptr_temp") - return LLVMBuildLoad(self->builder, p2, "cast_through_ptr_result") + return LLVMBuildLoad2(self->builder, to, p2, "cast_through_ptr_result") def do_binop( self, @@ -584,6 +587,7 @@ class AstToIR: value = self->do_expression(&ast->instantiation.field_values[i]) if field_uses_i8_ptr_hack(field) or field->belongs_to_union: value = self->do_cast_through_pointers(value, LLVMStructGetTypeAtIndex(type_to_llvm(instance_type), field->union_id)) + assert LLVMTypeOf(value) == LLVMStructGetTypeAtIndex(type_to_llvm(instance_type), field->union_id) result = LLVMBuildInsertValue(self->builder, result, value, field->union_id, "instance") elif ast->kind == AstExpressionKind::GetClassField: diff --git a/self_hosted/llvm.jou b/self_hosted/llvm.jou index c4423e61..abd7b1ee 100644 --- a/self_hosted/llvm.jou +++ b/self_hosted/llvm.jou @@ -215,6 +215,7 @@ declare LLVMGetParam(Fn: LLVMValue*, Index: int) -> LLVMValue* declare LLVMGetElementType(Ty: LLVMType*) -> LLVMType* declare LLVMStructGetTypeAtIndex(StructTy: LLVMType*, i: int) -> LLVMType* declare LLVMTypeOf(Val: LLVMValue*) -> LLVMType* +declare LLVMAlignOf(Ty: LLVMType*) -> LLVMValue* declare LLVMConstNull(Ty: LLVMType*) -> LLVMValue* declare LLVMGetUndef(Ty: LLVMType*) -> LLVMValue* declare LLVMConstInt(IntTy: LLVMType*, N: long, SignExtend: int) -> LLVMValue* @@ -252,8 +253,10 @@ declare LLVMBuildXor(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Na declare LLVMBuildNeg(Builder: LLVMBuilder*, V: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildFNeg(Builder: LLVMBuilder*, V: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildMemSet(Builder: LLVMBuilder*, Ptr: LLVMValue*, Val: LLVMValue*, Len: LLVMValue*, Align: int) -> LLVMValue* +declare LLVMBuildMemCpy(Builder: LLVMBuilder*, Dst: LLVMValue*, DstAlign: int, Src: LLVMValue*, SrcAlign: int, Size: LLVMValue*) -> LLVMValue* declare LLVMBuildAlloca(Builder: LLVMBuilder*, Ty: LLVMType*, Name: byte*) -> LLVMValue* declare LLVMBuildLoad(Builder: LLVMBuilder*, PointerVal: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildLoad2(Builder: LLVMBuilder*, Ty: LLVMType*, PointerVal: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildStore(Builder: LLVMBuilder*, Val: LLVMValue*, Ptr: LLVMValue*) -> LLVMValue* declare LLVMBuildGEP(Builder: LLVMBuilder*, Pointer: LLVMValue*, Indices: LLVMValue**, NumIndices: int, Name: byte*) -> LLVMValue* declare LLVMBuildStructGEP2(Builder: LLVMBuilder*, Ty: LLVMType*, Pointer: LLVMValue*, Idx: int, Name: byte*) -> LLVMValue*