Skip to content

Commit

Permalink
♻️ More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 1, 2024
1 parent e6e3d5d commit f65925c
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 37 deletions.
2 changes: 1 addition & 1 deletion bin/main.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
open Core
open! Core
open Stappl

let print_position (outx : Out_channel.t) (lexbuf : Lexing.lexbuf) : unit =
Expand Down
22 changes: 11 additions & 11 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
open Core
open! Core
open Program

let gen_sym =
Expand Down Expand Up @@ -71,7 +71,7 @@ let compile (program : program) : Graph.t * Det_exp.t =
let c2 e1 e2 ctor =
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
Graph.(g1 @+ g2, ctor de1 de2)
Graph.(g1 @| g2, ctor de1 de2)
in

let open Det_exp in
Expand All @@ -93,14 +93,14 @@ let compile (program : program) : Graph.t * Det_exp.t =
obs_map = Id.Map.empty;
}
in
(g @+ g', Var v)
(g @| g', Var v)
| Observe (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
let v = gen_vertex () in
let f1 = Dist.score de1 v in
let f = Dist.(If_pred (pred, f1, One)) in
let fvs = Set.union (fv de1) (Pred.fv pred) in
let fvs = Id.(fv de1 @| Pred.fv pred) in
if Core.not @@ Set.is_empty (fv de2) then raise Not_closed_observation;
let g' =
{
Expand All @@ -110,23 +110,23 @@ let compile (program : program) : Graph.t * Det_exp.t =
obs_map = Id.Map.singleton v de2;
}
in
(g1 @+ g2 @+ g', de2)
(g1 @| g2 @| g', de2)
| Assign (x, e, body) ->
let g1, det_exp1 = compile' e in
let sub_body = sub body x det_exp1 in
let g2, det_exp2 = compile' sub_body in
(g1 @+ g2, det_exp2)
(g1 @| g2, det_exp2)
| If (e_pred, e_con, e_alt) ->
let g1, det_exp_pred = compile' e_pred in
let open Pred in
let g2, det_exp_con = compile (pred &&& det_exp_pred) e_con in
let g3, det_exp_alt = compile (pred &&! det_exp_pred) e_alt in
(g1 @+ g2 @+ g3, If (det_exp_pred, det_exp_con, det_exp_alt))
(g1 @| g2 @| g3, If (det_exp_pred, det_exp_con, det_exp_alt))
| Call (c, params) -> (
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
(g @| g', de))
in
match Env.find env ~name:c with
| Some f ->
Expand All @@ -137,7 +137,7 @@ let compile (program : program) : Graph.t * Det_exp.t =
~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp)
in
let g_body, det_exp_body = compile' sub_body in
(g @+ g_body, det_exp_body)
(g @| g_body, det_exp_body)
| None -> (g, Prim_call (c, det_exps)))
| Add (e1, e2) -> c2 e1 e2 add
| Radd (e1, e2) -> c2 e1 e2 radd
Expand All @@ -160,15 +160,15 @@ let compile (program : program) : Graph.t * Det_exp.t =
let g, des =
List.fold_map es ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
(g @| g', de))
in
(g, List des)
| Record fields ->
let g, des =
List.fold_map fields ~init:Graph.empty ~f:(fun g (k, v) ->
let g_k, de_k = compile' k in
let g_v, de_v = compile' v in
(g @+ g_k @+ g_v, (de_k, de_v)))
(g @| g_k @| g_v, (de_k, de_v)))
in
(g, Record des)
in
Expand Down
2 changes: 1 addition & 1 deletion lib/dist.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
open Core
open! Core
open Program

type t = string [@@deriving sexp]
Expand Down
2 changes: 1 addition & 1 deletion lib/env.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
open Core
open! Core
open Program

type t = fn Id.Map.t
Expand Down
4 changes: 2 additions & 2 deletions lib/graph.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
open Core
open! Core
open Program

type vertex = Id.t [@@deriving sexp]
Expand Down Expand Up @@ -34,5 +34,5 @@ let union g1 g2 =
| `Both _ -> failwith "Graph.union: duplicate observation");
}

let ( @+ ) = union
let ( @| ) = union
let pretty (graph : t) : string = graph |> sexp_of_t |> Sexp.to_string_hum
8 changes: 4 additions & 4 deletions lib/pred.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
open Core
open! Core
open Program

type t = Empty | And of Det_exp.t * t | And_not of Det_exp.t * t
Expand All @@ -7,6 +7,6 @@ type t = Empty | And of Det_exp.t * t | And_not of Det_exp.t * t
let ( &&& ) p de = And (de, p)
let ( &&! ) p de = And_not (de, p)

let rec fv : t -> Set.M(Id).t = function
| Empty -> Set.empty (module Id)
| And (de, p) | And_not (de, p) -> Set.union (Det_exp.fv de) (fv p)
let rec fv : t -> Id.Set.t = function
| Empty -> Id.Set.empty
| And (de, p) | And_not (de, p) -> Id.(Det_exp.fv de @| fv p)
35 changes: 18 additions & 17 deletions lib/program.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
open Core
module Id = String
open! Core

module Id = struct
let ( @| ) = Set.union

include String
end

module Det_exp = struct
type t =
Expand Down Expand Up @@ -28,9 +33,11 @@ module Det_exp = struct
| Prim_call of Id.t * t list
[@@deriving sexp, variants, stable_variant]

let rec fv : t -> (Id.t, Id.comparator_witness) Set.t = function
| Int _ | Real _ -> Set.empty (module Id)
| Var x -> Set.singleton (module Id) x
let rec fv : t -> Id.Set.t =
let open Id in
function
| Int _ | Real _ -> Id.Set.empty
| Var x -> Id.Set.singleton x
| Add (e1, e2)
| Radd (e1, e2)
| Minus (e1, e2)
Expand All @@ -44,21 +51,15 @@ module Det_exp = struct
| Less (e1, e2)
| And (e1, e2)
| Or (e1, e2) ->
Set.union (fv e1) (fv e2)
fv e1 @| fv e2
| Neg e | Rneg e | Not e -> fv e
| List es ->
List.fold es
~init:(Set.empty (module Id))
~f:(fun acc e -> Set.union acc (fv e))
| List es -> List.fold es ~init:Id.Set.empty ~f:(fun acc e -> acc @| fv e)
| Record fields ->
List.fold fields
~init:(Set.empty (module Id))
~f:(fun acc (k, v) -> Set.(union acc (union (fv k) (fv v))))
| If (cond, e1, e2) -> Set.(union (fv cond) (union (fv e1) (fv e2)))
List.fold fields ~init:Id.Set.empty ~f:(fun acc (k, v) ->
acc @| fv k @| fv v)
| If (cond, e1, e2) -> fv cond @| fv e1 @| fv e2
| Prim_call (id, es) ->
List.fold es
~init:(Set.singleton (module Id) id)
~f:(fun acc e -> Set.union acc (fv e))
List.fold es ~init:(Id.Set.singleton id) ~f:(fun acc e -> acc @| fv e)
end

module Exp = struct
Expand Down

0 comments on commit f65925c

Please sign in to comment.