Skip to content

Commit

Permalink
✨ Use OCaml bindings for Python with seaborn instead of PLplot
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 12, 2024
1 parent 25ee366 commit dafc2a5
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 15 deletions.
2 changes: 1 addition & 1 deletion dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
(fmt (>= 0.9.0))
(logs (>= 0.7.0))
(owl (>= 1.1))
(owl-plplot (>= 1.0))
(pyml (>= 20231101))
(string_dict (>= 0.16.0))
(ppx_jane (>= 0.16.0))
(menhir (>= 20231231))))
Expand Down
2 changes: 1 addition & 1 deletion lib/dune
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
(library
(name stappl)
(libraries core owl owl-plplot string_dict logs)
(libraries core owl pyml string_dict logs)
(inline_tests)
(preprocess
(pps ppx_jane)))
Expand Down
17 changes: 7 additions & 10 deletions lib/evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ let rec eval_pmdf :
(pmdf, Ex (dty, eval_dist ctx { ty; exp }))

let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) :
float array =
floatarray =
(* Initialize the context with the observed values. Float conversion must
succeed as observed variables do not contain free variables *)
let default : type a. a dty -> a = function
Expand All @@ -79,7 +79,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) :
(* Adapted from gibbs_sampling of Owl *)
let a, b = (1000, 10) in
let num_iter = a + (b * num_samples) in
let samples = Array.create ~len:num_samples 0. in
let samples = Stdlib.Float.Array.init num_samples (fun _ -> 0.) in
for i = 0 to num_iter - 1 do
(* Gibbs step *)
List.iter unobserved ~f:(fun (name, Ex exp) ->
Expand Down Expand Up @@ -122,7 +122,7 @@ let gibbs_sampling ~(num_samples : int) (graph : Graph.t) (Ex query : query) :
| Tyi, i -> float_of_int i
| Tyr, r -> r
in
samples.((i - a) / b) <- query
Stdlib.Float.Array.set samples ((i - a) / b) query
done;

samples
Expand All @@ -134,11 +134,8 @@ let infer ?(filename : string = "out") ?(num_samples : int = 100_000)
let filename = String.chop_suffix_if_exists filename ~suffix:".stp" in
let plot_path = filename ^ ".png" in

let open Owl_plplot in
let h = Plot.create plot_path in
Plot.set_title h
Typed_tree.Erased.([%sexp (of_rv query : exp)] |> Sexp.to_string);
let mat = Owl.Mat.of_array samples 1 num_samples in
Plot.histogram ~h ~bin:50 mat;
Plot.output h;
Plot.draw ~plot_path
~title:Typed_tree.Erased.([%sexp (of_rv query : exp)] |> Sexp.to_string)
~samples ~num_samples;

plot_path
27 changes: 27 additions & 0 deletions lib/plot.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
open! Core

let () = Py.initialize ()

let draw ~plot_path ~title ~samples ~num_samples =
let m = Py.Import.add_module "ocaml" in
List.iter2_exn ~f:(Py.Module.set m)
[ "plot_path"; "title"; "samples"; "num_samples" ]
[
Py.String.of_string plot_path;
Py.String.of_string title;
Py.Array.numpy samples;
Py.Int.of_int num_samples;
];

Py.Run.eval ~start:Py.File
{|
from ocaml import plot_path, title, samples, num_samples
import seaborn as sns

sns.set_theme()

g = sns.displot(samples, element="step", stat="probability", bins=num_samples // 1500)
g.set_titles(title)
g.tight_layout()
g.savefig(plot_path) |}
|> ignore
4 changes: 2 additions & 2 deletions stappl.opam
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ depends: [
"fmt" {>= "0.9.0"}
"logs" {>= "0.7.0"}
"owl" {>= "1.1"}
"owl-plplot" {>= "1.0"}
"pyml" {>= "20231101"}
"string_dict" {>= "0.16.0"}
"ppx_jane" {>= "0.16.0"}
"menhir" {>= "20231231"}
Expand All @@ -38,5 +38,5 @@ build: [
dev-repo: "git+https://github.com/shapespeare/stappl.git"
pin-depends: [
[ "owl.1.1" "git+https://github.com/owlbarn/owl#06943b0267e7e80dd0eba94ebf63ca4d25c71910" ]
[ "owl-plplot.1.0" "git+https://github.com/owlbarn/owl-plplot#ebc73c09a907c1c6ca2c4b970bdb70202ec90b50" ]
[ "pyml.20231101" "git+https://github.com/Zeta611/pyml#d62a7b9c2e3a856121c9cc850d71a11b00243b0c" ]
]
2 changes: 1 addition & 1 deletion stappl.opam.template
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pin-depends: [
[ "owl.1.1" "git+https://github.com/owlbarn/owl#06943b0267e7e80dd0eba94ebf63ca4d25c71910" ]
[ "owl-plplot.1.0" "git+https://github.com/owlbarn/owl-plplot#ebc73c09a907c1c6ca2c4b970bdb70202ec90b50" ]
[ "pyml.20231101" "git+https://github.com/Zeta611/pyml#d62a7b9c2e3a856121c9cc850d71a11b00243b0c" ]
]

0 comments on commit dafc2a5

Please sign in to comment.