From e15e7053bb2bc3e62ad339c3d237c4a903ab5379 Mon Sep 17 00:00:00 2001 From: ThibeauWouters Date: Sat, 3 Feb 2024 04:25:46 -0800 Subject: [PATCH 1/2] refactored jim for plotting --- .gitignore | 5 +++++ example/GW150914.py | 4 ++++ src/jimgw/jim.py | 28 +++++++++++++++++----------- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 5d6606d3..291e6b51 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,8 @@ H1.txt L1.txt V1.txt test_data + +*.png +*.npz +*.pdf +*.txt diff --git a/example/GW150914.py b/example/GW150914.py index bb8c6ffd..7335ad00 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -123,3 +123,7 @@ ) jim.sample(jax.random.PRNGKey(42)) + +jim.print_summary() +jim.Sampler.plot_summary("training") +jim.Sampler.plot_summary("production") \ No newline at end of file diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 1961d9c7..9b9f4643 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -7,10 +7,12 @@ from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.utils.PRNG_keys import initialize_rng_keys from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer -from flowMC.sampler.flowHMC import flowHMC +# from flowMC.sampler.flowHMC import flowHMC from jimgw.prior import Prior from jimgw.base import LikelihoodBase +from jimgw.utils.hyperparameters import jim_default_hyperparameters + class Jim(object): @@ -23,14 +25,18 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): self.Likelihood = likelihood self.Prior = prior - seed = kwargs.get("seed", 0) - n_chains = kwargs.get("n_chains", 20) - - rng_key_set = initialize_rng_keys(n_chains, seed=seed) - num_layers = kwargs.get("num_layers", 10) - hidden_size = kwargs.get("hidden_size", [128, 128]) - num_bins = kwargs.get("num_bins", 8) - + # Set and override any given hyperparameters, and save as attribute + self.hyperparameters = jim_default_hyperparameters + hyperparameter_names = list(self.hyperparameters.keys()) + + for key, value in kwargs.items(): + if key in hyperparameter_names: + self.hyperparameters[key] = value + + for key, value in self.hyperparameters.items(): + setattr(self, key, value) + + self.rng_key_set = initialize_rng_keys(self.hyperparameters["n_chains"], seed=self.hyperparameters["seed"]) local_sampler_arg = kwargs.get("local_sampler_arg", {}) local_sampler = MALA( @@ -39,7 +45,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): flowHMC_params = kwargs.get("flowHMC_params", {}) model = MaskedCouplingRQSpline( - self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1] + self.Prior.n_dim, self.num_layers, self.hidden_size, self.num_bins, self.rng_key_set[-1] ) if len(flowHMC_params) > 0: global_sampler = flowHMC( @@ -57,7 +63,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): self.Sampler = Sampler( self.Prior.n_dim, - rng_key_set, + self.rng_key_set, None, # type: ignore local_sampler, model, From 521de0a64c82d98ed78271b9b5d04273dc124d01 Mon Sep 17 00:00:00 2001 From: ThibeauWouters Date: Sat, 3 Feb 2024 04:27:04 -0800 Subject: [PATCH 2/2] added utilities files --- src/jimgw/utils/hyperparameters.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 src/jimgw/utils/hyperparameters.py diff --git a/src/jimgw/utils/hyperparameters.py b/src/jimgw/utils/hyperparameters.py new file mode 100644 index 00000000..ff6f01df --- /dev/null +++ b/src/jimgw/utils/hyperparameters.py @@ -0,0 +1,22 @@ +jim_default_hyperparameters = { + "seed": 0, + "n_chains": 20, + "num_layers": 10, + "hidden_size": [128,128], + "num_bins": 8, + "local_sampler_arg": {}, + "n_walkers_maximize_likelihood": 100, + "n_loops_maximize_likelihood": 200, +} + +jim_explanation_hyperparameters = { + "seed": "(int) Value of the random seed used", + "n_chains": "(int) Number of chains to be used", + "num_layers": "(int) Number of hidden layers of the NF", + "hidden_size": "List[int, int] Sizes of the hidden layers of the NF", + "num_bins": "(int) Number of bins used in MaskedCouplingRQSpline", + "local_sampler_arg": "(dict) Additional arguments to be used in the local sampler", + "rng_key_set": "(jnp.array) Key set to be used in PRNG keys", + "n_walkers_maximize_likelihood": "(int) Number of walkers used in the maximization of the likelihood with the evolutionary optimizer", + "n_loops_maximize_likelihood": "(int) Number of loops to run the evolutionary optimizer in the maximization of the likelihood", +}