Skip to content

Commit

Permalink
♻️ Better error messages and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 7, 2024
1 parent a46f49e commit 9b70b6b
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 134 deletions.
92 changes: 31 additions & 61 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -60,39 +60,37 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
({ ty; exp = Value v } :: tl, Some ((ty, v) :: vargs))
| te, (tl, _) -> (te :: tl, None))

let rec score : type a. (a, det) texp -> Id.t -> (a, det) texp =
fun e var ->
match e.exp with
| If (e_pred, e_con, e_alt) ->
let s_con = score e_con var in
let s_alt = score e_alt var in
{ ty = e.ty; exp = If (e_pred, s_con, s_alt) }
| Call _ -> e
let rec score : type a. (a, det) texp -> (a, det) texp = function
| { ty; exp = If (e_pred, e_con, e_alt) } ->
let s_con = score e_con and s_alt = score e_alt in
{ ty; exp = If (e_pred, s_con, s_alt) }
| { exp = Call _; _ } as e -> e
| _ -> raise Score_invalid_arguments

type pred = (bool, det) texp

let rec compile :
type a.
env -> (bool, det) texp -> (a, non_det) texp -> Graph.t * (a, det) texp =
fun env pred e ->
let { ty; exp } = e in
type a. env:env -> pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp
=
fun ~env ~pred { ty; exp } ->
match exp with
| Value v -> (Graph.empty, { ty; exp = Value v })
| Var x -> (
let (Ex { ty = tx; exp }) = Map.find_exn env x in
match (tx, ty) with
let (Ex { ty = tyx; exp }) = Map.find_exn env x in
match (tyx, ty) with
| Tyi, Tyi -> (Graph.empty, { ty; exp })
| Tyr, Tyr -> (Graph.empty, { ty; exp })
| Tyb, Tyb -> (Graph.empty, { ty; exp })
| _, _ -> assert false)
| Bop (op, e1, e2) ->
let g1, te1 = compile env pred e1 in
let g2, te2 = compile env pred e2 in
let g1, te1 = compile ~env ~pred e1 in
let g2, te2 = compile ~env ~pred e2 in
Graph.(g1 @| g2, { ty; exp = Bop (op, te1, te2) })
| Uop (op, e) ->
let g, te = compile env pred e in
let g, te = compile ~env ~pred e in
(g, { ty; exp = Uop (op, te) })
| If (e_pred, e_con, e_alt) -> (
let g1, de_pred = compile env pred e_pred in
let g1, de_pred = compile ~env ~pred e_pred in
let pred_con =
peval
{ ty = Tyb; exp = Bop ({ f = ( && ); name = "&&" }, pred, de_pred) }
Expand All @@ -108,27 +106,27 @@ let rec compile :
{ ty = Tyb; exp = Uop ({ f = not; name = "!" }, de_pred) } );
}
in
let g2, de_con = compile env pred_con e_con in
let g3, de_alt = compile env pred_alt e_alt in
let g2, de_con = compile ~env ~pred:pred_con e_con in
let g3, de_alt = compile ~env ~pred:pred_alt e_alt in
let g = Graph.(g1 @| g2 @| g3) in
match pred_con.exp with
| Value true -> (g, de_con)
| Value false -> (g, de_alt)
| _ -> (g, { ty; exp = If (de_pred, de_con, de_alt) }))
| Let (x, e, body) ->
let g1, det_exp1 = compile env pred e in
let g1, det_exp1 = compile ~env ~pred e in
let g2, det_exp2 =
compile (Map.set env ~key:x ~data:(Ex det_exp1)) pred body
compile ~env:(Map.set env ~key:x ~data:(Ex det_exp1)) ~pred body
in
Graph.(g1 @| g2, det_exp2)
| Call (f, args) ->
let g, args = compile_args env pred args in
(g, { ty; exp = Call (f, args) })
| Sample e ->
let g, de = compile env pred e in
let g, de = compile ~env ~pred e in
let v = gen_vertex () in
let de_fvs = fv de.exp in
let f : some_det = Ex (score de v) in
let f : some_det = Ex (score de) in
let g' =
Graph.
{
Expand All @@ -140,39 +138,13 @@ let rec compile :
in
Graph.(g @| g', { ty; exp = Var v })
| Observe (e1, e2) ->
let g1, de1 = compile env pred e1 in
let g2, de2 = compile env pred e2 in
let g1, de1 = compile ~env ~pred e1 in
let g2, de2 = compile ~env ~pred e2 in
let v = gen_vertex () in
let f1 = score de1 v in
let one : type a. a ty -> (a, unit) dist =
fun ty ->
match ty with
| Tyi ->
{
ret = ty;
name = "one";
params = [];
sampler = (fun _ -> 1);
log_pmdf = (fun [] _ -> 0.0);
}
| Tyr ->
{
ret = ty;
name = "one";
params = [];
sampler = (fun _ -> 1.0);
log_pmdf = (fun [] _ -> 0.0);
}
| Tyb ->
{
ret = Tyb;
name = "one";
params = [];
sampler = (fun _ -> true);
log_pmdf = (fun [] _ -> 0.0);
}
let f1 = score de1 in
let f =
{ ty; exp = If (pred, f1, { ty; exp = Call (Dist.one ty, []) }) }
in
let f = { ty; exp = If (pred, f1, { ty; exp = Call (one ty, []) }) } in
let fvs = Id.(fv de1.exp @| fv pred.exp) in
if not (Set.is_empty (fv de2.exp)) then raise Not_closed_observation;
let g' =
Expand All @@ -187,18 +159,16 @@ let rec compile :
Graph.(g1 @| g2 @| g', de2)

and compile_args :
type a.
env -> (bool, det) texp -> (a, non_det) args -> Graph.t * (a, det) args =
type a. env -> pred -> (a, non_det) args -> Graph.t * (a, det) args =
fun env pred args ->
match args with
| [] -> (Graph.empty, [])
| arg :: args ->
let g, arg = compile env pred arg in
let g, arg = compile ~env ~pred arg in
let g', args = compile_args env pred args in
Graph.(g @| g', arg :: args)

let compile_program (prog : program) : Graph.t * some_det =
let open Typing in
let (Ex e) = convert Id.Map.empty (inline prog) in
let g, e = compile Id.Map.empty { ty = Tyb; exp = Value true } e in
let (Ex e) = Typing.check_program prog in
let g, e = compile ~env:Id.Map.empty ~pred:{ ty = Tyb; exp = Value true } e in
(g, Ex e)
53 changes: 53 additions & 0 deletions lib/dist.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
open! Core
open Typed_tree

let one : type a. a ty -> (a, unit) dist = function
| Tyi ->
{
ret = Tyi;
name = "one";
params = [];
sampler = (fun [] -> 1);
log_pmdf = (fun [] _ -> 0.0);
}
| Tyr ->
{
ret = Tyr;
name = "one";
params = [];
sampler = (fun [] -> 1.0);
log_pmdf = (fun [] _ -> 0.0);
}
| Tyb ->
{
ret = Tyb;
name = "one";
params = [];
sampler = (fun [] -> true);
log_pmdf = (fun [] _ -> 0.0);
}

let get_dist (name : Id.t) : some_dist =
let open Owl.Stats in
match name with
| "bernoulli" ->
Ex
{
ret = Tyb;
name = "bernoulli";
params = [ Tyr ];
sampler = (fun [ (Tyr, p) ] -> binomial_rvs ~p ~n:1 = 1);
log_pmdf =
(fun [ (Tyr, p) ] b -> binomial_logpdf ~p ~n:1 (Bool.to_int b));
}
| "normal" ->
Ex
{
ret = Tyr;
name = "normal";
params = [ Tyr; Tyr ];
sampler = (fun [ (Tyr, mu); (Tyr, sigma) ] -> gaussian_rvs ~mu ~sigma);
log_pmdf =
(fun [ (Tyr, mu); (Tyr, sigma) ] -> gaussian_logpdf ~mu ~sigma);
}
| _ -> failwith "Unknown primitive function"
5 changes: 5 additions & 0 deletions lib/typed_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ open! Core
type real = float
type _ ty = Tyi : int ty | Tyr : real ty | Tyb : bool ty

let string_of_ty : type a. a ty -> string = function
| Tyi -> "int"
| Tyr -> "real"
| Tyb -> "bool"

type _ params =
| [] : unit params
| ( :: ) : 'a ty * 'b params -> ('a * 'b) params
Expand Down
Loading

0 comments on commit 9b70b6b

Please sign in to comment.