Initial compiler! 1.5x to 10x speedups
This commit implements compilation of HVM terms. The idea, as explained
in an early commit, is to modify the `@F ~ X` deref rule, so that,
instead of unrolling `@F` and letting reductions happen normally, we
instead pass `X` to a compiled procedure that attempts to perform some
local reductions BEFORE allocating `@F`. For example, the HVM term:

add = λa λb (+ a b)

Is compiled to HVM as:

@add = (a (b r))
& #1 ~ <a <b r>>

So, if we apply `(add 123 100)`, we will have the following net:

    (#123 (#100 R)) ~ ⟦(a (b r))⟧
    ⟦#1⟧            ~ ⟦<a <b r>>⟧

Notice that `(#123 (#100 R))` is a dynamic net, and everything inside
these ⟦double brackets⟧ is part of the "source code" of `@add`, i.e.,
these are static nets. As such, once we send `X` to the compiled
`add()`, we can immediately detect that `X` isn't an aux port, but is
actually the main port of the `(#123 (#100 R))` tree. Furthermore, since
the root of `@add` is also two CON nodes (representing `λa` and `λb`),
we can immediately substitute `a <~ #123`, `b <~ #100`, and `r <~ R`,
performing two "local annihilations" before allocating the body of
`@add`. As a result, we'll have the following net:

    R    ~ ⟦r⟧
    ⟦#1⟧ ~ ⟦<#123 <#100 r>>⟧
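
To make the "local annihilation" step concrete, here is a minimal,
self-contained sketch of the fast-apply check. It is NOT the real
runtime: `Ptr`, `Net`, `fast_apply`, `enter_add` and `build_call` are
hypothetical names over a toy heap, and the slow path uses fresh
variables as a stand-in for the VR1/VR2 pointers of the real thing.

```rust
// Toy model only: these types are hypothetical, not HVM's real `Ptr`/heap.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Ptr {
    Num(u32),   // an unboxed number, e.g. #123
    Var(usize), // an aux port (a wire we know nothing about yet)
    Con(usize), // main port of a CON node stored at `heap[idx]`
}

#[derive(Default)]
struct Net {
    heap: Vec<Option<(Ptr, Ptr)>>, // CON nodes as (P1, P2); None = freed
    redexes: Vec<(Ptr, Ptr)>,      // active pairs left for the runtime
    anni: usize,                   // annihilations performed locally
    next_var: usize,
}

impl Net {
    fn fresh_var(&mut self) -> Ptr {
        self.next_var += 1;
        Ptr::Var(self.next_var - 1)
    }

    fn alloc_con(&mut self, p1: Ptr, p2: Ptr) -> Ptr {
        self.heap.push(Some((p1, p2)));
        Ptr::Con(self.heap.len() - 1)
    }

    /// "fast apply": if `x` is the main port of a CON node, read its two
    /// aux ports and free it (one annihilation, zero allocations).
    /// Otherwise allocate the static CON node and leave a redex behind.
    fn fast_apply(&mut self, x: Ptr) -> (Ptr, Ptr) {
        if let Ptr::Con(i) = x {
            let (p1, p2) = self.heap[i].take().expect("CON node already freed");
            self.anni += 1;
            (p1, p2)
        } else {
            let (p1, p2) = (self.fresh_var(), self.fresh_var());
            let node = self.alloc_con(p1, p2);
            self.redexes.push((node, x));
            (p1, p2)
        }
    }

    /// Traverses the root `(a (b r))` of `@add` against the argument `x`.
    fn enter_add(&mut self, x: Ptr) -> (Ptr, Ptr, Ptr) {
        let (a, rest) = self.fast_apply(x);
        let (b, r) = self.fast_apply(rest);
        (a, b, r)
    }
}

/// Builds the dynamic net `(#a (#b R))`; returns its main port and `R`.
fn build_call(net: &mut Net, a: u32, b: u32) -> (Ptr, Ptr) {
    let r = net.fresh_var();
    let inner = net.alloc_con(Ptr::Num(b), r);
    let outer = net.alloc_con(Ptr::Num(a), inner);
    (outer, r)
}
```

With `(#123 (#100 R))` as input, `enter_add` should bind `a`, `b` and
`r` through two annihilations and leave nothing for the runtime; with a
bare aux port as input, it takes the allocating branch instead.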

Now, we have an OP2 node connected to the number #1. Normally, that
would require 4 rewrites to reduce to normal form:

    R  ~ r
    #1 ~ <#123 <#100 r>>
    -------------------- OP2
    R    ~ r
    #123 ~ <#1 <#100 r>>
    -------------------- OP1
    R     ~ r
    #+123 ~ <#100 r>
    ---------------- OP2
    R    ~ r
    #100 ~ <#+123 r>
    ---------------- OP1
    R    ~ r
    #223 ~ r
    -------- subst
    #223 ~ R

Yet, the compiled `add()` function can see, on its local registers, that
`op = #1`, `a = #123` and `b = #100`. As such, it doesn't need to
allocate any OP2 node, and can shortcut the reduction directly to:

    R    ~ ⟦r⟧
    ⟦#1⟧ ~ ⟦<#123 <#100 r>>⟧
    ------------------------ OP2 + OP1 + OP2 + OP1
    R ~ #223

Which bypasses the runtime entirely, saving several allocations and
redex pushing/popping/matching.
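
As a rough illustration of the difference, the sketch below models both
paths for `#1 ~ <#123 <#100 r>>`: a step-by-step reduction that counts
the OP2/OP1 rewrites, and the register-level shortcut. Everything here
is an assumption made for the example: the `Num` and `Tree` types, the
`AddWith` encoding of `#+123`, and the function names are hypothetical
and do not reflect how the runtime actually encodes numbers.

```rust
// Toy model only: hypothetical encoding of numbers and OP2/OP1 trees.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Num {
    Lit(u32),     // a plain number such as #123 (or the selector #1)
    AddWith(u32), // a partially applied addition such as #+123
}

const ADD: u32 = 1; // assumption: #1 selects addition

/// Applies an operator (or partially applied operator) to an operand.
fn op(f: Num, x: Num) -> Num {
    match (f, x) {
        (Num::Lit(ADD), Num::Lit(v)) => Num::AddWith(v),
        (Num::AddWith(a), Num::Lit(b)) => Num::Lit(a.wrapping_add(b)),
        _ => panic!("unsupported operator in this sketch"),
    }
}

enum Tree {
    Var,                       // the aux port `r`
    Num(Num),                  // a number sitting in an OP2's first port
    Op2(Box<Tree>, Box<Tree>), // <x rest>
    Op1(Num, Box<Tree>),       // <n rest>, produced by an OP2 rewrite
}

/// Interpreted path: reduces `lhs ~ rhs` one rewrite at a time and
/// returns the final number together with the number of rewrites.
fn reduce_interpreted(lhs: Num, rhs: Tree, rewrites: usize) -> (Num, usize) {
    match rhs {
        // OP2: swap the number on the left with the OP2's operand.
        Tree::Op2(x, rest) => match *x {
            Tree::Num(n) => reduce_interpreted(n, Tree::Op1(lhs, rest), rewrites + 1),
            _ => panic!("operand is not a number yet"),
        },
        // OP1: both sides are numbers, so apply the operation.
        Tree::Op1(f, rest) => reduce_interpreted(op(f, lhs), *rest, rewrites + 1),
        // Only `r` is left: the result is ready to be substituted.
        Tree::Var => (lhs, rewrites),
        Tree::Num(_) => panic!("number ~ number is not handled here"),
    }
}

/// Compiled path: all three values are already in local registers, so
/// the result is computed without building any OP2 node.
fn reduce_fast(opr: Num, a: Num, b: Num) -> Num {
    op(op(opr, a), b)
}

/// Builds the static tree `<#a <#b r>>`.
fn add_tree(a: u32, b: u32) -> Tree {
    let inner = Tree::Op2(Box::new(Tree::Num(Num::Lit(b))), Box::new(Tree::Var));
    Tree::Op2(Box::new(Tree::Num(Num::Lit(a))), Box::new(inner))
}
```

Calling `reduce_interpreted(Num::Lit(ADD), add_tree(123, 100), 0)` walks
through the same four rewrites as the trace above, while `reduce_fast`
reaches the same number in a single expression.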

Sadly, Rust functions, unlike interaction nets, obey an evaluation
order. As such, keeping a mini "local interaction net runtime" on
registers would be impractical. As such, we make a choice on the order
that we traverse the "static net"; specifically, we first traverse the
root tree, then the redex trees, in order. This is relevant, because it
means that the order matters for which optimizations are used. For
example, in this case, if we first traversed the redex trees, we'd have:

    ⟦#1⟧            ~ ⟦<a <b r>>⟧
    (#123 (#100 R)) ~ ⟦(a (b r))⟧
    ----------------------------- alloc `<a <b r>>`
    #1              ~ <a <b r>>
    (#123 (#100 R)) ~ ⟦(a (b r))⟧
    ----------------------------- alloc `(a (b r))`
    #1              ~ <a <b r>>
    (#123 (#100 R)) ~ (a (b r))
    ---------------------------
    ... proceed reduction as normal

I.e., when traversing `#1 ~ <a <b r>>`, the compiled `add()` function
would see `a` and `b` (i.e., aux ports) instead of concrete numbers and,
as such, it would be forced to allocate 2 OP2 nodes, `<a <b r>>`, and
the optimization would fail, causing it to fall back to interpreted
speed. As such, it is important for tools emitting HVMC code to sort
redexes in a way that allows optimizations to be performed more often.
If redexes are sorted respecting the corresponding "strict evaluation"
order, then functions compiled from classical paradigms should always
hit the optimization case.
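
The effect of traversal order can be shown with a very small toy model.
It is only a sketch: `Step`, `Outcome` and `run_compiled` are
hypothetical, the "registers" are a plain map, and the redex is
hard-coded to the addition from `@add`. The point is just that the fast
branch can only fire if the root was traversed first.

```rust
use std::collections::HashMap;

// Toy model only: the two trees of `@add`, in the order we visit them.
#[derive(Clone, Copy)]
enum Step {
    Root,     // traverse `(a (b r))`, binding `a` and `b` from the argument
    AddRedex, // traverse `#1 ~ <a <b r>>`
}

#[derive(Debug, Default, PartialEq)]
struct Outcome {
    fast_ops: usize,     // times the "fast op" branch was taken
    op2_allocs: usize,   // OP2 nodes that had to be allocated instead
    folded: Option<u32>, // result, if it was computed on registers
}

/// Runs the traversal in the given order for the call `(add arg.0 arg.1)`.
fn run_compiled(order: &[Step], arg: (u32, u32)) -> Outcome {
    // Registers holding the values known so far; a missing entry means
    // the name is still an aux port.
    let mut regs: HashMap<&'static str, u32> = HashMap::new();
    let mut out = Outcome::default();
    for step in order {
        match step {
            Step::Root => {
                regs.insert("a", arg.0);
                regs.insert("b", arg.1);
            }
            Step::AddRedex => match (regs.get("a"), regs.get("b")) {
                // Both operands are concrete numbers: fold locally.
                (Some(a), Some(b)) => {
                    out.fast_ops += 1;
                    out.folded = Some(a.wrapping_add(*b));
                }
                // At least one operand is unknown: allocate `<a <b r>>`.
                _ => out.op2_allocs += 2,
            },
        }
    }
    out
}
```

Running it with `[Step::Root, Step::AddRedex]` takes the fast branch,
while `[Step::AddRedex, Step::Root]` ends up allocating both OP2 nodes,
which mirrors the two traversals described above.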

For illustration, here is the compiled `add()` procedure:

pub fn F_add(&mut self, ptr: Ptr, x: Ptr) -> bool {
  let xx : Ptr;
  let xy : Ptr;
  // fast apply
  if x.tag() == CT0 {
    self.anni += 1;
    xx = self.heap.get(x.val(), P1);
    xy = self.heap.get(x.val(), P2);
    self.heap.free(x.val());
  } else {
    let k1 = self.heap.alloc(1);
    xx = Ptr::new(VR1, k1);
    xy = Ptr::new(VR2, k1);
    self.link(Ptr::new(CT0, k1), x);
  }
  let xyx : Ptr;
  let xyy : Ptr;
  // fast apply
  if xy.tag() == CT0 {
    self.anni += 1;
    xyx = self.heap.get(xy.val(), P1);
    xyy = self.heap.get(xy.val(), P2);
    self.heap.free(xy.val());
  } else {
    let k2 = self.heap.alloc(1);
    xyx = Ptr::new(VR1, k2);
    xyy = Ptr::new(VR2, k2);
    self.link(Ptr::new(CT0, k2), xy);
  }
  let _k3 = Ptr::new(NUM, 0x1);
  let k4 : Ptr;
  // fast op
  if _k3.is_num() && xx.is_num() && xyx.is_num() {
    self.oper += 4;
    k4 = Ptr::new(NUM, self.op(self.op(_k3.val(),xx.val()),xyx.val()));
  } else {
    let k5 = self.heap.alloc(1);
    let k6 = self.heap.alloc(1);
    self.heap.set(k5, P2, Ptr::new(OP2, k6));
    self.link(Ptr::new(VR1,k5), xx);
    self.link(Ptr::new(VR1,k6), xyx);
    self.link(Ptr::new(OP2,k5), _k3);
    k4 = Ptr::new(VR2, k6);
  }
  self.link(k4, xyy);
  return true;
}

Each optimization branch is labelled with a comment. The more
optimization branches are hit, the faster your program will be.

This commit results in a 1.55x speedup in the 'burn' benchmark (the one
that decrements λ-encoded bits in parallel), a 2.94x speedup in a tree
recursive sum, and a 5.64x speedup in a tail recursive sum. Note that
tail recursion was NOT implemented yet, and there are still some
allocations that can be skipped. With a better codegen, the maximum
theoretical speedup should be around 36x, which is what we obtain by
manually polishing the generated functions.
VictorTaelin committed Nov 12, 2023
1 parent b47c777 commit fb5a2a9
Showing 20 changed files with 288 additions and 327 deletions.
7 changes: 0 additions & 7 deletions examples/alloc_big_tree.hvmc

This file was deleted.

6 changes: 5 additions & 1 deletion examples/dec_bits_tree.hvmc → examples/burn.hvmc
@@ -1,7 +1,11 @@
// Decreases a tree of binary counters until they're all 0 (parallel)
// Decreases a tree of λ-encoded binary counters until they're all 0 (parallel).
// Takes about ~16s on Apple M1, and ~0.5s on RTX 4090

@c4 = ([[[(d c) (c b)] (b a)] (a R)] (d R))
@c6 = ([[[[[(f e) (e d)] (d c)] (c b)] (b a)] (a R)] (f R))
@c8 = ([[[[[[[(h g) (g f)] (f e)] (e d)] (d c)] (c b)] (b a)] (a R)] (h R))
@c10 = ([[[[[[[[[(j i) (i h)] (h g)] (g f)] (f e)] (e d)] (d c)] (c b)] (b a)] (a R)] (j R))
@c12 = ([[[[[[[[[[[(l k) (k j)] (j i)] (i h)] (h g)] (g f)] (f e)] (e d)] (d c)] (c b)] (b a)] (a R)] (l R))
@c14 = ([[[[[[[[[[[[[(n m) (m l)] (l k)] (k j)] (j i)] (i h)] (h g)] (g f)] (f e)] (e d)] (d c)] (c b)] (b a)] (a R)] (n R))
@c16 = ([[[[[[[[[[[[[[[(p o) (o n)] (n m)] (m l)] (l k)] (k j)] (j i)] (i h)] (h g)] (g f)] (f e)] (e d)] (d c)] (c b)] (b a)] (a R)] (p R))

12 changes: 12 additions & 0 deletions examples/church.hvm2
@@ -0,0 +1,12 @@
S = λn λs λz (s (n s z))
Z = λs λz z

c2 = λf λx (f (f x))
c3 = λf λx (f (f (f x)))
c4 = (S (S (S (S Z))))

add = λa λb λs λz (a s (b s z))
mul = λa λb λs λz (a (b s) z)

// 2 * 3 + 4
main = (add (mul c2 c3) c4)
21 changes: 21 additions & 0 deletions examples/church.hvmc
@@ -0,0 +1,21 @@
@S = ((a (b c)) ({2 a (c d)} (b d)))
@Z = (* (a a))

@add = ((a (b c)) ((d (e b)) ({2 d a} (e c))))
@mul = ((a (b c)) ((d a) (d (b c))))

@c2 = ({2 (a b) (b c)} (a c))
@c3 = ({2 (a b) {3 (b c) (c d)}} (a d))

@c4
= a
& @S ~ (b a)
& @S ~ (c b)
& @S ~ (d c)
& @S ~ (@Z d)

@main
= a
& @add ~ (b (@c4 a))
& @mul ~ (@c2 (@c3 b))

10 changes: 0 additions & 10 deletions examples/church_exp.hvmc

This file was deleted.

24 changes: 0 additions & 24 deletions examples/dec_bits.hvmc

This file was deleted.

117 changes: 0 additions & 117 deletions examples/examples.hvmc

This file was deleted.

46 changes: 0 additions & 46 deletions examples/loop.hvmc

This file was deleted.

2 changes: 2 additions & 0 deletions examples/num_add.hvm2
@@ -0,0 +1,2 @@
add = λa λb (+ a b)
main = (add 123 100)
6 changes: 6 additions & 0 deletions examples/num_add.hvmc
@@ -0,0 +1,6 @@
// #1 represents addition
@add = (a (b R))
& #1 ~ <a <b R>>

@main = R
& @add ~ (#123 (#100 R))
6 changes: 6 additions & 0 deletions examples/num_match.hvm2
@@ -0,0 +1,6 @@
pred = λx match x {
0 : 0
1+p : p
}

main = (pred 10)
6 changes: 6 additions & 0 deletions examples/num_match.hvmc
@@ -0,0 +1,6 @@
@pred = (? (#0 (x x)) R R)

@main
= R
& @pred ~ (#10 R)

5 changes: 0 additions & 5 deletions examples/num_mt_ex.hvmc

This file was deleted.

6 changes: 0 additions & 6 deletions examples/num_op_ex.hvmc

This file was deleted.

6 changes: 6 additions & 0 deletions examples/sum_rec.hvm2
@@ -0,0 +1,6 @@
sum = λn match n {
0 : 1
1+p : (+ (sum p) (sum p))
}

main = (sum 24)
9 changes: 9 additions & 0 deletions examples/sum_rec.hvmc
@@ -0,0 +1,9 @@
@sum = (? (#1 @sumS) a a)

@sumS = ({2 a b} c)
& @sum ~ (a e)
& @sum ~ (b d)
& #1 ~ <d <e c>>

@main = R
& @sum ~ (#24 R)
6 changes: 6 additions & 0 deletions examples/sum_tail.hvm2
@@ -0,0 +1,6 @@
sum = λa match a {
0 : λs s
1+p : λs (sum p (+ p s))
}

main = (sum 10000000 0)
10 changes: 10 additions & 0 deletions examples/sum_tail.hvmc
@@ -0,0 +1,10 @@
@sum = (? (@sumZ @sumS) a a)

@sumZ = (a a)

@sumS = ({2 a b} (c d))
& @sum ~ (b (e d))
& #1 ~ <a <c e>>

@main = R
& @sum ~ (#10000000 (#0 R))