Skip to content

Commit

Permalink
Merge pull request #10 from kazewong/run-manager
Browse files Browse the repository at this point in the history
Run manager
  • Loading branch information
xuyuon authored Aug 29, 2024
2 parents 2a9d696 + 15a103d commit 2dae7b9
Show file tree
Hide file tree
Showing 21 changed files with 2,591 additions and 775 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@

name: Python package

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
on: [push, pull_request]

jobs:
build:
Expand Down
Empty file added docs/tutorials/naming_system.md
Empty file.
113 changes: 58 additions & 55 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import jax
import jax.numpy as jnp

Expand All @@ -12,59 +11,50 @@
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
mass_matrix = mass_matrix * 3e-3
local_sampler_arg = {"step_size": mass_matrix}
bounds = jnp.array(
[
[10.0, 40.0],
[0.125, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[0.0, 2000.0],
[-0.05, 0.05],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
[0.0, jnp.pi],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
]
)


run = SingleEventRun(
seed=0,
path="test_data/GW150914/",
detectors=["H1", "L1"],
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
},
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
"M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0},
"s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0},
"t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"iota": {"name": "SinePrior"},
"psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"dec": {"name": "CosinePrior"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
jim_parameters={
"n_loop_training": 10,
"n_loop_production": 10,
"n_local_steps": 150,
"n_global_steps": 150,
"n_chains": 500,
"n_epochs": 50,
"learning_rate": 0.001,
"n_max_examples": 45000,
"momentum": 0.9,
"batch_size": 50000,
"use_global": True,
"keep_quantile": 0.0,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
likelihood_parameters={"name": "HeterodynedTransientLikelihoodFD", "bounds": bounds},
likelihood_parameters={"name": "TransientLikelihoodFD"},
sample_transforms=[
{"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,},
{"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,},
{"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,},
{"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
],
likelihood_transforms=[
{"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]},
],
injection=True,
injection_parameters={
"M_c": 28.6,
Expand All @@ -79,15 +69,28 @@
"ra": 1.2,
"dec": 0.3,
},
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
jim_parameters={
"n_loop_training": 100,
"n_loop_production": 20,
"n_local_steps": 10,
"n_global_steps": 1000,
"n_chains": 500,
"n_epochs": 30,
"learning_rate": 1e-4,
"n_max_examples": 30000,
"momentum": 0.9,
"batch_size": 30000,
"use_global": True,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
)

run_manager = SingleEventPERunManager(run=run)
run_manager.sample()

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()
run_manager.save_summary()
127 changes: 96 additions & 31 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,51 @@

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior
from jimgw.transforms import BijectiveTransform, NtoMTransform


class Jim(object):
"""
Master class for interfacing with flowMC
"""

def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
self.Likelihood = likelihood
self.Prior = prior
likelihood: LikelihoodBase
prior: Prior

# Name of parameters to sample from
sample_transforms: list[BijectiveTransform]
likelihood_transforms: list[NtoMTransform]
parameter_names: list[str]
sampler: Sampler

def __init__(
self,
likelihood: LikelihoodBase,
prior: Prior,
sample_transforms: list[BijectiveTransform] = [],
likelihood_transforms: list[NtoMTransform] = [],
**kwargs,
):
self.likelihood = likelihood
self.prior = prior

self.sample_transforms = sample_transforms
self.likelihood_transforms = likelihood_transforms
self.parameter_names = prior.parameter_names

if len(sample_transforms) == 0:
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)

seed = kwargs.get("seed", 0)

Expand All @@ -33,30 +67,56 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

rng_key, subkey = jax.random.split(rng_key)
model = MaskedCouplingRQSpline(
self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey
self.prior.n_dim, num_layers, hidden_size, num_bins, subkey
)

self.Sampler = Sampler(
self.Prior.n_dim,
self.sampler = Sampler(
self.prior.n_dim,
rng_key,
None, # type: ignore
local_sampler,
model,
**kwargs,
)

def posterior(self, params: Float[Array, " n_dim"], data: dict):
prior_params = self.Prior.add_name(params.T)
prior = self.Prior.log_prob(prior_params)
return (
self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior
)
def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
"""
Turn an array into a dictionary
def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.Sampler.sample(initial_guess, None) # type: ignore
Parameters
----------
x : Array
An array of parameters. Shape (n_dim,).
"""

return dict(zip(self.parameter_names, x))

def posterior(self, params: Float[Array, " n_dim"], data: dict):
named_params = self.add_name(params)
transform_jacobian = 0.0
for transform in reversed(self.sample_transforms):
named_params, jacobian = transform.inverse(named_params)
transform_jacobian += jacobian
prior = self.prior.log_prob(named_params) + transform_jacobian
for transform in self.likelihood_transforms:
named_params = transform.forward(named_params)
return self.likelihood.evaluate(named_params, data) + prior

def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
if initial_position.size == 0:
initial_guess = []
for _ in range(self.sampler.n_chains):
flag = True
while flag:
key = jax.random.split(key)[1]
guess = self.prior.sample(key, 1)
for transform in self.sample_transforms:
guess = transform.forward(guess)
guess = jnp.array([i for i in guess.values()]).T[0]
flag = not jnp.all(jnp.isfinite(guess))
initial_guess.append(guess)
initial_position = jnp.array(initial_guess)
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
self,
Expand All @@ -67,7 +127,7 @@ def maximize_likelihood(
):
key = jax.random.PRNGKey(seed)
set_nwalkers = set_nwalkers
initial_guess = self.Prior.sample(key, set_nwalkers)
initial_guess = self.prior.sample(key, set_nwalkers)

def negative_posterior(x: Float[Array, " n_dim"]):
return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet
Expand All @@ -78,7 +138,7 @@ def negative_posterior(x: Float[Array, " n_dim"]):
print("Done compiling")

print("Starting the optimizer")
optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True)
optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True)
_ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops)
best_fit = optimizer.get_result()[0]
return best_fit
Expand All @@ -89,22 +149,24 @@ def print_summary(self, transform: bool = True):
"""

train_summary = self.Sampler.get_sampler_state(training=True)
production_summary = self.Sampler.get_sampler_state(training=False)
train_summary = self.sampler.get_sampler_state(training=True)
production_summary = self.sampler.get_sampler_state(training=False)

training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T
training_chain = self.Prior.add_name(training_chain)
training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
training_chain = self.add_name(training_chain)
if transform:
training_chain = self.Prior.transform(training_chain)
for sample_transform in reversed(self.sample_transforms):
training_chain = sample_transform.backward(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T
production_chain = self.Prior.add_name(production_chain)
production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
production_chain = self.add_name(production_chain)
if transform:
production_chain = self.Prior.transform(production_chain)
for sample_transform in reversed(self.sample_transforms):
production_chain = sample_transform.backward(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]
Expand Down Expand Up @@ -156,11 +218,14 @@ def get_samples(self, training: bool = False) -> dict:
"""
if training:
chains = self.Sampler.get_sampler_state(training=True)["chains"]
chains = self.sampler.get_sampler_state(training=True)["chains"]
else:
chains = self.Sampler.get_sampler_state(training=False)["chains"]
chains = self.sampler.get_sampler_state(training=False)["chains"]

chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1)))
chains = chains.transpose(2, 0, 1)
chains = self.add_name(chains)
for sample_transform in reversed(self.sample_transforms):
chains = sample_transform.backward(chains)
return chains

def plot(self):
Expand Down
Loading

0 comments on commit 2dae7b9

Please sign in to comment.