Skip to content

Commit

Permalink
♻️ Don't carry around environments
Browse files Browse the repository at this point in the history
  • Loading branch information
yhs0602 committed Jun 1, 2024
1 parent 15e1926 commit 02703fe
Showing 1 changed file with 152 additions and 71 deletions.
223 changes: 152 additions & 71 deletions lib/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -132,74 +132,155 @@ let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t =

exception Not_closed_observation

let rec compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) :
Graph.t * Det_exp.t =
ignore env;
ignore pred;
match exp with
| Int n -> (Graph.empty, Det_exp.Int n)
| Real r -> (Graph.empty, Det_exp.Real r)
| Var x -> (Graph.empty, Det_exp.Var x)
| Sample e ->
let g, de = compile env pred e in
let v = gen_vertex () in
let de_fvs = Det_exp.fv de in
let f = Dist.score de v in
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.empty (module Id);
}
in
(g @+ g', Det_exp.Var v)
| Observe (e1, e2) ->
let g1, de1 = compile env pred e1 in
let g2, de2 = compile env pred e2 in
let v = gen_vertex () in
let f1 = Dist.score de1 v in
let f = Dist.(If_pred (pred, f1, One)) in
let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in
if not @@ Set.is_empty (Det_exp.fv de2) then raise Not_closed_observation;
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.singleton (module Id) v de2;
}
in
(g1 @+ g2 @+ g', de2)
| Assign (x, e, body) ->
let g1, det_exp1 = compile env pred e in
let sub_body = sub body x det_exp1 in
let g2, det_exp2 = compile env pred sub_body in
let g = g1 @+ g2 in
(g, det_exp2)
| If (e_pred, e_con, e_alt) ->
let g1, det_exp_pred = compile env pred e_pred in
let pred_true = Pred.And (det_exp_pred, pred) in
let pred_false = Pred.And_not (det_exp_pred, pred) in
let g2, det_exp_con = compile env pred_true e_con in
let g3, det_exp_alt = compile env pred_false e_alt in
let g = g1 @+ g2 @+ g3 in
(g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt))
| Call (c, params) ->
let f = Env.find_exn env ~name:c in
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile env pred e in
(g @+ g', de))
in
let { params; body; _ } = f in
let param_det_pairs = List.zip_exn params det_exps in
let sub_body =
List.fold param_det_pairs ~init:body
~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp)
in
let g_body, det_exp_body = compile env pred sub_body in
(g @+ g_body, det_exp_body)
| _ -> failwith "Not implemented"
let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t =
let rec compile pred =
let compile' = compile pred in
function
| Exp.Int n -> (Graph.empty, Det_exp.Int n)
| Real r -> (Graph.empty, Det_exp.Real r)
| Var x -> (Graph.empty, Det_exp.Var x)
| Sample e ->
let g, de = compile' e in
let v = gen_vertex () in
let de_fvs = Det_exp.fv de in
let f = Dist.score de v in
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.empty (module Id);
}
in
(g @+ g', Det_exp.Var v)
| Observe (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
let v = gen_vertex () in
let f1 = Dist.score de1 v in
let f = Dist.(If_pred (pred, f1, One)) in
let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in
if not @@ Set.is_empty (Det_exp.fv de2) then
raise Not_closed_observation;
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.singleton (module Id) v de2;
}
in
(g1 @+ g2 @+ g', de2)
| Assign (x, e, body) ->
let g1, det_exp1 = compile' e in
let sub_body = sub body x det_exp1 in
let g2, det_exp2 = compile' sub_body in
let g = g1 @+ g2 in
(g, det_exp2)
| If (e_pred, e_con, e_alt) ->
let g1, det_exp_pred = compile' e_pred in
let pred_true = Pred.And (det_exp_pred, pred) in
let pred_false = Pred.And_not (det_exp_pred, pred) in
let g2, det_exp_con = compile pred_true e_con in
let g3, det_exp_alt = compile pred_false e_alt in
let g = g1 @+ g2 @+ g3 in
(g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt))
| Call (c, params) ->
let f = Env.find_exn env ~name:c in
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
in
let { params; body; _ } = f in
let param_det_pairs = List.zip_exn params det_exps in
let sub_body =
List.fold param_det_pairs ~init:body
~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp)
in
let g_body, det_exp_body = compile' sub_body in
(g @+ g_body, det_exp_body)
| Add (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Add (de1, de2))
| Radd (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Radd (de1, de2))
| Minus (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Minus (de1, de2))
| Rminus (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Rminus (de1, de2))
| Neg e ->
let g, de = compile' e in
(g, Det_exp.Neg de)
| Rneg e ->
let g, de = compile' e in
(g, Det_exp.Rneg de)
| Mult (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Mult (de1, de2))
| Rmult (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Rmult (de1, de2))
| Div (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Div (de1, de2))
| Rdiv (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Rdiv (de1, de2))
| Eq (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Eq (de1, de2))
| Noteq (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Noteq (de1, de2))
| Less (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Less (de1, de2))
| And (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.And (de1, de2))
| Or (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, Det_exp.Or (de1, de2))
| Seq (e1, e2) ->
let g1, _ = compile' e1 in
let g2, de2 = compile' e2 in
(g1 @+ g2, de2)
| Not e ->
let g, de = compile' e in
(g, Det_exp.Not de)
| Exp.List es ->
let g, des =
List.fold_map es ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
in
(g, Det_exp.List des)
| Record fields ->
let g, des =
List.fold_map fields ~init:Graph.empty ~f:(fun g (k, v) ->
let g_k, de_k = compile' k in
let g_v, de_v = compile' v in
(g @+ g_k @+ g_v, (de_k, de_v)))
in
(g, Det_exp.Record des)
in
compile pred exp

0 comments on commit 02703fe

Please sign in to comment.