From b37d0ffe05106ba8dd8be105bc7d7269e1334446 Mon Sep 17 00:00:00 2001 From: Akuli Date: Tue, 26 Dec 2023 17:16:29 +0200 Subject: [PATCH] Allow taking address of self when it is passed by value (#487) --- examples/aoc2023/day24/bigint.jou | 38 +++++------ src/typecheck.c | 87 +++++++++++++++--------- tests/should_succeed/method_by_value.jou | 10 ++- 3 files changed, 77 insertions(+), 58 deletions(-) diff --git a/examples/aoc2023/day24/bigint.jou b/examples/aoc2023/day24/bigint.jou index 0dbb0334..e96750ab 100644 --- a/examples/aoc2023/day24/bigint.jou +++ b/examples/aoc2023/day24/bigint.jou @@ -19,8 +19,7 @@ class BigInt: # assume that value fits into 64-bit long # also assume little-endian result: long - x = self # TODO: https://github.com/Akuli/jou/issues/485 - memcpy(&result, &x.data, sizeof(result)) + memcpy(&result, &self.data, sizeof(result)) return result def add(self: BigInt, other: BigInt) -> BigInt: @@ -28,8 +27,7 @@ class BigInt: carry_bit = 0 for i = 0; i < sizeof(self.data); i++: - x = self # TODO: https://github.com/Akuli/jou/issues/485 - result_byte = (x.data[i] as int) + (other.data[i] as int) + carry_bit + result_byte = (self.data[i] as int) + (other.data[i] as int) + carry_bit if result_byte >= 256: carry_bit = 1 else: @@ -42,10 +40,9 @@ class BigInt: def neg(self: BigInt) -> BigInt: # Flipping all bits (~x) is almost same as negating the value. # For example, -7 is f9ffffff... and ~7 is f8ffffff... - x = self # TODO: https://github.com/Akuli/jou/issues/485 for i = 0; i < sizeof(self.data); i++: - x.data[i] = (0xff as byte) - x.data[i] - return x.add(bigint(1)) + self.data[i] = (0xff as byte) - self.data[i] + return self.add(bigint(1)) # x-y def sub(self: BigInt, other: BigInt) -> BigInt: @@ -56,17 +53,16 @@ class BigInt: # self == other --> 0 # self > other --> 1 def compare(self: BigInt, other: BigInt) -> int: - x = self # TODO: https://github.com/Akuli/jou/issues/485 - self_sign_bit = x.data[sizeof(self.data) - 1] / 128 + self_sign_bit = self.data[sizeof(self.data) - 1] / 128 other_sign_bit = other.data[sizeof(other.data) - 1] / 128 if self_sign_bit != other_sign_bit: return other_sign_bit - self_sign_bit for i = sizeof(self.data) - 1; i >= 0; i--: - if (x.data[i] as int) < (other.data[i] as int): + if (self.data[i] as int) < (other.data[i] as int): return -1 - if (x.data[i] as int) > (other.data[i] as int): + if (self.data[i] as int) > (other.data[i] as int): return 1 return 0 @@ -96,13 +92,13 @@ class BigInt: # x*y def mul(self: BigInt, other: BigInt) -> BigInt: result_sign = self.sign() * other.sign() - self2 = self.abs() # TODO: https://github.com/Akuli/jou/issues/485 + self = self.abs() other = other.abs() result = bigint(0) - for i = 0; i < sizeof(self2.data); i++: + for i = 0; i < sizeof(self.data); i++: for k = 0; i+k < sizeof(result.data); k++: - temp = (self2.data[i] as int)*(other.data[k] as int) + temp = (self.data[i] as int)*(other.data[k] as int) gonna_add = bigint(0) gonna_add.data[i+k] = temp as byte @@ -122,10 +118,9 @@ class BigInt: if n >= sizeof(self.data): return bigint(0) - self2 = self # TODO: https://github.com/Akuli/jou/issues/485 - memmove(&self2.data, &self2.data[n], sizeof(self2.data) - n) - memset(&self2.data[sizeof(self2.data) - n], 0, n) - return self2 + memmove(&self.data, &self.data[n], sizeof(self.data) - n) + memset(&self.data[sizeof(self.data) - n], 0, n) + return self # x * 256^n for x >= 0 def shift_bigger(self: BigInt, n: int) -> BigInt: @@ -135,10 +130,9 @@ class BigInt: if n >= sizeof(self.data): return bigint(0) - self2 = self # TODO: https://github.com/Akuli/jou/issues/485 - memmove(&self2.data[n], &self2.data[0], sizeof(self2.data) - n) - memset(&self2.data, 0, n) - return self2 + memmove(&self.data[n], &self.data[0], sizeof(self.data) - n) + memset(&self.data, 0, n) + return self # [x/y, x%y] def divmod(self: BigInt, bottom: BigInt) -> BigInt[2]: diff --git a/src/typecheck.c b/src/typecheck.c index 52d5a96d..9ecfd4b9 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -462,13 +462,16 @@ value of the pointer &foo to bar. errmsg_template can be e.g. "cannot take address of %s" or "cannot assign to %s" */ -static void ensure_can_take_address(const AstExpression *expr, const char *errmsg_template) +static void ensure_can_take_address(const FunctionOrMethodTypes *fom, const AstExpression *expr, const char *errmsg_template) { + assert(fom != NULL); + switch(expr->kind) { case AST_EXPR_DEREFERENCE: case AST_EXPR_INDEXING: // &foo[bar] case AST_EXPR_DEREF_AND_GET_FIELD: // &foo->bar = foo + offset (it doesn't use &foo) - break; + return; + case AST_EXPR_GET_FIELD: // &foo.bar = &foo + offset { @@ -476,19 +479,29 @@ static void ensure_can_take_address(const AstExpression *expr, const char *errms // This assumes that errmsg_template is relatively simple, i.e. it only contains one %s somewhere. char *newtemplate = malloc(strlen(errmsg_template) + 100); sprintf(newtemplate, errmsg_template, "a field of %s"); - ensure_can_take_address(&expr->data.operands[0], newtemplate); + ensure_can_take_address(fom, &expr->data.operands[0], newtemplate); free(newtemplate); } - break; + return; + case AST_EXPR_GET_VARIABLE: - if (strcmp(expr->data.varname, "self") && get_special_constant(expr->data.varname) == -1) { - // not self or a special constant --> ok to take address - break; + if (get_special_constant(expr->data.varname) != -1) + goto error; + + // In methods that take self as a pointer, you cannot take address of self + if (!strcmp(expr->data.varname, "self")) { + if (fom->signature.argtypes[0]->kind == TYPE_POINTER) + goto error; } - __attribute__((fallthrough)); + + return; + default: - fail(expr->location, errmsg_template, short_expression_description(expr)); + goto error; } + +error: + fail(expr->location, errmsg_template, short_expression_description(expr)); } /* @@ -550,7 +563,11 @@ static bool can_cast_implicitly(const Type *from, const Type *to) } static void do_implicit_cast( - ExpressionTypes *types, const Type *to, Location location, const char *errormsg_template) + const FunctionOrMethodTypes *fom, + ExpressionTypes *types, + const Type *to, + Location location, + const char *errormsg_template) { assert(!types->implicit_cast_type); assert(!types->implicit_array_to_pointer_cast); @@ -580,18 +597,19 @@ static void do_implicit_cast( types->implicit_array_to_pointer_cast = (from->kind == TYPE_ARRAY && to->kind == TYPE_POINTER); if (types->implicit_array_to_pointer_cast) ensure_can_take_address( + fom, types->expr, "cannot create a pointer into an array that comes from %s (try storing it to a local variable first)" ); } -static void cast_array_to_pointer(ExpressionTypes *types) +static void cast_array_to_pointer(const FunctionOrMethodTypes *fom, ExpressionTypes *types) { assert(types->type->kind == TYPE_ARRAY); - do_implicit_cast(types, get_pointer_type(types->type->data.array.membertype), (Location){0}, NULL); + do_implicit_cast(fom, types, get_pointer_type(types->type->data.array.membertype), (Location){0}, NULL); } -static void do_explicit_cast(ExpressionTypes *types, const Type *to, Location location) +static void do_explicit_cast(const FunctionOrMethodTypes *fom, ExpressionTypes *types, const Type *to, Location location) { assert(!types->implicit_cast_type); const Type *from = types->type; @@ -615,7 +633,7 @@ static void do_explicit_cast(ExpressionTypes *types, const Type *to, Location lo } if (from->kind == TYPE_ARRAY && is_pointer_type(to)) - cast_array_to_pointer(types); + cast_array_to_pointer(fom, types); } static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression *expr); @@ -647,10 +665,11 @@ static void typecheck_expression_with_implicit_cast( const char *errormsg_template) { ExpressionTypes *types = typecheck_expression_not_void(ft, expr); - do_implicit_cast(types, casttype, expr->location, errormsg_template); + do_implicit_cast(ft->current_fom_types, types, casttype, expr->location, errormsg_template); } static const Type *check_binop( + const FunctionOrMethodTypes *fom, enum AstExpressionKind op, Location location, ExpressionTypes *lhstypes, @@ -713,8 +732,8 @@ static const Type *check_binop( cast_type = intType; assert(cast_type); - do_implicit_cast(lhstypes, cast_type, (Location){0}, NULL); - do_implicit_cast(rhstypes, cast_type, (Location){0}, NULL); + do_implicit_cast(fom, lhstypes, cast_type, (Location){0}, NULL); + do_implicit_cast(fom, rhstypes, cast_type, (Location){0}, NULL); switch(op) { case AST_EXPR_ADD: @@ -753,7 +772,7 @@ static const Type *check_increment_or_decrement(FileTypes *ft, const AstExpressi assert(0); } - ensure_can_take_address(&expr->data.operands[0], bad_expr_fmt); + ensure_can_take_address(ft->current_fom_types, &expr->data.operands[0], bad_expr_fmt); const Type *t = typecheck_expression_not_void(ft, &expr->data.operands[0])->type; if (!is_integer_type(t) && !is_pointer_type(t)) fail(expr->location, bad_type_fmt, t->name); @@ -775,7 +794,7 @@ static const Type *typecheck_indexing( const Type *ptrtype; if (types->type->kind == TYPE_ARRAY) { - cast_array_to_pointer(types); + cast_array_to_pointer(ft->current_fom_types, types); ptrtype = types->implicit_cast_type; } else { if (types->type->kind != TYPE_POINTER) @@ -794,7 +813,7 @@ static const Type *typecheck_indexing( // LLVM assumes that indexes smaller than 64 bits are signed. // https://github.com/Akuli/jou/issues/48 - do_implicit_cast(indextypes, longType, (Location){0}, NULL); + do_implicit_cast(ft->current_fom_types, indextypes, longType, (Location){0}, NULL); return ptrtype->data.valuetype; } @@ -883,16 +902,16 @@ static const Type *typecheck_function_or_method_call(FileTypes *ft, const AstCal ExpressionTypes *types = typecheck_expression_not_void(ft, &call->args[i]); if (types->type->kind == TYPE_ARRAY) - cast_array_to_pointer(types); + cast_array_to_pointer(ft->current_fom_types, types); else if ( (is_integer_type(types->type) && types->type->data.width_in_bits < 32) || types->type == boolType) { // Add implicit cast to signed int, just like in C. - do_implicit_cast(types, intType, (Location){0}, NULL); + do_implicit_cast(ft->current_fom_types, types, intType, (Location){0}, NULL); } else if (types->type == floatType) - do_implicit_cast(types, doubleType, (Location){0}, NULL); + do_implicit_cast(ft->current_fom_types, types, doubleType, (Location){0}, NULL); } free(sigstr); @@ -957,7 +976,7 @@ static bool enum_member_exists(const Type *t, const char *name) return false; } -static const Type *cast_array_members_to_a_common_type(Location error_location, ExpressionTypes **exprtypes) +static const Type *cast_array_members_to_a_common_type(const FunctionOrMethodTypes *fom, Location error_location, ExpressionTypes **exprtypes) { // Avoid O(ntypes^2) code in a long array where all or almost all items have the same type. // This is at most O(ntypes*k) where k is the number of distinct types. @@ -1002,7 +1021,7 @@ static const Type *cast_array_members_to_a_common_type(Location error_location, free(compatible_with_all.ptr); for (ExpressionTypes **et = exprtypes; *et; et++) - do_implicit_cast(*et, elemtype, error_location, NULL); + do_implicit_cast(fom, *et, elemtype, error_location, NULL); return elemtype; } @@ -1044,7 +1063,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression for (int i = 0; i < n; i++) exprtypes[i] = typecheck_expression_not_void(ft, &expr->data.array.items[i]); - const Type *membertype = cast_array_members_to_a_common_type(expr->location, exprtypes); + const Type *membertype = cast_array_members_to_a_common_type(ft->current_fom_types, expr->location, exprtypes); free(exprtypes); result = get_array_type(membertype, n); } @@ -1091,7 +1110,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression tmp, sizeof tmp, "cannot take address of %%s, needed for calling the %s() method", expr->data.methodcall.call.calledname); - ensure_can_take_address(expr->data.methodcall.obj, tmp); + ensure_can_take_address(ft->current_fom_types, expr->data.methodcall.obj, tmp); } found = true; break; @@ -1106,7 +1125,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression result = typecheck_indexing(ft, &expr->data.operands[0], &expr->data.operands[1]); break; case AST_EXPR_ADDRESS_OF: - ensure_can_take_address(&expr->data.operands[0], "the '&' operator cannot be used with %s"); + ensure_can_take_address(ft->current_fom_types, &expr->data.operands[0], "the '&' operator cannot be used with %s"); temptype = typecheck_expression_not_void(ft, &expr->data.operands[0])->type; result = get_pointer_type(temptype); break; @@ -1159,7 +1178,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression { ExpressionTypes *lhstypes = typecheck_expression_not_void(ft, &expr->data.operands[0]); ExpressionTypes *rhstypes = typecheck_expression_not_void(ft, &expr->data.operands[1]); - result = check_binop(expr->kind, expr->location, lhstypes, rhstypes); + result = check_binop(ft->current_fom_types, expr->kind, expr->location, lhstypes, rhstypes); break; } case AST_EXPR_PRE_INCREMENT: @@ -1172,7 +1191,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression { ExpressionTypes *origtypes = typecheck_expression_not_void(ft, expr->data.as.obj); result = type_from_ast(ft, &expr->data.as.type); - do_explicit_cast(origtypes, result, expr->location); + do_explicit_cast(ft->current_fom_types, origtypes, result, expr->location); } break; } @@ -1249,7 +1268,7 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt) add_variable(ft, types->type, targetexpr->data.varname); } else { // Convert value to the type of an existing variable or other assignment target. - ensure_can_take_address(targetexpr, "cannot assign to %s"); + ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s"); char errmsg[500]; if (targetexpr->kind == AST_EXPR_DEREFERENCE) { @@ -1286,16 +1305,16 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt) default: assert(0); } - ensure_can_take_address(targetexpr, "cannot assign to %s"); + ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s"); ExpressionTypes *targettypes = typecheck_expression_not_void(ft, targetexpr); ExpressionTypes *valuetypes = typecheck_expression_not_void(ft, valueexpr); - const Type *t = check_binop(op, stmt->location, targettypes, valuetypes); + const Type *t = check_binop(ft->current_fom_types, op, stmt->location, targettypes, valuetypes); ExpressionTypes tempvalue_types = { .expr = targetexpr, .type = t }; char msg[500]; snprintf(msg, sizeof msg, "%s produced a value of type FROM which cannot be assigned back to TO", opname); - do_implicit_cast(&tempvalue_types, targettypes->type, stmt->location, msg); + do_implicit_cast(ft->current_fom_types, &tempvalue_types, targettypes->type, stmt->location, msg); // I think it is currently impossible to cast target. // If this assert fails, we probably need to add another error message for it. diff --git a/tests/should_succeed/method_by_value.jou b/tests/should_succeed/method_by_value.jou index 05ec4a1b..507965c4 100644 --- a/tests/should_succeed/method_by_value.jou +++ b/tests/should_succeed/method_by_value.jou @@ -7,7 +7,8 @@ class Foo: return Foo{x = self.x + n} def add_one(self: Foo) -> Foo: - return Foo{x = self.x + 1} + self = Foo{x = self.x + 1} + return self def print(self: Foo) -> None: printf("%d\n", self.x) @@ -16,8 +17,13 @@ class Foo: def main() -> int: Foo{x=5}.add(10).add_one().print() # Output: 16 - # test the '->' operator + # The methods get a local copy f = Foo{x=100} + f.add(123) + f.add_one() + f.print() # Output: 100 + + # test the '->' operator (&f)->print() # Output: 100 return 0