Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add == for booleans #490

Merged
merged 2 commits into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions self_hosted/create_llvm_ir.jou
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,13 @@ class AstToIR:
if rhs_type->kind == TypeKind::Enum:
rhs_type = int_type

got_numbers = lhs_type->is_number_type() and rhs_type->is_number_type()
got_pointers = lhs_type->is_pointer_type() and rhs_type->is_pointer_type()
assert got_numbers or got_pointers
if lhs_type == &bool_type and rhs_type == &bool_type:
# bools are 1-bit integers in llvm
if op == AstExpressionKind::Eq:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs, rhs, "eq")
if op == AstExpressionKind::Ne:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs, rhs, "ne")
assert False

if lhs_type->kind == TypeKind::FloatingPoint and rhs_type->kind == TypeKind::FloatingPoint:
if op == AstExpressionKind::Add:
Expand Down
7 changes: 5 additions & 2 deletions self_hosted/typecheck.jou
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ def check_binop(
do_what = "compare"
result_is_bool = True

got_bools = lhs_types->original_type == &bool_type and rhs_types->original_type == &bool_type
got_integers = lhs_types->original_type->is_integer_type() and rhs_types->original_type->is_integer_type()
got_numbers = lhs_types->original_type->is_number_type() and rhs_types->original_type->is_number_type()
got_enums = lhs_types->original_type->kind == TypeKind::Enum and rhs_types->original_type->kind == TypeKind::Enum
Expand All @@ -703,7 +704,7 @@ def check_binop(
)

if (
(not got_numbers and not got_enums and not got_pointers)
(not got_bools and not got_numbers and not got_enums and not got_pointers)
or (op != AstExpressionKind::Eq and op != AstExpressionKind::Ne and not got_numbers)
):
message: byte[500]
Expand All @@ -714,7 +715,9 @@ def check_binop(
)
fail(location, message)

if got_integers:
if got_bools:
cast_type = &bool_type
elif got_integers:
size = max(lhs_types->original_type->size_in_bits, rhs_types->original_type->size_in_bits)
if (
lhs_types->original_type->kind == TypeKind::SignedInteger
Expand Down
82 changes: 64 additions & 18 deletions src/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,41 @@ static const LocalVariable *build_cast(
assert(0);
}

static const LocalVariable *build_bool_eq(struct State *st, Location location, const LocalVariable *a, const LocalVariable *b)
{
assert(a->type == boolType);
assert(b->type == boolType);

/*
Pseudo code:

if a:
result = b
else:
result = not b
*/
const LocalVariable *result = add_local_var(st, boolType);

CfBlock *atrue = add_block(st);
CfBlock *afalse = add_block(st);
CfBlock *done = add_block(st);

// if a:
add_jump(st, a, atrue, afalse, atrue);

// result = b
add_unary_op(st, location, CF_VARCPY, b, result);

// else:
add_jump(st, NULL, done, done, afalse);

// result = not b
add_unary_op(st, location, CF_BOOL_NEGATE, b, result);

add_jump(st, NULL, done, done, done);
return result;
}

static const LocalVariable *build_binop(
struct State *st,
enum AstExpressionKind op,
Expand All @@ -175,32 +210,43 @@ static const LocalVariable *build_binop(
const LocalVariable *rhs,
const Type *result_type)
{
bool got_bools = lhs->type == boolType && rhs->type == boolType;
bool got_numbers = is_number_type(lhs->type) && is_number_type(rhs->type);
bool got_pointers = is_pointer_type(lhs->type) && is_pointer_type(rhs->type);
assert(got_numbers || got_pointers);
assert(got_bools || got_numbers || got_pointers);

enum CfInstructionKind k;
bool negate = false;
bool swap = false;

switch(op) {
case AST_EXPR_ADD: k = CF_NUM_ADD; break;
case AST_EXPR_SUB: k = CF_NUM_SUB; break;
case AST_EXPR_MUL: k = CF_NUM_MUL; break;
case AST_EXPR_DIV: k = CF_NUM_DIV; break;
case AST_EXPR_MOD: k = CF_NUM_MOD; break;
case AST_EXPR_EQ: k = CF_NUM_EQ; break;
case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break;
case AST_EXPR_LT: k = CF_NUM_LT; break;
case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break;
case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break;
case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break;
default: assert(0);
const LocalVariable *destvar;
if (got_bools) {
assert(result_type == boolType);
destvar = build_bool_eq(st, location, lhs, rhs);
switch(op) {
case AST_EXPR_EQ: break;
case AST_EXPR_NE: negate=true; break;
default: assert(0); break;
}
} else {
destvar = add_local_var(st, result_type);
enum CfInstructionKind k;
switch(op) {
case AST_EXPR_ADD: k = CF_NUM_ADD; break;
case AST_EXPR_SUB: k = CF_NUM_SUB; break;
case AST_EXPR_MUL: k = CF_NUM_MUL; break;
case AST_EXPR_DIV: k = CF_NUM_DIV; break;
case AST_EXPR_MOD: k = CF_NUM_MOD; break;
case AST_EXPR_EQ: k = CF_NUM_EQ; break;
case AST_EXPR_NE: k = CF_NUM_EQ; negate=true; break;
case AST_EXPR_LT: k = CF_NUM_LT; break;
case AST_EXPR_GT: k = CF_NUM_LT; swap=true; break;
case AST_EXPR_LE: k = CF_NUM_LT; negate=true; swap=true; break;
case AST_EXPR_GE: k = CF_NUM_LT; negate=true; break;
default: assert(0);
}
add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar);
}

const LocalVariable *destvar = add_local_var(st, result_type);
add_binary_op(st, location, k, swap?rhs:lhs, swap?lhs:rhs, destvar);

if (!negate)
return destvar;

Expand Down
7 changes: 5 additions & 2 deletions src/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ static const Type *check_binop(
assert(0);
}

bool got_bools = lhstypes->type == boolType && rhstypes->type == boolType;
bool got_integers = is_integer_type(lhstypes->type) && is_integer_type(rhstypes->type);
bool got_numbers = is_number_type(lhstypes->type) && is_number_type(rhstypes->type);
bool got_enums = lhstypes->type->kind == TYPE_ENUM && rhstypes->type->kind == TYPE_ENUM;
Expand All @@ -692,13 +693,15 @@ static const Type *check_binop(
);

if (
(!got_numbers && !got_enums && !got_pointers)
|| (got_enums && op != AST_EXPR_EQ && op != AST_EXPR_NE)
(!got_bools && !got_numbers && !got_enums && !got_pointers)
|| ((got_bools || got_enums) && op != AST_EXPR_EQ && op != AST_EXPR_NE)
|| (got_pointers && op != AST_EXPR_EQ && op != AST_EXPR_NE && op != AST_EXPR_GT && op != AST_EXPR_GE && op != AST_EXPR_LT && op != AST_EXPR_LE)
)
fail(location, "wrong types: cannot %s %s and %s", do_what, lhstypes->type->name, rhstypes->type->name);

const Type *cast_type = NULL;
if (got_bools)
cast_type = boolType;
if (got_integers) {
cast_type = get_integer_type(
max(lhstypes->type->data.width_in_bits, rhstypes->type->data.width_in_bits),
Expand Down
19 changes: 19 additions & 0 deletions tests/should_succeed/bool_eq_bool.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import "stdlib/io.jou"

def do_stuff(a: bool, b: bool) -> None:
if a == b:
puts("Hi")

def main() -> int:
do_stuff(False, False) # Output: Hi

# Output: 1001
printf(
"%d%d%d%d\n",
True == True, # Warning: this code will never run
True == False, # Warning: this code will never run
False == True, # Warning: this code will never run
False == False, # Warning: this code will never run
)

return 0
3 changes: 3 additions & 0 deletions tests/wrong_type/bool_plus_bool.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def main() -> int:
x = True + False # Error: wrong types: cannot add bool and bool
return 0