Skip to content

Commit

Permalink
Allow using match statement with integers
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli committed Jan 26, 2025
1 parent dd896f4 commit 0f8b238
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 19 deletions.
36 changes: 22 additions & 14 deletions compiler/typecheck/step3_function_and_method_bodies.jou
Original file line number Diff line number Diff line change
Expand Up @@ -1075,21 +1075,32 @@ def typecheck_if_statement(state: State*, ifstmt: AstIfStatement*) -> None:

def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) -> None:
msg: byte[500]
sig_string: byte* = NULL
remaining: byte** = NULL
nremaining = -1

if match_stmt->func_name[0] == '\0':
case_type = typecheck_expression_not_void(state, &match_stmt->match_obj)
if case_type->kind != TypeKind.Enum:
# TODO: extend match statements to other equals-comparable values
snprintf(msg, sizeof(msg), "match statements can only be used with enums, not %s", case_type->name)
fail(match_stmt->match_obj.location, msg)

# Ensure user checks all possible enum values
nremaining = case_type->enummembers.count
remaining: byte** = malloc(sizeof(remaining[0]) * nremaining)
assert remaining != NULL
for i = 0; i < nremaining; i++:
remaining[i] = case_type->enummembers.names[i]
sig_string: byte* = NULL
match case_type->kind:
case TypeKind.SignedInteger | TypeKind.UnsignedInteger:
# no special handling needed
pass
case TypeKind.Enum:
# Ensure user checks all possible enum values
nremaining = case_type->enummembers.count
remaining = malloc(sizeof(remaining[0]) * nremaining)
assert remaining != NULL
for i = 0; i < nremaining; i++:
remaining[i] = case_type->enummembers.names[i]
case _:
snprintf(msg, sizeof(msg), "cannot match a value of type %s", case_type->name)
if strlen(msg) + 100 < sizeof(msg):
if case_type == byteType->pointer_type():
strcat(msg, " (try adding 'with strcmp')")
if case_type == boolType:
strcat(msg, " (use if/else instead)")
fail(match_stmt->match_obj.location, msg)
else:
sig = state->file_types->find_function(match_stmt->func_name)
if sig == NULL:
Expand All @@ -1109,10 +1120,7 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) ->

snprintf(msg, sizeof(msg), "cannot match <from> with %s", sig_string)
typecheck_expression_with_implicit_cast(state, &match_stmt->match_obj, sig->argtypes[0], msg)

case_type = sig->argtypes[1]
remaining = NULL
nremaining = -1

for i = 0; i < match_stmt->ncases; i++:
for j = 0; j < match_stmt->cases[i].n_case_objs; j++:
Expand Down
8 changes: 8 additions & 0 deletions tests/should_succeed/match.jou
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ def show_evaluation(foo: Foo, msg: byte*) -> Foo:


def main() -> int:
match 1 + 2:
case 2:
printf("lol wat?\n")
case 3:
printf("Three\n") # Output: Three
case 4:
printf("lol wat?!!\n")

f = Foo.Bar
match f:
case Foo.Bar:
Expand Down
7 changes: 7 additions & 0 deletions tests/wrong_type/match_bool.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import "stdlib/io.jou"

def main() -> int:
match True: # Error: cannot match a value of type bool (use if/else instead)
case True:
printf("ya\n")
return 0
11 changes: 11 additions & 0 deletions tests/wrong_type/match_class.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import "stdlib/io.jou"

class Foo:
x: int
y: int

def main() -> int:
match Foo{x=1, y=2}: # Error: cannot match a value of type Foo
case Foo{x=1, y=2}:
printf("ya\n")
return 0
6 changes: 6 additions & 0 deletions tests/wrong_type/match_float.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def bruh() -> None:
# Floats and doubles usually shouldn't be == compared, so matching them
# is a bad idea (unless you use "match ... with ..." syntax)
match 1.0: # Error: cannot match a value of type double
case 2.0:
pass
5 changes: 0 additions & 5 deletions tests/wrong_type/match_not_enum.jou

This file was deleted.

7 changes: 7 additions & 0 deletions tests/wrong_type/match_string_without_strcmp.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import "stdlib/io.jou"

def main() -> int:
match "hi": # Error: cannot match a value of type byte* (try adding 'with strcmp')
case "ho":
printf("ya\n")
return 0

0 comments on commit 0f8b238

Please sign in to comment.