diff --git a/bin/main.ml b/bin/main.ml index bcf4a25..82104d5 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -51,7 +51,7 @@ let command : Command.t = Out_channel.flush stdout; let graph, query = get_program filename |> Compiler.compile_program in graph_query := Some (graph, query); - print_s [%sexp (Printing.of_graph graph : Printing.graph)]); + print_s Graph.Erased.([%sexp (of_graph graph : t)])); if pp_opt || graph_opt then printf "\n"; printf "Inference: %s\n" filename; diff --git a/lib/compiler.ml b/lib/compiler.ml index e77b7b2..7a0014a 100644 --- a/lib/compiler.ml +++ b/lib/compiler.ml @@ -2,7 +2,7 @@ open! Core open Parse_tree open Typed_tree -type env = some_det Id.Map.t +type env = some_det_texp Id.Map.t let gen_vertex = let cnt = ref 0 in @@ -54,8 +54,7 @@ let rec peval : type a. (a, det) texp -> (a, det) texp = | If_pred (p, de) -> ( let p = peval_pred p and de = peval de in match p with (* TODO: *) _ -> If_pred (p, de)) - | If_con de -> If_con (peval de) - | If_alt de -> If_alt (peval de) + | If_just de -> If_just (peval de) in { ty; exp } @@ -96,8 +95,8 @@ let rec score : type a. (a, det) texp -> (a, det) texp = function | _ -> raise Score_invalid_arguments let rec compile : - type a s. - env:env -> ?pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp = + type a s. env:env -> ?pred:pred -> (a, ndet) texp -> Graph.t * (a, det) texp + = fun ~env ?(pred = Empty) { ty; exp } -> match exp with | Value _ as exp -> (Graph.empty, { ty; exp }) @@ -132,8 +131,8 @@ let rec compile : let g3, de_alt = compile ~env ~pred:pred_alt e_alt in let g = Graph.(g1 @| g2 @| g3) in match pred_con with - | True -> (g, { ty; exp = If_con de_con }) - | False -> (g, { ty; exp = If_alt de_alt }) + | True -> (g, { ty; exp = If_just de_con }) + | False -> (g, { ty; exp = If_just de_alt }) | _ -> (g, { ty; exp = If (de_pred, de_con, de_alt) })) | Let (x, e, body) -> let g1, det_exp1 = compile ~env ~pred e in @@ -154,7 +153,7 @@ let rec compile : { vertices = [ v ]; arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v)); - pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp); + pmdf_map = Id.Map.singleton v (Ex f : some_dist_det_texp); obs_map = Id.Map.empty; } in @@ -173,14 +172,14 @@ let rec compile : { vertices = [ v ]; arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v)); - pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp); - obs_map = Id.Map.singleton v (Ex de2 : some_dat_texp); + pmdf_map = Id.Map.singleton v (Ex f : some_dist_det_texp); + obs_map = Id.Map.singleton v (Ex de2 : some_val_det_texp); } in Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () }) and compile_args : - type a. env -> pred -> (a, non_det) args -> Graph.t * (a, det) args = + type a. env -> pred -> (a, ndet) args -> Graph.t * (a, det) args = fun env pred args -> match args with | [] -> (Graph.empty, []) @@ -191,7 +190,7 @@ and compile_args : exception Query_not_found -let compile_program (prog : program) : Graph.t * some_rv_texp = +let compile_program (prog : program) : Graph.t * some_rv_det_texp = Logs.debug (fun m -> m "Inlining program %a" Sexp.pp_hum [%sexp (prog : Parse_tree.program)]); let exp = Preprocessor.inline prog in diff --git a/lib/evaluator.ml b/lib/evaluator.ml index 214d4f9..affd2a4 100644 --- a/lib/evaluator.ml +++ b/lib/evaluator.ml @@ -27,8 +27,7 @@ let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a = | Uop ({ op; _ }, te) -> op (eval_dat ctx te) | If (te_pred, te_cons, te_alt) -> if eval_dat ctx te_pred then eval_dat ctx te_cons else eval_dat ctx te_alt - | If_con te -> eval_dat ctx te - | If_alt te -> eval_dat ctx te + | If_just te -> eval_dat ctx te and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a = fun ctx { ty = Dist_ty dty as ty; exp } -> @@ -81,7 +80,7 @@ let rec eval_pmdf : (* TODO: Remove existential wrapper *) let gibbs_sampling ~(num_samples : int) (graph : Graph.t) - (Ex query : some_rv_texp) : float array = + (Ex query : some_rv_det_texp) : float array = (* Initialize the context with the observed values. Float conversion must succeed as observed variables do not contain free variables *) let default : type a. a dty -> a = function @@ -154,7 +153,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) samples let infer ?(filename : string = "out") ?(num_samples : int = 100_000) - (graph : Graph.t) (query : some_rv_texp) : string = + (graph : Graph.t) (query : some_rv_det_texp) : string = let samples = gibbs_sampling graph ~num_samples query in let filename = String.chop_suffix_if_exists filename ~suffix:".stp" in @@ -162,7 +161,8 @@ let infer ?(filename : string = "out") ?(num_samples : int = 100_000) let open Owl_plplot in let h = Plot.create plot_path in - Plot.set_title h (Printing.of_rv query); + Plot.set_title h + Typed_tree.Erased.([%sexp (of_rv query : exp)] |> Sexp.to_string); let mat = Owl.Mat.of_array samples 1 num_samples in Plot.histogram ~h ~bin:50 mat; Plot.output h; diff --git a/lib/graph.ml b/lib/graph.ml index ad66577..97b8f45 100644 --- a/lib/graph.ml +++ b/lib/graph.ml @@ -3,8 +3,8 @@ open Typed_tree type vertex = Id.t type arc = vertex * vertex -type pmdf_map = some_dist_texp Id.Map.t -type obs_map = some_dat_texp Id.Map.t +type pmdf_map = some_dist_det_texp Id.Map.t +type obs_map = some_val_det_texp Id.Map.t type t = { vertices : vertex list; @@ -36,9 +36,31 @@ let union g1 g2 = let ( @| ) = union let unobserved_vertices_pmdfs ({ vertices; pmdf_map; obs_map; _ } : t) : - (vertex * some_dist_texp) list = + (vertex * some_dist_det_texp) list = List.filter_map vertices ~f:(fun v -> if Map.mem obs_map v then None else let pmdf = Map.find_exn pmdf_map v in Some (v, pmdf)) + +module Erased = struct + open Typed_tree.Erased + + type typed = t + + type t = { + vertices : Id.t list; + arcs : (Id.t * Id.t) list; + pmdf_map : exp Id.Map.t; + obs_map : exp Id.Map.t; + } + [@@deriving sexp] + + let of_graph ({ vertices; arcs; pmdf_map; obs_map } : typed) : t = + { + vertices; + arcs; + pmdf_map = Map.map pmdf_map ~f:(fun (Ex e) -> of_exp e); + obs_map = Map.map obs_map ~f:(fun (Ex e) -> of_exp e); + } +end diff --git a/lib/printing.ml b/lib/printing.ml deleted file mode 100644 index 88873b3..0000000 --- a/lib/printing.ml +++ /dev/null @@ -1,70 +0,0 @@ -open! Core -open Typed_tree - -type t = - | Value : string -> t - | Var : Id.t -> t - | Bop : Id.t * t * t -> t - | Uop : Id.t * t -> t - (* TODO: Add list and record constructors *) - (*| List : ('a, 'd) exp list -> ('a list, 'd) exp*) - (*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*) - | If : t * t * t -> t - | If_con : t -> t - | If_alt : t -> t - | Let : Id.t * t * t -> t - | Call : Id.t * t list -> t - | Sample : t -> t - | Observe : t * t -> t -[@@deriving sexp] - -type graph = { - vertices : Id.t list; - arcs : (Id.t * Id.t) list; - pmdf_map : t Id.Map.t; - obs_map : t Id.Map.t; -} -[@@deriving sexp] - -let rec of_exp : type a d. (a, d) texp -> t = - fun { ty; exp } -> - match exp with - | If (pred, cons, alt) -> If (of_exp pred, of_exp cons, of_exp alt) - | If_pred (pred, cons) -> If (of_pred pred, of_exp cons, Value "1") - | If_con exp -> If_con (of_exp exp) - | If_alt exp -> If_alt (of_exp exp) - | Value v -> ( - match ty with - | Dat_ty (Tyu, _) -> Value "()" - | Dat_ty (Tyi, _) -> Value (string_of_int v) - | Dat_ty (Tyr, _) -> Value (string_of_float v) - | Dat_ty (Tyb, _) -> Value (string_of_bool v)) - | Var v -> Var v - | Bop (op, e1, e2) -> Bop (op.name, of_exp e1, of_exp e2) - | Uop (op, e) -> Uop (op.name, of_exp e) - | Let (x, e1, e2) -> Let (x, of_exp e1, of_exp e2) - | Call (f, args) -> Call (f.name, of_args args) - | Sample e -> Sample (of_exp e) - | Observe (d, e) -> Observe (of_exp d, of_exp e) - -and of_args : type a d. (a, d) args -> t list = function - | [] -> [] - | arg :: args -> of_exp arg :: of_args args - -and of_pred : pred -> t = function - | Empty -> Value "" - | True -> Value "true" - | False -> Value "false" - | And (pred, exp) -> Bop ("&&", of_pred pred, of_exp exp) - | And_not (pred, exp) -> Bop ("&&", of_pred pred, Uop ("not", of_exp exp)) - -let of_graph ({ vertices; arcs; pmdf_map; obs_map } : Graph.t) : graph = - { - vertices; - arcs; - pmdf_map = Map.map pmdf_map ~f:(fun (Ex e) -> of_exp e); - obs_map = Map.map obs_map ~f:(fun (Ex e) -> of_exp e); - } - -let of_rv (Ex rv : some_rv_texp) = - rv |> of_exp |> sexp_of_t |> Sexp.to_string_hum diff --git a/lib/typed_tree.ml b/lib/typed_tree.ml index 9e5c13b..f1c4d71 100644 --- a/lib/typed_tree.ml +++ b/lib/typed_tree.ml @@ -1,7 +1,8 @@ open! Core -type (_, _) eq = Refl : ('a, 'a) eq type real = float +type ('a, 'b) uop = { name : Id.t; op : 'a -> 'b } +type ('a, 'b, 'c) bop = { name : Id.t; op : 'a -> 'b -> 'c } type _ dty = Tyu : unit dty | Tyi : int dty | Tyr : real dty | Tyb : bool dty type value = Val_ph type rv = Rv_ph @@ -13,13 +14,13 @@ type _ ty = | Dat_ty : 'a dty * 'b stamp -> ('a, 'b) dat_ty ty | Dist_ty : 'a dty -> 'a dist_ty ty +type det = Det_ph +type ndet = Ndet_ph + type _ params = | [] : unit params | ( :: ) : 'a dty * 'b params -> ('a * 'b) params -type det = Det_ph -type non_det = Non_det_ph - type _ vargs = | [] : unit vargs | ( :: ) : ('a dty * 'a) * 'b vargs -> ('a * 'b) vargs @@ -32,9 +33,6 @@ type ('a, 'b) dist = { log_pmdf : 'b vargs -> 'a -> real; } -type ('a, 'b, 'c) bop = { name : Id.t; op : 'a -> 'b -> 'c } -type ('a, 'b) uop = { name : Id.t; op : 'a -> 'b } - (* TODO: Why args should also be det? *) type (_, _) args = | [] : (unit, _) args @@ -62,37 +60,36 @@ and (_, _) exp = * (('a, _) dat_ty, 'd) texp -> (('a, _) dat_ty, 'd) exp | If_pred : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp - | If_con : (('a, 's) dat_ty, det) texp -> (('a, _) dat_ty, det) exp - | If_alt : (('a, 's) dat_ty, det) texp -> (('a, _) dat_ty, det) exp - | Let : Id.t * ('a, non_det) texp * ('b, non_det) texp -> ('b, non_det) exp + | If_just : (('a, 's) dat_ty, det) texp -> (('a, _) dat_ty, det) exp + | Let : Id.t * ('a, ndet) texp * ('b, ndet) texp -> ('b, ndet) exp | Call : ('a, 'b) dist * ('b, 'd) args -> ('a dist_ty, 'd) exp - | Sample : ('a dist_ty, non_det) texp -> (('a, rv) dat_ty, non_det) exp + | Sample : ('a dist_ty, ndet) texp -> (('a, rv) dat_ty, ndet) exp | Observe : - ('a dist_ty, non_det) texp * (('a, value) dat_ty, non_det) texp - -> ((unit, value) dat_ty, non_det) exp - -type some_non_det_texp = Ex : (_, non_det) texp -> some_non_det_texp -type some_det = Ex : (_, det) texp -> some_det -type some_rv_texp = Ex : ((_, rv) dat_ty, det) texp -> some_rv_texp -type some_dat_texp = Ex : ((_, value) dat_ty, det) texp -> some_dat_texp -type some_dist_texp = Ex : (_ dist_ty, det) texp -> some_dist_texp + ('a dist_ty, ndet) texp * (('a, value) dat_ty, ndet) texp + -> ((unit, value) dat_ty, ndet) exp + type some_dty = Ex : _ dty -> some_dty -type some_ty = Ex : _ ty -> some_ty type some_stamp = Ex : _ stamp -> some_stamp +type some_ty = Ex : _ ty -> some_ty +type some_ndet_texp = Ex : (_, ndet) texp -> some_ndet_texp +type some_det_texp = Ex : (_, det) texp -> some_det_texp +type some_dat_ndet_texp = Ex : (_ dat_ty, ndet) texp -> some_dat_ndet_texp -type _ some_dat_non_det_texp = - | Ex : (('a, _) dat_ty, non_det) texp -> 'a some_dat_non_det_texp +type _ some_dat_ndet_texp1 = + | Ex : (('a, _) dat_ty, ndet) texp -> 'a some_dat_ndet_texp1 -type 'a dist_non_det_texp = ('a dist_ty, non_det) texp +type some_val_det_texp = + | Ex : ((_, value) dat_ty, det) texp -> some_val_det_texp -type some_ndist_ndet_texp = - | Ex : (_ dat_ty, non_det) texp -> some_ndist_ndet_texp +type some_rv_det_texp = Ex : ((_, rv) dat_ty, det) texp -> some_rv_det_texp +type some_dist_det_texp = Ex : (_ dist_ty, det) texp -> some_dist_det_texp +type (_, _) eq = Refl : ('a, 'a) eq let dty_of_ty : type a. (a, _) dat_ty ty -> a dty = function | Dat_ty (dty, _) -> dty let some_dat_ndet_texp_of_ndet_texp : - type a. (a, non_det) texp -> some_ndist_ndet_texp option = + type a. (a, ndet) texp -> some_dat_ndet_texp option = fun texp -> match texp.ty with | Dat_ty (Tyu, _) -> Some (Ex texp) @@ -103,8 +100,8 @@ let some_dat_ndet_texp_of_ndet_texp : let eq_dat_ndet_texps : type a1 a2. - ((a1, _) dat_ty, non_det) texp -> - ((a2, _) dat_ty, non_det) texp -> + ((a1, _) dat_ty, ndet) texp -> + ((a2, _) dat_ty, ndet) texp -> (a1, a2) eq option = fun te_con te_alt -> match (dty_of_ty te_con.ty, dty_of_ty te_alt.ty) with @@ -142,8 +139,7 @@ let rec fv : type a. (a, det) exp -> Id.Set.t = function | If ({ exp = e_pred; _ }, { exp = e_cons; _ }, { exp = e_alt; _ }) -> Id.(fv e_pred @| fv e_cons @| fv e_alt) | If_pred (pred, { exp = e_cons; _ }) -> Id.(fv_pred pred @| fv e_cons) - | If_con { exp; _ } -> fv exp - | If_alt { exp; _ } -> fv exp + | If_just { exp; _ } -> fv exp | Call (_, args) -> fv_args args and fv_args : type a. (a, det) args -> Id.Set.t = function @@ -154,3 +150,54 @@ and fv_pred : pred -> Id.Set.t = function | Empty | True | False -> Id.Set.empty | And (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p) | And_not (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p) + +module Erased = struct + type exp = + | Value : string -> exp + | Var : Id.t -> exp + | Bop : Id.t * exp * exp -> exp + | Uop : Id.t * exp -> exp + (* TODO: Add list and record constructors *) + (*| List : ('a, 'd) exp list -> ('a list, 'd) exp*) + (*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*) + | If : exp * exp * exp -> exp + | If_just : exp -> exp + | Let : Id.t * exp * exp -> exp + | Call : Id.t * exp list -> exp + | Sample : exp -> exp + | Observe : exp * exp -> exp + [@@deriving sexp] + + let rec of_exp : type a d. (a, d) texp -> exp = + fun { ty; exp } -> + match exp with + | If (pred, cons, alt) -> If (of_exp pred, of_exp cons, of_exp alt) + | If_pred (pred, cons) -> If (of_pred pred, of_exp cons, Value "1") + | If_just exp -> If_just (of_exp exp) + | Value v -> ( + match ty with + | Dat_ty (Tyu, _) -> Value "()" + | Dat_ty (Tyi, _) -> Value (string_of_int v) + | Dat_ty (Tyr, _) -> Value (string_of_float v) + | Dat_ty (Tyb, _) -> Value (string_of_bool v)) + | Var v -> Var v + | Bop (op, e1, e2) -> Bop (op.name, of_exp e1, of_exp e2) + | Uop (op, e) -> Uop (op.name, of_exp e) + | Let (x, e1, e2) -> Let (x, of_exp e1, of_exp e2) + | Call (f, args) -> Call (f.name, of_args args) + | Sample e -> Sample (of_exp e) + | Observe (d, e) -> Observe (of_exp d, of_exp e) + + and of_args : type a d. (a, d) args -> exp list = function + | [] -> [] + | arg :: args -> of_exp arg :: of_args args + + and of_pred : pred -> exp = function + | Empty -> Value "" + | True -> Value "true" + | False -> Value "false" + | And (pred, exp) -> Bop ("&&", of_pred pred, of_exp exp) + | And_not (pred, exp) -> Bop ("&&", of_pred pred, Uop ("not", of_exp exp)) + + let of_rv (Ex rv : some_rv_det_texp) = rv |> of_exp +end diff --git a/lib/typing.ml b/lib/typing.ml index b38b094..6e30d8a 100644 --- a/lib/typing.ml +++ b/lib/typing.ml @@ -49,11 +49,11 @@ let get_uop : type a b. Parse_tree.exp * (a dty * b dty) -> (a, b) uop = let unify_branches : type a_con a_alt s_pred s_con s_alt. - ((bool, s_pred) dat_ty, non_det) texp -> - ((a_con, s_con) dat_ty, non_det) texp -> - ((a_alt, s_alt) dat_ty, non_det) texp -> + ((bool, s_pred) dat_ty, ndet) texp -> + ((a_con, s_con) dat_ty, ndet) texp -> + ((a_alt, s_alt) dat_ty, ndet) texp -> (a_con, a_alt) eq -> - a_con some_dat_non_det_texp = + a_con some_dat_ndet_texp1 = fun te_pred te_con te_alt Refl -> match te_pred.ty with | Dat_ty (Tyb, Val) -> ( @@ -126,7 +126,7 @@ let unify_branches : Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) }) let rec check_dat : - type a. tyenv -> Parse_tree.exp * a dty -> a some_dat_non_det_texp = + type a. tyenv -> Parse_tree.exp * a dty -> a some_dat_ndet_texp1 = fun tyenv (exp, dty) -> Logs.debug (fun m -> m "Checking exp (%a : %a)" Sexp.pp_hum @@ -292,7 +292,7 @@ and check_uop : (a, ret) uop -> Parse_tree.exp * a dty -> ret dty -> - ret some_dat_non_det_texp = + ret some_dat_ndet_texp1 = fun tyenv uop (e, t) tret -> let (Ex ({ ty = Dat_ty (_, s); _ } as te)) = check_dat tyenv (e, t) in match s with @@ -306,7 +306,7 @@ and check_bop : Parse_tree.exp * a1 dty -> Parse_tree.exp * a2 dty -> ret dty -> - ret some_dat_non_det_texp = + ret some_dat_ndet_texp1 = fun tyenv bop (e1, t1) (e2, t2) tret -> let (Ex ({ ty = Dat_ty (_, s1); _ } as te1)) = check_dat tyenv (e1, t1) in let (Ex ({ ty = Dat_ty (_, s2); _ } as te2)) = check_dat tyenv (e2, t2) in @@ -315,8 +315,7 @@ and check_bop : | _, _ -> Ex { ty = Dat_ty (tret, Rv); exp = Bop (bop, te1, te2) } and check_args : - type a. tyenv -> Id.t -> Parse_tree.exp list * a params -> (a, non_det) args - = + type a. tyenv -> Id.t -> Parse_tree.exp list * a params -> (a, ndet) args = fun tyenv prim (es, dtys) -> match dtys with | [] -> [] @@ -328,8 +327,8 @@ and check_args : let args = check_args tyenv prim (args, dtys) in arg :: args) -and check_dist : type a. tyenv -> Parse_tree.exp * a dty -> a dist_non_det_texp - = +and check_dist : + type a. tyenv -> Parse_tree.exp * a dty -> (a dist_ty, ndet) texp = fun tyenv (exp, dty) -> Logs.debug (fun m -> m "Checking exp (%a : %a dist)" Sexp.pp_hum @@ -381,7 +380,7 @@ and check_dist : type a. tyenv -> Parse_tree.exp * a dty -> a dist_non_det_texp | Or _ | Neg _ | Rneg _ | Not _ | Sample _ | Observe _ | List _ | Record _ -> raise (Type_error "Expected distribution") -and infer (tyenv : tyenv) (exp : Parse_tree.exp) : some_non_det_texp = +and infer (tyenv : tyenv) (exp : Parse_tree.exp) : some_ndet_texp = Logs.debug (fun m -> m "Infering exp %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]); match exp with @@ -441,4 +440,4 @@ and infer (tyenv : tyenv) (exp : Parse_tree.exp) : some_non_det_texp = | List _ -> failwith "List not implemented" | Record _ -> failwith "Record not implemented" -let check : Parse_tree.exp -> some_non_det_texp = infer Id.Map.empty +let check : Parse_tree.exp -> some_ndet_texp = infer Id.Map.empty