diff --git a/bin/dune b/bin/dune index a817757..b0a19bf 100644 --- a/bin/dune +++ b/bin/dune @@ -1,7 +1,13 @@ (executable (public_name stappl) (name main) - (libraries core core_unix.command_unix core_unix.filename_unix stappl) - (modes byte exe) + (libraries + core + core_unix.command_unix + core_unix.filename_unix + logs + logs.fmt + stappl) (preprocess - (pps ppx_jane))) + (pps ppx_jane)) + (modes byte exe)) diff --git a/bin/main.ml b/bin/main.ml index 768bd6b..bcf4a25 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -35,8 +35,11 @@ let command : Command.t = (let%map_open.Command filename = anon (maybe_with_default "-" ("filename" %: Filename_unix.arg_type)) and pp_opt = flag "-pp" no_arg ~doc:" Pretty print the program" - and graph_opt = flag "-graph" no_arg ~doc:" Print the compiled graph" in + and graph_opt = flag "-graph" no_arg ~doc:" Print the compiled graph" + and debug_opt = flag "-debug" no_arg ~doc:" Debug mode" in fun () -> + if debug_opt then Logs.set_level (Some Logs.Debug); + if pp_opt then ( printf "Pretty-print: %s\n" filename; print_s [%sexp (get_program filename : Parse_tree.program)]); @@ -49,6 +52,7 @@ let command : Command.t = let graph, query = get_program filename |> Compiler.compile_program in graph_query := Some (graph, query); print_s [%sexp (Printing.of_graph graph : Printing.graph)]); + if pp_opt || graph_opt then printf "\n"; printf "Inference: %s\n" filename; Out_channel.flush stdout; @@ -60,4 +64,7 @@ let command : Command.t = printf "Query result saved at %s\n" (Evaluator.infer ~filename graph query)) -let () = Command_unix.run ~version:"0.1.0" ~build_info:"STAPPL" command +let () = + Logs.set_reporter (Logs_fmt.reporter ()); + Command_unix.run ~version:"0.1.0" ~build_info:"STAPPL" command; + exit (if Logs.err_count () > 0 then 1 else 0) diff --git a/lib/compiler.ml b/lib/compiler.ml index 119cf46..e77b7b2 100644 --- a/lib/compiler.ml +++ b/lib/compiler.ml @@ -16,21 +16,24 @@ exception Not_closed_observation let rec peval : type a. (a, det) texp -> (a, det) texp = fun { ty; exp } -> + (* TODO: consider other cases *) let exp = match exp with - | (Value _ | Var _) as e -> e + | Value _ -> exp + | Var _ -> exp | Bop (op, te1, te2) -> ( match (peval te1, peval te2) with - | { exp = Value v1; _ }, { exp = Value v2; _ } -> Value (op.f v1 v2) + (*| { ty = ty1; exp = Value v1 }, { ty = ty2; exp = Value v2 } ->*) + (* Value (op.op v1 v2)*) | te1, te2 -> Bop (op, te1, te2)) | Uop (op, te) -> ( match peval te with - | { exp = Value v; _ } -> Value (op.f v) + (*| { exp = Value v; _ } -> Value (op.op v)*) | e -> Uop (op, e)) | If (te_pred, te_cons, te_alt) -> ( match peval te_pred with - | { exp = Value true; _ } -> (peval te_cons).exp - | { exp = Value false; _ } -> (peval te_alt).exp + (*| { exp = Value true; _ } -> (peval te_cons).exp*) + (*| { exp = Value false; _ } -> (peval te_alt).exp*) | te_pred -> If (te_pred, peval te_cons, peval te_alt)) | Call (f, args) -> ( match peval_args args with @@ -48,7 +51,13 @@ let rec peval : type a. (a, det) texp -> (a, det) texp = log_pmdf = (fun [] -> f.log_pmdf vargs); }, [] )) + | If_pred (p, de) -> ( + let p = peval_pred p and de = peval de in + match p with (* TODO: *) _ -> If_pred (p, de)) + | If_con de -> If_con (peval de) + | If_alt de -> If_alt (peval de) in + { ty; exp } and peval_args : type a. (a, det) args -> (a, det) args * a vargs option = @@ -57,31 +66,57 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option = | te :: tl -> ( match (peval te, peval_args tl) with | { ty; exp = Value v }, (tl, Some vargs) -> - ({ ty; exp = Value v } :: tl, Some ((ty, v) :: vargs)) + ({ ty; exp = Value v } :: tl, Some ((dty_of_ty ty, v) :: vargs)) | te, (tl, _) -> (te :: tl, None)) +and peval_pred : pred -> pred = function + | Empty -> failwith "[Bug] Empty predicate" + | True -> True + | False -> False + | And (p, de) -> ( + match peval de with + | { exp = Value true; _ } -> peval_pred p + | { exp = Value false; _ } -> False + | de -> And (p, de)) + | And_not (p, de) -> ( + match peval de with + | { exp = Value true; _ } -> False + | { exp = Value false; _ } -> peval_pred p + | de -> And_not (p, de)) + +let ( &&& ) p de = peval_pred (And (p, de)) +let ( &&! ) p de = peval_pred (And_not (p, de)) + let rec score : type a. (a, det) texp -> (a, det) texp = function + (* TODO: consider other cases *) | { ty; exp = If (e_pred, e_con, e_alt) } -> let s_con = score e_con and s_alt = score e_alt in { ty; exp = If (e_pred, s_con, s_alt) } | { exp = Call _; _ } as e -> e | _ -> raise Score_invalid_arguments -type pred = (bool, det) texp - let rec compile : - type a. env:env -> pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp - = - fun ~env ~pred { ty; exp } -> + type a s. + env:env -> ?pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp = + fun ~env ?(pred = Empty) { ty; exp } -> match exp with - | Value v -> (Graph.empty, { ty; exp = Value v }) + | Value _ as exp -> (Graph.empty, { ty; exp }) | Var x -> ( let (Ex { ty = tyx; exp }) = Map.find_exn env x in match (tyx, ty) with - | Tyi, Tyi -> (Graph.empty, { ty; exp }) - | Tyr, Tyr -> (Graph.empty, { ty; exp }) - | Tyb, Tyb -> (Graph.empty, { ty; exp }) - | _, _ -> assert false) + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Val) -> (Graph.empty, { ty; exp }) + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Val) -> (Graph.empty, { ty; exp }) + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Val) -> (Graph.empty, { ty; exp }) + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Val) -> (Graph.empty, { ty; exp }) + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Rv) -> (Graph.empty, { ty; exp }) + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Rv) -> (Graph.empty, { ty; exp }) + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Rv) -> (Graph.empty, { ty; exp }) + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Rv) -> (Graph.empty, { ty; exp }) + | Dist_ty Tyu, Dist_ty Tyu -> (Graph.empty, { ty; exp }) + | Dist_ty Tyb, Dist_ty Tyb -> (Graph.empty, { ty; exp }) + | Dist_ty Tyi, Dist_ty Tyi -> (Graph.empty, { ty; exp }) + | Dist_ty Tyr, Dist_ty Tyr -> (Graph.empty, { ty; exp }) + | _, _ -> failwith "[Bug] Type mismatch") | Bop (op, e1, e2) -> let g1, te1 = compile ~env ~pred e1 in let g2, te2 = compile ~env ~pred e2 in @@ -91,27 +126,14 @@ let rec compile : (g, { ty; exp = Uop (op, te) }) | If (e_pred, e_con, e_alt) -> ( let g1, de_pred = compile ~env ~pred e_pred in - let pred_con = - peval - { ty = Tyb; exp = Bop ({ f = ( && ); name = "&&" }, pred, de_pred) } - in - let pred_alt = - peval - { - ty = Tyb; - exp = - Bop - ( { f = ( && ); name = "&&" }, - pred, - { ty = Tyb; exp = Uop ({ f = not; name = "!" }, de_pred) } ); - } - in + let pred_con = pred &&& de_pred in + let pred_alt = pred &&! de_pred in let g2, de_con = compile ~env ~pred:pred_con e_con in let g3, de_alt = compile ~env ~pred:pred_alt e_alt in let g = Graph.(g1 @| g2 @| g3) in - match pred_con.exp with - | Value true -> (g, de_con) - | Value false -> (g, de_alt) + match pred_con with + | True -> (g, { ty; exp = If_con de_con }) + | False -> (g, { ty; exp = If_alt de_alt }) | _ -> (g, { ty; exp = If (de_pred, de_con, de_alt) })) | Let (x, e, body) -> let g1, det_exp1 = compile ~env ~pred e in @@ -126,13 +148,13 @@ let rec compile : let g, de = compile ~env ~pred e in let v = gen_vertex () in let de_fvs = fv de.exp in - let f : some_det = Ex (score de) in + let f = score de in let g' = Graph. { vertices = [ v ]; arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v)); - pmdf_map = Id.Map.singleton v f; + pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp); obs_map = Id.Map.empty; } in @@ -142,21 +164,20 @@ let rec compile : let g2, de2 = compile ~env ~pred e2 in let v = gen_vertex () in let f1 = score de1 in - let f = - { ty; exp = If (pred, f1, { ty; exp = Call (Dist.one ty, []) }) } - in - let fvs = Id.(fv de1.exp @| fv pred.exp) in - if not (Set.is_empty (fv de2.exp)) then raise Not_closed_observation; + let f = { ty = f1.ty; exp = If_pred (pred, f1) } in + let fvs = Id.(fv de1.exp @| fv_pred pred) in + if not (Set.is_empty (fv de2.exp)) then + failwith "[Bug] Not closed observation"; let g' = Graph. { vertices = [ v ]; arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v)); - pmdf_map = Id.Map.singleton v (Ex f : some_det); - obs_map = Id.Map.singleton v (Ex de2 : some_det); + pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp); + obs_map = Id.Map.singleton v (Ex de2 : some_dat_texp); } in - Graph.(g1 @| g2 @| g', de2) + Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () }) and compile_args : type a. env -> pred -> (a, non_det) args -> Graph.t * (a, det) args = @@ -168,7 +189,17 @@ and compile_args : let g', args = compile_args env pred args in Graph.(g @| g', arg :: args) -let compile_program (prog : program) : Graph.t * some_det = - let (Ex e) = Typing.check_program prog in - let g, e = compile ~env:Id.Map.empty ~pred:{ ty = Tyb; exp = Value true } e in - (g, Ex e) +exception Query_not_found + +let compile_program (prog : program) : Graph.t * some_rv_texp = + Logs.debug (fun m -> + m "Inlining program %a" Sexp.pp_hum [%sexp (prog : Parse_tree.program)]); + let exp = Preprocessor.inline prog in + Logs.debug (fun m -> + m "Inlined program %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]); + + let (Ex e) = Typing.check exp in + let g, { ty; exp } = compile ~env:Id.Map.empty e in + match ty with + | Dat_ty (_, Rv) -> (g, Ex { ty; exp }) + | _ -> raise Query_not_found diff --git a/lib/dist.ml b/lib/dist.ml index be83080..81c1bae 100644 --- a/lib/dist.ml +++ b/lib/dist.ml @@ -1,7 +1,17 @@ open! Core open Typed_tree -let one : type a. a ty -> (a, unit) dist = function +type some_dist = Ex : _ dist -> some_dist + +let one : type a. a dty -> (a, unit) dist = function + | Tyu -> + { + ret = Tyu; + name = "one"; + params = []; + sampler = (fun [] -> ()); + log_pmdf = (fun [] _ -> 0.0); + } | Tyi -> { ret = Tyi; diff --git a/lib/dune b/lib/dune index 1e1fb83..bf16f5d 100644 --- a/lib/dune +++ b/lib/dune @@ -1,6 +1,6 @@ (library (name stappl) - (libraries core owl owl-plplot string_dict) + (libraries core owl owl-plplot string_dict logs) (inline_tests) (preprocess (pps ppx_jane))) diff --git a/lib/evaluator.ml b/lib/evaluator.ml index 788fe38..214d4f9 100644 --- a/lib/evaluator.ml +++ b/lib/evaluator.ml @@ -1,6 +1,8 @@ open! Core open Typed_tree +type some_v = Ex : ('a dty * 'a) -> some_v + module Ctx = struct type t = some_v Id.Table.t @@ -9,62 +11,94 @@ module Ctx = struct let find_exn = Hashtbl.find_exn end -let rec eval : type a. Ctx.t -> (a, det) texp -> a = +let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a = fun ctx { ty; exp } -> match exp with | Value v -> v | Var x -> ( let (Ex (tv, v)) = Ctx.find_exn ctx x in match (ty, tv) with - | Tyi, Tyi -> v - | Tyr, Tyr -> v - | Tyb, Tyb -> v + | Dat_ty (Tyu, _), Tyu -> v + | Dat_ty (Tyb, _), Tyb -> v + | Dat_ty (Tyi, _), Tyi -> v + | Dat_ty (Tyr, _), Tyr -> v | _ -> assert false) - | Bop (op, te1, te2) -> op.f (eval ctx te1) (eval ctx te2) - | Uop (op, te) -> op.f (eval ctx te) + | Bop ({ op; _ }, te1, te2) -> op (eval_dat ctx te1) (eval_dat ctx te2) + | Uop ({ op; _ }, te) -> op (eval_dat ctx te) | If (te_pred, te_cons, te_alt) -> - if eval ctx te_pred then eval ctx te_cons else eval ctx te_alt + if eval_dat ctx te_pred then eval_dat ctx te_cons else eval_dat ctx te_alt + | If_con te -> eval_dat ctx te + | If_alt te -> eval_dat ctx te + +and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a = + fun ctx { ty = Dist_ty dty as ty; exp } -> + match exp with | Call (f, args) -> f.sampler (eval_args ctx args) + | Var x -> ( + let (Ex (tv, v)) = Ctx.find_exn ctx x in + match (dty, tv) with + | Tyu, Tyu -> v + | Tyb, Tyb -> v + | Tyi, Tyi -> v + | Tyr, Tyr -> v + | _ -> assert false) + | If_pred (pred, dist) -> + if eval_pred ctx pred then eval_dist ctx dist + else eval_dist ctx { ty; exp = Call (Dist.one dty, []) } + +and eval_pred (ctx : Ctx.t) : pred -> bool = + (*print_endline "[eval_pred]";*) + function + | Empty | True -> true + | False -> false + | And (p, de) -> eval_dat ctx de && eval_pred ctx p + | And_not (p, de) -> (not (eval_dat ctx de)) && eval_pred ctx p and eval_args : type a. Ctx.t -> (a, det) args -> a vargs = fun ctx -> function | [] -> [] - | te :: tl -> (te.ty, eval ctx te) :: eval_args ctx tl + | ({ ty = Dat_ty (dty, _); _ } as te) :: tl -> + (dty, eval_dat ctx te) :: eval_args ctx tl -let rec eval_pmdf (ctx : Ctx.t) (Ex { ty; exp } : some_det) : - (some_v -> float) * some_v = +let rec eval_pmdf : + type a. Ctx.t -> (a dist_ty, det) texp -> (some_v -> real) * some_v = + fun ctx { ty = Dist_ty dty as ty; exp } -> match exp with - | If (te_pred, te_cons, te_alt) -> - if eval ctx te_pred then eval_pmdf ctx (Ex te_cons) - else eval_pmdf ctx (Ex te_alt) + | If_pred (pred, te) -> + if eval_pred ctx pred then eval_pmdf ctx te + else eval_pmdf ctx { ty; exp = Call (Dist.one dty, []) } | Call (f, args) -> let pmdf (Ex (ty', v) : some_v) = - match (ty, ty') with + match (dty, ty') with + | Tyu, Tyu -> f.log_pmdf (eval_args ctx args) v + | Tyb, Tyb -> f.log_pmdf (eval_args ctx args) v | Tyi, Tyi -> f.log_pmdf (eval_args ctx args) v | Tyr, Tyr -> f.log_pmdf (eval_args ctx args) v - | Tyb, Tyb -> f.log_pmdf (eval_args ctx args) v | _, _ -> assert false in - (pmdf, Ex (ty, eval ctx { ty; exp })) + (pmdf, Ex (dty, eval_dist ctx { ty; exp })) | _ -> (* not reachable *) assert false -let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : some_det) : - float array = +(* TODO: Remove existential wrapper *) +let gibbs_sampling ~(num_samples : int) (graph : Graph.t) + (Ex query : some_rv_texp) : float array = (* Initialize the context with the observed values. Float conversion must succeed as observed variables do not contain free variables *) - let default : type a. a ty -> a = function + let default : type a. a dty -> a = function + | Tyu -> () + | Tyb -> false | Tyi -> 0 | Tyr -> 0.0 - | Tyb -> false in let ctx = Id.Table.create () in - Map.iteri graph.obs_map ~f:(fun ~key:name ~data:(Ex { ty; exp }) -> - let value : some_v = Ex (ty, eval ctx { ty; exp }) in + Map.iteri graph.obs_map + ~f:(fun ~key:name ~data:(Ex { ty = Dat_ty (dty, _) as ty; exp }) -> + let value : some_v = Ex (dty, eval_dat ctx { ty; exp }) in Ctx.set ctx ~name ~value); let unobserved = Graph.unobserved_vertices_pmdfs graph in - List.iter unobserved ~f:(fun (name, Ex { ty; _ }) -> - let value : some_v = Ex (ty, default ty) in + List.iter unobserved ~f:(fun (name, Ex { ty = Dist_ty dty; _ }) -> + let value : some_v = Ex (dty, default dty) in Ctx.set ctx ~name ~value); (* Adapted from gibbs_sampling of Owl *) @@ -73,7 +107,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : some_det) : let samples = Array.create ~len:num_samples 0. in for i = 0 to num_iter - 1 do (* Gibbs step *) - List.iter unobserved ~f:(fun (name, exp) -> + List.iter unobserved ~f:(fun (name, Ex exp) -> let curr = Ctx.find_exn ctx name in let log_pmdf, cand = eval_pmdf ctx exp in @@ -84,12 +118,12 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : some_det) : (* variables influenced by "name" *) let name_infl = - Map.filteri graph.pmdf_map - ~f:(fun ~key:name' ~data:(Ex { exp; _ }) -> + Map.filteri graph.pmdf_map ~f:(fun ~key:name' ~data:(Ex { exp; _ }) -> Id.(name' = name) || Set.mem (fv exp) name) in let log_alpha = - Map.fold name_infl ~init:log_alpha ~f:(fun ~key:name' ~data:exp acc -> + Map.fold name_infl ~init:log_alpha + ~f:(fun ~key:name' ~data:(Ex exp) acc -> let prob_w_cand = (fst (eval_pmdf ctx exp)) (Ctx.find_exn ctx name') in @@ -106,12 +140,13 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : some_det) : if Float.(uniform > alpha) then Ctx.set ctx ~name ~value:curr); if i >= a && i mod b = 0 then - let (Ex query) = query in let query = - match (query.ty, eval ctx query) with - | Tyi, i -> float_of_int i - | Tyr, r -> r - | Tyb, b -> if b then 1.0 else 0.0 + match (query.ty, eval_dat ctx query) with + (* TODO: Fix query type error *) + | Dat_ty (Tyu, Rv), _ -> 0.0 + | Dat_ty (Tyb, Rv), b -> if b then 1.0 else 0.0 + | Dat_ty (Tyi, Rv), i -> float_of_int i + | Dat_ty (Tyr, Rv), r -> r in samples.((i - a) / b) <- query done; @@ -119,7 +154,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (query : some_det) : samples let infer ?(filename : string = "out") ?(num_samples : int = 100_000) - (graph : Graph.t) (query : some_det) : string = + (graph : Graph.t) (query : some_rv_texp) : string = let samples = gibbs_sampling graph ~num_samples query in let filename = String.chop_suffix_if_exists filename ~suffix:".stp" in @@ -127,7 +162,7 @@ let infer ?(filename : string = "out") ?(num_samples : int = 100_000) let open Owl_plplot in let h = Plot.create plot_path in - Plot.set_title h (Printing.to_string query); + Plot.set_title h (Printing.of_rv query); let mat = Owl.Mat.of_array samples 1 num_samples in Plot.histogram ~h ~bin:50 mat; Plot.output h; diff --git a/lib/graph.ml b/lib/graph.ml index f688409..ad66577 100644 --- a/lib/graph.ml +++ b/lib/graph.ml @@ -3,8 +3,8 @@ open Typed_tree type vertex = Id.t type arc = vertex * vertex -type pmdf_map = some_det Id.Map.t -type obs_map = some_det Id.Map.t +type pmdf_map = some_dist_texp Id.Map.t +type obs_map = some_dat_texp Id.Map.t type t = { vertices : vertex list; @@ -36,7 +36,7 @@ let union g1 g2 = let ( @| ) = union let unobserved_vertices_pmdfs ({ vertices; pmdf_map; obs_map; _ } : t) : - (vertex * some_det) list = + (vertex * some_dist_texp) list = List.filter_map vertices ~f:(fun v -> if Map.mem obs_map v then None else diff --git a/lib/preprocessor.ml b/lib/preprocessor.ml new file mode 100644 index 0000000..580a201 --- /dev/null +++ b/lib/preprocessor.ml @@ -0,0 +1,120 @@ +open! Core +open Parse_tree + +type subst_map = Id.t Id.Map.t + +exception Arity_mismatch of string +exception Unbound_variable of string + +let gen_args = + let cnt = ref 0 in + fun () -> + let arg = "$arg" ^ string_of_int !cnt in + incr cnt; + arg + +let rec subst (env : subst_map) : exp -> exp = + (* 𝜂-expansion required to avoid infinite recursion *) + let subst' e = subst env e in + + function + | (Int _ | Real _ | Bool _) as e -> e + | Var x -> ( + match Map.find env x with + | None -> raise (Unbound_variable x) + | Some v -> Var v) + | Add (e1, e2) -> Add (subst' e1, subst' e2) + | Radd (e1, e2) -> Radd (subst' e1, subst' e2) + | Minus (e1, e2) -> Minus (subst' e1, subst' e2) + | Rminus (e1, e2) -> Rminus (subst' e1, subst' e2) + | Neg e -> Neg (subst' e) + | Rneg e -> Rneg (subst' e) + | Mult (e1, e2) -> Mult (subst' e1, subst' e2) + | Rmult (e1, e2) -> Rmult (subst' e1, subst' e2) + | Div (e1, e2) -> Div (subst' e1, subst' e2) + | Rdiv (e1, e2) -> Rdiv (subst' e1, subst' e2) + | Eq (e1, e2) -> Eq (subst' e1, subst' e2) + | Req (e1, e2) -> Req (subst' e1, subst' e2) + | Noteq (e1, e2) -> Noteq (subst' e1, subst' e2) + | Less (e1, e2) -> Less (subst' e1, subst' e2) + | Rless (e1, e2) -> Rless (subst' e1, subst' e2) + | And (e1, e2) -> And (subst' e1, subst' e2) + | Or (e1, e2) -> Or (subst' e1, subst' e2) + | Seq (e1, e2) -> Seq (subst' e1, subst' e2) + | Not e -> Not (subst' e) + | Assign (x, e1, e2) -> + Assign (x, subst' e1, subst (Map.set env ~key:x ~data:x) e2) + | If (e_pred, e_con, e_alt) -> If (subst' e_pred, subst' e_con, subst' e_alt) + | Call (f, args) -> + let args = List.map ~f:subst' args in + Call (f, args) + | Sample e -> Sample (subst' e) + | Observe (d, e) -> Observe (subst' d, subst' e) + | List _ -> failwith "List not implemented" + | Record _ -> failwith "Record not implemented" + +let rec inline_one (fn : fn) (prog : program) = + let rec inline_exp scope (exp : exp) = + let inline_exp' = inline_exp scope in + match exp with + | (Int _ | Real _ | Bool _) as e -> e + | Var x as e -> if Set.mem scope x then e else raise (Unbound_variable x) + | Add (e1, e2) -> Add (inline_exp' e1, inline_exp' e2) + | Radd (e1, e2) -> Radd (inline_exp' e1, inline_exp' e2) + | Minus (e1, e2) -> Minus (inline_exp' e1, inline_exp' e2) + | Rminus (e1, e2) -> Rminus (inline_exp' e1, inline_exp' e2) + | Neg e -> Neg (inline_exp' e) + | Rneg e -> Rneg (inline_exp' e) + | Mult (e1, e2) -> Mult (inline_exp' e1, inline_exp' e2) + | Rmult (e1, e2) -> Rmult (inline_exp' e1, inline_exp' e2) + | Div (e1, e2) -> Div (inline_exp' e1, inline_exp' e2) + | Rdiv (e1, e2) -> Rdiv (inline_exp' e1, inline_exp' e2) + | Eq (e1, e2) -> Eq (inline_exp' e1, inline_exp' e2) + | Req (e1, e2) -> Req (inline_exp' e1, inline_exp' e2) + | Noteq (e1, e2) -> Noteq (inline_exp' e1, inline_exp' e2) + | Less (e1, e2) -> Less (inline_exp' e1, inline_exp' e2) + | Rless (e1, e2) -> Rless (inline_exp' e1, inline_exp' e2) + | And (e1, e2) -> And (inline_exp' e1, inline_exp' e2) + | Or (e1, e2) -> Or (inline_exp' e1, inline_exp' e2) + | Seq (e1, e2) -> Seq (inline_exp' e1, inline_exp' e2) + | Not e -> Not (inline_exp' e) + | Assign (x, e1, e2) -> + Assign (x, inline_exp' e1, inline_exp (Set.add scope x) e2) + | If (e_pred, e_con, e_alt) -> + If (inline_exp' e_pred, inline_exp' e_con, inline_exp' e_alt) + | Call (f, args) -> + (* A-Normalize the arguments. For example, f(sample(e)) should only evaluate sample(e) once. *) + let args = List.map ~f:inline_exp' args in + if Id.(f <> fn.name) then Call (f, args) + else + let param_args = + try List.zip_exn fn.params args + with _ -> raise (Arity_mismatch fn.name) + in + let param_args = + List.map ~f:(fun (p, a) -> (p, gen_args (), a)) param_args + in + let env = + List.fold param_args ~init:Id.Map.empty ~f:(fun env (p, p', _a) -> + Map.set env ~key:p ~data:p') + in + List.fold param_args ~init:(subst env fn.body) + ~f:(fun body (_p, p', a) -> Assign (p', a, body)) + | Sample e -> Sample (inline_exp' e) + | Observe (d, e) -> Observe (inline_exp' d, inline_exp' e) + | List _ -> failwith "List not implemented" + | Record _ -> failwith "Record not implemented" + in + + let { funs; exp } = prog in + match funs with + | [] -> { funs = []; exp = inline_exp Id.Set.empty exp } + | { name; params; body } :: funs -> + let body = inline_exp (Id.Set.of_list params) body in + if Id.(name = fn.name) then { funs = { name; params; body } :: funs; exp } + else + let { funs; exp } = inline_one fn { funs; exp } in + { funs = { name; params; body } :: funs; exp } + +let rec inline ({ funs; exp } : program) : exp = + match funs with [] -> exp | f :: funs -> inline (inline_one f { funs; exp }) diff --git a/lib/printing.ml b/lib/printing.ml index 793c935..88873b3 100644 --- a/lib/printing.ml +++ b/lib/printing.ml @@ -2,7 +2,7 @@ open! Core open Typed_tree type t = - | Value : Id.t -> t + | Value : string -> t | Var : Id.t -> t | Bop : Id.t * t * t -> t | Uop : Id.t * t -> t @@ -10,6 +10,8 @@ type t = (*| List : ('a, 'd) exp list -> ('a list, 'd) exp*) (*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*) | If : t * t * t -> t + | If_con : t -> t + | If_alt : t -> t | Let : Id.t * t * t -> t | Call : Id.t * t list -> t | Sample : t -> t @@ -27,15 +29,19 @@ type graph = { let rec of_exp : type a d. (a, d) texp -> t = fun { ty; exp } -> match exp with + | If (pred, cons, alt) -> If (of_exp pred, of_exp cons, of_exp alt) + | If_pred (pred, cons) -> If (of_pred pred, of_exp cons, Value "1") + | If_con exp -> If_con (of_exp exp) + | If_alt exp -> If_alt (of_exp exp) | Value v -> ( match ty with - | Tyi -> Value (string_of_int v) - | Tyr -> Value (string_of_float v) - | Tyb -> Value (string_of_bool v)) + | Dat_ty (Tyu, _) -> Value "()" + | Dat_ty (Tyi, _) -> Value (string_of_int v) + | Dat_ty (Tyr, _) -> Value (string_of_float v) + | Dat_ty (Tyb, _) -> Value (string_of_bool v)) | Var v -> Var v | Bop (op, e1, e2) -> Bop (op.name, of_exp e1, of_exp e2) | Uop (op, e) -> Uop (op.name, of_exp e) - | If (pred, cons, alt) -> If (of_exp pred, of_exp cons, of_exp alt) | Let (x, e1, e2) -> Let (x, of_exp e1, of_exp e2) | Call (f, args) -> Call (f.name, of_args args) | Sample e -> Sample (of_exp e) @@ -45,6 +51,13 @@ and of_args : type a d. (a, d) args -> t list = function | [] -> [] | arg :: args -> of_exp arg :: of_args args +and of_pred : pred -> t = function + | Empty -> Value "" + | True -> Value "true" + | False -> Value "false" + | And (pred, exp) -> Bop ("&&", of_pred pred, of_exp exp) + | And_not (pred, exp) -> Bop ("&&", of_pred pred, Uop ("not", of_exp exp)) + let of_graph ({ vertices; arcs; pmdf_map; obs_map } : Graph.t) : graph = { vertices; @@ -53,4 +66,5 @@ let of_graph ({ vertices; arcs; pmdf_map; obs_map } : Graph.t) : graph = obs_map = Map.map obs_map ~f:(fun (Ex e) -> of_exp e); } -let to_string (Ex e : some_det) = e |> of_exp |> sexp_of_t |> Sexp.to_string_hum +let of_rv (Ex rv : some_rv_texp) = + rv |> of_exp |> sexp_of_t |> Sexp.to_string_hum diff --git a/lib/typed_tree.ml b/lib/typed_tree.ml index 2198738..7cd8f7f 100644 --- a/lib/typed_tree.ml +++ b/lib/typed_tree.ml @@ -1,54 +1,114 @@ open! Core type real = float -type _ ty = Tyi : int ty | Tyr : real ty | Tyb : bool ty +type _ dty = Tyu : unit dty | Tyi : int dty | Tyr : real dty | Tyb : bool dty +type value = Val_ph +type rv = Rv_ph +type _ stamp = Val : value stamp | Rv : rv stamp +type ('a, 'b) dat_ty = Dat_ty_ph +type 'a dist_ty = Dist_ty_ph -let string_of_ty : type a. a ty -> string = function - | Tyi -> "int" - | Tyr -> "real" - | Tyb -> "bool" +type _ ty = + | Dat_ty : 'a dty * 'b stamp -> ('a, 'b) dat_ty ty + | Dist_ty : 'a dty -> 'a dist_ty ty type _ params = | [] : unit params - | ( :: ) : 'a ty * 'b params -> ('a * 'b) params + | ( :: ) : 'a dty * 'b params -> ('a * 'b) params -type det = Det -type non_det = Non_det +type det = Det_ph +type non_det = Non_det_ph type _ vargs = | [] : unit vargs - | ( :: ) : ('a ty * 'a) * 'b vargs -> ('a * 'b) vargs + | ( :: ) : ('a dty * 'a) * 'b vargs -> ('a * 'b) vargs type ('a, 'b) dist = { - ret : 'a ty; + ret : 'a dty; name : Id.t; params : 'b params; sampler : 'b vargs -> 'a; log_pmdf : 'b vargs -> 'a -> real; } -type ('a, 'b, 'c) bop = { name : Id.t; f : 'a -> 'b -> 'c } -type ('a, 'b) uop = { name : Id.t; f : 'a -> 'b } +type ('a, 'b, 'c) bop = { name : Id.t; op : 'a -> 'b -> 'c } +type ('a, 'b) uop = { name : Id.t; op : 'a -> 'b } +(* TODO: Why args should also be det? *) type (_, _) args = | [] : (unit, _) args - | ( :: ) : ('a, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args + | ( :: ) : (('a, _) dat_ty, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args + +and pred = + | Empty : pred + | True : pred + | False : pred + | And : pred * ((bool, _) dat_ty, det) texp -> pred + | And_not : pred * ((bool, _) dat_ty, det) texp -> pred + +and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp } and (_, _) exp = - | Value : 'a -> ('a, _) exp + | Value : 'a -> (('a, value) dat_ty, _) exp | Var : Id.t -> _ exp - | Bop : ('a, 'b, 'c) bop * ('a, 'd) texp * ('b, 'd) texp -> ('c, 'd) exp - | Uop : ('a, 'b) uop * ('a, 'd) texp -> ('b, 'd) exp - (* TODO: Add list and record constructors *) - (*| List : ('a, 'd) exp list -> ('a list, 'd) exp*) - (*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*) - | If : (bool, 'd) texp * ('a, 'd) texp * ('a, 'd) texp -> ('a, 'd) exp + | Bop : + ('a, 'b, 'c) bop * (('a, _) dat_ty, 'd) texp * (('b, _) dat_ty, 'd) texp + -> (('c, _) dat_ty, 'd) exp + | Uop : ('a, 'b) uop * (('a, _) dat_ty, 'd) texp -> (('b, _) dat_ty, 'd) exp + | If : + ((bool, _) dat_ty, 'd) texp + * (('a, _) dat_ty, 'd) texp + * (('a, _) dat_ty, 'd) texp + -> (('a, _) dat_ty, 'd) exp + | If_pred : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp + | If_con : (('a, 's) dat_ty, det) texp -> (('a, _) dat_ty, det) exp + | If_alt : (('a, 's) dat_ty, det) texp -> (('a, _) dat_ty, det) exp | Let : Id.t * ('a, non_det) texp * ('b, non_det) texp -> ('b, non_det) exp - | Call : ('a, 'b) dist * ('b, 'd) args -> ('a, 'd) exp - | Sample : ('a, non_det) texp -> ('a, non_det) exp - | Observe : ('a, non_det) texp * ('a, non_det) texp -> ('a, non_det) exp + | Call : ('a, 'b) dist * ('b, 'd) args -> ('a dist_ty, 'd) exp + | Sample : ('a dist_ty, non_det) texp -> (('a, rv) dat_ty, non_det) exp + | Observe : + ('a dist_ty, non_det) texp * (('a, value) dat_ty, non_det) texp + -> ((unit, value) dat_ty, non_det) exp -and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp } +type some_non_det_texp = Ex : (_, non_det) texp -> some_non_det_texp +type some_det = Ex : (_, det) texp -> some_det +type some_rv_texp = Ex : ((_, rv) dat_ty, det) texp -> some_rv_texp +type some_dat_texp = Ex : ((_, value) dat_ty, det) texp -> some_dat_texp +type some_dist_texp = Ex : (_ dist_ty, det) texp -> some_dist_texp +type some_dty = Ex : _ dty -> some_dty +type some_ty = Ex : _ ty -> some_ty +type some_stamp = Ex : _ stamp -> some_stamp + +type _ some_dat_non_det_texp = + | Ex : (('a, _) dat_ty, non_det) texp -> 'a some_dat_non_det_texp + +type 'a dist_non_det_texp = ('a dist_ty, non_det) texp +(* | Ex : ('a dist_ty ty, non_det) texp -> 'a some_dist_non_det_texp*) + +(*type _ some_dist_texp = Ex : ('a dist_ty, non_det) texp -> 'a some_dist_texp*) + +let dty_of_ty : type a. (a, _) dat_ty ty -> a dty = function + | Dat_ty (dty, _) -> dty + +let string_of_dty : type a. a dty -> string = function + | Tyu -> "unit" + | Tyi -> "int" + | Tyr -> "real" + | Tyb -> "bool" + +let string_of_ty : type a. a ty -> string = function + | Dat_ty (Tyu, Val) -> "unit val" + | Dat_ty (Tyi, Val) -> "int val" + | Dat_ty (Tyr, Val) -> "real val" + | Dat_ty (Tyb, Val) -> "bool val" + | Dat_ty (Tyu, Rv) -> "unit rv" + | Dat_ty (Tyi, Rv) -> "int rv" + | Dat_ty (Tyr, Rv) -> "real rv" + | Dat_ty (Tyb, Rv) -> "bool rv" + | Dist_ty Tyu -> "unit dist" + | Dist_ty Tyi -> "int dist" + | Dist_ty Tyr -> "real dist" + | Dist_ty Tyb -> "bool dist" let rec fv : type a. (a, det) exp -> Id.Set.t = function | Value _ -> Id.Set.empty @@ -57,15 +117,16 @@ let rec fv : type a. (a, det) exp -> Id.Set.t = function | Uop (_, { exp; _ }) -> fv exp | If ({ exp = e_pred; _ }, { exp = e_cons; _ }, { exp = e_alt; _ }) -> Id.(fv e_pred @| fv e_cons @| fv e_alt) + | If_pred (pred, { exp = e_cons; _ }) -> Id.(fv_pred pred @| fv e_cons) + | If_con { exp; _ } -> fv exp + | If_alt { exp; _ } -> fv exp | Call (_, args) -> fv_args args and fv_args : type a. (a, det) args -> Id.Set.t = function | [] -> Id.Set.empty | { exp; _ } :: es -> Id.(fv exp @| fv_args es) -type some_ndet = Ex : (_, non_det) texp -> some_ndet -type some_det = Ex : (_, det) texp -> some_det -type some_ty = Ex : _ ty -> some_ty -type some_params = Ex : _ params -> some_params -type some_v = Ex : ('a ty * 'a) -> some_v -type some_dist = Ex : _ dist -> some_dist +and fv_pred : pred -> Id.Set.t = function + | Empty | True | False -> Id.Set.empty + | And (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p) + | And_not (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p) diff --git a/lib/typing.ml b/lib/typing.ml index 6614b6b..590ebc7 100644 --- a/lib/typing.ml +++ b/lib/typing.ml @@ -5,415 +5,490 @@ type tyenv = some_ty Id.Map.t exception Arity_mismatch of string exception Unbound_variable of string -exception Type_mismatch of string +exception Type_error of string -let gen_args = - let cnt = ref 0 in - fun () -> - let arg = "$arg" ^ string_of_int !cnt in - incr cnt; - arg +let raise_if_type_error (exp : Parse_tree.exp) (t1 : _ ty) (t2 : _ ty) : _ = + raise + (Type_error + (sprintf + "Branches of an if statement must return the same type: got (%s) and \ + (%s) in %s" + (string_of_ty t1) (string_of_ty t2) + ([%sexp (exp : Parse_tree.exp)] |> Sexp.to_string_hum))) -let rec subst (env : Id.t Id.Map.t) : Parse_tree.exp -> Parse_tree.exp = - let subst' = subst env in +let get_bop : + type a b c. Parse_tree.exp * (a dty * b dty * c dty) -> (a, b, c) bop = function - | (Int _ | Real _ | Bool _) as e -> e - | Var x -> ( - match Map.find env x with - | None -> raise (Unbound_variable x) - | Some v -> Var v) - | Add (e1, e2) -> Add (subst' e1, subst' e2) - | Radd (e1, e2) -> Radd (subst' e1, subst' e2) - | Minus (e1, e2) -> Minus (subst' e1, subst' e2) - | Rminus (e1, e2) -> Rminus (subst' e1, subst' e2) - | Neg e -> Neg (subst' e) - | Rneg e -> Rneg (subst' e) - | Mult (e1, e2) -> Mult (subst' e1, subst' e2) - | Rmult (e1, e2) -> Rmult (subst' e1, subst' e2) - | Div (e1, e2) -> Div (subst' e1, subst' e2) - | Rdiv (e1, e2) -> Rdiv (subst' e1, subst' e2) - | Eq (e1, e2) -> Eq (subst' e1, subst' e2) - | Req (e1, e2) -> Req (subst' e1, subst' e2) - | Noteq (e1, e2) -> Noteq (subst' e1, subst' e2) - | Less (e1, e2) -> Less (subst' e1, subst' e2) - | Rless (e1, e2) -> Rless (subst' e1, subst' e2) - | And (e1, e2) -> And (subst' e1, subst' e2) - | Or (e1, e2) -> Or (subst' e1, subst' e2) - | Seq (e1, e2) -> Seq (subst' e1, subst' e2) - | Not e -> Not (subst' e) - | Assign (x, e1, e2) -> - Assign (x, subst' e1, subst (Map.set env ~key:x ~data:x) e2) - | If (cond, yes, no) -> If (subst' cond, subst' yes, subst' no) - | Call (f, args) -> - let args = List.map ~f:subst' args in - Call (f, args) - | Sample e -> Sample (subst' e) - | Observe (d, e) -> Observe (subst' d, subst' e) - | List _ -> failwith "List not implemented" - | Record _ -> failwith "Record not implemented" - -let rec inline_one (fn : Parse_tree.fn) (prog : Parse_tree.program) = - let open Parse_tree in - let rec inline_exp scope (e : exp) = - let inline_exp' = inline_exp scope in - match e with - | Int _ | Real _ | Bool _ -> e - | Var x -> if Set.mem scope x then e else raise (Unbound_variable x) - | Add (e1, e2) -> Add (inline_exp' e1, inline_exp' e2) - | Radd (e1, e2) -> Radd (inline_exp' e1, inline_exp' e2) - | Minus (e1, e2) -> Minus (inline_exp' e1, inline_exp' e2) - | Rminus (e1, e2) -> Rminus (inline_exp' e1, inline_exp' e2) - | Neg e -> Neg (inline_exp' e) - | Rneg e -> Rneg (inline_exp' e) - | Mult (e1, e2) -> Mult (inline_exp' e1, inline_exp' e2) - | Rmult (e1, e2) -> Rmult (inline_exp' e1, inline_exp' e2) - | Div (e1, e2) -> Div (inline_exp' e1, inline_exp' e2) - | Rdiv (e1, e2) -> Rdiv (inline_exp' e1, inline_exp' e2) - | Eq (e1, e2) -> Eq (inline_exp' e1, inline_exp' e2) - | Req (e1, e2) -> Req (inline_exp' e1, inline_exp' e2) - | Noteq (e1, e2) -> Noteq (inline_exp' e1, inline_exp' e2) - | Less (e1, e2) -> Less (inline_exp' e1, inline_exp' e2) - | Rless (e1, e2) -> Rless (inline_exp' e1, inline_exp' e2) - | And (e1, e2) -> And (inline_exp' e1, inline_exp' e2) - | Or (e1, e2) -> Or (inline_exp' e1, inline_exp' e2) - | Seq (e1, e2) -> Seq (inline_exp' e1, inline_exp' e2) - | Not e -> Not (inline_exp' e) - | Assign (x, e1, e2) -> - Assign (x, inline_exp' e1, inline_exp (Set.add scope x) e2) - | If (cond, yes, no) -> - If (inline_exp' cond, inline_exp' yes, inline_exp' no) - | Call (f, args) as e -> - (* A-Normalize the arguments. For example, f(sample(e)) should only evaluate sample(e) once. *) - let args = List.map ~f:inline_exp' args in - if Id.(f <> fn.name) then e - else - let param_args = - try List.zip_exn fn.params args - with _ -> raise (Arity_mismatch fn.name) - in - let param_args = - List.map ~f:(fun (p, a) -> (p, gen_args (), a)) param_args - in - let env = - List.fold param_args ~init:Id.Map.empty ~f:(fun env (p, p', _a) -> - Map.set env ~key:p ~data:p') - in - List.fold param_args ~init:(subst env fn.body) - ~f:(fun body (_p, p', a) -> Assign (p', a, body)) - | Sample e -> Sample (inline_exp' e) - | Observe (d, e) -> Observe (inline_exp' d, inline_exp' e) - | List _ -> failwith "List not implemented" - | Record _ -> failwith "Record not implemented" - in + | Add _, (Tyi, Tyi, Tyi) -> { name = "+"; op = ( + ) } + | Radd _, (Tyr, Tyr, Tyr) -> { name = "+."; op = ( +. ) } + | Minus _, (Tyi, Tyi, Tyi) -> { name = "-"; op = ( - ) } + | Rminus _, (Tyr, Tyr, Tyr) -> { name = "-."; op = ( -. ) } + | Mult _, (Tyi, Tyi, Tyi) -> { name = "*"; op = ( * ) } + | Rmult _, (Tyr, Tyr, Tyr) -> { name = "*."; op = ( *. ) } + | Div _, (Tyi, Tyi, Tyi) -> { name = "/"; op = ( / ) } + | Rdiv _, (Tyr, Tyr, Tyr) -> { name = "/."; op = ( /. ) } + | Eq _, (Tyi, Tyi, Tyb) -> { name = "="; op = ( = ) } + | Req _, (Tyr, Tyr, Tyb) -> { name = "=."; op = Float.( = ) } + | Noteq _, (Tyi, Tyi, Tyb) -> { name = "<>"; op = ( <> ) } + | Less _, (Tyi, Tyi, Tyb) -> { name = "<"; op = ( < ) } + | Rless _, (Tyr, Tyr, Tyb) -> { name = "<."; op = Float.( < ) } + | And _, (Tyb, Tyb, Tyb) -> { name = "&&"; op = ( && ) } + | Or _, (Tyb, Tyb, Tyb) -> { name = "||"; op = ( || ) } + | _ -> raise (Type_error "Expected binary operation") - let { funs; exp } = prog in - match funs with - | [] -> { funs = []; exp = inline_exp Id.Set.empty exp } - | { name; params; body } :: funs -> - let body = inline_exp (Id.Set.of_list params) body in - if Id.(name = fn.name) then { funs = { name; params; body } :: funs; exp } - else - let { funs; exp } = inline_one fn { funs; exp } in - { funs = { name; params; body } :: funs; exp } +let get_uop : type a b. Parse_tree.exp * (a dty * b dty) -> (a, b) uop = + function + | Neg _, (Tyi, Tyi) -> { name = "~-"; op = ( ~- ) } + | Rneg _, (Tyr, Tyr) -> { name = "~-."; op = ( ~-. ) } + | Not _, (Tyb, Tyb) -> { name = "!"; op = not } + | e, _ -> + raise + (Type_error + ("Expected unary operation, got " + ^ ([%sexp (e : Parse_tree.exp)] |> Sexp.to_string_hum))) -let rec inline (prog : Parse_tree.program) = - let open Parse_tree in - let { funs; exp } = prog in - match funs with - | [] -> exp - | fn :: funs -> inline (inline_one fn { funs; exp }) +let rec check_dat : + type a. tyenv -> Parse_tree.exp * a dty -> a some_dat_non_det_texp = + fun tyenv (exp, dty) -> + Logs.debug (fun m -> + m "Checking exp (%a : %a)" Sexp.pp_hum + [%sexp (exp : Parse_tree.exp)] + (fun fmt dty -> Format.pp_print_string fmt (string_of_dty dty)) + dty); -let rec check : type a. tyenv -> Parse_tree.exp -> a ty -> (a, non_det) texp = - fun tyenv e ty -> - match e with + match exp with | Var x -> ( match Map.find tyenv x with | None -> raise (Unbound_variable x) - | Some (Ex t) -> ( - match (t, ty) with - | Tyi, Tyi -> { ty; exp = Var x } - | Tyr, Tyr -> { ty; exp = Var x } - | Tyb, Tyb -> { ty; exp = Var x } - | t1, t2 -> + | Some (Ex ty_x) -> ( + match (ty_x, dty) with + | Dat_ty (Tyu, _), Tyu -> Ex { ty = ty_x; exp = Var x } + | Dat_ty (Tyi, _), Tyi -> Ex { ty = ty_x; exp = Var x } + | Dat_ty (Tyr, _), Tyr -> Ex { ty = ty_x; exp = Var x } + | Dat_ty (Tyb, _), Tyb -> Ex { ty = ty_x; exp = Var x } + | ty_x, dty -> raise - (Type_mismatch - (sprintf "Variable %s: expected %s, got %s" x - (string_of_ty t2) (string_of_ty t1))))) + (Type_error + (sprintf "Variable %s: expected (%s _), got %s" x + (string_of_dty dty) (string_of_ty ty_x))))) | Int i -> ( - match ty with - | Tyi -> { ty; exp = Value i } - | t -> - raise - (Type_mismatch (sprintf "Expected int, got %s" (string_of_ty t)))) - | Add (e1, e2) -> ( - match ty with - | Tyi -> check_bop tyenv "+" ( + ) Tyi Tyi Tyi e1 e2 - | t -> - raise - (Type_mismatch - (sprintf "Expected int for Add, got %s" (string_of_ty t)))) - | Minus (e1, e2) -> ( - match ty with - | Tyi -> check_bop tyenv "-" ( - ) Tyi Tyi Tyi e1 e2 - | t -> - raise - (Type_mismatch - (sprintf "Expected int for Minus, got %s" (string_of_ty t)))) - | Neg e -> ( - match ty with - | Tyi -> check_uop tyenv "-" Int.neg Tyi Tyi e - | t -> + match dty with + | Tyi -> Ex { ty = Dat_ty (Tyi, Val); exp = Value i } + | dty -> raise - (Type_mismatch - (sprintf "Expected int for Neg, got %s" (string_of_ty t)))) - | Mult (e1, e2) -> ( - match ty with - | Tyi -> check_bop tyenv "*" ( * ) Tyi Tyi Tyi e1 e2 - | t -> - raise - (Type_mismatch - (sprintf "Expected int for Mult, got %s" (string_of_ty t)))) - | Div (e1, e2) -> ( - match ty with - | Tyi -> check_bop tyenv "/" ( / ) Tyi Tyi Tyi e1 e2 - | t -> + (Type_error (sprintf "Expected int, got %s" (string_of_dty dty)))) + | Bool b -> ( + match dty with + | Tyb -> Ex { ty = Dat_ty (Tyb, Val); exp = Value b } + | dty -> raise - (Type_mismatch - (sprintf "Expected int for Div, got %s" (string_of_ty t)))) + (Type_error (sprintf "Expected bool, got %s" (string_of_dty dty)))) | Real r -> ( - match ty with - | Tyr -> { ty; exp = Value r } - | t -> - raise - (Type_mismatch (sprintf "Expected float, got %s" (string_of_ty t)))) - | Radd (e1, e2) -> ( - match ty with - | Tyr -> check_bop tyenv "+" ( +. ) Tyr Tyr Tyr e1 e2 - | t -> - raise - (Type_mismatch - (sprintf "Expected float for Radd, got %s" (string_of_ty t)))) - | Rminus (e1, e2) -> ( - match ty with - | Tyr -> check_bop tyenv "-" ( -. ) Tyr Tyr Tyr e1 e2 - | t -> - raise - (Type_mismatch - (sprintf "Expected float for Rminus, got %s" (string_of_ty t)))) - | Rneg e -> ( - match ty with - | Tyr -> check_uop tyenv "-" Float.neg Tyr Tyr e - | t -> - raise - (Type_mismatch - (sprintf "Expected float for Rneg, got %s" (string_of_ty t)))) - | Rmult (e1, e2) -> ( - match ty with - | Tyr -> check_bop tyenv "*" ( *. ) Tyr Tyr Tyr e1 e2 - | t -> - raise - (Type_mismatch - (sprintf "Expected float for Rmult, got %s" (string_of_ty t)))) - | Rdiv (e1, e2) -> ( - match ty with - | Tyr -> check_bop tyenv "/" ( /. ) Tyr Tyr Tyr e1 e2 - | t -> - raise - (Type_mismatch - (sprintf "Expected float for Rdiv, got %s" (string_of_ty t)))) - | Bool b -> ( - match ty with - | Tyb -> { ty; exp = Value b } - | t -> + match dty with + | Tyr -> Ex { ty = Dat_ty (Tyr, Val); exp = Value r } + | dty -> raise - (Type_mismatch (sprintf "Expected bool, got %s" (string_of_ty t)))) - | Eq (e1, e2) -> ( - match ty with - | Tyb -> check_bop tyenv "=" Int.( = ) Tyi Tyi Tyb e1 e2 - | t -> + (Type_error (sprintf "Expected real, got %s" (string_of_dty dty)))) + | Add (e1, e2) | Minus (e1, e2) | Mult (e1, e2) | Div (e1, e2) -> ( + match dty with + | Tyi -> + let bop = get_bop (exp, (Tyi, Tyi, Tyi)) in + check_bop tyenv bop (e1, Tyi) (e2, Tyi) Tyi + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for Eq, got %s" (string_of_ty t)))) - | Req (e1, e2) -> ( - match ty with - | Tyb -> check_bop tyenv "=" Float.( = ) Tyr Tyr Tyb e1 e2 - | t -> + (Type_error (sprintf "Expected int, got %s" (string_of_dty dty)))) + | Radd (e1, e2) | Rminus (e1, e2) | Rmult (e1, e2) | Rdiv (e1, e2) -> ( + match dty with + | Tyr -> + let bop = get_bop (exp, (Tyr, Tyr, Tyr)) in + check_bop tyenv bop (e1, Tyr) (e2, Tyr) Tyr + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for Req, got %s" (string_of_ty t)))) - | Noteq (e1, e2) -> ( - match ty with - | Tyb -> check_bop tyenv "<>" Int.( <> ) Tyi Tyi Tyb e1 e2 - | t -> + (Type_error (sprintf "Expected real, got %s" (string_of_dty dty)))) + | Eq (e1, e2) | Noteq (e1, e2) | Less (e1, e2) -> ( + match dty with + | Tyb -> + let bop = get_bop (exp, (Tyi, Tyi, Tyb)) in + check_bop tyenv bop (e1, Tyi) (e2, Tyi) Tyb + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for Noteq, got %s" (string_of_ty t)))) - | Less (e1, e2) -> ( - match ty with - | Tyb -> check_bop tyenv "<" Int.( < ) Tyi Tyi Tyb e1 e2 - | t -> + (Type_error (sprintf "Expected int, got %s" (string_of_dty dty)))) + | Req (e1, e2) | Rless (e1, e2) -> ( + match dty with + | Tyb -> + let bop = get_bop (exp, (Tyr, Tyr, Tyb)) in + check_bop tyenv bop (e1, Tyr) (e2, Tyr) Tyb + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for Less, got %s" (string_of_ty t)))) - | Rless (e1, e2) -> ( - match ty with - | Tyb -> check_bop tyenv "<" Float.( < ) Tyr Tyr Tyb e1 e2 - | t -> + (Type_error (sprintf "Expected real, got %s" (string_of_dty dty)))) + | And (e1, e2) | Or (e1, e2) -> ( + match dty with + | Tyb -> + let bop = get_bop (exp, (Tyb, Tyb, Tyb)) in + check_bop tyenv bop (e1, Tyb) (e2, Tyb) Tyb + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for Rless, got %s" (string_of_ty t)))) - | And (e1, e2) -> ( - match ty with - | Tyb -> check_bop tyenv "&&" ( && ) Tyb Tyb Tyb e1 e2 - | t -> + (Type_error (sprintf "Expected bool, got %s" (string_of_dty dty)))) + | Neg e -> ( + match dty with + | Tyi -> + let uop = get_uop (exp, (Tyi, Tyi)) in + check_uop tyenv uop (e, Tyi) Tyi + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for And, got %s" (string_of_ty t)))) - | Or (e1, e2) -> ( - match ty with - | Tyb -> check_bop tyenv "||" ( || ) Tyb Tyb Tyb e1 e2 - | t -> + (Type_error (sprintf "Expected int, got %s" (string_of_dty dty)))) + | Rneg e -> ( + match dty with + | Tyr -> + let uop = get_uop (exp, (Tyr, Tyr)) in + check_uop tyenv uop (e, Tyr) Tyr + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for Or, got %s" (string_of_ty t)))) + (Type_error (sprintf "Expected real, got %s" (string_of_dty dty)))) | Not e -> ( - match ty with - | Tyb -> check_uop tyenv "!" not Tyb Tyb e - | t -> + match dty with + | Tyb -> + let uop = get_uop (exp, (Tyb, Tyb)) in + check_uop tyenv uop (e, Tyb) Tyb + | dty -> raise - (Type_mismatch - (sprintf "Expected bool for Not, got %s" (string_of_ty t)))) - | Observe (d, e) -> ( - let (Ex td) = convert tyenv d in - let (Ex te) = convert tyenv e in - match (ty, td.ty, te.ty) with - | Tyi, Tyi, Tyi -> { ty; exp = Observe (td, te) } - | Tyr, Tyr, Tyr -> { ty; exp = Observe (td, te) } - | Tyb, Tyb, Tyb -> { ty; exp = Observe (td, te) } - | _ -> + (Type_error (sprintf "Expected bool, got %s" (string_of_dty dty)))) + | Observe (de, ve) -> ( + match dty with + | Tyu -> ( + let tde = infer tyenv de in + let tve = infer tyenv ve in + match (tde, tve) with + | ( Ex ({ ty = Dist_ty Tyu; _ } as tde), + Ex ({ ty = Dat_ty (Tyu, Val); _ } as tve) ) -> + Ex { ty = Dat_ty (Tyu, Val); exp = Observe (tde, tve) } + | ( Ex ({ ty = Dist_ty Tyb; _ } as tde), + Ex ({ ty = Dat_ty (Tyb, Val); _ } as tve) ) -> + Ex { ty = Dat_ty (Tyu, Val); exp = Observe (tde, tve) } + | ( Ex ({ ty = Dist_ty Tyi; _ } as tde), + Ex ({ ty = Dat_ty (Tyi, Val); _ } as tve) ) -> + Ex { ty = Dat_ty (Tyu, Val); exp = Observe (tde, tve) } + | ( Ex ({ ty = Dist_ty Tyr; _ } as tde), + Ex ({ ty = Dat_ty (Tyr, Val); _ } as tve) ) -> + Ex { ty = Dat_ty (Tyu, Val); exp = Observe (tde, tve) } + | _, _ -> + (* TODO: more precise error message *) + raise + (Type_error + (sprintf "Arguments to observe have unexpected types"))) + | dty -> raise - (Type_mismatch (sprintf "Argument to observe has different types"))) + (Type_error (sprintf "Expected unit, got %s" (string_of_dty dty)))) | Seq (e1, e2) -> - let (Ex te1) = convert tyenv e1 in - let te2 = check tyenv e2 ty in - { ty; exp = Let ("_", te1, te2) } + let (Ex te1) = infer tyenv e1 in + let (Ex te2) = check_dat tyenv (e2, dty) in + Ex { te2 with exp = Let ("_", te1, te2) } | Assign (x, e1, e2) -> - let (Ex ({ ty = ty1; exp = _ } as te1)) = convert tyenv e1 in - let tyenv = Map.set tyenv ~key:x ~data:(Ex ty1) in - let te2 = check tyenv e2 ty in - { ty; exp = Let (x, te1, te2) } - | If (pred, conseq, alt) -> - let tpred = check tyenv pred Tyb in - let tconseq = check tyenv conseq ty in - let talt = check tyenv alt ty in - { ty; exp = If (tpred, tconseq, talt) } - | Call (prim, args) -> ( - let (Ex dist) = Dist.get_dist prim in - let args = check_args tyenv args dist.params in - match (dist.ret, ty) with - | Tyi, Tyi -> { ty; exp = Call (dist, args) } - | Tyr, Tyr -> { ty; exp = Call (dist, args) } - | Tyb, Tyb -> { ty; exp = Call (dist, args) } - | _ -> - raise - (Type_mismatch - (sprintf "Expected %s for Call, got %s" (string_of_ty dist.ret) - (string_of_ty ty)))) - | Sample e -> - let te = check tyenv e ty in - { ty; exp = Sample te } - | List _ -> raise (Type_mismatch "List not implemented") - | Record _ -> raise (Type_mismatch "Record not implemented") + let (Ex te1) = infer tyenv e1 in + let tyenv = Map.set tyenv ~key:x ~data:(Ex te1.ty) in + let (Ex te2) = check_dat tyenv (e2, dty) in + Ex { te2 with exp = Let (x, te1, te2) } + | If (e_pred, e_con, e_alt) -> ( + let (Ex te_pred) = check_dat tyenv (e_pred, Tyb) in + let (Ex te_con) = check_dat tyenv (e_con, dty) in + let (Ex te_alt) = check_dat tyenv (e_alt, dty) in + match te_pred.ty with + | Dat_ty (Tyb, Val) -> ( + match (te_con.ty, te_alt.ty) with + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) }) + | Dat_ty (Tyb, Rv) -> ( + match (te_con.ty, te_alt.ty) with + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) })) + | Sample e -> ( + let te = check_dist tyenv (e, dty) in + match te.ty with + | Dist_ty Tyu -> Ex { ty = Dat_ty (Tyu, Rv); exp = Sample te } + | Dist_ty Tyb -> Ex { ty = Dat_ty (Tyb, Rv); exp = Sample te } + | Dist_ty Tyi -> Ex { ty = Dat_ty (Tyi, Rv); exp = Sample te } + | Dist_ty Tyr -> Ex { ty = Dat_ty (Tyr, Rv); exp = Sample te }) + | List _ -> raise (Type_error "List not implemented") + | Record _ -> raise (Type_error "Record not implemented") + | Call (f, e) -> + raise + (Type_error + ("Expected data type, got distribution: " ^ f ^ " " + ^ ([%sexp (e : Parse_tree.exp list)] |> Sexp.to_string_hum))) and check_uop : - type arg ret. + type a ret. tyenv -> - Id.t -> - (arg -> ret) -> - arg ty -> - ret ty -> - Parse_tree.exp -> - (ret, non_det) texp = - fun tyenv name f t ty e -> - let te = check tyenv e t in - { ty; exp = Uop ({ name; f }, te) } + (a, ret) uop -> + Parse_tree.exp * a dty -> + ret dty -> + ret some_dat_non_det_texp = + fun tyenv uop (e, t) tret -> + let (Ex ({ ty = Dat_ty (_, s); _ } as te)) = check_dat tyenv (e, t) in + match s with + | Val -> Ex { ty = Dat_ty (tret, Val); exp = Uop (uop, te) } + | _ -> Ex { ty = Dat_ty (tret, Rv); exp = Uop (uop, te) } and check_bop : - type arg1 arg2 ret. + type a1 a2 ret. tyenv -> - Id.t -> - (arg1 -> arg2 -> ret) -> - arg1 ty -> - arg2 ty -> - ret ty -> - Parse_tree.exp -> - Parse_tree.exp -> - (ret, non_det) texp = - fun tyenv name f t1 t2 ty e1 e2 -> - let te1 = check tyenv e1 t1 in - let te2 = check tyenv e2 t2 in - { ty; exp = Bop ({ name; f }, te1, te2) } + (a1, a2, ret) bop -> + Parse_tree.exp * a1 dty -> + Parse_tree.exp * a2 dty -> + ret dty -> + ret some_dat_non_det_texp = + fun tyenv bop (e1, t1) (e2, t2) tret -> + let (Ex ({ ty = Dat_ty (_, s1); _ } as te1)) = check_dat tyenv (e1, t1) in + let (Ex ({ ty = Dat_ty (_, s2); _ } as te2)) = check_dat tyenv (e2, t2) in + match (s1, s2) with + | Val, Val -> Ex { ty = Dat_ty (tret, Val); exp = Bop (bop, te1, te2) } + | _, _ -> Ex { ty = Dat_ty (tret, Rv); exp = Bop (bop, te1, te2) } and check_args : - type a. tyenv -> Parse_tree.exp list -> a params -> (a, non_det) args = - fun tyenv el tyl -> - match tyl with + type a. tyenv -> Id.t -> Parse_tree.exp list * a params -> (a, non_det) args + = + fun tyenv prim (es, dtys) -> + match dtys with | [] -> [] - | argty :: argtys -> ( - match el with - | [] -> failwith "Primitive call failed" + | dty :: dtys -> ( + match es with + | [] -> raise (Arity_mismatch prim) | arg :: args -> - let arg = check tyenv arg argty in - let args = check_args tyenv args argtys in + let (Ex arg) = check_dat tyenv (arg, dty) in + let args = check_args tyenv prim (args, dtys) in arg :: args) -and convert (tyenv : tyenv) (e : Parse_tree.exp) : some_ndet = - match e with +and check_dist : type a. tyenv -> Parse_tree.exp * a dty -> a dist_non_det_texp + = + fun tyenv (exp, dty) -> + Logs.debug (fun m -> + m "Checking exp (%a : %a dist)" Sexp.pp_hum + [%sexp (exp : Parse_tree.exp)] + (fun fmt dty -> Format.pp_print_string fmt (string_of_dty dty)) + dty); + + match exp with | Var x -> ( match Map.find tyenv x with - | None -> failwith ("Unbound variable " ^ x) + | None -> raise (Unbound_variable x) + | Some (Ex ty_x) -> ( + match (ty_x, dty) with + | Dist_ty Tyu, Tyu -> { ty = ty_x; exp = Var x } + | Dist_ty Tyb, Tyb -> { ty = ty_x; exp = Var x } + | Dist_ty Tyi, Tyi -> { ty = ty_x; exp = Var x } + | Dist_ty Tyr, Tyr -> { ty = ty_x; exp = Var x } + | ty_x, dty -> + raise + (Type_error + (sprintf "Variable %s: expected (%s _), got %s" x + (string_of_dty dty) (string_of_ty ty_x))))) + | Seq (e1, e2) -> + let (Ex te1) = infer tyenv e1 in + let te2 = check_dist tyenv (e2, dty) in + { ty = te2.ty; exp = Let ("_", te1, te2) } + | Assign (x, e1, e2) -> + let (Ex te1) = infer tyenv e1 in + let tyenv = Map.set tyenv ~key:x ~data:(Ex te1.ty) in + let te2 = check_dist tyenv (e2, dty) in + { ty = te2.ty; exp = Let (x, te1, te2) } + | If _ -> + raise (Type_error "You cannot return a distribution from a conditional") + | Call (prim, args) -> ( + let (Ex dist) = Dist.get_dist prim in + let args = check_args tyenv dist.name (args, dist.params) in + match (dist.ret, dty) with + | Tyu, Tyu -> { ty = Dist_ty Tyu; exp = Call (dist, args) } + | Tyb, Tyb -> { ty = Dist_ty Tyb; exp = Call (dist, args) } + | Tyi, Tyi -> { ty = Dist_ty Tyi; exp = Call (dist, args) } + | Tyr, Tyr -> { ty = Dist_ty Tyr; exp = Call (dist, args) } + | _ -> + raise + (Type_error + (sprintf "Expected %s for Call, got %s" (string_of_dty dist.ret) + (string_of_dty dty)))) + | Bool _ | Int _ | Real _ | Add _ | Radd _ | Minus _ | Rminus _ | Mult _ + | Rmult _ | Div _ | Rdiv _ | Eq _ | Req _ | Noteq _ | Less _ | Rless _ | And _ + | Or _ | Neg _ | Rneg _ | Not _ | Sample _ | Observe _ | List _ | Record _ -> + raise (Type_error "Expected distribution") + +and infer (tyenv : tyenv) (exp : Parse_tree.exp) : some_non_det_texp = + Logs.debug (fun m -> + m "Infering exp %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]); + match exp with + | Var x -> ( + match Map.find tyenv x with + | None -> raise (Unbound_variable x) | Some (Ex t) -> Ex { ty = t; exp = Var x }) - | Int _ | Add _ | Minus _ | Neg _ | Mult _ | Div _ -> Ex (check tyenv e Tyi) - | Real _ | Radd _ | Rminus _ | Rneg _ | Rmult _ | Rdiv _ -> - Ex (check tyenv e Tyr) - | Bool _ | Eq _ | Req _ | Noteq _ | Less _ | Rless _ | And _ | Or _ | Not _ -> - Ex (check tyenv e Tyb) - | Observe (d, e) -> ( - let (Ex td) = convert tyenv d in - let (Ex te) = convert tyenv e in - match (td.ty, te.ty) with - | Tyi, Tyi -> Ex { ty = Tyi; exp = Observe (td, te) } - | Tyr, Tyr -> Ex { ty = Tyr; exp = Observe (td, te) } - | Tyb, Tyb -> Ex { ty = Tyb; exp = Observe (td, te) } - | _, _ -> failwith "Argument to observe has different types.") + | (Int _ | Add _ | Minus _ | Neg _ | Mult _ | Div _) as e -> + let (Ex t) = check_dat tyenv (e, Tyi) in + Ex t + | (Real _ | Radd _ | Rminus _ | Rneg _ | Rmult _ | Rdiv _) as e -> + let (Ex t) = check_dat tyenv (e, Tyr) in + Ex t + | (Bool _ | Eq _ | Req _ | Noteq _ | Less _ | Rless _ | And _ | Or _ | Not _) + as e -> + let (Ex t) = check_dat tyenv (e, Tyb) in + Ex t + | Observe _ as e -> + let (Ex t) = check_dat tyenv (e, Tyu) in + Ex t | Seq (e1, e2) -> - let (Ex te1) = convert tyenv e1 in - let (Ex ({ ty = ty2; exp = _ } as te2)) = convert tyenv e2 in - Ex { ty = ty2; exp = Let ("_", te1, te2) } + let (Ex te1) = infer tyenv e1 in + let (Ex te2) = infer tyenv e2 in + Ex { ty = te2.ty; exp = Let ("_", te1, te2) } | Assign (x, e1, e2) -> - let (Ex ({ ty = ty1; exp = _ } as te1)) = convert tyenv e1 in - let tyenv = Map.set tyenv ~key:x ~data:(Ex ty1) in - let (Ex ({ ty = ty2; exp = _ } as te2)) = convert tyenv e2 in - Ex { ty = ty2; exp = Let (x, te1, te2) } - | If (pred, conseq, alt) -> ( - let tpred = check tyenv pred Tyb in - let (Ex tconseq) = convert tyenv conseq in - let (Ex talt) = convert tyenv alt in - match (tconseq.ty, talt.ty) with - | Tyi, Tyi -> Ex { ty = Tyi; exp = If (tpred, tconseq, talt) } - | Tyr, Tyr -> Ex { ty = Tyr; exp = If (tpred, tconseq, talt) } - | Tyb, Tyb -> Ex { ty = Tyb; exp = If (tpred, tconseq, talt) } - | _, _ -> failwith "Branches of an if statement must return the same type" - ) + let (Ex te1) = infer tyenv e1 in + let tyenv = Map.set tyenv ~key:x ~data:(Ex te1.ty) in + let (Ex te2) = infer tyenv e2 in + Ex { ty = te2.ty; exp = Let (x, te1, te2) } + | If (e_pred, e_con, e_alt) -> ( + let (Ex te_pred) = check_dat tyenv (e_pred, Tyb) in + let (Ex te_con) = infer tyenv e_con in + let (Ex te_alt) = infer tyenv e_alt in + match te_pred.ty with + | Dat_ty (Tyb, Val) -> ( + match (te_con.ty, te_alt.ty) with + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Val); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | t1, t2 -> raise_if_type_error exp t1 t2) + | Dat_ty (Tyb, Rv) -> ( + match (te_con.ty, te_alt.ty) with + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Val), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Val) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyu, Rv), Dat_ty (Tyu, Rv) -> + Ex { ty = Dat_ty (Tyu, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Val), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Val) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyb, Rv), Dat_ty (Tyb, Rv) -> + Ex { ty = Dat_ty (Tyb, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Val), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Val) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyi, Rv), Dat_ty (Tyi, Rv) -> + Ex { ty = Dat_ty (Tyi, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Val), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Val) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | Dat_ty (Tyr, Rv), Dat_ty (Tyr, Rv) -> + Ex { ty = Dat_ty (Tyr, Rv); exp = If (te_pred, te_con, te_alt) } + | t1, t2 -> raise_if_type_error exp t1 t2)) | Call (prim, args) -> let (Ex dist) = Dist.get_dist prim in - let args = check_args tyenv args dist.params in - Ex { ty = dist.ret; exp = Call (dist, args) } - | Sample e -> - let (Ex te) = convert tyenv e in - Ex { ty = te.ty; exp = Sample te } + let args = check_args tyenv prim (args, dist.params) in + Ex { ty = Dist_ty dist.ret; exp = Call (dist, args) } + | Sample e -> ( + let (Ex te) = infer tyenv e in + match te with + | { ty = Dist_ty Tyu; _ } -> Ex { ty = Dat_ty (Tyu, Rv); exp = Sample te } + | { ty = Dist_ty Tyb; _ } -> Ex { ty = Dat_ty (Tyb, Rv); exp = Sample te } + | { ty = Dist_ty Tyi; _ } -> Ex { ty = Dat_ty (Tyi, Rv); exp = Sample te } + | { ty = Dist_ty Tyr; _ } -> Ex { ty = Dat_ty (Tyr, Rv); exp = Sample te } + | _ -> raise (Type_error "Expected distribution")) | List _ -> failwith "List not implemented" | Record _ -> failwith "Record not implemented" -let check_program (program : Parse_tree.program) : some_ndet = - convert Id.Map.empty (inline program) +let check : Parse_tree.exp -> some_non_det_texp = infer Id.Map.empty diff --git a/samples/normal_bernoulli.png b/samples/normal_bernoulli.png index 752f5c0..e8d5183 100644 Binary files a/samples/normal_bernoulli.png and b/samples/normal_bernoulli.png differ diff --git a/samples/simple_itpp.png b/samples/simple_itpp.png index bcf754a..4d23ca1 100644 Binary files a/samples/simple_itpp.png and b/samples/simple_itpp.png differ diff --git a/samples/student.png b/samples/student.png index 6c369a7..25e674e 100644 Binary files a/samples/student.png and b/samples/student.png differ