Skip to content

Commit

Permalink
Add == for booleans
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli committed Dec 26, 2023
1 parent 075ea5c commit 82e84bb
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 20 deletions.
82 changes: 64 additions & 18 deletions src/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,41 @@ static const LocalVariable *build_cast(
assert(0);
}

static const LocalVariable *build_bool_eq(struct State *st, Location location, const LocalVariable *a, const LocalVariable *b)
{
assert(a->type == boolType);
assert(b->type == boolType);

/*
Pseudo code:
if a:
result = b
else:
result = not b
*/
const LocalVariable *result = add_local_var(st, boolType);

CfBlock *atrue = add_block(st);
CfBlock *afalse = add_block(st);
CfBlock *done = add_block(st);

// if a:
add_jump(st, a, atrue, afalse, atrue);

// result = b
add_unary_op(st, location, CF_VARCPY, b, result);

// else:
add_jump(st, NULL, done, done, afalse);

// result = not b
add_unary_op(st, location, CF_BOOL_NEGATE, b, result);

add_jump(st, NULL, done, done, done);
return result;
}

static const LocalVariable *build_binop(
struct State *st,
enum AstExpressionKind op,
Expand All @@ -175,32 +210,43 @@ static const LocalVariable *build_binop(
const LocalVariable *rhs,
const Type *result_type)
{
bool got_bools = lhs->type == boolType && rhs->type == boolType;
bool got_numbers = is_number_type(lhs->type) && is_number_type(rhs->type);
bool got_pointers = is_pointer_type(lhs->type) && is_pointer_type(rhs->type);
assert(got_numbers || got_pointers);
assert(got_bools || got_numbers || got_pointers);

enum CfInstructionKind k;
bool negate = false;
bool swap = false;

switch(op) {
case AST_EXPR_ADD: k = CF_NUM_ADD; break;
case AST_EXPR_SUB: k = CF_NUM_SUB; break;
case AST_EXPR_MUL: k = CF_NUM_MUL; break;
case AST_EXPR_DIV: k = CF_NUM_DIV; break;
case AST_EXPR_MOD: k = CF_NUM_MOD; break;
case AST_EXPR_EQ: k = CF_NUM_EQ; break;
case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break;
case AST_EXPR_LT: k = CF_NUM_LT; break;
case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break;
case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break;
case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break;
default: assert(0);
const LocalVariable *destvar;
if (got_bools) {
assert(result_type == boolType);
destvar = build_bool_eq(st, location, lhs, rhs);
switch(op) {
case AST_EXPR_EQ: break;
case AST_EXPR_NE: negate=true; break;
default: assert(0); break;
}
} else {
destvar = add_local_var(st, result_type);
enum CfInstructionKind k;
switch(op) {
case AST_EXPR_ADD: k = CF_NUM_ADD; break;
case AST_EXPR_SUB: k = CF_NUM_SUB; break;
case AST_EXPR_MUL: k = CF_NUM_MUL; break;
case AST_EXPR_DIV: k = CF_NUM_DIV; break;
case AST_EXPR_MOD: k = CF_NUM_MOD; break;
case AST_EXPR_EQ: k = CF_NUM_EQ; break;
case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break;
case AST_EXPR_LT: k = CF_NUM_LT; break;
case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break;
case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break;
case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break;
default: assert(0);
}
add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar);
}

const LocalVariable *destvar = add_local_var(st, result_type);
add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar);

if (!negate)
return destvar;

Expand Down
7 changes: 5 additions & 2 deletions src/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ static const Type *check_binop(
assert(0);
}

bool got_bools = lhstypes->type == boolType && rhstypes->type == boolType;
bool got_integers = is_integer_type(lhstypes->type) && is_integer_type(rhstypes->type);
bool got_numbers = is_number_type(lhstypes->type) && is_number_type(rhstypes->type);
bool got_enums = lhstypes->type->kind == TYPE_ENUM && rhstypes->type->kind == TYPE_ENUM;
Expand All @@ -692,13 +693,15 @@ static const Type *check_binop(
);

if (
(!got_numbers && !got_enums && !got_pointers)
|| (got_enums && op != AST_EXPR_EQ && op != AST_EXPR_NE)
(!got_bools && !got_numbers && !got_enums && !got_pointers)
|| ((got_bools || got_enums) && op != AST_EXPR_EQ && op != AST_EXPR_NE)
|| (got_pointers && op != AST_EXPR_EQ && op != AST_EXPR_NE && op != AST_EXPR_GT && op != AST_EXPR_GE && op != AST_EXPR_LT && op != AST_EXPR_LE)
)
fail(location, "wrong types: cannot %s %s and %s", do_what, lhstypes->type->name, rhstypes->type->name);

const Type *cast_type = NULL;
if (got_bools)
cast_type = boolType;
if (got_integers) {
cast_type = get_integer_type(
max(lhstypes->type->data.width_in_bits, rhstypes->type->data.width_in_bits),
Expand Down
19 changes: 19 additions & 0 deletions tests/should_succeed/bool_eq_bool.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import "stdlib/io.jou"

def do_stuff(a: bool, b: bool) -> None:
if a == b:
puts("Hi")

def main() -> int:
do_stuff(False, False) # Output: Hi

# Output: 1001
printf(
"%d%d%d%d\n",
True == True, # Warning: this code will never run
True == False, # Warning: this code will never run
False == True, # Warning: this code will never run
False == False, # Warning: this code will never run
)

return 0
3 changes: 3 additions & 0 deletions tests/wrong_type/bool_plus_bool.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def main() -> int:
x = True + False # Error: wrong types: cannot add bool and bool
return 0

0 comments on commit 82e84bb

Please sign in to comment.