Skip to content

Commit

Permalink
♻️ Refactor using first-class constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 1, 2024
1 parent 05c2ef7 commit 86bd7da
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 110 deletions.
184 changes: 75 additions & 109 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,41 @@ let gen_vertex =

let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t =
let sub' exp = sub exp x det_exp in
let s2 ctor e1 e2 = ctor (sub' e1) (sub' e2) in
let s1 ctor e = ctor (sub' e) in
let open Exp in
match exp with
| Int n -> Int n
| Real r -> Real r
| Var y when Id.(x = y) -> Exp.of_det_exp det_exp
| Var y -> Var y
| Add (e1, e2) -> Add (sub' e1, sub' e2)
| Radd (e1, e2) -> Radd (sub' e1, sub' e2)
| Minus (e1, e2) -> Minus (sub' e1, sub' e2)
| Rminus (e1, e2) -> Rminus (sub' e1, sub' e2)
| Neg e -> Neg (sub' e)
| Rneg e -> Rneg (sub' e)
| Mult (e1, e2) -> Mult (sub' e1, sub' e2)
| Rmult (e1, e2) -> Rmult (sub' e1, sub' e2)
| Div (e1, e2) -> Div (sub' e1, sub' e2)
| Rdiv (e1, e2) -> Rdiv (sub' e1, sub' e2)
| Eq (e1, e2) -> Eq (sub' e1, sub' e2)
| Noteq (e1, e2) -> Noteq (sub' e1, sub' e2)
| Less (e1, e2) -> Less (sub' e1, sub' e2)
| And (e1, e2) -> And (sub' e1, sub' e2)
| Or (e1, e2) -> Or (sub' e1, sub' e2)
| Seq (e1, e2) -> Seq (sub' e1, sub' e2)
| Not e -> Not (sub' e)
| Int _ | Real _ | Var _ -> exp
| Add (e1, e2) -> s2 add e1 e2
| Radd (e1, e2) -> s2 radd e1 e2
| Minus (e1, e2) -> s2 minus e1 e2
| Rminus (e1, e2) -> s2 rminus e1 e2
| Neg e -> s1 neg e
| Rneg e -> s1 rneg e
| Mult (e1, e2) -> s2 mult e1 e2
| Rmult (e1, e2) -> s2 rmult e1 e2
| Div (e1, e2) -> s2 div e1 e2
| Rdiv (e1, e2) -> s2 rdiv e1 e2
| Eq (e1, e2) -> s2 eq e1 e2
| Noteq (e1, e2) -> s2 noteq e1 e2
| Less (e1, e2) -> s2 less e1 e2
| And (e1, e2) -> s2 and_ e1 e2
| Or (e1, e2) -> s2 or_ e1 e2
| Seq (e1, e2) -> s2 seq e1 e2
| Not e -> s1 not e
| List es -> List (List.map es ~f:sub')
| Record fs -> Record (List.map fs ~f:(fun (f, e) -> (f, sub' e)))
| Assign (y, e, body) when Id.(x = y) -> Assign (y, sub' e, body)
| Assign (y, e, body) when not (Set.mem (Det_exp.fv det_exp) y) ->
| Assign (y, e, body) when Core.not (Set.mem (Det_exp.fv det_exp) y) ->
Assign (y, sub' e, sub' body)
| Assign (y, e, body) ->
let z = gen_sym () in
Assign (z, sub' e, sub' @@ sub body y (Det_exp.Var z))
| If (e_pred, e_con, e_alt) -> If (sub' e_pred, sub' e_con, sub' e_alt)
| Call (f, es) -> Call (f, List.map es ~f:sub')
| Sample e -> Sample (sub' e)
| Observe (e1, e2) -> Observe (sub' e1, sub' e2)
| Sample e -> s1 sample e
| Observe (e1, e2) -> s2 observe e1 e2

let gather_functions (prog : program) : Env.t =
List.fold prog.funs ~init:Env.empty ~f:(fun env f ->
Expand All @@ -58,19 +59,33 @@ exception Not_closed_observation

let compile (program : program) : Graph.t * Det_exp.t =
let env = gather_functions program in
let rec compile pred =

let rec compile (pred : Pred.t) : Exp.t -> Graph.t * Det_exp.t =
let compile' e = compile pred e in
let open Graph in

let c0 x = (Graph.empty, x) in
let c1 e ctor =
let g, de = compile' e in
(g, ctor de)
in
let c2 e1 e2 ctor =
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
Graph.(g1 @+ g2, ctor de1 de2)
in

let open Det_exp in
function
| Exp.Int n -> (Graph.empty, Det_exp.Int n)
| Real r -> (Graph.empty, Det_exp.Real r)
| Var x -> (Graph.empty, Det_exp.Var x)
| Int n -> c0 (Int n)
| Real r -> c0 (Real r)
| Var x -> c0 (Var x)
| Sample e ->
let g, de = compile' e in
let v = gen_vertex () in
let de_fvs = Det_exp.fv de in
let de_fvs = fv de in
let f = Dist.score de v in

let open Graph in
let g' =
{
vertices = [ v ];
Expand All @@ -79,16 +94,17 @@ let compile (program : program) : Graph.t * Det_exp.t =
obs_map = Id.Map.empty;
}
in
(g @+ g', Det_exp.Var v)
(g @+ g', Var v)
| Observe (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
let v = gen_vertex () in
let f1 = Dist.score de1 v in
let f = Dist.(If_pred (pred, f1, One)) in
let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in
if not @@ Set.is_empty (Det_exp.fv de2) then
raise Not_closed_observation;
let fvs = Set.union (fv de1) (Pred.fv pred) in
if Core.not @@ Set.is_empty (fv de2) then raise Not_closed_observation;

let open Graph in
let g' =
{
vertices = [ v ];
Expand All @@ -102,23 +118,21 @@ let compile (program : program) : Graph.t * Det_exp.t =
let g1, det_exp1 = compile' e in
let sub_body = sub body x det_exp1 in
let g2, det_exp2 = compile' sub_body in
let g = g1 @+ g2 in
(g, det_exp2)
Graph.(g1 @+ g2, det_exp2)
| If (e_pred, e_con, e_alt) ->
let g1, det_exp_pred = compile' e_pred in
let pred_true = Pred.And (det_exp_pred, pred) in
let pred_false = Pred.And_not (det_exp_pred, pred) in
let g2, det_exp_con = compile pred_true e_con in
let g3, det_exp_alt = compile pred_false e_alt in
let g = g1 @+ g2 @+ g3 in
(g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt))
Graph.(g1 @+ g2 @+ g3, If (det_exp_pred, det_exp_con, det_exp_alt))
| Call (c, params) -> (
match Env.find env ~name:c with
| Some f ->
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
Graph.(g @+ g', de))
in
let { params; body; _ } = f in
let param_det_pairs = List.zip_exn params det_exps in
Expand All @@ -127,93 +141,45 @@ let compile (program : program) : Graph.t * Det_exp.t =
~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp)
in
let g_body, det_exp_body = compile' sub_body in
(g @+ g_body, det_exp_body)
Graph.(g @+ g_body, det_exp_body)
| None ->
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
Graph.(g @+ g', de))
in
(g, Prim_call (c, det_exps)))
| Add (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Add (de1, de2))
| Radd (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Radd (de1, de2))
| Minus (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Minus (de1, de2))
| Rminus (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Rminus (de1, de2))
| Neg e ->
let g, de = compile' e in
(g, Det_exp.Neg de)
| Rneg e ->
let g, de = compile' e in
(g, Det_exp.Rneg de)
| Mult (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Mult (de1, de2))
| Rmult (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Rmult (de1, de2))
| Div (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Div (de1, de2))
| Rdiv (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Rdiv (de1, de2))
| Eq (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Eq (de1, de2))
| Noteq (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Noteq (de1, de2))
| Less (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Less (de1, de2))
| And (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.And (de1, de2))
| Or (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Or (de1, de2))
| Seq (e1, e2) ->
let g1, _ = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, de2)
| Not e ->
let g, de = compile' e in
(g, Det_exp.Not de)
| Add (e1, e2) -> c2 e1 e2 add
| Radd (e1, e2) -> c2 e1 e2 radd
| Minus (e1, e2) -> c2 e1 e2 minus
| Rminus (e1, e2) -> c2 e1 e2 rminus
| Neg e -> c1 e neg
| Rneg e -> c1 e rneg
| Mult (e1, e2) -> c2 e1 e2 mult
| Rmult (e1, e2) -> c2 e1 e2 rmult
| Div (e1, e2) -> c2 e1 e2 div
| Rdiv (e1, e2) -> c2 e1 e2 rdiv
| Eq (e1, e2) -> c2 e1 e2 eq
| Noteq (e1, e2) -> c2 e1 e2 noteq
| Less (e1, e2) -> c2 e1 e2 less
| And (e1, e2) -> c2 e1 e2 and_
| Or (e1, e2) -> c2 e1 e2 or_
| Seq (e1, e2) -> c2 e1 e2 (fun _ de -> de)
| Not e -> c1 e not
| List es ->
let g, des =
List.fold_map es ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
Graph.(g @+ g', de))
in
(g, Det_exp.List des)
(g, List des)
| Record fields ->
let g, des =
List.fold_map fields ~init:Graph.empty ~f:(fun g (k, v) ->
let g_k, de_k = compile' k in
let g_v, de_v = compile' v in
(g @+ g_k @+ g_v, (de_k, de_v)))
Graph.(g @+ g_k @+ g_v, (de_k, de_v)))
in
(g, Det_exp.Record des)
(g, Record des)
in
compile Pred.Empty program.exp
3 changes: 2 additions & 1 deletion lib/program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ module Det_exp = struct
| Record of (t * t) list
| If of t * t * t
| Prim_call of Id.t * t list
[@@deriving sexp, stable_variant]
[@@deriving sexp, variants, stable_variant]

let rec fv : t -> (Id.t, Id.comparator_witness) Set.t = function
| Int _ | Real _ -> Set.empty (module Id)
Expand Down Expand Up @@ -92,6 +92,7 @@ module Exp = struct
| Observe of t * t
[@@deriving
sexp,
variants,
stable_variant ~version:Det_exp.t
~remove:[ Call; Seq; Assign; Sample; Observe ]
~add:[ Prim_call ]]
Expand Down

0 comments on commit 86bd7da

Please sign in to comment.