Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow taking address of self when it is passed by value #487

Merged
merged 3 commits into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions examples/aoc2023/day24/bigint.jou
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@ class BigInt:
# assume that value fits into 64-bit long
# also assume little-endian
result: long
x = self # TODO: https://github.com/Akuli/jou/issues/485
memcpy(&result, &x.data, sizeof(result))
memcpy(&result, &self.data, sizeof(result))
return result

def add(self: BigInt, other: BigInt) -> BigInt:
result = bigint(0)
carry_bit = 0

for i = 0; i < sizeof(self.data); i++:
x = self # TODO: https://github.com/Akuli/jou/issues/485
result_byte = (x.data[i] as int) + (other.data[i] as int) + carry_bit
result_byte = (self.data[i] as int) + (other.data[i] as int) + carry_bit
if result_byte >= 256:
carry_bit = 1
else:
Expand All @@ -42,10 +40,9 @@ class BigInt:
def neg(self: BigInt) -> BigInt:
# Flipping all bits (~x) is almost same as negating the value.
# For example, -7 is f9ffffff... and ~7 is f8ffffff...
x = self # TODO: https://github.com/Akuli/jou/issues/485
for i = 0; i < sizeof(self.data); i++:
x.data[i] = (0xff as byte) - x.data[i]
return x.add(bigint(1))
self.data[i] = (0xff as byte) - self.data[i]
return self.add(bigint(1))

# x-y
def sub(self: BigInt, other: BigInt) -> BigInt:
Expand All @@ -56,17 +53,16 @@ class BigInt:
# self == other --> 0
# self > other --> 1
def compare(self: BigInt, other: BigInt) -> int:
x = self # TODO: https://github.com/Akuli/jou/issues/485
self_sign_bit = x.data[sizeof(self.data) - 1] / 128
self_sign_bit = self.data[sizeof(self.data) - 1] / 128
other_sign_bit = other.data[sizeof(other.data) - 1] / 128

if self_sign_bit != other_sign_bit:
return other_sign_bit - self_sign_bit

for i = sizeof(self.data) - 1; i >= 0; i--:
if (x.data[i] as int) < (other.data[i] as int):
if (self.data[i] as int) < (other.data[i] as int):
return -1
if (x.data[i] as int) > (other.data[i] as int):
if (self.data[i] as int) > (other.data[i] as int):
return 1

return 0
Expand Down Expand Up @@ -96,13 +92,13 @@ class BigInt:
# x*y
def mul(self: BigInt, other: BigInt) -> BigInt:
result_sign = self.sign() * other.sign()
self2 = self.abs() # TODO: https://github.com/Akuli/jou/issues/485
self = self.abs()
other = other.abs()

result = bigint(0)
for i = 0; i < sizeof(self2.data); i++:
for i = 0; i < sizeof(self.data); i++:
for k = 0; i+k < sizeof(result.data); k++:
temp = (self2.data[i] as int)*(other.data[k] as int)
temp = (self.data[i] as int)*(other.data[k] as int)

gonna_add = bigint(0)
gonna_add.data[i+k] = temp as byte
Expand All @@ -122,10 +118,9 @@ class BigInt:
if n >= sizeof(self.data):
return bigint(0)

self2 = self # TODO: https://github.com/Akuli/jou/issues/485
memmove(&self2.data, &self2.data[n], sizeof(self2.data) - n)
memset(&self2.data[sizeof(self2.data) - n], 0, n)
return self2
memmove(&self.data, &self.data[n], sizeof(self.data) - n)
memset(&self.data[sizeof(self.data) - n], 0, n)
return self

# x * 256^n for x >= 0
def shift_bigger(self: BigInt, n: int) -> BigInt:
Expand All @@ -135,10 +130,9 @@ class BigInt:
if n >= sizeof(self.data):
return bigint(0)

self2 = self # TODO: https://github.com/Akuli/jou/issues/485
memmove(&self2.data[n], &self2.data[0], sizeof(self2.data) - n)
memset(&self2.data, 0, n)
return self2
memmove(&self.data[n], &self.data[0], sizeof(self.data) - n)
memset(&self.data, 0, n)
return self

# [x/y, x%y]
def divmod(self: BigInt, bottom: BigInt) -> BigInt[2]:
Expand Down
87 changes: 53 additions & 34 deletions src/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -462,33 +462,46 @@ value of the pointer &foo to bar.

errmsg_template can be e.g. "cannot take address of %s" or "cannot assign to %s"
*/
static void ensure_can_take_address(const AstExpression *expr, const char *errmsg_template)
static void ensure_can_take_address(const FunctionOrMethodTypes *fom, const AstExpression *expr, const char *errmsg_template)
{
assert(fom != NULL);

switch(expr->kind) {
case AST_EXPR_DEREFERENCE:
case AST_EXPR_INDEXING: // &foo[bar]
case AST_EXPR_DEREF_AND_GET_FIELD: // &foo->bar = foo + offset (it doesn't use &foo)
break;
return;

case AST_EXPR_GET_FIELD:
// &foo.bar = &foo + offset
{
// Turn "cannot assign to %s" into "cannot assign to a field of %s".
// This assumes that errmsg_template is relatively simple, i.e. it only contains one %s somewhere.
char *newtemplate = malloc(strlen(errmsg_template) + 100);
sprintf(newtemplate, errmsg_template, "a field of %s");
ensure_can_take_address(&expr->data.operands[0], newtemplate);
ensure_can_take_address(fom, &expr->data.operands[0], newtemplate);
free(newtemplate);
}
break;
return;

case AST_EXPR_GET_VARIABLE:
if (strcmp(expr->data.varname, "self") && get_special_constant(expr->data.varname) == -1) {
// not self or a special constant --> ok to take address
break;
if (get_special_constant(expr->data.varname) != -1)
goto error;

// In methods that take self as a pointer, you cannot take address of self
if (!strcmp(expr->data.varname, "self")) {
if (fom->signature.argtypes[0]->kind == TYPE_POINTER)
goto error;
}
__attribute__((fallthrough));

return;

default:
fail(expr->location, errmsg_template, short_expression_description(expr));
goto error;
}

error:
fail(expr->location, errmsg_template, short_expression_description(expr));
}

/*
Expand Down Expand Up @@ -550,7 +563,11 @@ static bool can_cast_implicitly(const Type *from, const Type *to)
}

static void do_implicit_cast(
ExpressionTypes *types, const Type *to, Location location, const char *errormsg_template)
const FunctionOrMethodTypes *fom,
ExpressionTypes *types,
const Type *to,
Location location,
const char *errormsg_template)
{
assert(!types->implicit_cast_type);
assert(!types->implicit_array_to_pointer_cast);
Expand Down Expand Up @@ -580,18 +597,19 @@ static void do_implicit_cast(
types->implicit_array_to_pointer_cast = (from->kind == TYPE_ARRAY && to->kind == TYPE_POINTER);
if (types->implicit_array_to_pointer_cast)
ensure_can_take_address(
fom,
types->expr,
"cannot create a pointer into an array that comes from %s (try storing it to a local variable first)"
);
}

static void cast_array_to_pointer(ExpressionTypes *types)
static void cast_array_to_pointer(const FunctionOrMethodTypes *fom, ExpressionTypes *types)
{
assert(types->type->kind == TYPE_ARRAY);
do_implicit_cast(types, get_pointer_type(types->type->data.array.membertype), (Location){0}, NULL);
do_implicit_cast(fom, types, get_pointer_type(types->type->data.array.membertype), (Location){0}, NULL);
}

static void do_explicit_cast(ExpressionTypes *types, const Type *to, Location location)
static void do_explicit_cast(const FunctionOrMethodTypes *fom, ExpressionTypes *types, const Type *to, Location location)
{
assert(!types->implicit_cast_type);
const Type *from = types->type;
Expand All @@ -615,7 +633,7 @@ static void do_explicit_cast(ExpressionTypes *types, const Type *to, Location lo
}

if (from->kind == TYPE_ARRAY && is_pointer_type(to))
cast_array_to_pointer(types);
cast_array_to_pointer(fom, types);
}

static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression *expr);
Expand Down Expand Up @@ -647,10 +665,11 @@ static void typecheck_expression_with_implicit_cast(
const char *errormsg_template)
{
ExpressionTypes *types = typecheck_expression_not_void(ft, expr);
do_implicit_cast(types, casttype, expr->location, errormsg_template);
do_implicit_cast(ft->current_fom_types, types, casttype, expr->location, errormsg_template);
}

static const Type *check_binop(
const FunctionOrMethodTypes *fom,
enum AstExpressionKind op,
Location location,
ExpressionTypes *lhstypes,
Expand Down Expand Up @@ -713,8 +732,8 @@ static const Type *check_binop(
cast_type = intType;
assert(cast_type);

do_implicit_cast(lhstypes, cast_type, (Location){0}, NULL);
do_implicit_cast(rhstypes, cast_type, (Location){0}, NULL);
do_implicit_cast(fom, lhstypes, cast_type, (Location){0}, NULL);
do_implicit_cast(fom, rhstypes, cast_type, (Location){0}, NULL);

switch(op) {
case AST_EXPR_ADD:
Expand Down Expand Up @@ -753,7 +772,7 @@ static const Type *check_increment_or_decrement(FileTypes *ft, const AstExpressi
assert(0);
}

ensure_can_take_address(&expr->data.operands[0], bad_expr_fmt);
ensure_can_take_address(ft->current_fom_types, &expr->data.operands[0], bad_expr_fmt);
const Type *t = typecheck_expression_not_void(ft, &expr->data.operands[0])->type;
if (!is_integer_type(t) && !is_pointer_type(t))
fail(expr->location, bad_type_fmt, t->name);
Expand All @@ -775,7 +794,7 @@ static const Type *typecheck_indexing(

const Type *ptrtype;
if (types->type->kind == TYPE_ARRAY) {
cast_array_to_pointer(types);
cast_array_to_pointer(ft->current_fom_types, types);
ptrtype = types->implicit_cast_type;
} else {
if (types->type->kind != TYPE_POINTER)
Expand All @@ -794,7 +813,7 @@ static const Type *typecheck_indexing(

// LLVM assumes that indexes smaller than 64 bits are signed.
// https://github.com/Akuli/jou/issues/48
do_implicit_cast(indextypes, longType, (Location){0}, NULL);
do_implicit_cast(ft->current_fom_types, indextypes, longType, (Location){0}, NULL);

return ptrtype->data.valuetype;
}
Expand Down Expand Up @@ -883,16 +902,16 @@ static const Type *typecheck_function_or_method_call(FileTypes *ft, const AstCal
ExpressionTypes *types = typecheck_expression_not_void(ft, &call->args[i]);

if (types->type->kind == TYPE_ARRAY)
cast_array_to_pointer(types);
cast_array_to_pointer(ft->current_fom_types, types);
else if (
(is_integer_type(types->type) && types->type->data.width_in_bits < 32)
|| types->type == boolType)
{
// Add implicit cast to signed int, just like in C.
do_implicit_cast(types, intType, (Location){0}, NULL);
do_implicit_cast(ft->current_fom_types, types, intType, (Location){0}, NULL);
}
else if (types->type == floatType)
do_implicit_cast(types, doubleType, (Location){0}, NULL);
do_implicit_cast(ft->current_fom_types, types, doubleType, (Location){0}, NULL);
}

free(sigstr);
Expand Down Expand Up @@ -957,7 +976,7 @@ static bool enum_member_exists(const Type *t, const char *name)
return false;
}

static const Type *cast_array_members_to_a_common_type(Location error_location, ExpressionTypes **exprtypes)
static const Type *cast_array_members_to_a_common_type(const FunctionOrMethodTypes *fom, Location error_location, ExpressionTypes **exprtypes)
{
// Avoid O(ntypes^2) code in a long array where all or almost all items have the same type.
// This is at most O(ntypes*k) where k is the number of distinct types.
Expand Down Expand Up @@ -1002,7 +1021,7 @@ static const Type *cast_array_members_to_a_common_type(Location error_location,
free(compatible_with_all.ptr);

for (ExpressionTypes **et = exprtypes; *et; et++)
do_implicit_cast(*et, elemtype, error_location, NULL);
do_implicit_cast(fom, *et, elemtype, error_location, NULL);
return elemtype;
}

Expand Down Expand Up @@ -1044,7 +1063,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
for (int i = 0; i < n; i++)
exprtypes[i] = typecheck_expression_not_void(ft, &expr->data.array.items[i]);

const Type *membertype = cast_array_members_to_a_common_type(expr->location, exprtypes);
const Type *membertype = cast_array_members_to_a_common_type(ft->current_fom_types, expr->location, exprtypes);
free(exprtypes);
result = get_array_type(membertype, n);
}
Expand Down Expand Up @@ -1091,7 +1110,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
tmp, sizeof tmp,
"cannot take address of %%s, needed for calling the %s() method",
expr->data.methodcall.call.calledname);
ensure_can_take_address(expr->data.methodcall.obj, tmp);
ensure_can_take_address(ft->current_fom_types, expr->data.methodcall.obj, tmp);
}
found = true;
break;
Expand All @@ -1106,7 +1125,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
result = typecheck_indexing(ft, &expr->data.operands[0], &expr->data.operands[1]);
break;
case AST_EXPR_ADDRESS_OF:
ensure_can_take_address(&expr->data.operands[0], "the '&' operator cannot be used with %s");
ensure_can_take_address(ft->current_fom_types, &expr->data.operands[0], "the '&' operator cannot be used with %s");
temptype = typecheck_expression_not_void(ft, &expr->data.operands[0])->type;
result = get_pointer_type(temptype);
break;
Expand Down Expand Up @@ -1159,7 +1178,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
{
ExpressionTypes *lhstypes = typecheck_expression_not_void(ft, &expr->data.operands[0]);
ExpressionTypes *rhstypes = typecheck_expression_not_void(ft, &expr->data.operands[1]);
result = check_binop(expr->kind, expr->location, lhstypes, rhstypes);
result = check_binop(ft->current_fom_types, expr->kind, expr->location, lhstypes, rhstypes);
break;
}
case AST_EXPR_PRE_INCREMENT:
Expand All @@ -1172,7 +1191,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
{
ExpressionTypes *origtypes = typecheck_expression_not_void(ft, expr->data.as.obj);
result = type_from_ast(ft, &expr->data.as.type);
do_explicit_cast(origtypes, result, expr->location);
do_explicit_cast(ft->current_fom_types, origtypes, result, expr->location);
}
break;
}
Expand Down Expand Up @@ -1249,7 +1268,7 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt)
add_variable(ft, types->type, targetexpr->data.varname);
} else {
// Convert value to the type of an existing variable or other assignment target.
ensure_can_take_address(targetexpr, "cannot assign to %s");
ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s");

char errmsg[500];
if (targetexpr->kind == AST_EXPR_DEREFERENCE) {
Expand Down Expand Up @@ -1286,16 +1305,16 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt)
default: assert(0);
}

ensure_can_take_address(targetexpr, "cannot assign to %s");
ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s");
ExpressionTypes *targettypes = typecheck_expression_not_void(ft, targetexpr);
ExpressionTypes *valuetypes = typecheck_expression_not_void(ft, valueexpr);

const Type *t = check_binop(op, stmt->location, targettypes, valuetypes);
const Type *t = check_binop(ft->current_fom_types, op, stmt->location, targettypes, valuetypes);
ExpressionTypes tempvalue_types = { .expr = targetexpr, .type = t };

char msg[500];
snprintf(msg, sizeof msg, "%s produced a value of type FROM which cannot be assigned back to TO", opname);
do_implicit_cast(&tempvalue_types, targettypes->type, stmt->location, msg);
do_implicit_cast(ft->current_fom_types, &tempvalue_types, targettypes->type, stmt->location, msg);

// I think it is currently impossible to cast target.
// If this assert fails, we probably need to add another error message for it.
Expand Down
10 changes: 8 additions & 2 deletions tests/should_succeed/method_by_value.jou
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ class Foo:
return Foo{x = self.x + n}

def add_one(self: Foo) -> Foo:
return Foo{x = self.x + 1}
self = Foo{x = self.x + 1}
return self

def print(self: Foo) -> None:
printf("%d\n", self.x)
Expand All @@ -16,8 +17,13 @@ class Foo:
def main() -> int:
Foo{x=5}.add(10).add_one().print() # Output: 16

# test the '->' operator
# The methods get a local copy
f = Foo{x=100}
f.add(123)
f.add_one()
f.print() # Output: 100

# test the '->' operator
(&f)->print() # Output: 100

return 0