-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
260 additions
and
99 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
val gen_sym : unit -> Program.Id.t | ||
val gen_vertex : unit -> string | ||
val sub : Program.Exp.t -> Program.Id.t -> Program.Det_exp.t -> Program.Exp.t | ||
|
||
exception Not_closed_observation | ||
|
||
val compile : Env.t -> Pred.t -> Program.Exp.t -> Graph.t * Program.Det_exp.t |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
open Core | ||
open Program | ||
|
||
type t [@@deriving sexp] | ||
type one = One [@@deriving sexp] | ||
|
||
type exp = | ||
| If_de of Det_exp.t * exp * exp | ||
| If_pred of Pred.t * exp * one | ||
| Dist_obj of { dist : t; var : Id.t; args : Det_exp.t list } | ||
[@@deriving sexp] | ||
|
||
exception Score_invalid_arguments | ||
|
||
let prim_to_dist : Id.t -> t = failwith "Not implemented" | ||
|
||
let rec score (det_exp : Det_exp.t) (var : Id.t) = | ||
match det_exp with | ||
| If (e_pred, e_con, e_alt) -> | ||
let s_con = score e_con var in | ||
let s_alt = score e_alt var in | ||
If_de (e_pred, s_con, s_alt) | ||
| Prim_call (c, es) -> Dist_obj { dist = prim_to_dist c; var; args = es } | ||
| _ -> raise Score_invalid_arguments |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
type t | ||
|
||
val t_of_sexp : Sexplib0.Sexp.t -> t | ||
val sexp_of_t : t -> Sexplib0.Sexp.t | ||
|
||
type one = One | ||
|
||
val one_of_sexp : Sexplib0.Sexp.t -> one | ||
val sexp_of_one : one -> Sexplib0.Sexp.t | ||
|
||
type exp = | ||
| If_de of Program.Det_exp.t * exp * exp | ||
| If_pred of Pred.t * exp * one | ||
| Dist_obj of { dist : t; var : string; args : Program.Det_exp.t list } | ||
|
||
val exp_of_sexp : Sexplib0.Sexp.t -> exp | ||
val sexp_of_exp : exp -> Sexplib0.Sexp.t | ||
|
||
exception Score_invalid_arguments | ||
|
||
val prim_to_dist : string -> t | ||
val score : Program.Det_exp.t -> string -> exp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
open Core | ||
open Program | ||
|
||
type t = fn Map.M(Id).t | ||
|
||
let empty : t = Map.empty (module Id) | ||
let add (env : t) ~(name : Id.t) ~(fn : fn) = Map.add env ~key:name ~data:fn | ||
let find_exn (env : t) ~(name : Id.t) : fn = Map.find_exn env name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
type t = Program.fn Base.Map.M(Program.Id).t | ||
|
||
val empty : t | ||
|
||
val add : | ||
t -> | ||
name:string -> | ||
fn:Program.fn -> | ||
(string, Program.fn, Base.String.comparator_witness) Base.Map.t | ||
Base.Map.Or_duplicate.t | ||
|
||
val find_exn : t -> name:string -> Program.fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
open Core | ||
open Program | ||
|
||
type vertex = Id.t [@@deriving sexp] | ||
type arc = vertex * vertex [@@deriving sexp] | ||
type det_map = Dist.exp Map.M(Id).t [@@deriving sexp] | ||
type obs_map = Det_exp.t Map.M(Id).t [@@deriving sexp] | ||
|
||
type t = { | ||
vertices : vertex list; | ||
arcs : arc list; | ||
det_map : det_map; | ||
obs_map : obs_map; | ||
} | ||
[@@deriving sexp] | ||
|
||
let empty = | ||
{ | ||
vertices = []; | ||
arcs = []; | ||
det_map = Map.empty (module Id); | ||
obs_map = Map.empty (module Id); | ||
} | ||
|
||
let union g1 g2 = | ||
{ | ||
vertices = g1.vertices @ g2.vertices; | ||
arcs = g1.arcs @ g2.arcs; | ||
det_map = | ||
Map.merge g1.det_map g2.det_map ~f:(fun ~key:_ v -> | ||
match v with | ||
| `Left det | `Right det -> Some det | ||
| `Both _ -> | ||
failwith "Graph.union: duplicate deterministic expression"); | ||
obs_map = | ||
Map.merge g1.obs_map g2.obs_map ~f:(fun ~key:_ v -> | ||
match v with | ||
| `Left obs | `Right obs -> Some obs | ||
| `Both _ -> failwith "Graph.union: duplicate observation"); | ||
} | ||
|
||
let pp (graph : t) : string = graph |> sexp_of_t |> Sexp.to_string_hum | ||
let ( @+ ) = union |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
type vertex = string | ||
|
||
val vertex_of_sexp : Sexplib0.Sexp.t -> vertex | ||
val sexp_of_vertex : vertex -> Sexplib0.Sexp.t | ||
|
||
type arc = vertex * vertex | ||
|
||
val arc_of_sexp : Sexplib0.Sexp.t -> arc | ||
val sexp_of_arc : arc -> Sexplib0.Sexp.t | ||
|
||
type det_map = Dist.exp Base.Map.M(Program.Id).t | ||
|
||
val det_map_of_sexp : Sexplib0.Sexp.t -> det_map | ||
val sexp_of_det_map : det_map -> Sexplib0.Sexp.t | ||
|
||
type obs_map = Program.Det_exp.t Base.Map.M(Program.Id).t | ||
|
||
val obs_map_of_sexp : Sexplib0.Sexp.t -> obs_map | ||
val sexp_of_obs_map : obs_map -> Sexplib0.Sexp.t | ||
|
||
type t = { | ||
vertices : vertex list; | ||
arcs : arc list; | ||
det_map : det_map; | ||
obs_map : obs_map; | ||
} | ||
|
||
val t_of_sexp : Sexplib0.Sexp.t -> t | ||
val sexp_of_t : t -> Sexplib0.Sexp.t | ||
val empty : t | ||
val union : t -> t -> t | ||
val pp : t -> string | ||
val ( @+ ) : t -> t -> t |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
open Core | ||
open Program | ||
|
||
type t = Empty | And of Det_exp.t * t | And_not of Det_exp.t * t | ||
[@@deriving sexp] | ||
|
||
let rec fv : t -> Set.M(Id).t = function | ||
| Empty -> Set.empty (module Id) | ||
| And (de, p) | And_not (de, p) -> Set.union (Det_exp.fv de) (fv p) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
type t = | ||
| Empty | ||
| And of Program.Det_exp.t * t | ||
| And_not of Program.Det_exp.t * t | ||
|
||
val t_of_sexp : Sexplib0.Sexp.t -> t | ||
val sexp_of_t : t -> Sexplib0.Sexp.t | ||
val fv : t -> Base.Set.M(Program.Id).t |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
module Id = Core.String | ||
|
||
module Exp : sig | ||
type t = | ||
| Int of int | ||
| Real of float | ||
| Var of string | ||
| Add of t * t | ||
| Radd of t * t | ||
| Minus of t * t | ||
| Rminus of t * t | ||
| Neg of t | ||
| Rneg of t | ||
| Mult of t * t | ||
| Rmult of t * t | ||
| Div of t * t | ||
| Rdiv of t * t | ||
| Eq of t * t | ||
| Noteq of t * t | ||
| Less of t * t | ||
| And of t * t | ||
| Or of t * t | ||
| Seq of t * t | ||
| Not of t | ||
| List of t list | ||
| Record of (t * t) list | ||
| Assign of string * t * t | ||
| If of t * t * t | ||
| Call of string * t list | ||
| Sample of t | ||
| Observe of t * t | ||
|
||
val t_of_sexp : Sexplib0.Sexp.t -> t | ||
val sexp_of_t : t -> Sexplib0.Sexp.t | ||
end | ||
|
||
type fn = { name : string; params : string list; body : Exp.t } | ||
|
||
val fn_of_sexp : Sexplib0.Sexp.t -> fn | ||
val sexp_of_fn : fn -> Sexplib0.Sexp.t | ||
|
||
type program = { funs : fn list; exp : Exp.t } | ||
|
||
val program_of_sexp : Sexplib0.Sexp.t -> program | ||
val sexp_of_program : program -> Sexplib0.Sexp.t | ||
|
||
module Det_exp : sig | ||
type t = | ||
| Int of int | ||
| Real of float | ||
| Var of string | ||
| Add of t * t | ||
| Radd of t * t | ||
| Minus of t * t | ||
| Rminus of t * t | ||
| Neg of t | ||
| Rneg of t | ||
| Mult of t * t | ||
| Rmult of t * t | ||
| Div of t * t | ||
| Rdiv of t * t | ||
| Eq of t * t | ||
| Noteq of t * t | ||
| Less of t * t | ||
| And of t * t | ||
| Or of t * t | ||
| Not of t | ||
| List of t list | ||
| Record of (t * t) list | ||
| If of t * t * t | ||
| Prim_call of string * t list | ||
|
||
val t_of_sexp : Sexplib0.Sexp.t -> t | ||
val sexp_of_t : t -> Sexplib0.Sexp.t | ||
val to_exp : t -> Exp.t | ||
val fv : t -> (string, Id.comparator_witness) Base.Set.t | ||
end | ||
|
||
val pp : program -> unit |