Skip to content

Commit

Permalink
✨ Handle primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 1, 2024
1 parent 45b0764 commit 8bad53b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 26 deletions.
8 changes: 3 additions & 5 deletions bin/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ let () =
in
try
let pgm = Parser.program Lexer.start lexbuf in
let _ = print_endline "=== Printing Input Program ===" in
print_endline "=== Printing Input Program ===";
pp pgm;
let _ = print_endline "=== Gathering Functions ===" in
let open Compiler in
let env = gather_functions pgm in
let _ = print_endline "=== Compiling Output Program ===" in
let graph, _de = compile''' env Pred.Empty pgm.exp in
let _ = print_endline "=== Printing Output Program ===" in
let graph, _de = compile env Pred.Empty pgm.exp in
print_endline "=== Printing Output Program ===";
Printf.printf "%s" (Graph.pp graph)
with Parsing.Parse_error ->
print_endline ("Parsing Error: " ^ lexbuf_contents lexbuf)
Expand Down
42 changes: 24 additions & 18 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ let gen_vertex =

let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t =
let sub' exp = sub exp x det_exp in
Printf.printf "factorial %s\n" x;
match exp with
| Int n -> Int n
| Real r -> Real r
Expand Down Expand Up @@ -57,8 +56,7 @@ let gather_functions (prog : program) : Env.t =

exception Not_closed_observation

let compile''' (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t
=
let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t =
let rec compile pred =
let compile' e = compile pred e in
let open Graph in
Expand Down Expand Up @@ -114,21 +112,29 @@ let compile''' (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t
let g3, det_exp_alt = compile pred_false e_alt in
let g = g1 @+ g2 @+ g3 in
(g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt))
| Call (c, params) ->
let f = Env.find_exn env ~name:c in
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
in
let { params; body; _ } = f in
let param_det_pairs = List.zip_exn params det_exps in
let sub_body =
List.fold param_det_pairs ~init:body
~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp)
in
let g_body, det_exp_body = compile' sub_body in
(g @+ g_body, det_exp_body)
| Call (c, params) -> (
match Env.find env ~name:c with
| Some f ->
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
in
let { params; body; _ } = f in
let param_det_pairs = List.zip_exn params det_exps in
let sub_body =
List.fold param_det_pairs ~init:body
~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp)
in
let g_body, det_exp_body = compile' sub_body in
(g @+ g_body, det_exp_body)
| None ->
let g, det_exps =
List.fold_map params ~init:Graph.empty ~f:(fun g e ->
let g', de = compile' e in
(g @+ g', de))
in
(g, Prim_call (c, det_exps)))
| Add (e1, e2) ->
let g1, de1 = compile' e1 in
let g2, de2 = compile' e2 in
Expand Down
2 changes: 1 addition & 1 deletion lib/compiler.mli
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ val gather_functions : Program.program -> Env.t

exception Not_closed_observation

val compile''' : Env.t -> Pred.t -> Program.Exp.t -> Graph.t * Program.Det_exp.t
val compile : Env.t -> Pred.t -> Program.Exp.t -> Graph.t * Program.Det_exp.t
2 changes: 1 addition & 1 deletion lib/env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ let empty : t = Map.empty (module Id)
let add_exn (env : t) ~(name : Id.t) ~(fn : fn) =
Map.add_exn env ~key:name ~data:fn

let find_exn (env : t) ~(name : Id.t) : fn = Map.find_exn env name
let find (env : t) ~(name : Id.t) : fn option = Map.find env name
2 changes: 1 addition & 1 deletion lib/env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ val add_exn :
fn:Program.fn ->
(string, Program.fn, Base.String.comparator_witness) Base.Map.t

val find_exn : t -> name:string -> Program.fn
val find : t -> name:string -> Program.fn option

0 comments on commit 8bad53b

Please sign in to comment.