diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index 4c6f6a93..bb26220a 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -242,9 +242,13 @@ class AstToIR: if rhs_type->kind == TypeKind::Enum: rhs_type = int_type - got_numbers = lhs_type->is_number_type() and rhs_type->is_number_type() - got_pointers = lhs_type->is_pointer_type() and rhs_type->is_pointer_type() - assert got_numbers or got_pointers + if lhs_type == &bool_type and rhs_type == &bool_type: + # bools are 1-bit integers in llvm + if op == AstExpressionKind::Eq: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs, rhs, "eq") + if op == AstExpressionKind::Ne: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs, rhs, "ne") + assert False if lhs_type->kind == TypeKind::FloatingPoint and rhs_type->kind == TypeKind::FloatingPoint: if op == AstExpressionKind::Add: diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index 9b582e60..56baeaeb 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -688,6 +688,7 @@ def check_binop( do_what = "compare" result_is_bool = True + got_bools = lhs_types->original_type == &bool_type and rhs_types->original_type == &bool_type got_integers = lhs_types->original_type->is_integer_type() and rhs_types->original_type->is_integer_type() got_numbers = lhs_types->original_type->is_number_type() and rhs_types->original_type->is_number_type() got_enums = lhs_types->original_type->kind == TypeKind::Enum and rhs_types->original_type->kind == TypeKind::Enum @@ -703,7 +704,7 @@ def check_binop( ) if ( - (not got_numbers and not got_enums and not got_pointers) + (not got_bools and not got_numbers and not got_enums and not got_pointers) or (op != AstExpressionKind::Eq and op != AstExpressionKind::Ne and not got_numbers) ): message: byte[500] @@ -714,7 +715,9 @@ def check_binop( ) fail(location, message) - if got_integers: + if got_bools: + cast_type = &bool_type + elif got_integers: size = max(lhs_types->original_type->size_in_bits, rhs_types->original_type->size_in_bits) if ( lhs_types->original_type->kind == TypeKind::SignedInteger