Skip to content

Commit

Permalink
Allow taking address of self when it is passed by value (#487)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Dec 26, 2023
1 parent 075ea5c commit b37d0ff
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 58 deletions.
38 changes: 16 additions & 22 deletions examples/aoc2023/day24/bigint.jou
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@ class BigInt:
# assume that value fits into 64-bit long
# also assume little-endian
result: long
x = self # TODO: https://github.com/Akuli/jou/issues/485
memcpy(&result, &x.data, sizeof(result))
memcpy(&result, &self.data, sizeof(result))
return result

def add(self: BigInt, other: BigInt) -> BigInt:
result = bigint(0)
carry_bit = 0

for i = 0; i < sizeof(self.data); i++:
x = self # TODO: https://github.com/Akuli/jou/issues/485
result_byte = (x.data[i] as int) + (other.data[i] as int) + carry_bit
result_byte = (self.data[i] as int) + (other.data[i] as int) + carry_bit
if result_byte >= 256:
carry_bit = 1
else:
Expand All @@ -42,10 +40,9 @@ class BigInt:
def neg(self: BigInt) -> BigInt:
# Flipping all bits (~x) is almost same as negating the value.
# For example, -7 is f9ffffff... and ~7 is f8ffffff...
x = self # TODO: https://github.com/Akuli/jou/issues/485
for i = 0; i < sizeof(self.data); i++:
x.data[i] = (0xff as byte) - x.data[i]
return x.add(bigint(1))
self.data[i] = (0xff as byte) - self.data[i]
return self.add(bigint(1))

# x-y
def sub(self: BigInt, other: BigInt) -> BigInt:
Expand All @@ -56,17 +53,16 @@ class BigInt:
# self == other --> 0
# self > other --> 1
def compare(self: BigInt, other: BigInt) -> int:
x = self # TODO: https://github.com/Akuli/jou/issues/485
self_sign_bit = x.data[sizeof(self.data) - 1] / 128
self_sign_bit = self.data[sizeof(self.data) - 1] / 128
other_sign_bit = other.data[sizeof(other.data) - 1] / 128

if self_sign_bit != other_sign_bit:
return other_sign_bit - self_sign_bit

for i = sizeof(self.data) - 1; i >= 0; i--:
if (x.data[i] as int) < (other.data[i] as int):
if (self.data[i] as int) < (other.data[i] as int):
return -1
if (x.data[i] as int) > (other.data[i] as int):
if (self.data[i] as int) > (other.data[i] as int):
return 1

return 0
Expand Down Expand Up @@ -96,13 +92,13 @@ class BigInt:
# x*y
def mul(self: BigInt, other: BigInt) -> BigInt:
result_sign = self.sign() * other.sign()
self2 = self.abs() # TODO: https://github.com/Akuli/jou/issues/485
self = self.abs()
other = other.abs()

result = bigint(0)
for i = 0; i < sizeof(self2.data); i++:
for i = 0; i < sizeof(self.data); i++:
for k = 0; i+k < sizeof(result.data); k++:
temp = (self2.data[i] as int)*(other.data[k] as int)
temp = (self.data[i] as int)*(other.data[k] as int)

gonna_add = bigint(0)
gonna_add.data[i+k] = temp as byte
Expand All @@ -122,10 +118,9 @@ class BigInt:
if n >= sizeof(self.data):
return bigint(0)

self2 = self # TODO: https://github.com/Akuli/jou/issues/485
memmove(&self2.data, &self2.data[n], sizeof(self2.data) - n)
memset(&self2.data[sizeof(self2.data) - n], 0, n)
return self2
memmove(&self.data, &self.data[n], sizeof(self.data) - n)
memset(&self.data[sizeof(self.data) - n], 0, n)
return self

# x * 256^n for x >= 0
def shift_bigger(self: BigInt, n: int) -> BigInt:
Expand All @@ -135,10 +130,9 @@ class BigInt:
if n >= sizeof(self.data):
return bigint(0)

self2 = self # TODO: https://github.com/Akuli/jou/issues/485
memmove(&self2.data[n], &self2.data[0], sizeof(self2.data) - n)
memset(&self2.data, 0, n)
return self2
memmove(&self.data[n], &self.data[0], sizeof(self.data) - n)
memset(&self.data, 0, n)
return self

# [x/y, x%y]
def divmod(self: BigInt, bottom: BigInt) -> BigInt[2]:
Expand Down
87 changes: 53 additions & 34 deletions src/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -462,33 +462,46 @@ value of the pointer &foo to bar.
errmsg_template can be e.g. "cannot take address of %s" or "cannot assign to %s"
*/
static void ensure_can_take_address(const AstExpression *expr, const char *errmsg_template)
static void ensure_can_take_address(const FunctionOrMethodTypes *fom, const AstExpression *expr, const char *errmsg_template)
{
assert(fom != NULL);

switch(expr->kind) {
case AST_EXPR_DEREFERENCE:
case AST_EXPR_INDEXING: // &foo[bar]
case AST_EXPR_DEREF_AND_GET_FIELD: // &foo->bar = foo + offset (it doesn't use &foo)
break;
return;

case AST_EXPR_GET_FIELD:
// &foo.bar = &foo + offset
{
// Turn "cannot assign to %s" into "cannot assign to a field of %s".
// This assumes that errmsg_template is relatively simple, i.e. it only contains one %s somewhere.
char *newtemplate = malloc(strlen(errmsg_template) + 100);
sprintf(newtemplate, errmsg_template, "a field of %s");
ensure_can_take_address(&expr->data.operands[0], newtemplate);
ensure_can_take_address(fom, &expr->data.operands[0], newtemplate);
free(newtemplate);
}
break;
return;

case AST_EXPR_GET_VARIABLE:
if (strcmp(expr->data.varname, "self") && get_special_constant(expr->data.varname) == -1) {
// not self or a special constant --> ok to take address
break;
if (get_special_constant(expr->data.varname) != -1)
goto error;

// In methods that take self as a pointer, you cannot take address of self
if (!strcmp(expr->data.varname, "self")) {
if (fom->signature.argtypes[0]->kind == TYPE_POINTER)
goto error;
}
__attribute__((fallthrough));

return;

default:
fail(expr->location, errmsg_template, short_expression_description(expr));
goto error;
}

error:
fail(expr->location, errmsg_template, short_expression_description(expr));
}

/*
Expand Down Expand Up @@ -550,7 +563,11 @@ static bool can_cast_implicitly(const Type *from, const Type *to)
}

static void do_implicit_cast(
ExpressionTypes *types, const Type *to, Location location, const char *errormsg_template)
const FunctionOrMethodTypes *fom,
ExpressionTypes *types,
const Type *to,
Location location,
const char *errormsg_template)
{
assert(!types->implicit_cast_type);
assert(!types->implicit_array_to_pointer_cast);
Expand Down Expand Up @@ -580,18 +597,19 @@ static void do_implicit_cast(
types->implicit_array_to_pointer_cast = (from->kind == TYPE_ARRAY && to->kind == TYPE_POINTER);
if (types->implicit_array_to_pointer_cast)
ensure_can_take_address(
fom,
types->expr,
"cannot create a pointer into an array that comes from %s (try storing it to a local variable first)"
);
}

static void cast_array_to_pointer(ExpressionTypes *types)
static void cast_array_to_pointer(const FunctionOrMethodTypes *fom, ExpressionTypes *types)
{
assert(types->type->kind == TYPE_ARRAY);
do_implicit_cast(types, get_pointer_type(types->type->data.array.membertype), (Location){0}, NULL);
do_implicit_cast(fom, types, get_pointer_type(types->type->data.array.membertype), (Location){0}, NULL);
}

static void do_explicit_cast(ExpressionTypes *types, const Type *to, Location location)
static void do_explicit_cast(const FunctionOrMethodTypes *fom, ExpressionTypes *types, const Type *to, Location location)
{
assert(!types->implicit_cast_type);
const Type *from = types->type;
Expand All @@ -615,7 +633,7 @@ static void do_explicit_cast(ExpressionTypes *types, const Type *to, Location lo
}

if (from->kind == TYPE_ARRAY && is_pointer_type(to))
cast_array_to_pointer(types);
cast_array_to_pointer(fom, types);
}

static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression *expr);
Expand Down Expand Up @@ -647,10 +665,11 @@ static void typecheck_expression_with_implicit_cast(
const char *errormsg_template)
{
ExpressionTypes *types = typecheck_expression_not_void(ft, expr);
do_implicit_cast(types, casttype, expr->location, errormsg_template);
do_implicit_cast(ft->current_fom_types, types, casttype, expr->location, errormsg_template);
}

static const Type *check_binop(
const FunctionOrMethodTypes *fom,
enum AstExpressionKind op,
Location location,
ExpressionTypes *lhstypes,
Expand Down Expand Up @@ -713,8 +732,8 @@ static const Type *check_binop(
cast_type = intType;
assert(cast_type);

do_implicit_cast(lhstypes, cast_type, (Location){0}, NULL);
do_implicit_cast(rhstypes, cast_type, (Location){0}, NULL);
do_implicit_cast(fom, lhstypes, cast_type, (Location){0}, NULL);
do_implicit_cast(fom, rhstypes, cast_type, (Location){0}, NULL);

switch(op) {
case AST_EXPR_ADD:
Expand Down Expand Up @@ -753,7 +772,7 @@ static const Type *check_increment_or_decrement(FileTypes *ft, const AstExpressi
assert(0);
}

ensure_can_take_address(&expr->data.operands[0], bad_expr_fmt);
ensure_can_take_address(ft->current_fom_types, &expr->data.operands[0], bad_expr_fmt);
const Type *t = typecheck_expression_not_void(ft, &expr->data.operands[0])->type;
if (!is_integer_type(t) && !is_pointer_type(t))
fail(expr->location, bad_type_fmt, t->name);
Expand All @@ -775,7 +794,7 @@ static const Type *typecheck_indexing(

const Type *ptrtype;
if (types->type->kind == TYPE_ARRAY) {
cast_array_to_pointer(types);
cast_array_to_pointer(ft->current_fom_types, types);
ptrtype = types->implicit_cast_type;
} else {
if (types->type->kind != TYPE_POINTER)
Expand All @@ -794,7 +813,7 @@ static const Type *typecheck_indexing(

// LLVM assumes that indexes smaller than 64 bits are signed.
// https://github.com/Akuli/jou/issues/48
do_implicit_cast(indextypes, longType, (Location){0}, NULL);
do_implicit_cast(ft->current_fom_types, indextypes, longType, (Location){0}, NULL);

return ptrtype->data.valuetype;
}
Expand Down Expand Up @@ -883,16 +902,16 @@ static const Type *typecheck_function_or_method_call(FileTypes *ft, const AstCal
ExpressionTypes *types = typecheck_expression_not_void(ft, &call->args[i]);

if (types->type->kind == TYPE_ARRAY)
cast_array_to_pointer(types);
cast_array_to_pointer(ft->current_fom_types, types);
else if (
(is_integer_type(types->type) && types->type->data.width_in_bits < 32)
|| types->type == boolType)
{
// Add implicit cast to signed int, just like in C.
do_implicit_cast(types, intType, (Location){0}, NULL);
do_implicit_cast(ft->current_fom_types, types, intType, (Location){0}, NULL);
}
else if (types->type == floatType)
do_implicit_cast(types, doubleType, (Location){0}, NULL);
do_implicit_cast(ft->current_fom_types, types, doubleType, (Location){0}, NULL);
}

free(sigstr);
Expand Down Expand Up @@ -957,7 +976,7 @@ static bool enum_member_exists(const Type *t, const char *name)
return false;
}

static const Type *cast_array_members_to_a_common_type(Location error_location, ExpressionTypes **exprtypes)
static const Type *cast_array_members_to_a_common_type(const FunctionOrMethodTypes *fom, Location error_location, ExpressionTypes **exprtypes)
{
// Avoid O(ntypes^2) code in a long array where all or almost all items have the same type.
// This is at most O(ntypes*k) where k is the number of distinct types.
Expand Down Expand Up @@ -1002,7 +1021,7 @@ static const Type *cast_array_members_to_a_common_type(Location error_location,
free(compatible_with_all.ptr);

for (ExpressionTypes **et = exprtypes; *et; et++)
do_implicit_cast(*et, elemtype, error_location, NULL);
do_implicit_cast(fom, *et, elemtype, error_location, NULL);
return elemtype;
}

Expand Down Expand Up @@ -1044,7 +1063,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
for (int i = 0; i < n; i++)
exprtypes[i] = typecheck_expression_not_void(ft, &expr->data.array.items[i]);

const Type *membertype = cast_array_members_to_a_common_type(expr->location, exprtypes);
const Type *membertype = cast_array_members_to_a_common_type(ft->current_fom_types, expr->location, exprtypes);
free(exprtypes);
result = get_array_type(membertype, n);
}
Expand Down Expand Up @@ -1091,7 +1110,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
tmp, sizeof tmp,
"cannot take address of %%s, needed for calling the %s() method",
expr->data.methodcall.call.calledname);
ensure_can_take_address(expr->data.methodcall.obj, tmp);
ensure_can_take_address(ft->current_fom_types, expr->data.methodcall.obj, tmp);
}
found = true;
break;
Expand All @@ -1106,7 +1125,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
result = typecheck_indexing(ft, &expr->data.operands[0], &expr->data.operands[1]);
break;
case AST_EXPR_ADDRESS_OF:
ensure_can_take_address(&expr->data.operands[0], "the '&' operator cannot be used with %s");
ensure_can_take_address(ft->current_fom_types, &expr->data.operands[0], "the '&' operator cannot be used with %s");
temptype = typecheck_expression_not_void(ft, &expr->data.operands[0])->type;
result = get_pointer_type(temptype);
break;
Expand Down Expand Up @@ -1159,7 +1178,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
{
ExpressionTypes *lhstypes = typecheck_expression_not_void(ft, &expr->data.operands[0]);
ExpressionTypes *rhstypes = typecheck_expression_not_void(ft, &expr->data.operands[1]);
result = check_binop(expr->kind, expr->location, lhstypes, rhstypes);
result = check_binop(ft->current_fom_types, expr->kind, expr->location, lhstypes, rhstypes);
break;
}
case AST_EXPR_PRE_INCREMENT:
Expand All @@ -1172,7 +1191,7 @@ static ExpressionTypes *typecheck_expression(FileTypes *ft, const AstExpression
{
ExpressionTypes *origtypes = typecheck_expression_not_void(ft, expr->data.as.obj);
result = type_from_ast(ft, &expr->data.as.type);
do_explicit_cast(origtypes, result, expr->location);
do_explicit_cast(ft->current_fom_types, origtypes, result, expr->location);
}
break;
}
Expand Down Expand Up @@ -1249,7 +1268,7 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt)
add_variable(ft, types->type, targetexpr->data.varname);
} else {
// Convert value to the type of an existing variable or other assignment target.
ensure_can_take_address(targetexpr, "cannot assign to %s");
ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s");

char errmsg[500];
if (targetexpr->kind == AST_EXPR_DEREFERENCE) {
Expand Down Expand Up @@ -1286,16 +1305,16 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt)
default: assert(0);
}

ensure_can_take_address(targetexpr, "cannot assign to %s");
ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s");
ExpressionTypes *targettypes = typecheck_expression_not_void(ft, targetexpr);
ExpressionTypes *valuetypes = typecheck_expression_not_void(ft, valueexpr);

const Type *t = check_binop(op, stmt->location, targettypes, valuetypes);
const Type *t = check_binop(ft->current_fom_types, op, stmt->location, targettypes, valuetypes);
ExpressionTypes tempvalue_types = { .expr = targetexpr, .type = t };

char msg[500];
snprintf(msg, sizeof msg, "%s produced a value of type FROM which cannot be assigned back to TO", opname);
do_implicit_cast(&tempvalue_types, targettypes->type, stmt->location, msg);
do_implicit_cast(ft->current_fom_types, &tempvalue_types, targettypes->type, stmt->location, msg);

// I think it is currently impossible to cast target.
// If this assert fails, we probably need to add another error message for it.
Expand Down
10 changes: 8 additions & 2 deletions tests/should_succeed/method_by_value.jou
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ class Foo:
return Foo{x = self.x + n}

def add_one(self: Foo) -> Foo:
return Foo{x = self.x + 1}
self = Foo{x = self.x + 1}
return self

def print(self: Foo) -> None:
printf("%d\n", self.x)
Expand All @@ -16,8 +17,13 @@ class Foo:
def main() -> int:
Foo{x=5}.add(10).add_one().print() # Output: 16

# test the '->' operator
# The methods get a local copy
f = Foo{x=100}
f.add(123)
f.add_one()
f.print() # Output: 100

# test the '->' operator
(&f)->print() # Output: 100

return 0

0 comments on commit b37d0ff

Please sign in to comment.