Skip to content

Commit

Permalink
Methods that take self by value (#484)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Dec 26, 2023
1 parent 284c402 commit 37dddc9
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 40 deletions.
2 changes: 2 additions & 0 deletions self_hosted/parses_wrong.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
tests/syntax_error/assign_to_None.jou
tests/syntax_error/None_as_value.jou
tests/should_succeed/method_by_value.jou
tests/wrong_type/self_annotation.jou
2 changes: 2 additions & 0 deletions self_hosted/runs_wrong.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ tests/wrong_type/cannot_be_indexed.jou
tests/wrong_type/index.jou
tests/syntax_error/assign_to_None.jou
tests/syntax_error/None_as_value.jou
tests/should_succeed/method_by_value.jou
tests/wrong_type/self_annotation.jou
60 changes: 38 additions & 22 deletions src/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -404,25 +404,23 @@ static const LocalVariable *build_address_of_expression(struct State *st, const
assert(0);
}

static const LocalVariable *build_function_or_method_call(struct State *st, const Location location, const AstCall *call, const LocalVariable *self)
static const LocalVariable *build_function_or_method_call(
struct State *st,
const Location location,
const AstCall *call,
const AstExpression *self,
bool self_is_a_pointer)
{
if(self) {
assert(self->type->kind == TYPE_POINTER);
assert(self->type->data.valuetype->kind == TYPE_CLASS);
}

const LocalVariable **args = calloc(call->nargs + 2, sizeof(args[0])); // NOLINT
int k = 0;
if (self)
args[k++] = self;
for (int i = 0; i < call->nargs; i++)
args[k++] = build_expression(st, &call->args[i]);

const Signature *sig = NULL;

if(self) {
assert(self->type->kind == TYPE_POINTER);
const Type *selfclass = self->type->data.valuetype;
const Type *selfclass = get_expr_types(st, self)->type;
if (self_is_a_pointer) {
assert(selfclass->kind == TYPE_POINTER);
selfclass = selfclass->data.valuetype;
}
assert(selfclass->kind == TYPE_CLASS);

for (const Signature *s = selfclass->data.classdata.methods.ptr; s < End(selfclass->data.classdata.methods); s++) {
assert(get_self_class(s) == selfclass);
if (!strcmp(s->name, call->calledname)) {
Expand All @@ -440,6 +438,28 @@ static const LocalVariable *build_function_or_method_call(struct State *st, cons
}
assert(sig);

const LocalVariable **args = calloc(call->nargs + 2, sizeof(args[0])); // NOLINT
int k = 0;

if (self) {
if (is_pointer_type(sig->argtypes[0]) && !self_is_a_pointer) {
args[k++] = build_address_of_expression(st, self);
} else if (!is_pointer_type(sig->argtypes[0]) && self_is_a_pointer) {
const LocalVariable *self_ptr = build_expression(st, self);
assert(self_ptr->type->kind == TYPE_POINTER);

// dereference the pointer
const LocalVariable *val = add_local_var(st, self_ptr->type->data.valuetype);
add_unary_op(st, self->location, CF_PTR_LOAD, self_ptr, val);
args[k++] = val;
} else {
args[k++] = build_expression(st, self);
}
}

for (int i = 0; i < call->nargs; i++)
args[k++] = build_expression(st, &call->args[i]);

const LocalVariable *return_value;
if (sig->returntype)
return_value = add_local_var(st, sig->returntype);
Expand Down Expand Up @@ -539,21 +559,17 @@ static const LocalVariable *build_expression(struct State *st, const AstExpressi

switch(expr->kind) {
case AST_EXPR_DEREF_AND_CALL_METHOD:
temp = build_expression(st, expr->data.methodcall.obj);
assert(temp);
result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, temp);
result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, expr->data.methodcall.obj, true);
if (!result)
return NULL;
break;
case AST_EXPR_CALL_METHOD:
temp = build_address_of_expression(st, expr->data.methodcall.obj);
assert(temp);
result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, temp);
result = build_function_or_method_call(st, expr->location, &expr->data.methodcall.call, expr->data.methodcall.obj, false);
if (!result)
return NULL;
break;
case AST_EXPR_FUNCTION_CALL:
result = build_function_or_method_call(st, expr->location, &expr->data.call, NULL);
result = build_function_or_method_call(st, expr->location, &expr->data.call, NULL, false);
if (!result)
return NULL;
break;
Expand Down
6 changes: 5 additions & 1 deletion src/parse.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ static AstSignature parse_function_signature(ParserState *ps, bool accept_self)
} else if (is_keyword(ps->tokens, "self")) {
if (!accept_self)
fail(ps->tokens->location, "'self' cannot be used here");
AstNameTypeValue self_arg = { .name="self", .name_location=ps->tokens++->location };
AstNameTypeValue self_arg = { .name="self", .type.kind = AST_TYPE_NAMED, .type.data.name = "", .name_location=ps->tokens++->location };
if (is_operator(ps->tokens, ":")) {
ps->tokens++;
self_arg.type = parse_type(ps);
}
Append(&result.args, self_arg);
used_self = true;
} else {
Expand Down
56 changes: 41 additions & 15 deletions src/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ static ExportSymbol handle_global_var(FileTypes *ft, const AstNameTypeValue *var
return es;
}

static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, const Type *self_type)
static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, const Type *self_class)
{
if (find_function_or_method(ft, self_type, astsig->name))
fail(astsig->name_location, "a %s named '%s' already exists", self_type ? "method" : "function", astsig->name);
if (find_function_or_method(ft, self_class, astsig->name))
fail(astsig->name_location, "a %s named '%s' already exists", self_class ? "method" : "function", astsig->name);

Signature sig = { .nargs = astsig->args.len, .takes_varargs = astsig->takes_varargs };
safe_strcpy(sig.name, astsig->name);
Expand All @@ -231,10 +231,24 @@ static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, con

sig.argtypes = malloc(sizeof(sig.argtypes[0]) * sig.nargs); // NOLINT
for (int i = 0; i < sig.nargs; i++) {
if (!strcmp(sig.argnames[i], "self"))
sig.argtypes[i] = get_pointer_type(self_type);
else
sig.argtypes[i] = type_from_ast(ft, &astsig->args.ptr[i].type);
const Type *argtype;
if (
!strcmp(sig.argnames[i], "self")
&& astsig->args.ptr[i].type.kind == AST_TYPE_NAMED
&& astsig->args.ptr[i].type.data.name[0] == '\0'
) {
// just "self" without a type after it --> default to "self: Foo*" in class Foo
argtype = get_pointer_type(self_class);
} else {
argtype = type_from_ast(ft, &astsig->args.ptr[i].type);
}

if (!strcmp(sig.argnames[i], "self") && argtype != self_class && argtype != get_pointer_type(self_class))
{
fail(astsig->args.ptr[i].type.location, "type of self must be %s* (default) or %s", self_class->name, self_class->name);
}

sig.argtypes[i] = argtype;
}

sig.is_noreturn = is_noreturn(&astsig->returntype);
Expand All @@ -245,7 +259,7 @@ static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, con
else
sig.returntype = type_from_ast(ft, &astsig->returntype);

if (!self_type && !strcmp(sig.name, "main")) {
if (!self_class && !strcmp(sig.name, "main")) {
// special main() function checks
if (sig.returntype != intType)
fail(astsig->returntype.location, "the main() function must return int");
Expand All @@ -263,7 +277,7 @@ static Signature handle_signature(FileTypes *ft, const AstSignature *astsig, con

sig.returntype_location = astsig->returntype.location;

if (!self_type)
if (!self_class)
Append(&ft->functions, (struct SignatureAndUsedPtr){ .signature=copy_signature(&sig), .usedptr=NULL });

return sig;
Expand Down Expand Up @@ -1066,12 +1080,24 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
temptype = typecheck_expression_not_void(ft, expr->data.methodcall.obj)->type;
result = typecheck_function_or_method_call(ft, &expr->data.methodcall.call, temptype, expr->location);

char tmp[500];
snprintf(
tmp, sizeof tmp,
"cannot take address of %%s, needed for calling the %s() method",
expr->data.methodcall.call.calledname);
ensure_can_take_address(expr->data.methodcall.obj, tmp);
// If self argument is passed by pointer, make sure we can create that pointer
assert(temptype->kind == TYPE_CLASS);
bool found = false;
for (const Signature *m = temptype->data.classdata.methods.ptr; m < End(temptype->data.classdata.methods); m++) {
if (!strcmp(m->name, expr->data.methodcall.call.calledname)) {
if (is_pointer_type(m->argtypes[0])) {
char tmp[500];
snprintf(
tmp, sizeof tmp,
"cannot take address of %%s, needed for calling the %s() method",
expr->data.methodcall.call.calledname);
ensure_can_take_address(expr->data.methodcall.obj, tmp);
}
found = true;
break;
}
}
assert(found);

if (!result)
return NULL;
Expand Down
10 changes: 8 additions & 2 deletions src/types.c
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,14 @@ Type *create_enum(const char *name, int membercount, char (*membernames)[100])
const Type *get_self_class(const Signature *sig)
{
if (sig->nargs > 0 && !strcmp(sig->argnames[0], "self")) {
assert(sig->argtypes[0]->kind == TYPE_POINTER);
return sig->argtypes[0]->data.valuetype;
switch (sig->argtypes[0]->kind) {
case TYPE_POINTER:
return sig->argtypes[0]->data.valuetype;
case TYPE_CLASS:
return sig->argtypes[0];
default:
assert(0);
}
}
return NULL;
}
Expand Down
23 changes: 23 additions & 0 deletions tests/should_succeed/method_by_value.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import "stdlib/io.jou"

class Foo:
x: int

def add(self: Foo, n: int) -> Foo:
return Foo{x = self.x + n}

def add_one(self: Foo) -> Foo:
return Foo{x = self.x + 1}

def print(self: Foo) -> None:
printf("%d\n", self.x)


def main() -> int:
Foo{x=5}.add(10).add_one().print() # Output: 16

# test the '->' operator
f = Foo{x=100}
(&f)->print() # Output: 100

return 0
5 changes: 5 additions & 0 deletions tests/wrong_type/self_annotation.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class Foo:
def bar(
self: int, # Error: type of self must be Foo* (default) or Foo
) -> None:
return

0 comments on commit 37dddc9

Please sign in to comment.