diff --git a/lib/compile.ml b/lib/compile.ml index 81c7980..e85eb0c 100644 --- a/lib/compile.ml +++ b/lib/compile.ml @@ -132,74 +132,155 @@ let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t = exception Not_closed_observation -let rec compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : - Graph.t * Det_exp.t = - ignore env; - ignore pred; - match exp with - | Int n -> (Graph.empty, Det_exp.Int n) - | Real r -> (Graph.empty, Det_exp.Real r) - | Var x -> (Graph.empty, Det_exp.Var x) - | Sample e -> - let g, de = compile env pred e in - let v = gen_vertex () in - let de_fvs = Det_exp.fv de in - let f = Dist.score de v in - let g' = - Graph. - { - vertices = [ v ]; - arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v)); - det_map = Map.singleton (module Id) v f; - obs_map = Map.empty (module Id); - } - in - (g @+ g', Det_exp.Var v) - | Observe (e1, e2) -> - let g1, de1 = compile env pred e1 in - let g2, de2 = compile env pred e2 in - let v = gen_vertex () in - let f1 = Dist.score de1 v in - let f = Dist.(If_pred (pred, f1, One)) in - let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in - if not @@ Set.is_empty (Det_exp.fv de2) then raise Not_closed_observation; - let g' = - Graph. - { - vertices = [ v ]; - arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v)); - det_map = Map.singleton (module Id) v f; - obs_map = Map.singleton (module Id) v de2; - } - in - (g1 @+ g2 @+ g', de2) - | Assign (x, e, body) -> - let g1, det_exp1 = compile env pred e in - let sub_body = sub body x det_exp1 in - let g2, det_exp2 = compile env pred sub_body in - let g = g1 @+ g2 in - (g, det_exp2) - | If (e_pred, e_con, e_alt) -> - let g1, det_exp_pred = compile env pred e_pred in - let pred_true = Pred.And (det_exp_pred, pred) in - let pred_false = Pred.And_not (det_exp_pred, pred) in - let g2, det_exp_con = compile env pred_true e_con in - let g3, det_exp_alt = compile env pred_false e_alt in - let g = g1 @+ g2 @+ g3 in - (g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt)) - | Call (c, params) -> - let f = Env.find_exn env ~name:c in - let g, det_exps = - List.fold_map params ~init:Graph.empty ~f:(fun g e -> - let g', de = compile env pred e in - (g @+ g', de)) - in - let { params; body; _ } = f in - let param_det_pairs = List.zip_exn params det_exps in - let sub_body = - List.fold param_det_pairs ~init:body - ~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp) - in - let g_body, det_exp_body = compile env pred sub_body in - (g @+ g_body, det_exp_body) - | _ -> failwith "Not implemented" +let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t = + let rec compile pred = + let compile' = compile pred in + function + | Exp.Int n -> (Graph.empty, Det_exp.Int n) + | Real r -> (Graph.empty, Det_exp.Real r) + | Var x -> (Graph.empty, Det_exp.Var x) + | Sample e -> + let g, de = compile' e in + let v = gen_vertex () in + let de_fvs = Det_exp.fv de in + let f = Dist.score de v in + let g' = + Graph. + { + vertices = [ v ]; + arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v)); + det_map = Map.singleton (module Id) v f; + obs_map = Map.empty (module Id); + } + in + (g @+ g', Det_exp.Var v) + | Observe (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + let v = gen_vertex () in + let f1 = Dist.score de1 v in + let f = Dist.(If_pred (pred, f1, One)) in + let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in + if not @@ Set.is_empty (Det_exp.fv de2) then + raise Not_closed_observation; + let g' = + Graph. + { + vertices = [ v ]; + arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v)); + det_map = Map.singleton (module Id) v f; + obs_map = Map.singleton (module Id) v de2; + } + in + (g1 @+ g2 @+ g', de2) + | Assign (x, e, body) -> + let g1, det_exp1 = compile' e in + let sub_body = sub body x det_exp1 in + let g2, det_exp2 = compile' sub_body in + let g = g1 @+ g2 in + (g, det_exp2) + | If (e_pred, e_con, e_alt) -> + let g1, det_exp_pred = compile' e_pred in + let pred_true = Pred.And (det_exp_pred, pred) in + let pred_false = Pred.And_not (det_exp_pred, pred) in + let g2, det_exp_con = compile pred_true e_con in + let g3, det_exp_alt = compile pred_false e_alt in + let g = g1 @+ g2 @+ g3 in + (g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt)) + | Call (c, params) -> + let f = Env.find_exn env ~name:c in + let g, det_exps = + List.fold_map params ~init:Graph.empty ~f:(fun g e -> + let g', de = compile' e in + (g @+ g', de)) + in + let { params; body; _ } = f in + let param_det_pairs = List.zip_exn params det_exps in + let sub_body = + List.fold param_det_pairs ~init:body + ~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp) + in + let g_body, det_exp_body = compile' sub_body in + (g @+ g_body, det_exp_body) + | Add (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Add (de1, de2)) + | Radd (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Radd (de1, de2)) + | Minus (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Minus (de1, de2)) + | Rminus (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Rminus (de1, de2)) + | Neg e -> + let g, de = compile' e in + (g, Det_exp.Neg de) + | Rneg e -> + let g, de = compile' e in + (g, Det_exp.Rneg de) + | Mult (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Mult (de1, de2)) + | Rmult (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Rmult (de1, de2)) + | Div (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Div (de1, de2)) + | Rdiv (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Rdiv (de1, de2)) + | Eq (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Eq (de1, de2)) + | Noteq (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Noteq (de1, de2)) + | Less (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Less (de1, de2)) + | And (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.And (de1, de2)) + | Or (e1, e2) -> + let g1, de1 = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, Det_exp.Or (de1, de2)) + | Seq (e1, e2) -> + let g1, _ = compile' e1 in + let g2, de2 = compile' e2 in + (g1 @+ g2, de2) + | Not e -> + let g, de = compile' e in + (g, Det_exp.Not de) + | Exp.List es -> + let g, des = + List.fold_map es ~init:Graph.empty ~f:(fun g e -> + let g', de = compile' e in + (g @+ g', de)) + in + (g, Det_exp.List des) + | Record fields -> + let g, des = + List.fold_map fields ~init:Graph.empty ~f:(fun g (k, v) -> + let g_k, de_k = compile' k in + let g_v, de_v = compile' v in + (g @+ g_k @+ g_v, (de_k, de_v))) + in + (g, Det_exp.Record des) + in + compile pred exp