diff --git a/lib/compiler.ml b/lib/compiler.ml index 9f4755c..24f3913 100644 --- a/lib/compiler.ml +++ b/lib/compiler.ml @@ -15,40 +15,41 @@ let gen_vertex = let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t = let sub' exp = sub exp x det_exp in + let s2 ctor e1 e2 = ctor (sub' e1) (sub' e2) in + let s1 ctor e = ctor (sub' e) in + let open Exp in match exp with - | Int n -> Int n - | Real r -> Real r | Var y when Id.(x = y) -> Exp.of_det_exp det_exp - | Var y -> Var y - | Add (e1, e2) -> Add (sub' e1, sub' e2) - | Radd (e1, e2) -> Radd (sub' e1, sub' e2) - | Minus (e1, e2) -> Minus (sub' e1, sub' e2) - | Rminus (e1, e2) -> Rminus (sub' e1, sub' e2) - | Neg e -> Neg (sub' e) - | Rneg e -> Rneg (sub' e) - | Mult (e1, e2) -> Mult (sub' e1, sub' e2) - | Rmult (e1, e2) -> Rmult (sub' e1, sub' e2) - | Div (e1, e2) -> Div (sub' e1, sub' e2) - | Rdiv (e1, e2) -> Rdiv (sub' e1, sub' e2) - | Eq (e1, e2) -> Eq (sub' e1, sub' e2) - | Noteq (e1, e2) -> Noteq (sub' e1, sub' e2) - | Less (e1, e2) -> Less (sub' e1, sub' e2) - | And (e1, e2) -> And (sub' e1, sub' e2) - | Or (e1, e2) -> Or (sub' e1, sub' e2) - | Seq (e1, e2) -> Seq (sub' e1, sub' e2) - | Not e -> Not (sub' e) + | Int _ | Real _ | Var _ -> exp + | Add (e1, e2) -> s2 add e1 e2 + | Radd (e1, e2) -> s2 radd e1 e2 + | Minus (e1, e2) -> s2 minus e1 e2 + | Rminus (e1, e2) -> s2 rminus e1 e2 + | Neg e -> s1 neg e + | Rneg e -> s1 rneg e + | Mult (e1, e2) -> s2 mult e1 e2 + | Rmult (e1, e2) -> s2 rmult e1 e2 + | Div (e1, e2) -> s2 div e1 e2 + | Rdiv (e1, e2) -> s2 rdiv e1 e2 + | Eq (e1, e2) -> s2 eq e1 e2 + | Noteq (e1, e2) -> s2 noteq e1 e2 + | Less (e1, e2) -> s2 less e1 e2 + | And (e1, e2) -> s2 and_ e1 e2 + | Or (e1, e2) -> s2 or_ e1 e2 + | Seq (e1, e2) -> s2 seq e1 e2 + | Not e -> s1 not e | List es -> List (List.map es ~f:sub') | Record fs -> Record (List.map fs ~f:(fun (f, e) -> (f, sub' e))) | Assign (y, e, body) when Id.(x = y) -> Assign (y, sub' e, body) - | Assign (y, e, body) when not (Set.mem (Det_exp.fv det_exp) y) -> + | Assign (y, e, body) when Core.not (Set.mem (Det_exp.fv det_exp) y) -> Assign (y, sub' e, sub' body) | Assign (y, e, body) -> let z = gen_sym () in Assign (z, sub' e, sub' @@ sub body y (Det_exp.Var z)) | If (e_pred, e_con, e_alt) -> If (sub' e_pred, sub' e_con, sub' e_alt) | Call (f, es) -> Call (f, List.map es ~f:sub') - | Sample e -> Sample (sub' e) - | Observe (e1, e2) -> Observe (sub' e1, sub' e2) + | Sample e -> s1 sample e + | Observe (e1, e2) -> s2 observe e1 e2 let gather_functions (prog : program) : Env.t = List.fold prog.funs ~init:Env.empty ~f:(fun env f -> @@ -58,19 +59,33 @@ exception Not_closed_observation let compile (program : program) : Graph.t * Det_exp.t = let env = gather_functions program in - let rec compile pred = + + let rec compile (pred : Pred.t) : Exp.t -> Graph.t * Det_exp.t = let compile' e = compile pred e in - let open Graph in + + let c0 x = (Graph.empty, x) in + let c1 e ctor = + let g, de = compile' e in + (g, ctor de) + in + let c2 e1 e2 ctor = + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + Graph.(g1 @+ g2, ctor de1 de2) + in + + let open Det_exp in function - | Exp.Int n -> (Graph.empty, Det_exp.Int n) - | Real r -> (Graph.empty, Det_exp.Real r) - | Var x -> (Graph.empty, Det_exp.Var x) + | Int n -> c0 (Int n) + | Real r -> c0 (Real r) + | Var x -> c0 (Var x) | Sample e -> let g, de = compile' e in let v = gen_vertex () in - let de_fvs = Det_exp.fv de in + let de_fvs = fv de in let f = Dist.score de v in + let open Graph in let g' = { vertices = [ v ]; @@ -79,16 +94,17 @@ let compile (program : program) : Graph.t * Det_exp.t = obs_map = Id.Map.empty; } in - (g @+ g', Det_exp.Var v) + (g @+ g', Var v) | Observe (e1, e2) -> let g1, de1 = compile' e1 in let g2, de2 = compile' e2 in let v = gen_vertex () in let f1 = Dist.score de1 v in let f = Dist.(If_pred (pred, f1, One)) in - let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in - if not @@ Set.is_empty (Det_exp.fv de2) then - raise Not_closed_observation; + let fvs = Set.union (fv de1) (Pred.fv pred) in + if Core.not @@ Set.is_empty (fv de2) then raise Not_closed_observation; + + let open Graph in let g' = { vertices = [ v ]; @@ -102,23 +118,21 @@ let compile (program : program) : Graph.t * Det_exp.t = let g1, det_exp1 = compile' e in let sub_body = sub body x det_exp1 in let g2, det_exp2 = compile' sub_body in - let g = g1 @+ g2 in - (g, det_exp2) + Graph.(g1 @+ g2, det_exp2) | If (e_pred, e_con, e_alt) -> let g1, det_exp_pred = compile' e_pred in let pred_true = Pred.And (det_exp_pred, pred) in let pred_false = Pred.And_not (det_exp_pred, pred) in let g2, det_exp_con = compile pred_true e_con in let g3, det_exp_alt = compile pred_false e_alt in - let g = g1 @+ g2 @+ g3 in - (g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt)) + Graph.(g1 @+ g2 @+ g3, If (det_exp_pred, det_exp_con, det_exp_alt)) | Call (c, params) -> ( match Env.find env ~name:c with | Some f -> let g, det_exps = List.fold_map params ~init:Graph.empty ~f:(fun g e -> let g', de = compile' e in - (g @+ g', de)) + Graph.(g @+ g', de)) in let { params; body; _ } = f in let param_det_pairs = List.zip_exn params det_exps in @@ -127,93 +141,45 @@ let compile (program : program) : Graph.t * Det_exp.t = ~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp) in let g_body, det_exp_body = compile' sub_body in - (g @+ g_body, det_exp_body) + Graph.(g @+ g_body, det_exp_body) | None -> let g, det_exps = List.fold_map params ~init:Graph.empty ~f:(fun g e -> let g', de = compile' e in - (g @+ g', de)) + Graph.(g @+ g', de)) in (g, Prim_call (c, det_exps))) - | Add (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Add (de1, de2)) - | Radd (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Radd (de1, de2)) - | Minus (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Minus (de1, de2)) - | Rminus (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Rminus (de1, de2)) - | Neg e -> - let g, de = compile' e in - (g, Det_exp.Neg de) - | Rneg e -> - let g, de = compile' e in - (g, Det_exp.Rneg de) - | Mult (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Mult (de1, de2)) - | Rmult (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Rmult (de1, de2)) - | Div (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Div (de1, de2)) - | Rdiv (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Rdiv (de1, de2)) - | Eq (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Eq (de1, de2)) - | Noteq (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Noteq (de1, de2)) - | Less (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Less (de1, de2)) - | And (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.And (de1, de2)) - | Or (e1, e2) -> - let g1, de1 = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, Det_exp.Or (de1, de2)) - | Seq (e1, e2) -> - let g1, _ = compile' e1 in - let g2, de2 = compile' e2 in - (g1 @+ g2, de2) - | Not e -> - let g, de = compile' e in - (g, Det_exp.Not de) + | Add (e1, e2) -> c2 e1 e2 add + | Radd (e1, e2) -> c2 e1 e2 radd + | Minus (e1, e2) -> c2 e1 e2 minus + | Rminus (e1, e2) -> c2 e1 e2 rminus + | Neg e -> c1 e neg + | Rneg e -> c1 e rneg + | Mult (e1, e2) -> c2 e1 e2 mult + | Rmult (e1, e2) -> c2 e1 e2 rmult + | Div (e1, e2) -> c2 e1 e2 div + | Rdiv (e1, e2) -> c2 e1 e2 rdiv + | Eq (e1, e2) -> c2 e1 e2 eq + | Noteq (e1, e2) -> c2 e1 e2 noteq + | Less (e1, e2) -> c2 e1 e2 less + | And (e1, e2) -> c2 e1 e2 and_ + | Or (e1, e2) -> c2 e1 e2 or_ + | Seq (e1, e2) -> c2 e1 e2 (fun _ de -> de) + | Not e -> c1 e not | List es -> let g, des = List.fold_map es ~init:Graph.empty ~f:(fun g e -> let g', de = compile' e in - (g @+ g', de)) + Graph.(g @+ g', de)) in - (g, Det_exp.List des) + (g, List des) | Record fields -> let g, des = List.fold_map fields ~init:Graph.empty ~f:(fun g (k, v) -> let g_k, de_k = compile' k in let g_v, de_v = compile' v in - (g @+ g_k @+ g_v, (de_k, de_v))) + Graph.(g @+ g_k @+ g_v, (de_k, de_v))) in - (g, Det_exp.Record des) + (g, Record des) in compile Pred.Empty program.exp diff --git a/lib/program.ml b/lib/program.ml index 0158423..a1e8691 100644 --- a/lib/program.ml +++ b/lib/program.ml @@ -26,7 +26,7 @@ module Det_exp = struct | Record of (t * t) list | If of t * t * t | Prim_call of Id.t * t list - [@@deriving sexp, stable_variant] + [@@deriving sexp, variants, stable_variant] let rec fv : t -> (Id.t, Id.comparator_witness) Set.t = function | Int _ | Real _ -> Set.empty (module Id) @@ -92,6 +92,7 @@ module Exp = struct | Observe of t * t [@@deriving sexp, + variants, stable_variant ~version:Det_exp.t ~remove:[ Call; Seq; Assign; Sample; Observe ] ~add:[ Prim_call ]]