diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index 2b9d18f5..70461ed2 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -389,10 +389,19 @@ def handle_signature(ft: FileTypes*, astsig: AstSignature*, self_type: Type*) -> else: sig.return_type = type_from_ast(ft, &astsig->return_type) - # TODO: validate main() parameters - # TODO: test main() taking parameters - if self_type == NULL and strcmp(sig.name, "main") == 0 and sig.return_type != int_type: - fail(astsig->return_type.location, "the main() function must return int") + if self_type == NULL and strcmp(sig.name, "main") == 0: + # special main() function checks + if sig.return_type != int_type: + fail(astsig->return_type.location, "the main() function must return int") + if sig.nargs != 0 and not ( + sig.nargs == 2 + and sig.argtypes[0] == int_type + and sig.argtypes[1] == byte_type->get_pointer_type()->get_pointer_type() + ): + fail( + astsig->args[0].type.location, + "if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int" + ) return sig diff --git a/src/typecheck.c b/src/typecheck.c index ffdc4979..fb7ec78a 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -245,10 +245,20 @@ static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, con else sig.returntype = type_from_ast(ft, &astsig->returntype); - // TODO: validate main() parameters - // TODO: test main() taking parameters - if (!self_type && !strcmp(sig.name, "main") && sig.returntype != intType) { - fail(astsig->returntype.location, "the main() function must return int"); + if (!self_type && !strcmp(sig.name, "main")) { + // special main() function checks + if (sig.returntype != intType) + fail(astsig->returntype.location, "the main() function must return int"); + if (sig.nargs != 0 && !( + sig.nargs == 2 + && sig.argtypes[0] == intType + && sig.argtypes[1] == get_pointer_type(get_pointer_type(byteType)))) + { + fail( + astsig->args.ptr[0].type.location, + "if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int" + ); + } } sig.returntype_location = astsig->returntype.location; diff --git a/tests/should_succeed/main_funny_arg_names.jou b/tests/should_succeed/main_funny_arg_names.jou new file mode 100644 index 00000000..7322f8ef --- /dev/null +++ b/tests/should_succeed/main_funny_arg_names.jou @@ -0,0 +1,6 @@ +import "stdlib/io.jou" + +# Usually the args are named "argc" and "argv", but you can name them whatever you want. +def main(lol: int, wat: byte**) -> int: + printf("lol = %d\n", lol) # Output: lol = 1 + return 0 diff --git a/tests/wrong_type/main_1_arg.jou b/tests/wrong_type/main_1_arg.jou new file mode 100644 index 00000000..7250996a --- /dev/null +++ b/tests/wrong_type/main_1_arg.jou @@ -0,0 +1,4 @@ +def main( + argc: int # Error: if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int +) -> int: + return 0 diff --git a/tests/wrong_type/main_3_args.jou b/tests/wrong_type/main_3_args.jou new file mode 100644 index 00000000..43bf2b22 --- /dev/null +++ b/tests/wrong_type/main_3_args.jou @@ -0,0 +1,4 @@ +def main( + argc: int, argv: byte**, lol: int # Error: if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int +) -> int: + return 0 diff --git a/tests/wrong_type/main_argc.jou b/tests/wrong_type/main_argc.jou new file mode 100644 index 00000000..ef11e7c0 --- /dev/null +++ b/tests/wrong_type/main_argc.jou @@ -0,0 +1,4 @@ +def main( + argc: long, argv: byte** # Error: if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int +) -> int: + return 0 diff --git a/tests/wrong_type/main_argv.jou b/tests/wrong_type/main_argv.jou new file mode 100644 index 00000000..a896919a --- /dev/null +++ b/tests/wrong_type/main_argv.jou @@ -0,0 +1,4 @@ +def main( + argc: int, argv: byte* # Error: if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int +) -> int: + return 0