diff --git a/lib/compiler.ml b/lib/compiler.ml index d3ac648..119cf46 100644 --- a/lib/compiler.ml +++ b/lib/compiler.ml @@ -60,39 +60,37 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option = ({ ty; exp = Value v } :: tl, Some ((ty, v) :: vargs)) | te, (tl, _) -> (te :: tl, None)) -let rec score : type a. (a, det) texp -> Id.t -> (a, det) texp = - fun e var -> - match e.exp with - | If (e_pred, e_con, e_alt) -> - let s_con = score e_con var in - let s_alt = score e_alt var in - { ty = e.ty; exp = If (e_pred, s_con, s_alt) } - | Call _ -> e +let rec score : type a. (a, det) texp -> (a, det) texp = function + | { ty; exp = If (e_pred, e_con, e_alt) } -> + let s_con = score e_con and s_alt = score e_alt in + { ty; exp = If (e_pred, s_con, s_alt) } + | { exp = Call _; _ } as e -> e | _ -> raise Score_invalid_arguments +type pred = (bool, det) texp + let rec compile : - type a. - env -> (bool, det) texp -> (a, non_det) texp -> Graph.t * (a, det) texp = - fun env pred e -> - let { ty; exp } = e in + type a. env:env -> pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp + = + fun ~env ~pred { ty; exp } -> match exp with | Value v -> (Graph.empty, { ty; exp = Value v }) | Var x -> ( - let (Ex { ty = tx; exp }) = Map.find_exn env x in - match (tx, ty) with + let (Ex { ty = tyx; exp }) = Map.find_exn env x in + match (tyx, ty) with | Tyi, Tyi -> (Graph.empty, { ty; exp }) | Tyr, Tyr -> (Graph.empty, { ty; exp }) | Tyb, Tyb -> (Graph.empty, { ty; exp }) | _, _ -> assert false) | Bop (op, e1, e2) -> - let g1, te1 = compile env pred e1 in - let g2, te2 = compile env pred e2 in + let g1, te1 = compile ~env ~pred e1 in + let g2, te2 = compile ~env ~pred e2 in Graph.(g1 @| g2, { ty; exp = Bop (op, te1, te2) }) | Uop (op, e) -> - let g, te = compile env pred e in + let g, te = compile ~env ~pred e in (g, { ty; exp = Uop (op, te) }) | If (e_pred, e_con, e_alt) -> ( - let g1, de_pred = compile env pred e_pred in + let g1, de_pred = compile ~env ~pred e_pred in let pred_con = peval { ty = Tyb; exp = Bop ({ f = ( && ); name = "&&" }, pred, de_pred) } @@ -108,27 +106,27 @@ let rec compile : { ty = Tyb; exp = Uop ({ f = not; name = "!" }, de_pred) } ); } in - let g2, de_con = compile env pred_con e_con in - let g3, de_alt = compile env pred_alt e_alt in + let g2, de_con = compile ~env ~pred:pred_con e_con in + let g3, de_alt = compile ~env ~pred:pred_alt e_alt in let g = Graph.(g1 @| g2 @| g3) in match pred_con.exp with | Value true -> (g, de_con) | Value false -> (g, de_alt) | _ -> (g, { ty; exp = If (de_pred, de_con, de_alt) })) | Let (x, e, body) -> - let g1, det_exp1 = compile env pred e in + let g1, det_exp1 = compile ~env ~pred e in let g2, det_exp2 = - compile (Map.set env ~key:x ~data:(Ex det_exp1)) pred body + compile ~env:(Map.set env ~key:x ~data:(Ex det_exp1)) ~pred body in Graph.(g1 @| g2, det_exp2) | Call (f, args) -> let g, args = compile_args env pred args in (g, { ty; exp = Call (f, args) }) | Sample e -> - let g, de = compile env pred e in + let g, de = compile ~env ~pred e in let v = gen_vertex () in let de_fvs = fv de.exp in - let f : some_det = Ex (score de v) in + let f : some_det = Ex (score de) in let g' = Graph. { @@ -140,39 +138,13 @@ let rec compile : in Graph.(g @| g', { ty; exp = Var v }) | Observe (e1, e2) -> - let g1, de1 = compile env pred e1 in - let g2, de2 = compile env pred e2 in + let g1, de1 = compile ~env ~pred e1 in + let g2, de2 = compile ~env ~pred e2 in let v = gen_vertex () in - let f1 = score de1 v in - let one : type a. a ty -> (a, unit) dist = - fun ty -> - match ty with - | Tyi -> - { - ret = ty; - name = "one"; - params = []; - sampler = (fun _ -> 1); - log_pmdf = (fun [] _ -> 0.0); - } - | Tyr -> - { - ret = ty; - name = "one"; - params = []; - sampler = (fun _ -> 1.0); - log_pmdf = (fun [] _ -> 0.0); - } - | Tyb -> - { - ret = Tyb; - name = "one"; - params = []; - sampler = (fun _ -> true); - log_pmdf = (fun [] _ -> 0.0); - } + let f1 = score de1 in + let f = + { ty; exp = If (pred, f1, { ty; exp = Call (Dist.one ty, []) }) } in - let f = { ty; exp = If (pred, f1, { ty; exp = Call (one ty, []) }) } in let fvs = Id.(fv de1.exp @| fv pred.exp) in if not (Set.is_empty (fv de2.exp)) then raise Not_closed_observation; let g' = @@ -187,18 +159,16 @@ let rec compile : Graph.(g1 @| g2 @| g', de2) and compile_args : - type a. - env -> (bool, det) texp -> (a, non_det) args -> Graph.t * (a, det) args = + type a. env -> pred -> (a, non_det) args -> Graph.t * (a, det) args = fun env pred args -> match args with | [] -> (Graph.empty, []) | arg :: args -> - let g, arg = compile env pred arg in + let g, arg = compile ~env ~pred arg in let g', args = compile_args env pred args in Graph.(g @| g', arg :: args) let compile_program (prog : program) : Graph.t * some_det = - let open Typing in - let (Ex e) = convert Id.Map.empty (inline prog) in - let g, e = compile Id.Map.empty { ty = Tyb; exp = Value true } e in + let (Ex e) = Typing.check_program prog in + let g, e = compile ~env:Id.Map.empty ~pred:{ ty = Tyb; exp = Value true } e in (g, Ex e) diff --git a/lib/dist.ml b/lib/dist.ml new file mode 100644 index 0000000..be83080 --- /dev/null +++ b/lib/dist.ml @@ -0,0 +1,53 @@ +open! Core +open Typed_tree + +let one : type a. a ty -> (a, unit) dist = function + | Tyi -> + { + ret = Tyi; + name = "one"; + params = []; + sampler = (fun [] -> 1); + log_pmdf = (fun [] _ -> 0.0); + } + | Tyr -> + { + ret = Tyr; + name = "one"; + params = []; + sampler = (fun [] -> 1.0); + log_pmdf = (fun [] _ -> 0.0); + } + | Tyb -> + { + ret = Tyb; + name = "one"; + params = []; + sampler = (fun [] -> true); + log_pmdf = (fun [] _ -> 0.0); + } + +let get_dist (name : Id.t) : some_dist = + let open Owl.Stats in + match name with + | "bernoulli" -> + Ex + { + ret = Tyb; + name = "bernoulli"; + params = [ Tyr ]; + sampler = (fun [ (Tyr, p) ] -> binomial_rvs ~p ~n:1 = 1); + log_pmdf = + (fun [ (Tyr, p) ] b -> binomial_logpdf ~p ~n:1 (Bool.to_int b)); + } + | "normal" -> + Ex + { + ret = Tyr; + name = "normal"; + params = [ Tyr; Tyr ]; + sampler = (fun [ (Tyr, mu); (Tyr, sigma) ] -> gaussian_rvs ~mu ~sigma); + log_pmdf = + (fun [ (Tyr, mu); (Tyr, sigma) ] -> gaussian_logpdf ~mu ~sigma); + } + | _ -> failwith "Unknown primitive function" diff --git a/lib/typed_tree.ml b/lib/typed_tree.ml index 3976ecc..2198738 100644 --- a/lib/typed_tree.ml +++ b/lib/typed_tree.ml @@ -3,6 +3,11 @@ open! Core type real = float type _ ty = Tyi : int ty | Tyr : real ty | Tyb : bool ty +let string_of_ty : type a. a ty -> string = function + | Tyi -> "int" + | Tyr -> "real" + | Tyb -> "bool" + type _ params = | [] : unit params | ( :: ) : 'a ty * 'b params -> ('a * 'b) params diff --git a/lib/typing.ml b/lib/typing.ml index 1865155..6614b6b 100644 --- a/lib/typing.ml +++ b/lib/typing.ml @@ -3,6 +3,10 @@ open Typed_tree type tyenv = some_ty Id.Map.t +exception Arity_mismatch of string +exception Unbound_variable of string +exception Type_mismatch of string + let gen_args = let cnt = ref 0 in fun () -> @@ -10,13 +14,13 @@ let gen_args = incr cnt; arg -let rec subst (env : Id.t Id.Map.t) (e : Parse_tree.exp) = +let rec subst (env : Id.t Id.Map.t) : Parse_tree.exp -> Parse_tree.exp = let subst' = subst env in - match e with - | Int _ | Real _ | Bool _ -> e + function + | (Int _ | Real _ | Bool _) as e -> e | Var x -> ( match Map.find env x with - | None -> failwith ("Unbound variable " ^ x) + | None -> raise (Unbound_variable x) | Some v -> Var v) | Add (e1, e2) -> Add (subst' e1, subst' e2) | Radd (e1, e2) -> Radd (subst' e1, subst' e2) @@ -54,7 +58,7 @@ let rec inline_one (fn : Parse_tree.fn) (prog : Parse_tree.program) = let inline_exp' = inline_exp scope in match e with | Int _ | Real _ | Bool _ -> e - | Var x -> if Set.mem scope x then e else failwith ("Unbound variable " ^ x) + | Var x -> if Set.mem scope x then e else raise (Unbound_variable x) | Add (e1, e2) -> Add (inline_exp' e1, inline_exp' e2) | Radd (e1, e2) -> Radd (inline_exp' e1, inline_exp' e2) | Minus (e1, e2) -> Minus (inline_exp' e1, inline_exp' e2) @@ -78,35 +82,36 @@ let rec inline_one (fn : Parse_tree.fn) (prog : Parse_tree.program) = Assign (x, inline_exp' e1, inline_exp (Set.add scope x) e2) | If (cond, yes, no) -> If (inline_exp' cond, inline_exp' yes, inline_exp' no) - | Call (f, args) -> + | Call (f, args) as e -> + (* A-Normalize the arguments. For example, f(sample(e)) should only evaluate sample(e) once. *) let args = List.map ~f:inline_exp' args in - if Id.(equal f fn.name) then - let kvpair = + if Id.(f <> fn.name) then e + else + let param_args = try List.zip_exn fn.params args - with _ -> - failwith - ("Argument length mismatch when calling function " ^ fn.name) + with _ -> raise (Arity_mismatch fn.name) + in + let param_args = + List.map ~f:(fun (p, a) -> (p, gen_args (), a)) param_args in - let kvpair = List.map ~f:(fun (k, v) -> (k, gen_args (), v)) kvpair in let env = - List.fold kvpair ~init:Id.Map.empty ~f:(fun acc (k, v, _) -> - Map.set acc ~key:k ~data:v) + List.fold param_args ~init:Id.Map.empty ~f:(fun env (p, p', _a) -> + Map.set env ~key:p ~data:p') in - List.fold kvpair ~init:(subst env fn.body) ~f:(fun acc (_, v, arg) -> - Assign (v, arg, acc)) - else Call (f, args) + List.fold param_args ~init:(subst env fn.body) + ~f:(fun body (_p, p', a) -> Assign (p', a, body)) | Sample e -> Sample (inline_exp' e) | Observe (d, e) -> Observe (inline_exp' d, inline_exp' e) | List _ -> failwith "List not implemented" | Record _ -> failwith "Record not implemented" in + let { funs; exp } = prog in match funs with | [] -> { funs = []; exp = inline_exp Id.Set.empty exp } | { name; params; body } :: funs -> let body = inline_exp (Id.Set.of_list params) body in - if Id.(equal name fn.name) then - { funs = { name; params; body } :: funs; exp } + if Id.(name = fn.name) then { funs = { name; params; body } :: funs; exp } else let { funs; exp } = inline_one fn { funs; exp } in { funs = { name; params; body } :: funs; exp } @@ -118,127 +123,166 @@ let rec inline (prog : Parse_tree.program) = | [] -> exp | fn :: funs -> inline (inline_one fn { funs; exp }) -let get_dist (name : Id.t) : some_dist = - let open Owl.Stats in - match name with - | "bernoulli" -> - Ex - { - ret = Tyb; - name = "bernoulli"; - params = [ Tyr ]; - sampler = (fun [ (Tyr, p) ] -> binomial_rvs ~p ~n:1 = 1); - log_pmdf = - (fun [ (Tyr, p) ] b -> binomial_logpdf ~p ~n:1 (Bool.to_int b)); - } - | "normal" -> - Ex - { - ret = Tyr; - name = "normal"; - params = [ Tyr; Tyr ]; - sampler = (fun [ (Tyr, mu); (Tyr, sigma) ] -> gaussian_rvs ~mu ~sigma); - log_pmdf = - (fun [ (Tyr, mu); (Tyr, sigma) ] -> gaussian_logpdf ~mu ~sigma); - } - | _ -> failwith "Unknown primitive function" - let rec check : type a. tyenv -> Parse_tree.exp -> a ty -> (a, non_det) texp = fun tyenv e ty -> match e with | Var x -> ( match Map.find tyenv x with - | None -> failwith ("Unbound variable " ^ x) + | None -> raise (Unbound_variable x) | Some (Ex t) -> ( match (t, ty) with | Tyi, Tyi -> { ty; exp = Var x } | Tyr, Tyr -> { ty; exp = Var x } | Tyb, Tyb -> { ty; exp = Var x } - | _, _ -> failwith ("Variable " ^ x ^ " type mismatch"))) + | t1, t2 -> + raise + (Type_mismatch + (sprintf "Variable %s: expected %s, got %s" x + (string_of_ty t2) (string_of_ty t1))))) | Int i -> ( match ty with | Tyi -> { ty; exp = Value i } - | _ -> failwith "Expected something other than int") + | t -> + raise + (Type_mismatch (sprintf "Expected int, got %s" (string_of_ty t)))) | Add (e1, e2) -> ( match ty with | Tyi -> check_bop tyenv "+" ( + ) Tyi Tyi Tyi e1 e2 - | _ -> failwith "Expected something other than int") + | t -> + raise + (Type_mismatch + (sprintf "Expected int for Add, got %s" (string_of_ty t)))) | Minus (e1, e2) -> ( match ty with | Tyi -> check_bop tyenv "-" ( - ) Tyi Tyi Tyi e1 e2 - | _ -> failwith "Expected something other than int") + | t -> + raise + (Type_mismatch + (sprintf "Expected int for Minus, got %s" (string_of_ty t)))) | Neg e -> ( match ty with | Tyi -> check_uop tyenv "-" Int.neg Tyi Tyi e - | _ -> failwith "Expected something other than int") + | t -> + raise + (Type_mismatch + (sprintf "Expected int for Neg, got %s" (string_of_ty t)))) | Mult (e1, e2) -> ( match ty with | Tyi -> check_bop tyenv "*" ( * ) Tyi Tyi Tyi e1 e2 - | _ -> failwith "Expected something other than int") + | t -> + raise + (Type_mismatch + (sprintf "Expected int for Mult, got %s" (string_of_ty t)))) | Div (e1, e2) -> ( match ty with | Tyi -> check_bop tyenv "/" ( / ) Tyi Tyi Tyi e1 e2 - | _ -> failwith "Expected something other than int") + | t -> + raise + (Type_mismatch + (sprintf "Expected int for Div, got %s" (string_of_ty t)))) | Real r -> ( match ty with | Tyr -> { ty; exp = Value r } - | _ -> failwith "Expected something other than float") + | t -> + raise + (Type_mismatch (sprintf "Expected float, got %s" (string_of_ty t)))) | Radd (e1, e2) -> ( match ty with | Tyr -> check_bop tyenv "+" ( +. ) Tyr Tyr Tyr e1 e2 - | _ -> failwith "Expected something other than float") + | t -> + raise + (Type_mismatch + (sprintf "Expected float for Radd, got %s" (string_of_ty t)))) | Rminus (e1, e2) -> ( match ty with | Tyr -> check_bop tyenv "-" ( -. ) Tyr Tyr Tyr e1 e2 - | _ -> failwith "Expected something other than float") + | t -> + raise + (Type_mismatch + (sprintf "Expected float for Rminus, got %s" (string_of_ty t)))) | Rneg e -> ( match ty with | Tyr -> check_uop tyenv "-" Float.neg Tyr Tyr e - | _ -> failwith "Expected something other than float") + | t -> + raise + (Type_mismatch + (sprintf "Expected float for Rneg, got %s" (string_of_ty t)))) | Rmult (e1, e2) -> ( match ty with | Tyr -> check_bop tyenv "*" ( *. ) Tyr Tyr Tyr e1 e2 - | _ -> failwith "Expected something other than float") + | t -> + raise + (Type_mismatch + (sprintf "Expected float for Rmult, got %s" (string_of_ty t)))) | Rdiv (e1, e2) -> ( match ty with | Tyr -> check_bop tyenv "/" ( /. ) Tyr Tyr Tyr e1 e2 - | _ -> failwith "Expected something other than float") + | t -> + raise + (Type_mismatch + (sprintf "Expected float for Rdiv, got %s" (string_of_ty t)))) | Bool b -> ( match ty with | Tyb -> { ty; exp = Value b } - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch (sprintf "Expected bool, got %s" (string_of_ty t)))) | Eq (e1, e2) -> ( match ty with | Tyb -> check_bop tyenv "=" Int.( = ) Tyi Tyi Tyb e1 e2 - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for Eq, got %s" (string_of_ty t)))) | Req (e1, e2) -> ( match ty with | Tyb -> check_bop tyenv "=" Float.( = ) Tyr Tyr Tyb e1 e2 - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for Req, got %s" (string_of_ty t)))) | Noteq (e1, e2) -> ( match ty with | Tyb -> check_bop tyenv "<>" Int.( <> ) Tyi Tyi Tyb e1 e2 - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for Noteq, got %s" (string_of_ty t)))) | Less (e1, e2) -> ( match ty with | Tyb -> check_bop tyenv "<" Int.( < ) Tyi Tyi Tyb e1 e2 - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for Less, got %s" (string_of_ty t)))) | Rless (e1, e2) -> ( match ty with | Tyb -> check_bop tyenv "<" Float.( < ) Tyr Tyr Tyb e1 e2 - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for Rless, got %s" (string_of_ty t)))) | And (e1, e2) -> ( match ty with | Tyb -> check_bop tyenv "&&" ( && ) Tyb Tyb Tyb e1 e2 - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for And, got %s" (string_of_ty t)))) | Or (e1, e2) -> ( match ty with | Tyb -> check_bop tyenv "||" ( || ) Tyb Tyb Tyb e1 e2 - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for Or, got %s" (string_of_ty t)))) | Not e -> ( match ty with | Tyb -> check_uop tyenv "!" not Tyb Tyb e - | _ -> failwith "Expected something other than bool") + | t -> + raise + (Type_mismatch + (sprintf "Expected bool for Not, got %s" (string_of_ty t)))) | Observe (d, e) -> ( let (Ex td) = convert tyenv d in let (Ex te) = convert tyenv e in @@ -246,7 +290,9 @@ let rec check : type a. tyenv -> Parse_tree.exp -> a ty -> (a, non_det) texp = | Tyi, Tyi, Tyi -> { ty; exp = Observe (td, te) } | Tyr, Tyr, Tyr -> { ty; exp = Observe (td, te) } | Tyb, Tyb, Tyb -> { ty; exp = Observe (td, te) } - | _, _, _ -> failwith "Argument to observe has different types") + | _ -> + raise + (Type_mismatch (sprintf "Argument to observe has different types"))) | Seq (e1, e2) -> let (Ex te1) = convert tyenv e1 in let te2 = check tyenv e2 ty in @@ -262,18 +308,22 @@ let rec check : type a. tyenv -> Parse_tree.exp -> a ty -> (a, non_det) texp = let talt = check tyenv alt ty in { ty; exp = If (tpred, tconseq, talt) } | Call (prim, args) -> ( - let (Ex dist) = get_dist prim in + let (Ex dist) = Dist.get_dist prim in let args = check_args tyenv args dist.params in match (dist.ret, ty) with | Tyi, Tyi -> { ty; exp = Call (dist, args) } | Tyr, Tyr -> { ty; exp = Call (dist, args) } | Tyb, Tyb -> { ty; exp = Call (dist, args) } - | _, _ -> failwith "No") + | _ -> + raise + (Type_mismatch + (sprintf "Expected %s for Call, got %s" (string_of_ty dist.ret) + (string_of_ty ty)))) | Sample e -> let te = check tyenv e ty in { ty; exp = Sample te } - | List _ -> failwith "List not implemented" - | Record _ -> failwith "Record not implemented" + | List _ -> raise (Type_mismatch "List not implemented") + | Record _ -> raise (Type_mismatch "Record not implemented") and check_uop : type arg ret. @@ -356,7 +406,7 @@ and convert (tyenv : tyenv) (e : Parse_tree.exp) : some_ndet = | _, _ -> failwith "Branches of an if statement must return the same type" ) | Call (prim, args) -> - let (Ex dist) = get_dist prim in + let (Ex dist) = Dist.get_dist prim in let args = check_args tyenv args dist.params in Ex { ty = dist.ret; exp = Call (dist, args) } | Sample e -> @@ -364,3 +414,6 @@ and convert (tyenv : tyenv) (e : Parse_tree.exp) : some_ndet = Ex { ty = te.ty; exp = Sample te } | List _ -> failwith "List not implemented" | Record _ -> failwith "Record not implemented" + +let check_program (program : Parse_tree.program) : some_ndet = + convert Id.Map.empty (inline program)