Skip to content

Commit

Permalink
✨ Completely type-safe implementations via GADTs
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 10, 2024
1 parent 50bdae9 commit 1a1a9e7
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 114 deletions.
114 changes: 52 additions & 62 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,57 +7,52 @@ type env = det some_texp Id.Map.t
let gen_vertex =
let cnt = ref 0 in
fun () ->
let v = "X" ^ string_of_int !cnt in
incr cnt;
v
[%string "X%{!cnt#Int}"]

exception Score_invalid_arguments
exception Not_closed_observation

let rec peval : type a. (a, det) texp -> (a, det) texp =
fun { ty; exp } ->
(* TODO: consider other cases *)
let exp =
match exp with
| Value _ -> exp
| Var _ -> exp
| Bop (op, te1, te2) -> (
match (peval te1, peval te2) with
(*| { ty = ty1; exp = Value v1 }, { ty = ty2; exp = Value v2 } ->*)
(* Value (op.op v1 v2)*)
| te1, te2 -> Bop (op, te1, te2))
| Uop (op, te) -> (
match peval te with
(*| { exp = Value v; _ } -> Value (op.op v)*)
| e -> Uop (op, e))
| If (te_pred, te_cons, te_alt) -> (
match peval te_pred with
(*| { exp = Value true; _ } -> (peval te_cons).exp*)
(*| { exp = Value false; _ } -> (peval te_alt).exp*)
| te_pred -> If (te_pred, peval te_cons, peval te_alt))
| Call (f, args) -> (
match peval_args args with
| args, None -> Call (f, args)
| _, Some vargs ->
(* All arguments are fully evaluated;
Go ahead and fully evaluate the (primitive) call.
It is a primitive call as this is a deterministic expression. *)
Call
( {
ret = f.ret;
name = f.name;
params = [];
sampler = (fun [] -> f.sampler vargs);
log_pmdf = (fun [] -> f.log_pmdf vargs);
},
[] ))
| If_pred (p, de) -> (
let p = peval_pred p and de = peval de in
match p with (* TODO: *) _ -> If_pred (p, de))
| If_just de -> If_just (peval de)
in

{ ty; exp }
fun ({ ty; exp } as texp) ->
match exp with
| Value _ | Rvar _ -> texp
| Bop (bop, te1, te2, ms) -> (
match (peval te1, peval te2, ms) with
| { exp = Value v1; _ }, { exp = Value v2; _ }, Both_val _ ->
{ ty; exp = Value (bop.op v1 v2) }
| te1, te2, ms -> { ty; exp = Bop (bop, te1, te2, ms) })
| Uop (uop, te) -> (
match peval te with
| { exp = Value v; _ } -> { ty; exp = Value (uop.op v) }
| e -> { ty; exp = Uop (uop, e) })
| If_pred (pred, te_con, te_alt) -> (
match peval_pred pred with
| True -> peval { ty; exp = If_just te_con }
| False -> peval { ty; exp = If_just te_alt }
| p -> { ty; exp = If_pred (p, peval te_con, peval te_alt) })
| Call (f, args) -> (
match peval_args args with
| args, None -> { ty; exp = Call (f, args) }
| _, Some vargs ->
(* All arguments are fully evaluated;
Go ahead and fully evaluate the (primitive) call.
It is a primitive call as this is a deterministic expression. *)
let f_dist =
{
ret = f.ret;
name = f.name;
params = [];
sampler = (fun [] -> f.sampler vargs);
log_pmdf = (fun [] -> f.log_pmdf vargs);
}
in
{ ty; exp = Call (f_dist, []) })
| If_pred_dist (p, de) -> (
match peval_pred p with
| True -> peval de
| False -> { ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) }
| p -> { ty; exp = If_pred_dist (p, peval de) })
| If_just de -> { ty; exp = If_just (peval de) }

and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
function
Expand Down Expand Up @@ -86,13 +81,11 @@ and peval_pred : pred -> pred = function
let ( &&& ) p de = peval_pred (And (p, de))
let ( &&! ) p de = peval_pred (And_not (p, de))

let rec score : type a. (a, det) texp -> (a, det) texp = function
(* TODO: consider other cases *)
| { ty; exp = If (e_pred, e_con, e_alt) } ->
let s_con = score e_con and s_alt = score e_alt in
{ ty; exp = If (e_pred, s_con, s_alt) }
let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp =
function
| { exp = If_pred_dist (p, de); _ } ->
{ ty = de.ty; exp = If_pred_dist (p, score de) }
| { exp = Call _; _ } as e -> e
| _ -> raise Score_invalid_arguments

let rec compile :
type a s. env:env -> ?pred:pred -> (a, ndet) texp -> Graph.t * (a, det) texp
Expand All @@ -105,24 +98,21 @@ let rec compile :
match eq_tys tyx ty with
| Some Refl -> (Graph.empty, { ty; exp })
| None -> failwith "[Bug] Type mismatch")
| Bop (op, e1, e2) ->
| Bop (op, e1, e2, ms) ->
let g1, te1 = compile ~env ~pred e1 in
let g2, te2 = compile ~env ~pred e2 in
Graph.(g1 @| g2, { ty; exp = Bop (op, te1, te2) })
Graph.(g1 @| g2, peval { ty; exp = Bop (op, te1, te2, ms) })
| Uop (op, e) ->
let g, te = compile ~env ~pred e in
(g, { ty; exp = Uop (op, te) })
| If (e_pred, e_con, e_alt) -> (
(g, peval { ty; exp = Uop (op, te) })
| If (e_pred, e_con, e_alt, _, _) ->
let g1, de_pred = compile ~env ~pred e_pred in
let pred_con = pred &&& de_pred in
let pred_alt = pred &&! de_pred in
let g2, de_con = compile ~env ~pred:pred_con e_con in
let g3, de_alt = compile ~env ~pred:pred_alt e_alt in
let g = Graph.(g1 @| g2 @| g3) in
match pred_con with
| True -> (g, { ty; exp = If_just de_con })
| False -> (g, { ty; exp = If_just de_alt })
| _ -> (g, { ty; exp = If (de_pred, de_con, de_alt) }))
(g, peval { ty; exp = If_pred (pred_con, de_con, de_alt) })
| Let (x, e, body) ->
let g1, det_exp1 = compile ~env ~pred e in
let g2, det_exp2 =
Expand All @@ -146,13 +136,13 @@ let rec compile :
obs_map = Id.Map.empty;
}
in
Graph.(g @| g', { ty; exp = Var v })
Graph.(g @| g', { ty; exp = Rvar v })
| Observe (e1, e2) ->
let g1, de1 = compile ~env ~pred e1 in
let g2, de2 = compile ~env ~pred e2 in
let v = gen_vertex () in
let f1 = score de1 in
let f = { ty = f1.ty; exp = If_pred (pred, f1) } in
let f = { ty = f1.ty; exp = If_pred_dist (pred, f1) } in
let fvs = Id.(fv de1.exp @| fv_pred pred) in
if not (Set.is_empty (fv de2.exp)) then
failwith "[Bug] Not closed observation";
Expand Down
22 changes: 8 additions & 14 deletions lib/evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,26 @@ let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a =
fun ctx { ty; exp } ->
match exp with
| Value v -> v
| Var x -> (
| Rvar x -> (
let (Ex (tv, v)) = Ctx.find_exn ctx x in
match eq_dtys (dty_of_dat_ty ty) tv with
| Some Refl -> v
| None -> assert false)
| Bop ({ op; _ }, te1, te2) -> op (eval_dat ctx te1) (eval_dat ctx te2)
| Bop ({ op; _ }, te1, te2, _) -> op (eval_dat ctx te1) (eval_dat ctx te2)
| Uop ({ op; _ }, te) -> op (eval_dat ctx te)
| If (te_pred, te_cons, te_alt) ->
if eval_dat ctx te_pred then eval_dat ctx te_cons else eval_dat ctx te_alt
| If_pred (pred, te_con, te_alt) ->
if eval_pred ctx pred then eval_dat ctx te_con else eval_dat ctx te_alt
| If_just te -> eval_dat ctx te

and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a =
fun ctx { ty = Dist_ty dty as ty; exp } ->
match exp with
| Call (f, args) -> f.sampler (eval_args ctx args)
| Var x -> (
let (Ex (tv, v)) = Ctx.find_exn ctx x in
match eq_dtys dty tv with Some Refl -> v | None -> assert false)
| If_pred (pred, dist) ->
| If_pred_dist (pred, dist) ->
if eval_pred ctx pred then eval_dist ctx dist
else eval_dist ctx { ty; exp = Call (Dist.one dty, []) }

and eval_pred (ctx : Ctx.t) : pred -> bool =
(*print_endline "[eval_pred]";*)
function
and eval_pred (ctx : Ctx.t) : pred -> bool = function
| Empty | True -> true
| False -> false
| And (p, de) -> eval_dat ctx de && eval_pred ctx p
Expand All @@ -55,7 +50,7 @@ let rec eval_pmdf :
type a. Ctx.t -> (a dist_ty, det) texp -> (some_val -> real) * some_val =
fun ctx { ty = Dist_ty dty as ty; exp } ->
match exp with
| If_pred (pred, te) ->
| If_pred_dist (pred, te) ->
if eval_pred ctx pred then eval_pmdf ctx te
else eval_pmdf ctx { ty; exp = Call (Dist.one dty, []) }
| Call (f, args) ->
Expand All @@ -65,7 +60,6 @@ let rec eval_pmdf :
| _ -> assert false
in
(pmdf, Ex (dty, eval_dist ctx { ty; exp }))
| _ -> (* not reachable *) assert false

let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) :
float array =
Expand Down Expand Up @@ -98,7 +92,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) :
let curr = Ctx.find_exn ctx name in
let log_pmdf, cand = eval_pmdf ctx exp in

(* metropolis-hastings update logic *)
(* Metropolis-Hastings update logic *)
Ctx.set ctx ~name ~value:cand;
let log_pmdf', _ = eval_pmdf ctx exp in
let log_alpha = log_pmdf' curr -. log_pmdf cand in
Expand Down
3 changes: 1 addition & 2 deletions lib/preprocessor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ exception Unbound_variable of string
let gen_args =
let cnt = ref 0 in
fun () ->
let arg = "$arg" ^ string_of_int !cnt in
incr cnt;
arg
[%string "$arg%{!cnt#Int}"]

let rec subst (env : subst_map) : exp -> exp =
(* 𝜂-expansion required to avoid infinite recursion *)
Expand Down
69 changes: 47 additions & 22 deletions lib/typed_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ type _ dty = Tyu : unit dty | Tyi : int dty | Tyr : real dty | Tyb : bool dty
type value = Val_ph
type rv = Rv_ph
type _ stamp = Val : value stamp | Rv : rv stamp

type (_, _, _) merge_stamp =
| Both_val : value stamp * value stamp -> (value, value, value) merge_stamp
| Right_rv : value stamp * rv stamp -> (value, rv, rv) merge_stamp
| Left_rv : rv stamp * value stamp -> (rv, value, rv) merge_stamp
| Both_rv : rv stamp * rv stamp -> (rv, rv, rv) merge_stamp

type ('a, 'b) dat_ty = Dat_ty_ph
type 'a dist_ty = Dist_ty_ph

Expand Down Expand Up @@ -49,18 +56,27 @@ and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp }

and (_, _) exp =
| Value : 'a -> (('a, value) dat_ty, _) exp
| Var : Id.t -> _ exp
| Var : Id.t -> (_, ndet) exp
| Rvar : Id.t -> (('a, rv) dat_ty, det) exp
| Bop :
('a, 'b, 'c) bop * (('a, _) dat_ty, 'd) texp * (('b, _) dat_ty, 'd) texp
-> (('c, _) dat_ty, 'd) exp
| Uop : ('a, 'b) uop * (('a, _) dat_ty, 'd) texp -> (('b, _) dat_ty, 'd) exp
('a1, 'a2, 'a) bop
* (('a1, 's1) dat_ty, 'd) texp
* (('a2, 's2) dat_ty, 'd) texp
* ('s1, 's2, 's) merge_stamp
-> (('a, 's) dat_ty, 'd) exp
| Uop : ('a, 'b) uop * (('a, 's) dat_ty, 'd) texp -> (('b, 's) dat_ty, 'd) exp
| If :
((bool, _) dat_ty, 'd) texp
* (('a, _) dat_ty, 'd) texp
* (('a, _) dat_ty, 'd) texp
-> (('a, _) dat_ty, 'd) exp
| If_pred : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp
| If_just : (('a, 's) dat_ty, det) texp -> (('a, _) dat_ty, det) exp
((bool, 's_pred) dat_ty, ndet) texp
* (('a, 's_con) dat_ty, ndet) texp
* (('a, 's_alt) dat_ty, ndet) texp
* ('s_con, 's_alt, 's_ca) merge_stamp
* ('s_pred, 's_ca, 's) merge_stamp
-> (('a, 's) dat_ty, ndet) exp
| If_pred :
pred * (('a, _) dat_ty, det) texp * (('a, _) dat_ty, det) texp
-> (('a, _) dat_ty, det) exp
| If_pred_dist : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp
| If_just : (('a, _) dat_ty, det) texp -> (('a, _) dat_ty, det) exp
| Let : Id.t * ('a, ndet) texp * ('b, ndet) texp -> ('b, ndet) exp
| Call : ('a, 'b) dist * ('b, 'd) args -> ('a dist_ty, 'd) exp
| Sample : ('a dist_ty, ndet) texp -> (('a, rv) dat_ty, ndet) exp
Expand All @@ -70,7 +86,6 @@ and (_, _) exp =

type some_dty = Ex : _ dty -> some_dty
type some_val = Ex : ('a dty * 'a) -> some_val
type some_stamp = Ex : _ stamp -> some_stamp
type some_ty = Ex : _ ty -> some_ty
type _ some_texp = Ex : (_, 'd) texp -> 'd some_texp

Expand All @@ -83,6 +98,9 @@ type _ some_dat_texp = Ex : (_ dat_ty, 'd) texp -> 'd some_dat_texp
type _ some_dist_texp = Ex : (_ dist_ty, 'd) texp -> 'd some_dist_texp
type (_, _) eq = Refl : ('a, 'a) eq

type (_, _) some_merge_stamp =
| Ex : ('s1, 's2, 's) merge_stamp * 's stamp -> ('s1, 's2) some_merge_stamp

let dty_of_dat_ty : type a. (a, _) dat_ty ty -> a dty = function
| Dat_ty (dty, _) -> dty

Expand Down Expand Up @@ -111,8 +129,14 @@ let eq_dtys : type a1 a2. a1 dty -> a2 dty -> (a1, a2) eq option =
let unify_dtys : type a1 a2. a1 dty -> a2 dty -> (a1, a2) eq -> a1 dty =
fun t _ Refl -> t

let merge_stamps : type s1 s2. s1 stamp -> s2 stamp -> some_stamp =
fun s1 s2 -> match (s1, s2) with Val, Val -> Ex Val | _, _ -> Ex Rv
let merge_stamps : type s1 s2. s1 stamp -> s2 stamp -> (s1, s2) some_merge_stamp
=
fun s1 s2 ->
match (s1, s2) with
| Val, Val -> Ex (Both_val (Val, Val), Val)
| Val, Rv -> Ex (Right_rv (Val, Rv), Rv)
| Rv, Val -> Ex (Left_rv (Rv, Val), Rv)
| Rv, Rv -> Ex (Both_rv (Rv, Rv), Rv)

let eq_dat_tys :
type a1 a2. (a1, _) dat_ty ty -> (a2, _) dat_ty ty -> (a1, a2) eq option =
Expand Down Expand Up @@ -157,12 +181,12 @@ let string_of_ty : type a. a ty -> string = function

let rec fv : type a. (a, det) exp -> Id.Set.t = function
| Value _ -> Id.Set.empty
| Var x -> Id.Set.singleton x
| Bop (_, { exp = e1; _ }, { exp = e2; _ }) -> Id.(fv e1 @| fv e2)
| Rvar x -> Id.Set.singleton x
| Bop (_, { exp = e1; _ }, { exp = e2; _ }, _) -> Id.(fv e1 @| fv e2)
| Uop (_, { exp; _ }) -> fv exp
| If ({ exp = e_pred; _ }, { exp = e_cons; _ }, { exp = e_alt; _ }) ->
Id.(fv e_pred @| fv e_cons @| fv e_alt)
| If_pred (pred, { exp = e_cons; _ }) -> Id.(fv_pred pred @| fv e_cons)
| If_pred (pred, { exp = e_con; _ }, { exp = e_alt; _ }) ->
Id.(fv_pred pred @| fv e_con @| fv e_alt)
| If_pred_dist (pred, { exp = e_con; _ }) -> Id.(fv_pred pred @| fv e_con)
| If_just { exp; _ } -> fv exp
| Call (_, args) -> fv_args args

Expand Down Expand Up @@ -195,17 +219,18 @@ module Erased = struct
let rec of_exp : type a d. (a, d) texp -> exp =
fun { ty; exp } ->
match exp with
| If (pred, cons, alt) -> If (of_exp pred, of_exp cons, of_exp alt)
| If_pred (pred, cons) -> If (of_pred pred, of_exp cons, Value "1")
| If (pred, con, alt, _, _) -> If (of_exp pred, of_exp con, of_exp alt)
| If_pred (pred, con, alt) -> If (of_pred pred, of_exp con, of_exp alt)
| If_pred_dist (pred, con) -> If (of_pred pred, of_exp con, Value "1")
| If_just exp -> If_just (of_exp exp)
| Value v -> (
match ty with
| Dat_ty (Tyu, _) -> Value "()"
| Dat_ty (Tyi, _) -> Value (string_of_int v)
| Dat_ty (Tyr, _) -> Value (string_of_float v)
| Dat_ty (Tyb, _) -> Value (string_of_bool v))
| Var v -> Var v
| Bop (op, e1, e2) -> Bop (op.name, of_exp e1, of_exp e2)
| Var v | Rvar v -> Var v
| Bop (op, e1, e2, _) -> Bop (op.name, of_exp e1, of_exp e2)
| Uop (op, e) -> Uop (op.name, of_exp e)
| Let (x, e1, e2) -> Let (x, of_exp e1, of_exp e2)
| Call (f, args) -> Call (f.name, of_args args)
Expand Down
19 changes: 6 additions & 13 deletions lib/typing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,10 @@ let unify_branches :
fun ({ ty = Dat_ty (Tyb, s_pred); _ } as te_pred)
({ ty = Dat_ty (dty_con, s_con); _ } as te_con)
({ ty = Dat_ty (dty_alt, s_alt); _ } as te_alt) Refl ->
let exp = If (te_pred, te_con, te_alt) in
match s_pred with
| Val -> (
let dty = unify_dtys dty_con dty_alt Refl in
let (Ex s) = merge_stamps s_con s_alt in
match s with
| Val -> Ex { ty = Dat_ty (dty, Val); exp }
| Rv -> Ex { ty = Dat_ty (dty, Rv); exp })
| Rv ->
let dty = unify_dtys dty_con dty_alt Refl in
Ex { ty = Dat_ty (dty, Rv); exp }
let dty = unify_dtys dty_con dty_alt Refl in
let (Ex (ms_ca, s_ca)) = merge_stamps s_con s_alt in
let (Ex (ms, s)) = merge_stamps s_pred s_ca in
Ex { ty = Dat_ty (dty, s); exp = If (te_pred, te_con, te_alt, ms_ca, ms) }

let rec check_dat :
type a. tyenv -> Parse_tree.exp * a dty -> a some_dat_ndet_texp =
Expand Down Expand Up @@ -244,8 +237,8 @@ and check_bop :
fun tyenv bop (e1, t1) (e2, t2) tret ->
let (Ex ({ ty = Dat_ty (_, s1); _ } as te1)) = check_dat tyenv (e1, t1) in
let (Ex ({ ty = Dat_ty (_, s2); _ } as te2)) = check_dat tyenv (e2, t2) in
let (Ex s) = merge_stamps s1 s2 in
Ex { ty = Dat_ty (tret, s); exp = Bop (bop, te1, te2) }
let (Ex (ms, s)) = merge_stamps s1 s2 in
Ex { ty = Dat_ty (tret, s); exp = Bop (bop, te1, te2, ms) }

and check_args :
type a. tyenv -> Id.t -> Parse_tree.exp list * a params -> (a, ndet) args =
Expand Down
Binary file modified samples/normal_bernoulli.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified samples/simple_itpp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified samples/student.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 1a1a9e7

Please sign in to comment.