diff --git a/self_hosted/parses_wrong.txt b/self_hosted/parses_wrong.txt index 05a0c544..f3dcefd8 100644 --- a/self_hosted/parses_wrong.txt +++ b/self_hosted/parses_wrong.txt @@ -1,2 +1,4 @@ tests/syntax_error/assign_to_None.jou tests/syntax_error/None_as_value.jou +tests/should_succeed/method_by_value.jou +tests/wrong_type/self_annotation.jou diff --git a/self_hosted/runs_wrong.txt b/self_hosted/runs_wrong.txt index 02c37c70..429a878e 100644 --- a/self_hosted/runs_wrong.txt +++ b/self_hosted/runs_wrong.txt @@ -14,3 +14,5 @@ tests/wrong_type/cannot_be_indexed.jou tests/wrong_type/index.jou tests/syntax_error/assign_to_None.jou tests/syntax_error/None_as_value.jou +tests/should_succeed/method_by_value.jou +tests/wrong_type/self_annotation.jou diff --git a/src/build_cfg.c b/src/build_cfg.c index 2619e8a1..dcb55cd2 100644 --- a/src/build_cfg.c +++ b/src/build_cfg.c @@ -404,25 +404,23 @@ static const LocalVariable *build_address_of_expression(struct State *st, const assert(0); } -static const LocalVariable *build_function_or_method_call(struct State *st, const Location location, const AstCall *call, const LocalVariable *self) +static const LocalVariable *build_function_or_method_call( + struct State *st, + const Location location, + const AstCall *call, + const AstExpression *self, + bool self_is_a_pointer) { - if(self) { - assert(self->type->kind == TYPE_POINTER); - assert(self->type->data.valuetype->kind == TYPE_CLASS); - } - - const LocalVariable **args = calloc(call->nargs + 2, sizeof(args[0])); // NOLINT - int k = 0; - if (self) - args[k++] = self; - for (int i = 0; i < call->nargs; i++) - args[k++] = build_expression(st, &call->args[i]); - const Signature *sig = NULL; + if(self) { - assert(self->type->kind == TYPE_POINTER); - const Type *selfclass = self->type->data.valuetype; + const Type *selfclass = get_expr_types(st, self)->type; + if (self_is_a_pointer) { + assert(selfclass->kind == TYPE_POINTER); + selfclass = selfclass->data.valuetype; + } assert(selfclass->kind == TYPE_CLASS); + for (const Signature *s = selfclass->data.classdata.methods.ptr; s < End(selfclass->data.classdata.methods); s++) { assert(get_self_class(s) == selfclass); if (!strcmp(s->name, call->calledname)) { @@ -440,6 +438,28 @@ static const LocalVariable *build_function_or_method_call(struct State *st, cons } assert(sig); + const LocalVariable **args = calloc(call->nargs + 2, sizeof(args[0])); // NOLINT + int k = 0; + + if (self) { + if (is_pointer_type(sig->argtypes[0]) && !self_is_a_pointer) { + args[k++] = build_address_of_expression(st, self); + } else if (!is_pointer_type(sig->argtypes[0]) && self_is_a_pointer) { + const LocalVariable *self_ptr = build_expression(st, self); + assert(self_ptr->type->kind == TYPE_POINTER); + + // dereference the pointer + const LocalVariable *val = add_local_var(st, self_ptr->type->data.valuetype); + add_unary_op(st, self->location, CF_PTR_LOAD, self_ptr, val); + args[k++] = val; + } else { + args[k++] = build_expression(st, self); + } + } + + for (int i = 0; i < call->nargs; i++) + args[k++] = build_expression(st, &call->args[i]); + const LocalVariable *return_value; if (sig->returntype) return_value = add_local_var(st, sig->returntype); @@ -539,21 +559,17 @@ static const LocalVariable *build_expression(struct State *st, const AstExpressi switch(expr->kind) { case AST_EXPR_DEREF_AND_CALL_METHOD: - temp = build_expression(st, expr->data.methodcall.obj); - assert(temp); - result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, temp); + result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, expr->data.methodcall.obj, true); if (!result) return NULL; break; case AST_EXPR_CALL_METHOD: - temp = build_address_of_expression(st, expr->data.methodcall.obj); - assert(temp); - result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, temp); + result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, expr->data.methodcall.obj, false); if (!result) return NULL; break; case AST_EXPR_FUNCTION_CALL: - result = build_function_or_method_call(st, expr->location, &expr->data.call, NULL); + result = build_function_or_method_call(st, expr->location, &expr->data.call, NULL, false); if (!result) return NULL; break; diff --git a/src/parse.c b/src/parse.c index 0abe0fc1..a821f33a 100644 --- a/src/parse.c +++ b/src/parse.c @@ -155,7 +155,11 @@ static AstSignature parse_function_signature(ParserState *ps, bool accept_self) } else if (is_keyword(ps->tokens, "self")) { if (!accept_self) fail(ps->tokens->location, "'self' cannot be used here"); - AstNameTypeValue self_arg = { .name="self", .name_location=ps->tokens++->location }; + AstNameTypeValue self_arg = { .name="self", .type.kind = AST_TYPE_NAMED, .type.data.name = "", .name_location=ps->tokens++->location }; + if (is_operator(ps->tokens, ":")) { + ps->tokens++; + self_arg.type = parse_type(ps); + } Append(&result.args, self_arg); used_self = true; } else { diff --git a/src/typecheck.c b/src/typecheck.c index 1bab6820..52d5a96d 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -216,10 +216,10 @@ static ExportSymbol handle_global_var(FileTypes *ft, const AstNameTypeValue *var return es; } -static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, const Type *self_type) +static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, const Type *self_class) { - if (find_function_or_method(ft, self_type, astsig->name)) - fail(astsig->name_location, "a %s named '%s' already exists", self_type ? "method" : "function", astsig->name); + if (find_function_or_method(ft, self_class, astsig->name)) + fail(astsig->name_location, "a %s named '%s' already exists", self_class ? "method" : "function", astsig->name); Signature sig = { .nargs = astsig->args.len, .takes_varargs = astsig->takes_varargs }; safe_strcpy(sig.name, astsig->name); @@ -231,10 +231,24 @@ static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, con sig.argtypes = malloc(sizeof(sig.argtypes[0]) * sig.nargs); // NOLINT for (int i = 0; i < sig.nargs; i++) { - if (!strcmp(sig.argnames[i], "self")) - sig.argtypes[i] = get_pointer_type(self_type); - else - sig.argtypes[i] = type_from_ast(ft, &astsig->args.ptr[i].type); + const Type *argtype; + if ( + !strcmp(sig.argnames[i], "self") + && astsig->args.ptr[i].type.kind == AST_TYPE_NAMED + && astsig->args.ptr[i].type.data.name[0] == '\0' + ) { + // just "self" without a type after it --> default to "self: Foo*" in class Foo + argtype = get_pointer_type(self_class); + } else { + argtype = type_from_ast(ft, &astsig->args.ptr[i].type); + } + + if (!strcmp(sig.argnames[i], "self") && argtype != self_class && argtype != get_pointer_type(self_class)) + { + fail(astsig->args.ptr[i].type.location, "type of self must be %s* (default) or %s", self_class->name, self_class->name); + } + + sig.argtypes[i] = argtype; } sig.is_noreturn = is_noreturn(&astsig->returntype); @@ -245,7 +259,7 @@ static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, con else sig.returntype = type_from_ast(ft, &astsig->returntype); - if (!self_type && !strcmp(sig.name, "main")) { + if (!self_class && !strcmp(sig.name, "main")) { // special main() function checks if (sig.returntype != intType) fail(astsig->returntype.location, "the main() function must return int"); @@ -263,7 +277,7 @@ static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, con sig.returntype_location = astsig->returntype.location; - if (!self_type) + if (!self_class) Append(&ft->functions, (struct SignatureAndUsedPtr){ .signature=copy_signature(&sig), .usedptr=NULL }); return sig; @@ -1066,12 +1080,24 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression temptype = typecheck_expression_not_void(ft, expr->data.methodcall.obj)->type; result = typecheck_function_or_method_call(ft, &expr->data.methodcall.call, temptype, expr->location); - char tmp[500]; - snprintf( - tmp, sizeof tmp, - "cannot take address of %%s, needed for calling the %s() method", - expr->data.methodcall.call.calledname); - ensure_can_take_address(expr->data.methodcall.obj, tmp); + // If self argument is passed by pointer, make sure we can create that pointer + assert(temptype->kind == TYPE_CLASS); + bool found = false; + for (const Signature *m = temptype->data.classdata.methods.ptr; m < End(temptype->data.classdata.methods); m++) { + if (!strcmp(m->name, expr->data.methodcall.call.calledname)) { + if (is_pointer_type(m->argtypes[0])) { + char tmp[500]; + snprintf( + tmp, sizeof tmp, + "cannot take address of %%s, needed for calling the %s() method", + expr->data.methodcall.call.calledname); + ensure_can_take_address(expr->data.methodcall.obj, tmp); + } + found = true; + break; + } + } + assert(found); if (!result) return NULL; diff --git a/src/types.c b/src/types.c index 72eb44d0..5ce425ae 100644 --- a/src/types.c +++ b/src/types.c @@ -196,8 +196,14 @@ Type *create_enum(const char *name, int membercount, char (*membernames)[100]) const Type *get_self_class(const Signature *sig) { if (sig->nargs > 0 && !strcmp(sig->argnames[0], "self")) { - assert(sig->argtypes[0]->kind == TYPE_POINTER); - return sig->argtypes[0]->data.valuetype; + switch (sig->argtypes[0]->kind) { + case TYPE_POINTER: + return sig->argtypes[0]->data.valuetype; + case TYPE_CLASS: + return sig->argtypes[0]; + default: + assert(0); + } } return NULL; } diff --git a/tests/should_succeed/method_by_value.jou b/tests/should_succeed/method_by_value.jou new file mode 100644 index 00000000..05ec4a1b --- /dev/null +++ b/tests/should_succeed/method_by_value.jou @@ -0,0 +1,23 @@ +import "stdlib/io.jou" + +class Foo: + x: int + + def add(self: Foo, n: int) -> Foo: + return Foo{x = self.x + n} + + def add_one(self: Foo) -> Foo: + return Foo{x = self.x + 1} + + def print(self: Foo) -> None: + printf("%d\n", self.x) + + +def main() -> int: + Foo{x=5}.add(10).add_one().print() # Output: 16 + + # test the '->' operator + f = Foo{x=100} + (&f)->print() # Output: 100 + + return 0 diff --git a/tests/wrong_type/self_annotation.jou b/tests/wrong_type/self_annotation.jou new file mode 100644 index 00000000..ba3584e3 --- /dev/null +++ b/tests/wrong_type/self_annotation.jou @@ -0,0 +1,5 @@ +class Foo: + def bar( + self: int, # Error: type of self must be Foo* (default) or Foo + ) -> None: + return