diff --git a/bin/main.ml b/bin/main.ml index 5c3f91b..ba34671 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -21,14 +21,12 @@ let () = in try let pgm = Parser.program Lexer.start lexbuf in - let _ = print_endline "=== Printing Input Program ===" in + print_endline "=== Printing Input Program ==="; pp pgm; - let _ = print_endline "=== Gathering Functions ===" in let open Compiler in let env = gather_functions pgm in - let _ = print_endline "=== Compiling Output Program ===" in - let graph, _de = compile''' env Pred.Empty pgm.exp in - let _ = print_endline "=== Printing Output Program ===" in + let graph, _de = compile env Pred.Empty pgm.exp in + print_endline "=== Printing Output Program ==="; Printf.printf "%s" (Graph.pp graph) with Parsing.Parse_error -> print_endline ("Parsing Error: " ^ lexbuf_contents lexbuf) diff --git a/lib/compiler.ml b/lib/compiler.ml index 8196065..be4c443 100644 --- a/lib/compiler.ml +++ b/lib/compiler.ml @@ -15,7 +15,6 @@ let gen_vertex = let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t = let sub' exp = sub exp x det_exp in - Printf.printf "factorial %s\n" x; match exp with | Int n -> Int n | Real r -> Real r @@ -57,8 +56,7 @@ let gather_functions (prog : program) : Env.t = exception Not_closed_observation -let compile''' (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t - = +let compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t = let rec compile pred = let compile' e = compile pred e in let open Graph in @@ -114,21 +112,29 @@ let compile''' (env : Env.t) (pred : Pred.t) (exp : Exp.t) : Graph.t * Det_exp.t let g3, det_exp_alt = compile pred_false e_alt in let g = g1 @+ g2 @+ g3 in (g, Det_exp.If (det_exp_pred, det_exp_con, det_exp_alt)) - | Call (c, params) -> - let f = Env.find_exn env ~name:c in - let g, det_exps = - List.fold_map params ~init:Graph.empty ~f:(fun g e -> - let g', de = compile' e in - (g @+ g', de)) - in - let { params; body; _ } = f in - let param_det_pairs = List.zip_exn params det_exps in - let sub_body = - List.fold param_det_pairs ~init:body - ~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp) - in - let g_body, det_exp_body = compile' sub_body in - (g @+ g_body, det_exp_body) + | Call (c, params) -> ( + match Env.find env ~name:c with + | Some f -> + let g, det_exps = + List.fold_map params ~init:Graph.empty ~f:(fun g e -> + let g', de = compile' e in + (g @+ g', de)) + in + let { params; body; _ } = f in + let param_det_pairs = List.zip_exn params det_exps in + let sub_body = + List.fold param_det_pairs ~init:body + ~f:(fun acc (param_name, det_exp) -> sub acc param_name det_exp) + in + let g_body, det_exp_body = compile' sub_body in + (g @+ g_body, det_exp_body) + | None -> + let g, det_exps = + List.fold_map params ~init:Graph.empty ~f:(fun g e -> + let g', de = compile' e in + (g @+ g', de)) + in + (g, Prim_call (c, det_exps))) | Add (e1, e2) -> let g1, de1 = compile' e1 in let g2, de2 = compile' e2 in diff --git a/lib/compiler.mli b/lib/compiler.mli index d2be1bb..f25fefe 100644 --- a/lib/compiler.mli +++ b/lib/compiler.mli @@ -5,4 +5,4 @@ val gather_functions : Program.program -> Env.t exception Not_closed_observation -val compile''' : Env.t -> Pred.t -> Program.Exp.t -> Graph.t * Program.Det_exp.t +val compile : Env.t -> Pred.t -> Program.Exp.t -> Graph.t * Program.Det_exp.t diff --git a/lib/env.ml b/lib/env.ml index 672dddc..ac2e8b6 100644 --- a/lib/env.ml +++ b/lib/env.ml @@ -8,4 +8,4 @@ let empty : t = Map.empty (module Id) let add_exn (env : t) ~(name : Id.t) ~(fn : fn) = Map.add_exn env ~key:name ~data:fn -let find_exn (env : t) ~(name : Id.t) : fn = Map.find_exn env name +let find (env : t) ~(name : Id.t) : fn option = Map.find env name diff --git a/lib/env.mli b/lib/env.mli index 4abf7a4..c535c99 100644 --- a/lib/env.mli +++ b/lib/env.mli @@ -8,4 +8,4 @@ val add_exn : fn:Program.fn -> (string, Program.fn, Base.String.comparator_witness) Base.Map.t -val find_exn : t -> name:string -> Program.fn +val find : t -> name:string -> Program.fn option