Skip to content

Commit

Permalink
Add memswap() function to stdlib/mem.jou (#528)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Dec 11, 2024
1 parent 9dbdea9 commit 9a33db4
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 55 deletions.
9 changes: 2 additions & 7 deletions examples/aoc2023/day07/part1.jou
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import "stdlib/io.jou"
import "stdlib/str.jou"
import "stdlib/mem.jou"

class Hand:
letters: byte[6]
Expand Down Expand Up @@ -110,18 +111,12 @@ class Hand:
return 0


def swap(h1: Hand*, h2: Hand*) -> None:
temp = *h1
*h1 = *h2
*h2 = temp


def sort(hands: Hand*, nhands: int) -> None:
# bubble sort go brrr
for sorted_part_len = 1; sorted_part_len < nhands; sorted_part_len++:
i = sorted_part_len
while i > 0 and hands[i-1].compare(&hands[i]) == 1:
swap(&hands[i-1], &hands[i])
memswap(&hands[i-1], &hands[i], sizeof(hands[i]))
i--


Expand Down
8 changes: 1 addition & 7 deletions examples/aoc2023/day07/part2.jou
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,12 @@ class Hand:
return 0


def swap(h1: Hand*, h2: Hand*) -> None:
temp = *h1
*h1 = *h2
*h2 = temp


def sort(hands: Hand*, nhands: int) -> None:
# bubble sort go brrr
for sorted_part_len = 1; sorted_part_len < nhands; sorted_part_len++:
i = sorted_part_len
while i > 0 and hands[i-1].compare(&hands[i]) == 1:
swap(&hands[i-1], &hands[i])
memswap(&hands[i-1], &hands[i], sizeof(hands[i]))
i--


Expand Down
7 changes: 1 addition & 6 deletions examples/aoc2023/day08/part2.jou
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,13 @@ class List:
self->append(*v)


def swap(a: long*, b: long*) -> None:
temp = *a
*a = *b
*b = temp

def gcd(a: long, b: long) -> long:
assert a > 0 and b > 0
while True:
a %= b
if a == 0:
return b
swap(&a, &b)
memswap(&a, &b, sizeof(a))

def lcm(a: long, b: long) -> long:
return (a/gcd(a,b)) * b
Expand Down
8 changes: 1 addition & 7 deletions examples/aoc2023/day09/part2.jou
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,11 @@ def predict_next(nums: long*, len: int) -> long:
return result


def swap(a: long*, b: long*) -> None:
tmp = *a
*a = *b
*b = tmp


def reverse(nums: long*, len: int) -> None:
p = nums
q = &nums[len-1]
while p < q:
swap(p++, q--)
memswap(p++, q--, sizeof(*p))


# return value is an array terminated by nums_len=-1
Expand Down
8 changes: 1 addition & 7 deletions examples/aoc2023/day20/part2.jou
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,13 @@ def run_part_of_input(start_flip_flop: Module*, end: Module*) -> int[5]:
assert False


def swap(a: long*, b: long*) -> None:
old_a = *a
*a = *b
*b = old_a


def gcd(a: long, b: long) -> long:
assert a > 0 and b > 0
while True:
a %= b
if a == 0:
return b
swap(&a, &b)
memswap(&a, &b, sizeof(a))

def lcm(a: long, b: long) -> long:
return (a/gcd(a,b)) * b
Expand Down
9 changes: 2 additions & 7 deletions examples/aoc2024/day01/part1.jou
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import "stdlib/io.jou"
import "stdlib/math.jou"


def swap(a: int*, b: int*) -> None:
temp = *a
*a = *b
*b = temp
import "stdlib/mem.jou"


def sort(array: int*, length: int) -> None:
Expand All @@ -15,7 +10,7 @@ def sort(array: int*, length: int) -> None:
for i = 1; i < length; i++:
if array[i] < array[smallest]:
smallest = i
swap(&array[0], &array[smallest])
memswap(&array[0], &array[smallest], sizeof(array[0]))
array++
length--

Expand Down
8 changes: 1 addition & 7 deletions examples/aoc2024/day05/part2.jou
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ def job_is_valid(rules: int[2]*, nrules: int, job: int*) -> bool:
return True


def swap(a: int*, b: int*) -> None:
old_a = *a
*a = *b
*b = old_a


def fix_job(rules: int[2]*, nrules: int, job: int*) -> None:
length = 0
while job[length] != -1:
Expand All @@ -94,7 +88,7 @@ def fix_job(rules: int[2]*, nrules: int, job: int*) -> None:
first_idx = find_page(job, rule[0])
second_idx = find_page(job, rule[1])
if first_idx != -1 and second_idx != -1 and first_idx >= second_idx:
swap(&job[first_idx], &job[second_idx])
memswap(&job[first_idx], &job[second_idx], sizeof(job[0]))


def main() -> int:
Expand Down
9 changes: 2 additions & 7 deletions examples/quicksort.jou
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
# Maybe this will some day become a part of the standard library :)

import "stdlib/io.jou"


def swap(a: int*, b: int*) -> None:
temp = *a
*a = *b
*b = temp
import "stdlib/mem.jou"


def print_array(prefix: byte*, arr: int*, length: int) -> None:
Expand Down Expand Up @@ -90,7 +85,7 @@ def quicksort(arr: int*, length: int, depth: int) -> None:
# neither range can expand because of >pivot and <pivot elements
assert arr[end_of_small] > pivot
assert arr[start_of_big - 1] < pivot
swap(&arr[end_of_small], &arr[start_of_big - 1])
memswap(&arr[end_of_small], &arr[start_of_big - 1], sizeof(arr[0]))

# Add back the removed pivot. It becomes a part of the overlap.
arr[length++] = arr[end_of_small]
Expand Down
12 changes: 12 additions & 0 deletions stdlib/mem.jou
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,15 @@ declare memset(dest: void*, fill_byte: int, size: long) -> void*
# - it may be slightly faster.
declare memcpy(dest: void*, source: void*, size: long) -> void* # copy memory, overlaps are UB
declare memmove(dest: void*, source: void*, size: long) -> void* # copy memory, overlaps are ok

# Swaps the contents of two memory regions of the same size.
# This does nothing if the same memory region is passed twice.
# This probably doesn't do what you want if the memory regions overlap in some other way.
def memswap(a: void*, b: void*, size: long) -> None:
a_bytes: byte* = a
b_bytes: byte* = b

for i = 0L; i < size; i++:
old_a = a_bytes[i]
a_bytes[i] = b_bytes[i]
b_bytes[i] = old_a
15 changes: 15 additions & 0 deletions tests/should_succeed/memlibtest.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import "stdlib/mem.jou"
import "stdlib/io.jou"


def main() -> int:
a = 123
b = 456

memswap(&a, &b, sizeof(a))
printf("%d %d\n", a, b) # Output: 456 123

memswap(&a, &a, sizeof(a)) # does nothing
printf("%d %d\n", a, b) # Output: 456 123

return 0

0 comments on commit 9a33db4

Please sign in to comment.