Skip to content

Commit

Permalink
♻️ Add mli stubs
Browse files Browse the repository at this point in the history
  • Loading branch information
yhs0602 committed Jun 1, 2024
1 parent fa0c788 commit 9a9e5c5
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 99 deletions.
114 changes: 15 additions & 99 deletions lib/compile.ml → lib/compiler.ml
Original file line number Diff line number Diff line change
@@ -1,91 +1,6 @@
open Core
open Program

module Env = struct
type t = fn Map.M(Id).t

let empty : t = Map.empty (module Id)
let add (env : t) ~(name : Id.t) ~(fn : fn) = Map.add env ~key:name ~data:fn
let find_exn (env : t) ~(name : Id.t) : fn = Map.find_exn env name
end

module Pred = struct
type t = Empty | And of Det_exp.t * t | And_not of Det_exp.t * t
[@@deriving sexp]

let rec fv : t -> Set.M(Id).t = function
| Empty -> Set.empty (module Id)
| And (de, p) | And_not (de, p) -> Set.union (Det_exp.fv de) (fv p)
end

module Dist = struct
type t [@@deriving sexp]
type one = One [@@deriving sexp]

type exp =
| If_de of Det_exp.t * exp * exp
| If_pred of Pred.t * exp * one
| Dist_obj of { dist : t; var : Id.t; args : Det_exp.t list }
[@@deriving sexp]

exception Score_invalid_arguments

let prim_to_dist : Id.t -> t = failwith "Not implemented"

let rec score (det_exp : Det_exp.t) (var : Id.t) =
match det_exp with
| If (e_pred, e_con, e_alt) ->
let s_con = score e_con var in
let s_alt = score e_alt var in
If_de (e_pred, s_con, s_alt)
| Prim_call (c, es) -> Dist_obj { dist = prim_to_dist c; var; args = es }
| _ -> raise Score_invalid_arguments
end

module Graph = struct
type vertex = Id.t [@@deriving sexp]
type arc = vertex * vertex [@@deriving sexp]
type det_map = Dist.exp Map.M(Id).t [@@deriving sexp]
type obs_map = Det_exp.t Map.M(Id).t [@@deriving sexp]

type t = {
vertices : vertex list;
arcs : arc list;
det_map : det_map;
obs_map : obs_map;
}
[@@deriving sexp]

let empty =
{
vertices = [];
arcs = [];
det_map = Map.empty (module Id);
obs_map = Map.empty (module Id);
}

let union g1 g2 =
{
vertices = g1.vertices @ g2.vertices;
arcs = g1.arcs @ g2.arcs;
det_map =
Map.merge g1.det_map g2.det_map ~f:(fun ~key:_ v ->
match v with
| `Left det | `Right det -> Some det
| `Both _ ->
failwith "Graph.union: duplicate deterministic expression");
obs_map =
Map.merge g1.obs_map g2.obs_map ~f:(fun ~key:_ v ->
match v with
| `Left obs | `Right obs -> Some obs
| `Both _ -> failwith "Graph.union: duplicate observation");
}

let pp (graph : t) : string = graph |> sexp_of_t |> Sexp.to_string_hum
end

let ( @+ ) = Graph.union

let gen_sym =
let cnt = ref 0 in
fun () ->
Expand Down Expand Up @@ -140,6 +55,7 @@ exception Not_closed_observation
let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t =
let rec compile pred =
let compile' = compile pred in
let open Graph in
function
| Exp.Int n -> (Graph.empty, Det_exp.Int n)
| Real r -> (Graph.empty, Det_exp.Real r)
Expand All @@ -149,14 +65,14 @@ let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t =
let v = gen_vertex () in
let de_fvs = Det_exp.fv de in
let f = Dist.score de v in

let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.empty (module Id);
}
{
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.empty (module Id);
}
in
(g @+ g', Det_exp.Var v)
| Observe (e1, e2) ->
Expand All @@ -168,14 +84,14 @@ let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t =
let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in
if not @@ Set.is_empty (Det_exp.fv de2) then
raise Not_closed_observation;

let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.singleton (module Id) v de2;
}
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.singleton (module Id) v de2;
}
in
(g1 @+ g2 @+ g', de2)
| Assign (x, e, body) ->
Expand Down
7 changes: 7 additions & 0 deletions lib/compiler.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
val gen_sym : unit -> Program.Id.t
val gen_vertex : unit -> string
val sub : Program.Exp.t -> Program.Id.t -> Program.Det_exp.t -> Program.Exp.t

exception Not_closed_observation

val compile : Env.t -> Pred.t -> Program.Exp.t -> Graph.t * Program.Det_exp.t
24 changes: 24 additions & 0 deletions lib/dist.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
open Core
open Program

type t [@@deriving sexp]
type one = One [@@deriving sexp]

type exp =
| If_de of Det_exp.t * exp * exp
| If_pred of Pred.t * exp * one
| Dist_obj of { dist : t; var : Id.t; args : Det_exp.t list }
[@@deriving sexp]

exception Score_invalid_arguments

let prim_to_dist : Id.t -> t = failwith "Not implemented"

let rec score (det_exp : Det_exp.t) (var : Id.t) =
match det_exp with
| If (e_pred, e_con, e_alt) ->
let s_con = score e_con var in
let s_alt = score e_alt var in
If_de (e_pred, s_con, s_alt)
| Prim_call (c, es) -> Dist_obj { dist = prim_to_dist c; var; args = es }
| _ -> raise Score_invalid_arguments
22 changes: 22 additions & 0 deletions lib/dist.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
type t

val t_of_sexp : Sexplib0.Sexp.t -> t
val sexp_of_t : t -> Sexplib0.Sexp.t

type one = One

val one_of_sexp : Sexplib0.Sexp.t -> one
val sexp_of_one : one -> Sexplib0.Sexp.t

type exp =
| If_de of Program.Det_exp.t * exp * exp
| If_pred of Pred.t * exp * one
| Dist_obj of { dist : t; var : string; args : Program.Det_exp.t list }

val exp_of_sexp : Sexplib0.Sexp.t -> exp
val sexp_of_exp : exp -> Sexplib0.Sexp.t

exception Score_invalid_arguments

val prim_to_dist : string -> t
val score : Program.Det_exp.t -> string -> exp
8 changes: 8 additions & 0 deletions lib/env.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
open Core
open Program

type t = fn Map.M(Id).t

let empty : t = Map.empty (module Id)
let add (env : t) ~(name : Id.t) ~(fn : fn) = Map.add env ~key:name ~data:fn
let find_exn (env : t) ~(name : Id.t) : fn = Map.find_exn env name
12 changes: 12 additions & 0 deletions lib/env.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
type t = Program.fn Base.Map.M(Program.Id).t

val empty : t

val add :
t ->
name:string ->
fn:Program.fn ->
(string, Program.fn, Base.String.comparator_witness) Base.Map.t
Base.Map.Or_duplicate.t

val find_exn : t -> name:string -> Program.fn
43 changes: 43 additions & 0 deletions lib/graph.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
open Core
open Program

type vertex = Id.t [@@deriving sexp]
type arc = vertex * vertex [@@deriving sexp]
type det_map = Dist.exp Map.M(Id).t [@@deriving sexp]
type obs_map = Det_exp.t Map.M(Id).t [@@deriving sexp]

type t = {
vertices : vertex list;
arcs : arc list;
det_map : det_map;
obs_map : obs_map;
}
[@@deriving sexp]

let empty =
{
vertices = [];
arcs = [];
det_map = Map.empty (module Id);
obs_map = Map.empty (module Id);
}

let union g1 g2 =
{
vertices = g1.vertices @ g2.vertices;
arcs = g1.arcs @ g2.arcs;
det_map =
Map.merge g1.det_map g2.det_map ~f:(fun ~key:_ v ->
match v with
| `Left det | `Right det -> Some det
| `Both _ ->
failwith "Graph.union: duplicate deterministic expression");
obs_map =
Map.merge g1.obs_map g2.obs_map ~f:(fun ~key:_ v ->
match v with
| `Left obs | `Right obs -> Some obs
| `Both _ -> failwith "Graph.union: duplicate observation");
}

let pp (graph : t) : string = graph |> sexp_of_t |> Sexp.to_string_hum
let ( @+ ) = union
33 changes: 33 additions & 0 deletions lib/graph.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
type vertex = string

val vertex_of_sexp : Sexplib0.Sexp.t -> vertex
val sexp_of_vertex : vertex -> Sexplib0.Sexp.t

type arc = vertex * vertex

val arc_of_sexp : Sexplib0.Sexp.t -> arc
val sexp_of_arc : arc -> Sexplib0.Sexp.t

type det_map = Dist.exp Base.Map.M(Program.Id).t

val det_map_of_sexp : Sexplib0.Sexp.t -> det_map
val sexp_of_det_map : det_map -> Sexplib0.Sexp.t

type obs_map = Program.Det_exp.t Base.Map.M(Program.Id).t

val obs_map_of_sexp : Sexplib0.Sexp.t -> obs_map
val sexp_of_obs_map : obs_map -> Sexplib0.Sexp.t

type t = {
vertices : vertex list;
arcs : arc list;
det_map : det_map;
obs_map : obs_map;
}

val t_of_sexp : Sexplib0.Sexp.t -> t
val sexp_of_t : t -> Sexplib0.Sexp.t
val empty : t
val union : t -> t -> t
val pp : t -> string
val ( @+ ) : t -> t -> t
9 changes: 9 additions & 0 deletions lib/pred.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
open Core
open Program

type t = Empty | And of Det_exp.t * t | And_not of Det_exp.t * t
[@@deriving sexp]

let rec fv : t -> Set.M(Id).t = function
| Empty -> Set.empty (module Id)
| And (de, p) | And_not (de, p) -> Set.union (Det_exp.fv de) (fv p)
8 changes: 8 additions & 0 deletions lib/pred.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
type t =
| Empty
| And of Program.Det_exp.t * t
| And_not of Program.Det_exp.t * t

val t_of_sexp : Sexplib0.Sexp.t -> t
val sexp_of_t : t -> Sexplib0.Sexp.t
val fv : t -> Base.Set.M(Program.Id).t
79 changes: 79 additions & 0 deletions lib/program.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
module Id = Core.String

module Exp : sig
type t =
| Int of int
| Real of float
| Var of string
| Add of t * t
| Radd of t * t
| Minus of t * t
| Rminus of t * t
| Neg of t
| Rneg of t
| Mult of t * t
| Rmult of t * t
| Div of t * t
| Rdiv of t * t
| Eq of t * t
| Noteq of t * t
| Less of t * t
| And of t * t
| Or of t * t
| Seq of t * t
| Not of t
| List of t list
| Record of (t * t) list
| Assign of string * t * t
| If of t * t * t
| Call of string * t list
| Sample of t
| Observe of t * t

val t_of_sexp : Sexplib0.Sexp.t -> t
val sexp_of_t : t -> Sexplib0.Sexp.t
end

type fn = { name : string; params : string list; body : Exp.t }

val fn_of_sexp : Sexplib0.Sexp.t -> fn
val sexp_of_fn : fn -> Sexplib0.Sexp.t

type program = { funs : fn list; exp : Exp.t }

val program_of_sexp : Sexplib0.Sexp.t -> program
val sexp_of_program : program -> Sexplib0.Sexp.t

module Det_exp : sig
type t =
| Int of int
| Real of float
| Var of string
| Add of t * t
| Radd of t * t
| Minus of t * t
| Rminus of t * t
| Neg of t
| Rneg of t
| Mult of t * t
| Rmult of t * t
| Div of t * t
| Rdiv of t * t
| Eq of t * t
| Noteq of t * t
| Less of t * t
| And of t * t
| Or of t * t
| Not of t
| List of t list
| Record of (t * t) list
| If of t * t * t
| Prim_call of string * t list

val t_of_sexp : Sexplib0.Sexp.t -> t
val sexp_of_t : t -> Sexplib0.Sexp.t
val to_exp : t -> Exp.t
val fv : t -> (string, Id.comparator_witness) Base.Set.t
end

val pp : program -> unit

0 comments on commit 9a9e5c5

Please sign in to comment.