Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plotting and hyperparameters #66

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,8 @@ H1.txt
L1.txt
V1.txt
test_data

*.png
*.npz
*.pdf
*.txt
4 changes: 4 additions & 0 deletions example/GW150914.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,7 @@
)

jim.sample(jax.random.PRNGKey(42))

jim.print_summary()
jim.Sampler.plot_summary("training")
jim.Sampler.plot_summary("production")
28 changes: 17 additions & 11 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.utils.PRNG_keys import initialize_rng_keys
from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer
from flowMC.sampler.flowHMC import flowHMC
# from flowMC.sampler.flowHMC import flowHMC

from jimgw.prior import Prior
from jimgw.base import LikelihoodBase
from jimgw.utils.hyperparameters import jim_default_hyperparameters



class Jim(object):
Expand All @@ -23,14 +25,18 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
self.Likelihood = likelihood
self.Prior = prior

seed = kwargs.get("seed", 0)
n_chains = kwargs.get("n_chains", 20)

rng_key_set = initialize_rng_keys(n_chains, seed=seed)
num_layers = kwargs.get("num_layers", 10)
hidden_size = kwargs.get("hidden_size", [128, 128])
num_bins = kwargs.get("num_bins", 8)

# Set and override any given hyperparameters, and save as attribute
self.hyperparameters = jim_default_hyperparameters
hyperparameter_names = list(self.hyperparameters.keys())

for key, value in kwargs.items():
if key in hyperparameter_names:
self.hyperparameters[key] = value

for key, value in self.hyperparameters.items():
setattr(self, key, value)

self.rng_key_set = initialize_rng_keys(self.hyperparameters["n_chains"], seed=self.hyperparameters["seed"])
local_sampler_arg = kwargs.get("local_sampler_arg", {})

local_sampler = MALA(
Expand All @@ -39,7 +45,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

flowHMC_params = kwargs.get("flowHMC_params", {})
model = MaskedCouplingRQSpline(
self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]
self.Prior.n_dim, self.num_layers, self.hidden_size, self.num_bins, self.rng_key_set[-1]
)
if len(flowHMC_params) > 0:
global_sampler = flowHMC(
Expand All @@ -57,7 +63,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

self.Sampler = Sampler(
self.Prior.n_dim,
rng_key_set,
self.rng_key_set,
None, # type: ignore
local_sampler,
model,
Expand Down
22 changes: 22 additions & 0 deletions src/jimgw/utils/hyperparameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
jim_default_hyperparameters = {
"seed": 0,
"n_chains": 20,
"num_layers": 10,
"hidden_size": [128,128],
"num_bins": 8,
"local_sampler_arg": {},
"n_walkers_maximize_likelihood": 100,
"n_loops_maximize_likelihood": 200,
}

jim_explanation_hyperparameters = {
"seed": "(int) Value of the random seed used",
"n_chains": "(int) Number of chains to be used",
"num_layers": "(int) Number of hidden layers of the NF",
"hidden_size": "List[int, int] Sizes of the hidden layers of the NF",
"num_bins": "(int) Number of bins used in MaskedCouplingRQSpline",
"local_sampler_arg": "(dict) Additional arguments to be used in the local sampler",
"rng_key_set": "(jnp.array) Key set to be used in PRNG keys",
"n_walkers_maximize_likelihood": "(int) Number of walkers used in the maximization of the likelihood with the evolutionary optimizer",
"n_loops_maximize_likelihood": "(int) Number of loops to run the evolutionary optimizer in the maximization of the likelihood",
}
Loading