Skip to content

Commit

Permalink
♻️ Make type names more consistent and move Printing module
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 9, 2024
1 parent ab18585 commit 35fcfb2
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 134 deletions.
2 changes: 1 addition & 1 deletion bin/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ let command : Command.t =
Out_channel.flush stdout;
let graph, query = get_program filename |> Compiler.compile_program in
graph_query := Some (graph, query);
print_s [%sexp (Printing.of_graph graph : Printing.graph)]);
print_s Graph.Erased.([%sexp (of_graph graph : t)]));

if pp_opt || graph_opt then printf "\n";
printf "Inference: %s\n" filename;
Expand Down
23 changes: 11 additions & 12 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ open! Core
open Parse_tree
open Typed_tree

type env = some_det Id.Map.t
type env = some_det_texp Id.Map.t

let gen_vertex =
let cnt = ref 0 in
Expand Down Expand Up @@ -54,8 +54,7 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
| If_pred (p, de) -> (
let p = peval_pred p and de = peval de in
match p with (* TODO: *) _ -> If_pred (p, de))
| If_con de -> If_con (peval de)
| If_alt de -> If_alt (peval de)
| If_just de -> If_just (peval de)
in

{ ty; exp }
Expand Down Expand Up @@ -96,8 +95,8 @@ let rec score : type a. (a, det) texp -> (a, det) texp = function
| _ -> raise Score_invalid_arguments

let rec compile :
type a s.
env:env -> ?pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp =
type a s. env:env -> ?pred:pred -> (a, ndet) texp -> Graph.t * (a, det) texp
=
fun ~env ?(pred = Empty) { ty; exp } ->
match exp with
| Value _ as exp -> (Graph.empty, { ty; exp })
Expand Down Expand Up @@ -132,8 +131,8 @@ let rec compile :
let g3, de_alt = compile ~env ~pred:pred_alt e_alt in
let g = Graph.(g1 @| g2 @| g3) in
match pred_con with
| True -> (g, { ty; exp = If_con de_con })
| False -> (g, { ty; exp = If_alt de_alt })
| True -> (g, { ty; exp = If_just de_con })
| False -> (g, { ty; exp = If_just de_alt })
| _ -> (g, { ty; exp = If (de_pred, de_con, de_alt) }))
| Let (x, e, body) ->
let g1, det_exp1 = compile ~env ~pred e in
Expand All @@ -154,7 +153,7 @@ let rec compile :
{
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp);
pmdf_map = Id.Map.singleton v (Ex f : some_dist_det_texp);
obs_map = Id.Map.empty;
}
in
Expand All @@ -173,14 +172,14 @@ let rec compile :
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp);
obs_map = Id.Map.singleton v (Ex de2 : some_dat_texp);
pmdf_map = Id.Map.singleton v (Ex f : some_dist_det_texp);
obs_map = Id.Map.singleton v (Ex de2 : some_val_det_texp);
}
in
Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () })

and compile_args :
type a. env -> pred -> (a, non_det) args -> Graph.t * (a, det) args =
type a. env -> pred -> (a, ndet) args -> Graph.t * (a, det) args =
fun env pred args ->
match args with
| [] -> (Graph.empty, [])
Expand All @@ -191,7 +190,7 @@ and compile_args :

exception Query_not_found

let compile_program (prog : program) : Graph.t * some_rv_texp =
let compile_program (prog : program) : Graph.t * some_rv_det_texp =
Logs.debug (fun m ->
m "Inlining program %a" Sexp.pp_hum [%sexp (prog : Parse_tree.program)]);
let exp = Preprocessor.inline prog in
Expand Down
10 changes: 5 additions & 5 deletions lib/evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a =
| Uop ({ op; _ }, te) -> op (eval_dat ctx te)
| If (te_pred, te_cons, te_alt) ->
if eval_dat ctx te_pred then eval_dat ctx te_cons else eval_dat ctx te_alt
| If_con te -> eval_dat ctx te
| If_alt te -> eval_dat ctx te
| If_just te -> eval_dat ctx te

and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a =
fun ctx { ty = Dist_ty dty as ty; exp } ->
Expand Down Expand Up @@ -81,7 +80,7 @@ let rec eval_pmdf :

(* TODO: Remove existential wrapper *)
let gibbs_sampling ~(num_samples : int) (graph : Graph.t)
(Ex query : some_rv_texp) : float array =
(Ex query : some_rv_det_texp) : float array =
(* Initialize the context with the observed values. Float conversion must
succeed as observed variables do not contain free variables *)
let default : type a. a dty -> a = function
Expand Down Expand Up @@ -154,15 +153,16 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t)
samples

let infer ?(filename : string = "out") ?(num_samples : int = 100_000)
(graph : Graph.t) (query : some_rv_texp) : string =
(graph : Graph.t) (query : some_rv_det_texp) : string =
let samples = gibbs_sampling graph ~num_samples query in

let filename = String.chop_suffix_if_exists filename ~suffix:".stp" in
let plot_path = filename ^ ".png" in

let open Owl_plplot in
let h = Plot.create plot_path in
Plot.set_title h (Printing.of_rv query);
Plot.set_title h
Typed_tree.Erased.([%sexp (of_rv query : exp)] |> Sexp.to_string);
let mat = Owl.Mat.of_array samples 1 num_samples in
Plot.histogram ~h ~bin:50 mat;
Plot.output h;
Expand Down
28 changes: 25 additions & 3 deletions lib/graph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ open Typed_tree

type vertex = Id.t
type arc = vertex * vertex
type pmdf_map = some_dist_texp Id.Map.t
type obs_map = some_dat_texp Id.Map.t
type pmdf_map = some_dist_det_texp Id.Map.t
type obs_map = some_val_det_texp Id.Map.t

type t = {
vertices : vertex list;
Expand Down Expand Up @@ -36,9 +36,31 @@ let union g1 g2 =
let ( @| ) = union

let unobserved_vertices_pmdfs ({ vertices; pmdf_map; obs_map; _ } : t) :
(vertex * some_dist_texp) list =
(vertex * some_dist_det_texp) list =
List.filter_map vertices ~f:(fun v ->
if Map.mem obs_map v then None
else
let pmdf = Map.find_exn pmdf_map v in
Some (v, pmdf))

module Erased = struct
open Typed_tree.Erased

type typed = t

type t = {
vertices : Id.t list;
arcs : (Id.t * Id.t) list;
pmdf_map : exp Id.Map.t;
obs_map : exp Id.Map.t;
}
[@@deriving sexp]

let of_graph ({ vertices; arcs; pmdf_map; obs_map } : typed) : t =
{
vertices;
arcs;
pmdf_map = Map.map pmdf_map ~f:(fun (Ex e) -> of_exp e);
obs_map = Map.map obs_map ~f:(fun (Ex e) -> of_exp e);
}
end
70 changes: 0 additions & 70 deletions lib/printing.ml

This file was deleted.

Loading

0 comments on commit 35fcfb2

Please sign in to comment.