diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index 4c6f6a93..bb26220a 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -242,9 +242,13 @@ class AstToIR: if rhs_type->kind == TypeKind::Enum: rhs_type = int_type - got_numbers = lhs_type->is_number_type() and rhs_type->is_number_type() - got_pointers = lhs_type->is_pointer_type() and rhs_type->is_pointer_type() - assert got_numbers or got_pointers + if lhs_type == &bool_type and rhs_type == &bool_type: + # bools are 1-bit integers in llvm + if op == AstExpressionKind::Eq: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs, rhs, "eq") + if op == AstExpressionKind::Ne: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs, rhs, "ne") + assert False if lhs_type->kind == TypeKind::FloatingPoint and rhs_type->kind == TypeKind::FloatingPoint: if op == AstExpressionKind::Add: diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index 9b582e60..56baeaeb 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -688,6 +688,7 @@ def check_binop( do_what = "compare" result_is_bool = True + got_bools = lhs_types->original_type == &bool_type and rhs_types->original_type == &bool_type got_integers = lhs_types->original_type->is_integer_type() and rhs_types->original_type->is_integer_type() got_numbers = lhs_types->original_type->is_number_type() and rhs_types->original_type->is_number_type() got_enums = lhs_types->original_type->kind == TypeKind::Enum and rhs_types->original_type->kind == TypeKind::Enum @@ -703,7 +704,7 @@ def check_binop( ) if ( - (not got_numbers and not got_enums and not got_pointers) + (not got_bools and not got_numbers and not got_enums and not got_pointers) or (op != AstExpressionKind::Eq and op != AstExpressionKind::Ne and not got_numbers) ): message: byte[500] @@ -714,7 +715,9 @@ def check_binop( ) fail(location, message) - if got_integers: + if got_bools: + cast_type = &bool_type + elif got_integers: size = max(lhs_types->original_type->size_in_bits, rhs_types->original_type->size_in_bits) if ( lhs_types->original_type->kind == TypeKind::SignedInteger diff --git a/src/build_cfg.c b/src/build_cfg.c index dcb55cd2..56c9dcce 100644 --- a/src/build_cfg.c +++ b/src/build_cfg.c @@ -167,6 +167,41 @@ static const LocalVariable *build_cast( assert(0); } +static const LocalVariable *build_bool_eq(struct State *st, Location location, const LocalVariable *a, const LocalVariable *b) +{ + assert(a->type == boolType); + assert(b->type == boolType); + + /* + Pseudo code: + + if a: + result = b + else: + result = not b + */ + const LocalVariable *result = add_local_var(st, boolType); + + CfBlock *atrue = add_block(st); + CfBlock *afalse = add_block(st); + CfBlock *done = add_block(st); + + // if a: + add_jump(st, a, atrue, afalse, atrue); + + // result = b + add_unary_op(st, location, CF_VARCPY, b, result); + + // else: + add_jump(st, NULL, done, done, afalse); + + // result = not b + add_unary_op(st, location, CF_BOOL_NEGATE, b, result); + + add_jump(st, NULL, done, done, done); + return result; +} + static const LocalVariable *build_binop( struct State *st, enum AstExpressionKind op, @@ -175,32 +210,43 @@ static const LocalVariable *build_binop( const LocalVariable *rhs, const Type *result_type) { + bool got_bools = lhs->type == boolType && rhs->type == boolType; bool got_numbers = is_number_type(lhs->type) && is_number_type(rhs->type); bool got_pointers = is_pointer_type(lhs->type) && is_pointer_type(rhs->type); - assert(got_numbers || got_pointers); + assert(got_bools || got_numbers || got_pointers); - enum CfInstructionKind k; bool negate = false; bool swap = false; - switch(op) { - case AST_EXPR_ADD: k = CF_NUM_ADD; break; - case AST_EXPR_SUB: k = CF_NUM_SUB; break; - case AST_EXPR_MUL: k = CF_NUM_MUL; break; - case AST_EXPR_DIV: k = CF_NUM_DIV; break; - case AST_EXPR_MOD: k = CF_NUM_MOD; break; - case AST_EXPR_EQ: k = CF_NUM_EQ; break; - case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break; - case AST_EXPR_LT: k = CF_NUM_LT; break; - case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break; - case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break; - case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break; - default: assert(0); + const LocalVariable *destvar; + if (got_bools) { + assert(result_type == boolType); + destvar = build_bool_eq(st, location, lhs, rhs); + switch(op) { + case AST_EXPR_EQ: break; + case AST_EXPR_NE: negate=true; break; + default: assert(0); break; + } + } else { + destvar = add_local_var(st, result_type); + enum CfInstructionKind k; + switch(op) { + case AST_EXPR_ADD: k = CF_NUM_ADD; break; + case AST_EXPR_SUB: k = CF_NUM_SUB; break; + case AST_EXPR_MUL: k = CF_NUM_MUL; break; + case AST_EXPR_DIV: k = CF_NUM_DIV; break; + case AST_EXPR_MOD: k = CF_NUM_MOD; break; + case AST_EXPR_EQ: k = CF_NUM_EQ; break; + case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break; + case AST_EXPR_LT: k = CF_NUM_LT; break; + case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break; + case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break; + case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break; + default: assert(0); + } + add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar); } - const LocalVariable *destvar = add_local_var(st, result_type); - add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar); - if (!negate) return destvar; diff --git a/src/typecheck.c b/src/typecheck.c index 9ecfd4b9..387c1f39 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -696,6 +696,7 @@ static const Type *check_binop( assert(0); } + bool got_bools = lhstypes->type == boolType && rhstypes->type == boolType; bool got_integers = is_integer_type(lhstypes->type) && is_integer_type(rhstypes->type); bool got_numbers = is_number_type(lhstypes->type) && is_number_type(rhstypes->type); bool got_enums = lhstypes->type->kind == TYPE_ENUM && rhstypes->type->kind == TYPE_ENUM; @@ -711,13 +712,15 @@ static const Type *check_binop( ); if ( - (!got_numbers && !got_enums && !got_pointers) - || (got_enums && op != AST_EXPR_EQ && op != AST_EXPR_NE) + (!got_bools && !got_numbers && !got_enums && !got_pointers) + || ((got_bools || got_enums) && op != AST_EXPR_EQ && op != AST_EXPR_NE) || (got_pointers && op != AST_EXPR_EQ && op != AST_EXPR_NE && op != AST_EXPR_GT && op != AST_EXPR_GE && op != AST_EXPR_LT && op != AST_EXPR_LE) ) fail(location, "wrong types: cannot %s %s and %s", do_what, lhstypes->type->name, rhstypes->type->name); const Type *cast_type = NULL; + if (got_bools) + cast_type = boolType; if (got_integers) { cast_type = get_integer_type( max(lhstypes->type->data.width_in_bits, rhstypes->type->data.width_in_bits), diff --git a/tests/should_succeed/bool_eq_bool.jou b/tests/should_succeed/bool_eq_bool.jou new file mode 100644 index 00000000..8856782b --- /dev/null +++ b/tests/should_succeed/bool_eq_bool.jou @@ -0,0 +1,19 @@ +import "stdlib/io.jou" + +def do_stuff(a: bool, b: bool) -> None: + if a == b: + puts("Hi") + +def main() -> int: + do_stuff(False, False) # Output: Hi + + # Output: 1001 + printf( + "%d%d%d%d\n", + True == True, # Warning: this code will never run + True == False, # Warning: this code will never run + False == True, # Warning: this code will never run + False == False, # Warning: this code will never run + ) + + return 0 diff --git a/tests/wrong_type/bool_plus_bool.jou b/tests/wrong_type/bool_plus_bool.jou new file mode 100644 index 00000000..2dffae45 --- /dev/null +++ b/tests/wrong_type/bool_plus_bool.jou @@ -0,0 +1,3 @@ +def main() -> int: + x = True + False # Error: wrong types: cannot add bool and bool + return 0