diff --git a/examples/aoc2023/day24/bigint.jou b/examples/aoc2023/day24/bigint.jou index 4ce4acac..0dbb0334 100644 --- a/examples/aoc2023/day24/bigint.jou +++ b/examples/aoc2023/day24/bigint.jou @@ -1,10 +1,6 @@ import "stdlib/mem.jou" -class BigInt: - data: byte[48] # little endian, last byte includes sign bit - - def bigint(value: long) -> BigInt: # assumes little-endian CPU result = BigInt{} @@ -16,211 +12,195 @@ def bigint(value: long) -> BigInt: return result -def bigint_to_long(x: BigInt) -> long: - # assume that value fits into 64-bit long - # also assume little-endian - result: long - memcpy(&result, &x.data, sizeof(result)) - return result +class BigInt: + data: byte[48] # little endian, last byte includes sign bit + def to_long(self: BigInt) -> long: + # assume that value fits into 64-bit long + # also assume little-endian + result: long + x = self # TODO: https://github.com/Akuli/jou/issues/485 + memcpy(&result, &x.data, sizeof(result)) + return result -# TODO: methods are kinda annoying because you'd need temporary values. -# e.g. foo.bar().baz() doesn't work, tries to do take address of bar() return value + def add(self: BigInt, other: BigInt) -> BigInt: + result = bigint(0) + carry_bit = 0 + for i = 0; i < sizeof(self.data); i++: + x = self # TODO: https://github.com/Akuli/jou/issues/485 + result_byte = (x.data[i] as int) + (other.data[i] as int) + carry_bit + if result_byte >= 256: + carry_bit = 1 + else: + carry_bit = 0 + result.data[i] = result_byte as byte -# x+y -def bigadd(x: BigInt, y: BigInt) -> BigInt: - result = bigint(0) - carry_bit = 0 + return result - for i = 0; i < sizeof(x.data); i++: - result_byte = (x.data[i] as int) + (y.data[i] as int) + carry_bit - if result_byte >= 256: - carry_bit = 1 + # -x + def neg(self: BigInt) -> BigInt: + # Flipping all bits (~x) is almost same as negating the value. + # For example, -7 is f9ffffff... and ~7 is f8ffffff... + x = self # TODO: https://github.com/Akuli/jou/issues/485 + for i = 0; i < sizeof(self.data); i++: + x.data[i] = (0xff as byte) - x.data[i] + return x.add(bigint(1)) + + # x-y + def sub(self: BigInt, other: BigInt) -> BigInt: + return self.add(other.neg()) + + # Return values: + # self < other --> -1 + # self == other --> 0 + # self > other --> 1 + def compare(self: BigInt, other: BigInt) -> int: + x = self # TODO: https://github.com/Akuli/jou/issues/485 + self_sign_bit = x.data[sizeof(self.data) - 1] / 128 + other_sign_bit = other.data[sizeof(other.data) - 1] / 128 + + if self_sign_bit != other_sign_bit: + return other_sign_bit - self_sign_bit + + for i = sizeof(self.data) - 1; i >= 0; i--: + if (x.data[i] as int) < (other.data[i] as int): + return -1 + if (x.data[i] as int) > (other.data[i] as int): + return 1 + + return 0 + + # x == y + def equals(self: BigInt, other: BigInt) -> bool: + return self.compare(other) == 0 + + # x == 0 + def is_zero(self: BigInt) -> bool: + return self.compare(bigint(0)) == 0 + + # Return values: + # positive --> 1 + # zero --> 0 + # negative --> -1 + def sign(self: BigInt) -> int: + return self.compare(bigint(0)) + + # |x| + def abs(self: BigInt) -> BigInt: + if self.sign() < 0: + return self.neg() else: - carry_bit = 0 - result.data[i] = result_byte as byte - - return result - - -def bigadd3(x: BigInt, y: BigInt, z: BigInt) -> BigInt: - return bigadd(bigadd(x, y), z) -def bigadd4(x: BigInt, y: BigInt, z: BigInt, zz: BigInt) -> BigInt: - return bigadd(bigadd(bigadd(x, y), z), zz) -def bigadd5(x: BigInt, y: BigInt, z: BigInt, zz: BigInt, zzz: BigInt) -> BigInt: - return bigadd(bigadd(bigadd(bigadd(x, y), z), zz), zzz) -def bigadd6(x: BigInt, y: BigInt, z: BigInt, zz: BigInt, zzz: BigInt, zzzz: BigInt) -> BigInt: - return bigadd(bigadd(bigadd(bigadd(bigadd(x, y), z), zz), zzz), zzzz) - - -# -x -def bigneg(x: BigInt) -> BigInt: - # Flipping all bits (~x) is almost same as negating the value. - # For example, -7 is f9ffffff... and ~7 is f8ffffff... - for i = 0; i < sizeof(x.data); i++: - x.data[i] = (0xff as byte) - x.data[i] - return bigadd(x, bigint(1)) - - -# x-y -def bigsub(x: BigInt, y: BigInt) -> BigInt: - return bigadd(x, bigneg(y)) - - -# Return values: -# x < y --> -1 -# x == y --> 0 -# x > y --> 1 -def bigcmp(x: BigInt, y: BigInt) -> int: - x_sign_bit = x.data[sizeof(x.data) - 1] / 128 - y_sign_bit = y.data[sizeof(y.data) - 1] / 128 - - if x_sign_bit != y_sign_bit: - return y_sign_bit - x_sign_bit - - for i = sizeof(x.data) - 1; i >= 0; i--: - if (x.data[i] as int) < (y.data[i] as int): - return -1 - if (x.data[i] as int) > (y.data[i] as int): - return 1 - - return 0 - - -# x == y -def bigeq(x: BigInt, y: BigInt) -> bool: - return bigcmp(x, y) == 0 - - -# Return values: -# positive --> 1 -# zero --> 0 -# negative --> -1 -def bigsign(x: BigInt) -> int: - return bigcmp(x, bigint(0)) - - -# |x| -def bigabs(x: BigInt) -> BigInt: - if bigsign(x) < 0: - return bigneg(x) - else: - return x - - -# x*y -def bigmul(x: BigInt, y: BigInt) -> BigInt: - result_sign = bigsign(x) * bigsign(y) - x = bigabs(x) - y = bigabs(y) - - result = bigint(0) - for i = 0; i < sizeof(x.data); i++: - for k = 0; i+k < sizeof(result.data); k++: - temp = (x.data[i] as int)*(y.data[k] as int) - - gonna_add = bigint(0) - gonna_add.data[i+k] = temp as byte - if i+k+1 < sizeof(gonna_add.data): - gonna_add.data[i+k+1] = (temp / 256) as byte - result = bigadd(result, gonna_add) - - if bigsign(result) == result_sign: + return self + + # x*y + def mul(self: BigInt, other: BigInt) -> BigInt: + result_sign = self.sign() * other.sign() + self2 = self.abs() # TODO: https://github.com/Akuli/jou/issues/485 + other = other.abs() + + result = bigint(0) + for i = 0; i < sizeof(self2.data); i++: + for k = 0; i+k < sizeof(result.data); k++: + temp = (self2.data[i] as int)*(other.data[k] as int) + + gonna_add = bigint(0) + gonna_add.data[i+k] = temp as byte + if i+k+1 < sizeof(gonna_add.data): + gonna_add.data[i+k+1] = (temp / 256) as byte + result = result.add(gonna_add) + + if result.sign() != result_sign: + result = result.neg() return result - else: - return bigneg(result) - - -# x / 256^n for x >= 0 -def shift_smaller(x: BigInt, n: int) -> BigInt: - assert bigsign(x) >= 0 - assert n >= 0 - if n >= sizeof(x.data): - return bigint(0) - - memmove(&x.data, &x.data[n], sizeof(x.data) - n) - memset(&x.data[sizeof(x.data) - n], 0, n) - return x - - -# x * 256^n for x >= 0 -def shift_bigger(x: BigInt, n: int) -> BigInt: - assert bigsign(x) >= 0 - assert n >= 0 - - if n >= sizeof(x.data): - return bigint(0) - - memmove(&x.data[n], &x.data[0], sizeof(x.data) - n) - memset(&x.data, 0, n) - return x - - -# [x/y, x%y] -def bigdivmod(x: BigInt, y: BigInt) -> BigInt[2]: - assert not bigeq(y, bigint(0)) - - quotient = bigint(0) - remainder = bigabs(x) - yabs = bigabs(y) - - n = 0 - while bigcmp(shift_smaller(remainder, n), yabs) >= 0: - n++ - - assert n < sizeof(quotient.data) - while n --> 0: - # Find nth base-256 digit of result with trial and error. - d = 0 - bigger_y = shift_bigger(yabs, n) - while bigcmp(bigmul(bigger_y, bigint(d+1)), remainder) <= 0: - if d == 0: + # x / 256^n for x >= 0 + def shift_smaller(self: BigInt, n: int) -> BigInt: + assert self.sign() >= 0 + assert n >= 0 + + if n >= sizeof(self.data): + return bigint(0) + + self2 = self # TODO: https://github.com/Akuli/jou/issues/485 + memmove(&self2.data, &self2.data[n], sizeof(self2.data) - n) + memset(&self2.data[sizeof(self2.data) - n], 0, n) + return self2 + + # x * 256^n for x >= 0 + def shift_bigger(self: BigInt, n: int) -> BigInt: + assert self.sign() >= 0 + assert n >= 0 + + if n >= sizeof(self.data): + return bigint(0) + + self2 = self # TODO: https://github.com/Akuli/jou/issues/485 + memmove(&self2.data[n], &self2.data[0], sizeof(self2.data) - n) + memset(&self2.data, 0, n) + return self2 + + # [x/y, x%y] + def divmod(self: BigInt, bottom: BigInt) -> BigInt[2]: + assert not bottom.is_zero() + + quotient = bigint(0) + remainder = self.abs() + bottom_abs = bottom.abs() + + n = 0 + while remainder.shift_smaller(n).compare(bottom_abs) >= 0: + n++ + + assert n < sizeof(quotient.data) + while n --> 0: + # Find nth base-256 digit of result with trial and error. + d = 0 + bigger_bottom = bottom_abs.shift_bigger(n) + while bigger_bottom.mul(bigint(d+1)).compare(remainder) <= 0: + if d == 0: + d++ + else: + d *= 2 + d /= 2 + while bigger_bottom.mul(bigint(d+1)).compare(remainder) <= 0: d++ - else: - d *= 2 - d /= 2 - while bigcmp(bigmul(bigger_y, bigint(d+1)), remainder) <= 0: - d++ - - assert d < 256 - quotient.data[n] = d as byte - remainder = bigsub(remainder, bigmul(bigint(d), bigger_y)) - - if bigsign(x)*bigsign(y) < 0: - quotient = bigneg(quotient) - if bigsign(x) < 0: - remainder = bigneg(remainder) - - # When nonzero remainder, force its sign to be same sign as y, similar to jou % - if bigsign(remainder) != 0 and bigsign(remainder) != bigsign(y): - remainder = bigadd(remainder, y) - quotient = bigsub(quotient, bigint(1)) - - return [quotient, remainder] - - # Tests: - # - # for x = -100; x <= 100; x++: - # for y = -100; y <= 100; y++: - # if y != 0: - # result = bigdivmod(bigint(x), bigint(y)) - # assert x == (x/y)*y + (x%y) - # assert x == bigint_to_long(result[0])*y + bigint_to_long(result[1]) - # assert bigint_to_long(result[0]) == x / y - # assert bigint_to_long(result[1]) == x % y - - -# x / y -def bigdiv(x: BigInt, y: BigInt) -> BigInt: - pair = bigdivmod(x, y) - return pair[0] - - -# assert x % y == 0 -# x / y -def bigdiv_exact(x: BigInt, y: BigInt) -> BigInt: - pair = bigdivmod(x, y) - assert bigeq(pair[1], bigint(0)) - return pair[0] + + assert d < 256 + quotient.data[n] = d as byte + remainder = remainder.sub(bigint(d).mul(bigger_bottom)) + + if self.sign()*bottom.sign() < 0: + quotient = quotient.neg() + if self.sign() < 0: + remainder = remainder.neg() + + # When nonzero remainder, force its sign to be same sign as bottom, similar to jou % + if remainder.sign() == -bottom.sign(): + remainder = remainder.add(bottom) + quotient = quotient.sub(bigint(1)) + + return [quotient, remainder] + + # Tests: + # + # for x = -100; x <= 100; x++: + # for y = -100; y <= 100; y++: + # if y != 0: + # result = bigint(x).divmod(bigint(y)) + # assert x == (x/y)*y + (x%y) + # assert x == result[0].to_long()*y + result[1].to_long() + # assert result[0].to_long() == x / y + # assert result[1].to_long() == x % y + + # self / y + def div(self: BigInt, other: BigInt) -> BigInt: + pair = self.divmod(other) + return pair[0] + + # assert x % y == 0 + # x / y + def div_exact(self: BigInt, other: BigInt) -> BigInt: + pair = self.divmod(other) + assert pair[1].is_zero() + return pair[0] diff --git a/examples/aoc2023/day24/part1.jou b/examples/aoc2023/day24/part1.jou index a43cacd2..2ad745b9 100644 --- a/examples/aoc2023/day24/part1.jou +++ b/examples/aoc2023/day24/part1.jou @@ -7,11 +7,11 @@ def det(matrix: long[2][2]) -> BigInt: b = matrix[0][1] c = matrix[1][0] d = matrix[1][1] - return bigsub(bigmul(bigint(a), bigint(d)), bigmul(bigint(b), bigint(c))) + return bigint(a).mul(bigint(d)).sub(bigint(b).mul(bigint(c))) def dot(v1: long[2], v2: long[2]) -> BigInt: - return bigadd(bigmul(bigint(v1[0]), bigint(v2[0])), bigmul(bigint(v1[1]), bigint(v2[1]))) + return bigint(v1[0]).mul(bigint(v2[0])).add(bigint(v1[1]).mul(bigint(v2[1]))) def matrix_times_vector(matrix: long[2][2], vector: long[2]) -> BigInt[2]: @@ -21,10 +21,7 @@ def matrix_times_vector(matrix: long[2][2], vector: long[2]) -> BigInt[2]: d = bigint(matrix[1][1]) x = bigint(vector[0]) y = bigint(vector[1]) - return [ - bigadd(bigmul(a, x), bigmul(b, y)), - bigadd(bigmul(c, x), bigmul(d, y)), - ] + return [a.mul(x).add(b.mul(y)), c.mul(x).add(d.mul(y))] # Returns the x and y of matrix*[x,y] = coeff_vector as fractions: @@ -37,7 +34,7 @@ def solve_linear_system_of_2_equations(matrix: long[2][2], coeff_vector: long[2] d = matrix[1][1] determinant = det(matrix) - assert not bigeq(determinant, bigint(0)) # assume inverse matrix exists + assert not determinant.is_zero() # assume inverse matrix exists inverse_matrix_times_determinant = [ [d, -b], @@ -62,15 +59,15 @@ class Rectangle: ) def contains_fraction(self, qx: BigInt, qy: BigInt, q: BigInt) -> bool: - qxmin = bigmul(q, bigint(self->x_min)) - qxmax = bigmul(q, bigint(self->x_max)) - qymin = bigmul(q, bigint(self->y_min)) - qymax = bigmul(q, bigint(self->y_max)) + assert q.sign() > 0 - assert bigsign(q) > 0 + qxmin = bigint(self->x_min).mul(q) + qxmax = bigint(self->x_max).mul(q) + qymin = bigint(self->y_min).mul(q) + qymax = bigint(self->y_max).mul(q) return ( - bigcmp(qxmin, qx) <= 0 and bigcmp(qx, qxmax) <= 0 - and bigcmp(qymin, qy) <= 0 and bigcmp(qy, qymax) <= 0 + qxmin.compare(qx) <= 0 and qx.compare(qxmax) <= 0 + and qymin.compare(qy) <= 0 and qy.compare(qymax) <= 0 ) @@ -79,13 +76,13 @@ class Ray: dir: long[2] def intersects(self, other: Ray*, test_area: Rectangle) -> bool: - if bigeq(det([self->dir, other->dir]), bigint(0)): + if det([self->dir, other->dir]).is_zero(): # Rays go in parallel directions. start_diff = [ self->start[0] - other->start[0], self->start[1] - other->start[1], ] - if not bigeq(det([start_diff, self->dir]), bigint(0)): + if not det([start_diff, self->dir]).is_zero(): # Rays are not aligned to go along the same line. return False @@ -95,18 +92,18 @@ class Ray: other_start = dot(self->dir, other->start) other_dir = dot(self->dir, other->dir) - assert bigsign(self_dir) > 0 - assert not bigeq(other_dir, bigint(0)) - assert not bigeq(self_start, other_start) + assert self_dir.sign() > 0 + assert not other_dir.is_zero() + assert not self_start.equals(other_start) - if bigsign(other_dir) > 0: + if other_dir.sign() > 0: # Rays go in the same direction. Eventually one ray will reach the start of the other. - if bigcmp(self_start, other_start) > 0: + if self_start.compare(other_start) > 0: return test_area.contains(self->start) else: return test_area.contains(other->start) - if bigcmp(self_start, other_start) > 0: + if self_start.compare(other_start) > 0: # Rays point away from each other return False @@ -118,11 +115,11 @@ class Ray: # Math gives a solution: # # t = p/q, p = other_start - self_start, q = self_dir - other_dir - q = bigsub(self_dir, other_dir) - assert bigsign(q) > 0 - qt = bigsub(other_start, self_start) - qx = bigadd(bigmul(q, bigint(self->start[0])), bigmul(qt, bigint(self->dir[0]))) - qy = bigadd(bigmul(q, bigint(self->start[1])), bigmul(qt, bigint(self->dir[1]))) + q = self_dir.sub(other_dir) + assert q.sign() > 0 + qt = other_start.sub(self_start) + qx = bigint(self->start[0]).mul(q).add(bigint(self->dir[0]).mul(qt)) + qy = bigint(self->start[1]).mul(q).add(bigint(self->dir[1]).mul(qt)) return test_area.contains_fraction(qx, qy, q) # Vectors are not parallel. They will intersect somewhere, but where? @@ -144,17 +141,17 @@ class Ray: qb = solve_result[1] q = solve_result[2] - if bigsign(q) < 0: - qa = bigneg(qa) - qb = bigneg(qb) - q = bigneg(q) + if q.sign() < 0: + qa = qa.neg() + qb = qb.neg() + q = q.neg() # rays do not extend backwards - if bigsign(qa) < 0 or bigsign(qb) < 0: + if qa.sign() < 0 or qb.sign() < 0: return False - qx = bigadd(bigmul(q, bigint(self->start[0])), bigmul(qa, bigint(self->dir[0]))) - qy = bigadd(bigmul(q, bigint(self->start[1])), bigmul(qa, bigint(self->dir[1]))) + qx = bigint(self->start[0]).mul(q).add(bigint(self->dir[0]).mul(qa)) + qy = bigint(self->start[1]).mul(q).add(bigint(self->dir[1]).mul(qa)) return test_area.contains_fraction(qx, qy, q)