From 82e84bb68725b0484b51ee594345254a7248cf0a Mon Sep 17 00:00:00 2001 From: Akuli Date: Tue, 26 Dec 2023 17:50:30 +0200 Subject: [PATCH] Add == for booleans --- src/build_cfg.c | 82 +++++++++++++++++++++------ src/typecheck.c | 7 ++- tests/should_succeed/bool_eq_bool.jou | 19 +++++++ tests/wrong_type/bool_plus_bool.jou | 3 + 4 files changed, 91 insertions(+), 20 deletions(-) create mode 100644 tests/should_succeed/bool_eq_bool.jou create mode 100644 tests/wrong_type/bool_plus_bool.jou diff --git a/src/build_cfg.c b/src/build_cfg.c index dcb55cd2..56c9dcce 100644 --- a/src/build_cfg.c +++ b/src/build_cfg.c @@ -167,6 +167,41 @@ static const LocalVariable *build_cast( assert(0); } +static const LocalVariable *build_bool_eq(struct State *st, Location location, const LocalVariable *a, const LocalVariable *b) +{ + assert(a->type == boolType); + assert(b->type == boolType); + + /* + Pseudo code: + + if a: + result = b + else: + result = not b + */ + const LocalVariable *result = add_local_var(st, boolType); + + CfBlock *atrue = add_block(st); + CfBlock *afalse = add_block(st); + CfBlock *done = add_block(st); + + // if a: + add_jump(st, a, atrue, afalse, atrue); + + // result = b + add_unary_op(st, location, CF_VARCPY, b, result); + + // else: + add_jump(st, NULL, done, done, afalse); + + // result = not b + add_unary_op(st, location, CF_BOOL_NEGATE, b, result); + + add_jump(st, NULL, done, done, done); + return result; +} + static const LocalVariable *build_binop( struct State *st, enum AstExpressionKind op, @@ -175,32 +210,43 @@ static const LocalVariable *build_binop( const LocalVariable *rhs, const Type *result_type) { + bool got_bools = lhs->type == boolType && rhs->type == boolType; bool got_numbers = is_number_type(lhs->type) && is_number_type(rhs->type); bool got_pointers = is_pointer_type(lhs->type) && is_pointer_type(rhs->type); - assert(got_numbers || got_pointers); + assert(got_bools || got_numbers || got_pointers); - enum CfInstructionKind k; bool negate = false; bool swap = false; - switch(op) { - case AST_EXPR_ADD: k = CF_NUM_ADD; break; - case AST_EXPR_SUB: k = CF_NUM_SUB; break; - case AST_EXPR_MUL: k = CF_NUM_MUL; break; - case AST_EXPR_DIV: k = CF_NUM_DIV; break; - case AST_EXPR_MOD: k = CF_NUM_MOD; break; - case AST_EXPR_EQ: k = CF_NUM_EQ; break; - case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break; - case AST_EXPR_LT: k = CF_NUM_LT; break; - case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break; - case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break; - case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break; - default: assert(0); + const LocalVariable *destvar; + if (got_bools) { + assert(result_type == boolType); + destvar = build_bool_eq(st, location, lhs, rhs); + switch(op) { + case AST_EXPR_EQ: break; + case AST_EXPR_NE: negate=true; break; + default: assert(0); break; + } + } else { + destvar = add_local_var(st, result_type); + enum CfInstructionKind k; + switch(op) { + case AST_EXPR_ADD: k = CF_NUM_ADD; break; + case AST_EXPR_SUB: k = CF_NUM_SUB; break; + case AST_EXPR_MUL: k = CF_NUM_MUL; break; + case AST_EXPR_DIV: k = CF_NUM_DIV; break; + case AST_EXPR_MOD: k = CF_NUM_MOD; break; + case AST_EXPR_EQ: k = CF_NUM_EQ; break; + case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break; + case AST_EXPR_LT: k = CF_NUM_LT; break; + case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break; + case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break; + case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break; + default: assert(0); + } + add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar); } - const LocalVariable *destvar = add_local_var(st, result_type); - add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar); - if (!negate) return destvar; diff --git a/src/typecheck.c b/src/typecheck.c index 52d5a96d..c2955e6f 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -677,6 +677,7 @@ static const Type *check_binop( assert(0); } + bool got_bools = lhstypes->type == boolType && rhstypes->type == boolType; bool got_integers = is_integer_type(lhstypes->type) && is_integer_type(rhstypes->type); bool got_numbers = is_number_type(lhstypes->type) && is_number_type(rhstypes->type); bool got_enums = lhstypes->type->kind == TYPE_ENUM && rhstypes->type->kind == TYPE_ENUM; @@ -692,13 +693,15 @@ static const Type *check_binop( ); if ( - (!got_numbers && !got_enums && !got_pointers) - || (got_enums && op != AST_EXPR_EQ && op != AST_EXPR_NE) + (!got_bools && !got_numbers && !got_enums && !got_pointers) + || ((got_bools || got_enums) && op != AST_EXPR_EQ && op != AST_EXPR_NE) || (got_pointers && op != AST_EXPR_EQ && op != AST_EXPR_NE && op != AST_EXPR_GT && op != AST_EXPR_GE && op != AST_EXPR_LT && op != AST_EXPR_LE) ) fail(location, "wrong types: cannot %s %s and %s", do_what, lhstypes->type->name, rhstypes->type->name); const Type *cast_type = NULL; + if (got_bools) + cast_type = boolType; if (got_integers) { cast_type = get_integer_type( max(lhstypes->type->data.width_in_bits, rhstypes->type->data.width_in_bits), diff --git a/tests/should_succeed/bool_eq_bool.jou b/tests/should_succeed/bool_eq_bool.jou new file mode 100644 index 00000000..8856782b --- /dev/null +++ b/tests/should_succeed/bool_eq_bool.jou @@ -0,0 +1,19 @@ +import "stdlib/io.jou" + +def do_stuff(a: bool, b: bool) -> None: + if a == b: + puts("Hi") + +def main() -> int: + do_stuff(False, False) # Output: Hi + + # Output: 1001 + printf( + "%d%d%d%d\n", + True == True, # Warning: this code will never run + True == False, # Warning: this code will never run + False == True, # Warning: this code will never run + False == False, # Warning: this code will never run + ) + + return 0 diff --git a/tests/wrong_type/bool_plus_bool.jou b/tests/wrong_type/bool_plus_bool.jou new file mode 100644 index 00000000..2dffae45 --- /dev/null +++ b/tests/wrong_type/bool_plus_bool.jou @@ -0,0 +1,3 @@ +def main() -> int: + x = True + False # Error: wrong types: cannot add bool and bool + return 0