Skip to content

Commit

Permalink
🎉 Type-safe random-variable/value/dist distinction
Browse files Browse the repository at this point in the history
Needs more polishing here and there including a full-implementation of
partial evaluation, but the current version works too (in principle).
The sample STAPPL programs run without friction!
  • Loading branch information
Zeta611 committed Jun 9, 2024
1 parent 9b70b6b commit 996cd9e
Show file tree
Hide file tree
Showing 14 changed files with 851 additions and 492 deletions.
12 changes: 9 additions & 3 deletions bin/dune
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
(executable
(public_name stappl)
(name main)
(libraries core core_unix.command_unix core_unix.filename_unix stappl)
(modes byte exe)
(libraries
core
core_unix.command_unix
core_unix.filename_unix
logs
logs.fmt
stappl)
(preprocess
(pps ppx_jane)))
(pps ppx_jane))
(modes byte exe))
11 changes: 9 additions & 2 deletions bin/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ let command : Command.t =
(let%map_open.Command filename =
anon (maybe_with_default "-" ("filename" %: Filename_unix.arg_type))
and pp_opt = flag "-pp" no_arg ~doc:" Pretty print the program"
and graph_opt = flag "-graph" no_arg ~doc:" Print the compiled graph" in
and graph_opt = flag "-graph" no_arg ~doc:" Print the compiled graph"
and debug_opt = flag "-debug" no_arg ~doc:" Debug mode" in
fun () ->
if debug_opt then Logs.set_level (Some Logs.Debug);

if pp_opt then (
printf "Pretty-print: %s\n" filename;
print_s [%sexp (get_program filename : Parse_tree.program)]);
Expand All @@ -49,6 +52,7 @@ let command : Command.t =
let graph, query = get_program filename |> Compiler.compile_program in
graph_query := Some (graph, query);
print_s [%sexp (Printing.of_graph graph : Printing.graph)]);

if pp_opt || graph_opt then printf "\n";
printf "Inference: %s\n" filename;
Out_channel.flush stdout;
Expand All @@ -60,4 +64,7 @@ let command : Command.t =
printf "Query result saved at %s\n"
(Evaluator.infer ~filename graph query))

let () = Command_unix.run ~version:"0.1.0" ~build_info:"STAPPL" command
let () =
Logs.set_reporter (Logs_fmt.reporter ());
Command_unix.run ~version:"0.1.0" ~build_info:"STAPPL" command;
exit (if Logs.err_count () > 0 then 1 else 0)
127 changes: 79 additions & 48 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@ exception Not_closed_observation

let rec peval : type a. (a, det) texp -> (a, det) texp =
fun { ty; exp } ->
(* TODO: consider other cases *)
let exp =
match exp with
| (Value _ | Var _) as e -> e
| Value _ -> exp
| Var _ -> exp
| Bop (op, te1, te2) -> (
match (peval te1, peval te2) with
| { exp = Value v1; _ }, { exp = Value v2; _ } -> Value (op.f v1 v2)
(*| { ty = ty1; exp = Value v1 }, { ty = ty2; exp = Value v2 } ->*)
(* Value (op.op v1 v2)*)
| te1, te2 -> Bop (op, te1, te2))
| Uop (op, te) -> (
match peval te with
| { exp = Value v; _ } -> Value (op.f v)
(*| { exp = Value v; _ } -> Value (op.op v)*)
| e -> Uop (op, e))
| If (te_pred, te_cons, te_alt) -> (
match peval te_pred with
| { exp = Value true; _ } -> (peval te_cons).exp
| { exp = Value false; _ } -> (peval te_alt).exp
(*| { exp = Value true; _ } -> (peval te_cons).exp*)
(*| { exp = Value false; _ } -> (peval te_alt).exp*)
| te_pred -> If (te_pred, peval te_cons, peval te_alt))
| Call (f, args) -> (
match peval_args args with
Expand All @@ -48,7 +51,13 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
log_pmdf = (fun [] -> f.log_pmdf vargs);
},
[] ))
| If_pred (p, de) -> (
let p = peval_pred p and de = peval de in
match p with (* TODO: *) _ -> If_pred (p, de))
| If_con de -> If_con (peval de)
| If_alt de -> If_alt (peval de)
in

{ ty; exp }

and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
Expand All @@ -57,31 +66,57 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
| te :: tl -> (
match (peval te, peval_args tl) with
| { ty; exp = Value v }, (tl, Some vargs) ->
({ ty; exp = Value v } :: tl, Some ((ty, v) :: vargs))
({ ty; exp = Value v } :: tl, Some ((dty_of_ty ty, v) :: vargs))
| te, (tl, _) -> (te :: tl, None))

and peval_pred : pred -> pred = function
| Empty -> failwith "[Bug] Empty predicate"
| True -> True
| False -> False
| And (p, de) -> (
match peval de with
| { exp = Value true; _ } -> peval_pred p
| { exp = Value false; _ } -> False
| de -> And (p, de))
| And_not (p, de) -> (
match peval de with
| { exp = Value true; _ } -> False
| { exp = Value false; _ } -> peval_pred p
| de -> And_not (p, de))

let ( &&& ) p de = peval_pred (And (p, de))
let ( &&! ) p de = peval_pred (And_not (p, de))

let rec score : type a. (a, det) texp -> (a, det) texp = function
(* TODO: consider other cases *)
| { ty; exp = If (e_pred, e_con, e_alt) } ->
let s_con = score e_con and s_alt = score e_alt in
{ ty; exp = If (e_pred, s_con, s_alt) }
| { exp = Call _; _ } as e -> e
| _ -> raise Score_invalid_arguments

type pred = (bool, det) texp

let rec compile :
type a. env:env -> pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp
=
fun ~env ~pred { ty; exp } ->
type a s.
env:env -> ?pred:pred -> (a, non_det) texp -> Graph.t * (a, det) texp =
fun ~env ?(pred = Empty) { ty; exp } ->
match exp with
| Value v -> (Graph.empty, { ty; exp = Value v })
| Value _ as exp -> (Graph.empty, { ty; exp })
| Var x -> (
let (Ex { ty = tyx; exp }) = Map.find_exn env x in
match (tyx, ty) with
| Tyi, Tyi -> (Graph.empty, { ty; exp })
| Tyr, Tyr -> (Graph.empty, { ty; exp })
| Tyb, Tyb -> (Graph.empty, { ty; exp })
| _, _ -> assert false)
| Dat_ty (Tyu, Val), Dat_ty (Tyu, Val) -> (Graph.empty, { ty; exp })
| Dat_ty (Tyb, Val), Dat_ty (Tyb, Val) -> (Graph.empty, { ty; exp })
| Dat_ty (Tyi, Val), Dat_ty (Tyi, Val) -> (Graph.empty, { ty; exp })
| Dat_ty (Tyr, Val), Dat_ty (Tyr, Val) -> (Graph.empty, { ty; exp })
| Dat_ty (Tyu, Rv), Dat_ty (Tyu, Rv) -> (Graph.empty, { ty; exp })
| Dat_ty (Tyb, Rv), Dat_ty (Tyb, Rv) -> (Graph.empty, { ty; exp })
| Dat_ty (Tyi, Rv), Dat_ty (Tyi, Rv) -> (Graph.empty, { ty; exp })
| Dat_ty (Tyr, Rv), Dat_ty (Tyr, Rv) -> (Graph.empty, { ty; exp })
| Dist_ty Tyu, Dist_ty Tyu -> (Graph.empty, { ty; exp })
| Dist_ty Tyb, Dist_ty Tyb -> (Graph.empty, { ty; exp })
| Dist_ty Tyi, Dist_ty Tyi -> (Graph.empty, { ty; exp })
| Dist_ty Tyr, Dist_ty Tyr -> (Graph.empty, { ty; exp })
| _, _ -> failwith "[Bug] Type mismatch")
| Bop (op, e1, e2) ->
let g1, te1 = compile ~env ~pred e1 in
let g2, te2 = compile ~env ~pred e2 in
Expand All @@ -91,27 +126,14 @@ let rec compile :
(g, { ty; exp = Uop (op, te) })
| If (e_pred, e_con, e_alt) -> (
let g1, de_pred = compile ~env ~pred e_pred in
let pred_con =
peval
{ ty = Tyb; exp = Bop ({ f = ( && ); name = "&&" }, pred, de_pred) }
in
let pred_alt =
peval
{
ty = Tyb;
exp =
Bop
( { f = ( && ); name = "&&" },
pred,
{ ty = Tyb; exp = Uop ({ f = not; name = "!" }, de_pred) } );
}
in
let pred_con = pred &&& de_pred in
let pred_alt = pred &&! de_pred in
let g2, de_con = compile ~env ~pred:pred_con e_con in
let g3, de_alt = compile ~env ~pred:pred_alt e_alt in
let g = Graph.(g1 @| g2 @| g3) in
match pred_con.exp with
| Value true -> (g, de_con)
| Value false -> (g, de_alt)
match pred_con with
| True -> (g, { ty; exp = If_con de_con })
| False -> (g, { ty; exp = If_alt de_alt })
| _ -> (g, { ty; exp = If (de_pred, de_con, de_alt) }))
| Let (x, e, body) ->
let g1, det_exp1 = compile ~env ~pred e in
Expand All @@ -126,13 +148,13 @@ let rec compile :
let g, de = compile ~env ~pred e in
let v = gen_vertex () in
let de_fvs = fv de.exp in
let f : some_det = Ex (score de) in
let f = score de in
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
pmdf_map = Id.Map.singleton v f;
pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp);
obs_map = Id.Map.empty;
}
in
Expand All @@ -142,21 +164,20 @@ let rec compile :
let g2, de2 = compile ~env ~pred e2 in
let v = gen_vertex () in
let f1 = score de1 in
let f =
{ ty; exp = If (pred, f1, { ty; exp = Call (Dist.one ty, []) }) }
in
let fvs = Id.(fv de1.exp @| fv pred.exp) in
if not (Set.is_empty (fv de2.exp)) then raise Not_closed_observation;
let f = { ty = f1.ty; exp = If_pred (pred, f1) } in
let fvs = Id.(fv de1.exp @| fv_pred pred) in
if not (Set.is_empty (fv de2.exp)) then
failwith "[Bug] Not closed observation";
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
pmdf_map = Id.Map.singleton v (Ex f : some_det);
obs_map = Id.Map.singleton v (Ex de2 : some_det);
pmdf_map = Id.Map.singleton v (Ex f : some_dist_texp);
obs_map = Id.Map.singleton v (Ex de2 : some_dat_texp);
}
in
Graph.(g1 @| g2 @| g', de2)
Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () })

and compile_args :
type a. env -> pred -> (a, non_det) args -> Graph.t * (a, det) args =
Expand All @@ -168,7 +189,17 @@ and compile_args :
let g', args = compile_args env pred args in
Graph.(g @| g', arg :: args)

let compile_program (prog : program) : Graph.t * some_det =
let (Ex e) = Typing.check_program prog in
let g, e = compile ~env:Id.Map.empty ~pred:{ ty = Tyb; exp = Value true } e in
(g, Ex e)
exception Query_not_found

let compile_program (prog : program) : Graph.t * some_rv_texp =
Logs.debug (fun m ->
m "Inlining program %a" Sexp.pp_hum [%sexp (prog : Parse_tree.program)]);
let exp = Preprocessor.inline prog in
Logs.debug (fun m ->
m "Inlined program %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]);

let (Ex e) = Typing.check exp in
let g, { ty; exp } = compile ~env:Id.Map.empty e in
match ty with
| Dat_ty (_, Rv) -> (g, Ex { ty; exp })
| _ -> raise Query_not_found
12 changes: 11 additions & 1 deletion lib/dist.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
open! Core
open Typed_tree

let one : type a. a ty -> (a, unit) dist = function
type some_dist = Ex : _ dist -> some_dist

let one : type a. a dty -> (a, unit) dist = function
| Tyu ->
{
ret = Tyu;
name = "one";
params = [];
sampler = (fun [] -> ());
log_pmdf = (fun [] _ -> 0.0);
}
| Tyi ->
{
ret = Tyi;
Expand Down
2 changes: 1 addition & 1 deletion lib/dune
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
(library
(name stappl)
(libraries core owl owl-plplot string_dict)
(libraries core owl owl-plplot string_dict logs)
(inline_tests)
(preprocess
(pps ppx_jane)))
Expand Down
Loading

0 comments on commit 996cd9e

Please sign in to comment.