Skip to content

Commit

Permalink
✨ Derive S-expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 1, 2024
1 parent 02703fe commit b61e98a
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions lib/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ open Core
open Program

module Env = struct
type t = (Id.t, fn, Id.comparator_witness) Map.t
type t = fn Map.M(Id).t

let empty : t = Map.empty (module Id)
let add (env : t) ~(name : Id.t) ~(fn : fn) = Map.add env ~key:name ~data:fn
Expand All @@ -11,20 +11,22 @@ end

module Pred = struct
type t = Empty | And of Det_exp.t * t | And_not of Det_exp.t * t
[@@deriving sexp]

let rec fv : t -> Set.M(Id).t = function
| Empty -> Set.empty (module Id)
| And (de, p) | And_not (de, p) -> Set.union (Det_exp.fv de) (fv p)
end

module Dist = struct
type t
type one = One
type t [@@deriving sexp]
type one = One [@@deriving sexp]

type exp =
| If_de of Det_exp.t * exp * exp
| If_pred of Pred.t * exp * one
| Dist_obj of { dist : t; var : Id.t; args : Det_exp.t list }
[@@deriving sexp]

exception Score_invalid_arguments

Expand All @@ -41,17 +43,18 @@ module Dist = struct
end

module Graph = struct
type vertex = Id.t
type arc = vertex * vertex
type det_map = (Id.t, Dist.exp, Id.comparator_witness) Map.t
type obs_map = (Id.t, Det_exp.t, Id.comparator_witness) Map.t
type vertex = Id.t [@@deriving sexp]
type arc = vertex * vertex [@@deriving sexp]
type det_map = Dist.exp Map.M(Id).t [@@deriving sexp]
type obs_map = Det_exp.t Map.M(Id).t [@@deriving sexp]

type t = {
vertices : vertex list;
arcs : arc list;
det_map : det_map;
obs_map : obs_map;
}
[@@deriving sexp]

let empty =
{
Expand All @@ -77,6 +80,8 @@ module Graph = struct
| `Left obs | `Right obs -> Some obs
| `Both _ -> failwith "Graph.union: duplicate observation");
}

let pp (graph : t) : string = graph |> sexp_of_t |> Sexp.to_string_hum
end

let ( @+ ) = Graph.union
Expand Down Expand Up @@ -267,7 +272,7 @@ let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t =
| Not e ->
let g, de = compile' e in
(g, Det_exp.Not de)
| Exp.List es ->
| List es ->
let g, des =
List.fold_map es ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
Expand Down

0 comments on commit b61e98a

Please sign in to comment.