diff --git a/docs/tutorials/anatomy_of_jim.md b/docs/tutorials/anatomy_of_jim.md index 15ebef47..7faacf2b 100644 --- a/docs/tutorials/anatomy_of_jim.md +++ b/docs/tutorials/anatomy_of_jim.md @@ -25,4 +25,8 @@ At its core, `flowMC` is still a MCMC algorithm, so the hyperparameter tuning is 1. If you can, use more chains, especially on a GPU. Bring the number of chains up until you start to get significant performance hit or run out of memory. 2. Run it longer, in particular the training phase. In fact, most of the computation cost goes into the training part, once you get a reasonably tuned normalizing flow model, the production phase is usually quite cheap. To be concrete, blow `n_loop_training` up until you cannot stand how slow it is. +## Run Manager + +While Jim is the main object that will handle most of the work, there is a lot of bookkeeping that needs to be done around a run. + ## Analysis \ No newline at end of file diff --git a/example/GW150914.py b/example/GW150914.py index cd50892b..bb8c6ffd 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -1,8 +1,11 @@ import time from jimgw.jim import Jim -from jimgw.detector import H1, L1 -from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import ( + HeterodynedTransientLikelihoodFD, + TransientLikelihoodFD, +) +from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.prior import Unconstrained_Uniform, Composite import jax.numpy as jnp import jax diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index ac164357..134c45bf 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -1,8 +1,8 @@ import time from jimgw.jim import Jim -from jimgw.detector import H1, L1 -from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomPv2 +from 
jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomPv2 from jimgw.prior import Uniform, Composite, Sphere import jax.numpy as jnp import jax diff --git a/example/GW150914_PV2_newglobal.py b/example/GW150914_PV2_newglobal.py index fce24f74..d69fa2aa 100644 --- a/example/GW150914_PV2_newglobal.py +++ b/example/GW150914_PV2_newglobal.py @@ -1,8 +1,8 @@ import time from jimgw.jim import Jim -from jimgw.detector import H1, L1 -from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 from jimgw.prior import Uniform, Unconstrained_Uniform, Composite, Sphere import jax.numpy as jnp import jax diff --git a/example/GW150914_heterodyne.py b/example/GW150914_heterodyne.py index f0ed508f..bdee9d55 100644 --- a/example/GW150914_heterodyne.py +++ b/example/GW150914_heterodyne.py @@ -1,8 +1,8 @@ import time from jimgw.jim import Jim -from jimgw.detector import H1, L1 -from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.prior import Uniform, Composite import jax.numpy as jnp import jax diff --git a/example/GW170817.py b/example/GW170817.py index 8f01cc23..1bf0e799 100644 --- a/example/GW170817.py +++ b/example/GW170817.py @@ -1,8 +1,8 @@ import time from jimgw.jim import Jim -from jimgw.detector import H1, L1, V1 -from jimgw.likelihood 
import HeterodynedTransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.prior import Uniform from gwosc.datasets import event_gps import jax.numpy as jnp diff --git a/example/GW170817_heterodyne.py b/example/GW170817_heterodyne.py index a2178336..8e0f4b21 100644 --- a/example/GW170817_heterodyne.py +++ b/example/GW170817_heterodyne.py @@ -1,11 +1,12 @@ from jimgw.jim import Jim -from jimgw.detector import H1, L1, V1 -from jimgw.likelihood import HeterodynedTransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD -from jimgw.prior import Uniform, Powerlaw, Composite +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.prior import Uniform, PowerLaw, Composite import jax.numpy as jnp import jax import time + jax.config.update("jax_enable_x64", True) @@ -22,14 +23,20 @@ duration = T post_trigger_duration = 2 epoch = duration - post_trigger_duration -f_ref = fmin +f_ref = fmin ### Getting ifos and overwriting with above data tukey_alpha = 2 / (duration / 2) -H1.load_data(gps, duration, 2, fmin, fmax, psd_pad=duration+16, tukey_alpha=tukey_alpha) -L1.load_data(gps, duration, 2, fmin, fmax, psd_pad=duration+16, tukey_alpha=tukey_alpha) -V1.load_data(gps, duration, 2, fmin, fmax, psd_pad=duration+16, tukey_alpha=tukey_alpha) +H1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) +L1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) +V1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) ### Define priors @@ -41,13 +48,13 @@ naming=["q"], transforms={"q": ("eta", lambda params: params["q"] / 
(1 + params["q"]) ** 2)}, ) -s1z_prior = Uniform(-0.05, 0.05, naming=["s1_z"]) -s2z_prior = Uniform(-0.05, 0.05, naming=["s2_z"]) +s1z_prior = Uniform(-0.05, 0.05, naming=["s1_z"]) +s2z_prior = Uniform(-0.05, 0.05, naming=["s2_z"]) # External parameters -dL_prior = Powerlaw(1.0, 75.0, 2.0, naming=["d_L"]) -t_c_prior = Uniform(-0.1, 0.1, naming=["t_c"]) -phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +dL_prior = PowerLaw(1.0, 75.0, 2.0, naming=["d_L"]) +t_c_prior = Uniform(-0.1, 0.1, naming=["t_c"]) +phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) cos_iota_prior = Uniform( -1.0, 1.0, @@ -61,8 +68,8 @@ ) }, ) -psi_prior = Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +psi_prior = Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"]) sin_dec_prior = Uniform( -1.0, 1.0, @@ -77,7 +84,8 @@ }, ) -prior = Composite([ +prior = Composite( + [ Mc_prior, q_prior, s1z_prior, @@ -96,19 +104,27 @@ bounds = jnp.array([[p.xmin, p.xmax] for p in prior.priors]) ### Create likelihood object -likelihood = HeterodynedTransientLikelihoodFD([H1, L1, V1], prior=prior, bounds=bounds, waveform=RippleIMRPhenomD(), trigger_time=gps, duration=T, n_bins=500) +likelihood = HeterodynedTransientLikelihoodFD( + [H1, L1, V1], + prior=prior, + bounds=bounds, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=T, + n_bins=500, +) ### Create sampler and jim objects eps = 3e-2 n_dim = 11 mass_matrix = jnp.eye(n_dim) -mass_matrix = mass_matrix.at[0,0].set(1e-5) -mass_matrix = mass_matrix.at[1,1].set(1e-4) -mass_matrix = mass_matrix.at[2,2].set(1e-3) -mass_matrix = mass_matrix.at[3,3].set(1e-3) -mass_matrix = mass_matrix.at[5,5].set(1e-5) -mass_matrix = mass_matrix.at[9,9].set(1e-2) -mass_matrix = mass_matrix.at[10,10].set(1e-2) +mass_matrix = mass_matrix.at[0, 0].set(1e-5) +mass_matrix = mass_matrix.at[1, 1].set(1e-4) +mass_matrix = mass_matrix.at[2, 2].set(1e-3) +mass_matrix = 
mass_matrix.at[3, 3].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-5) +mass_matrix = mass_matrix.at[9, 9].set(1e-2) +mass_matrix = mass_matrix.at[10, 10].set(1e-2) local_sampler_arg = {"step_size": mass_matrix * eps} outdir_name = "./outdir/" @@ -129,11 +145,11 @@ use_global=True, keep_quantile=0.0, train_thinning=10, - output_thinning=30, - n_loops_maximize_likelihood = 2000, + output_thinning=30, + n_loops_maximize_likelihood=2000, local_sampler_arg=local_sampler_arg, - outdir_name=outdir_name + outdir_name=outdir_name, ) jim.sample(jax.random.PRNGKey(42)) -jim.print_summary() \ No newline at end of file +jim.print_summary() diff --git a/example/InjectionRecovery.py b/example/InjectionRecovery.py index a390b76d..3236ddcc 100644 --- a/example/InjectionRecovery.py +++ b/example/InjectionRecovery.py @@ -1,7 +1,7 @@ from jimgw.jim import Jim -from jimgw.detector import H1, L1, V1 -from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomPv2 +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomPv2 from jimgw.prior import Uniform from ripple import ms_to_Mc_eta import jax.numpy as jnp diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py new file mode 100644 index 00000000..3827914d --- /dev/null +++ b/example/Single_event_runManager.py @@ -0,0 +1,91 @@ + +from jimgw.single_event.runManager import SingleEventPERunManager, SingleEventRun +import jax.numpy as jnp +import jax + +jax.config.update("jax_enable_x64", True) + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +mass_matrix = mass_matrix * 3e-3 +local_sampler_arg = {"step_size": mass_matrix} +bounds = jnp.array( + [ + [10.0, 40.0], + [0.125, 1.0], + [-1.0, 1.0], + [-1.0, 1.0], + [0.0, 
2000.0], + [-0.05, 0.05], + [0.0, 2 * jnp.pi], + [-1.0, 1.0], + [0.0, jnp.pi], + [0.0, 2 * jnp.pi], + [-1.0, 1.0], + ] +) + + +run = SingleEventRun( + seed=0, + path="test_data/GW150914/", + detectors=["H1", "L1"], + priors={ + "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0}, + "q": {"name": "MassRatio"}, + "s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0}, + "t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "cos_iota": {"name": "CosIota"}, + "psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "sin_dec": {"name": "SinDec"}, + }, + waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, + jim_parameters={ + "n_loop_training": 10, + "n_loop_production": 10, + "n_local_steps": 150, + "n_global_steps": 150, + "n_chains": 500, + "n_epochs": 50, + "learning_rate": 0.001, + "max_samples": 45000, + "momentum": 0.9, + "batch_size": 50000, + "use_global": True, + "keep_quantile": 0.0, + "train_thinning": 1, + "output_thinning": 10, + "local_sampler_arg": local_sampler_arg, + }, + likelihood_parameters={"name": "HeterodynedTransientLikelihoodFD", "bounds": bounds}, + injection=True, + injection_parameters={ + "M_c": 28.6, + "eta": 0.24, + "s1_z": 0.05, + "s2_z": 0.05, + "d_L": 440.0, + "t_c": 0.0, + "phase_c": 0.0, + "iota": 0.5, + "psi": 0.7, + "ra": 1.2, + "dec": 0.3, + }, + data_parameters={ + "trigger_time": 1126259462.4, + "duration": 4, + "post_trigger_duration": 2, + "f_min": 20.0, + "f_max": 1024.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + }, +) + +run_manager = SingleEventPERunManager(run=run) diff --git a/setup.cfg b/setup.cfg index 46ba4d5c..3c61f7a3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ packages = find: install_requires = jax>=0.4.12 jaxlib>=0.4.12 - 
flowMC>=0.2.1 + flowMC>=0.2.4 ripplegw gwpy corner diff --git a/src/jimgw/base.py b/src/jimgw/base.py new file mode 100644 index 00000000..f05dddb1 --- /dev/null +++ b/src/jimgw/base.py @@ -0,0 +1,96 @@ +from abc import ABC, abstractmethod +import equinox as eqx +from jaxtyping import Array, Float + + +class Data(ABC): + @abstractmethod + def __init__(self): + raise NotImplementedError + + @abstractmethod + def fetch(self): + raise NotImplementedError + + +class Model(eqx.Module): + params: dict + + @abstractmethod + def __init__(self): + raise NotImplementedError + + def __call__(self, x: Array) -> float: + raise NotImplementedError + + +class LikelihoodBase(ABC): + """ + Base class for likelihoods. + Note that this likelihood class should work + for a some what general class of problems. + In light of that, this class would be some what abstract, + but the idea behind it is this handles two main components of a likelihood: + the data and the model. + It should be able to take the data and model and evaluate the likelihood for + a given set of parameters. + + """ + + _model: object + _data: object + + @property + def model(self): + """ + The model for the likelihood. + """ + return self._model + + @property + def data(self): + """ + The data for the likelihood. + """ + return self._data + + @abstractmethod + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """ + Evaluate the likelihood for a given set of parameters. + """ + raise NotImplementedError + + +class RunManager(ABC): + """ + Base class for run managers. + + A run manager is a class that help with book keeping for a run. + It should be able to log metadata, summarize the run, save the run, and load the run. + Individual use cases can extend this class to implement the actual functionality. + This class is meant to be a template for the functionality. + + """ + + def __init__(self, *args, **kwargs): + """ + Initialize the run manager. 
+ """ + self.likelihood = kwargs["likelihood"] + self.prior = kwargs["prior"] + self.jim = kwargs["jim"] + + @abstractmethod + def save(self, path: str): + """ + Save the run. + """ + raise NotImplementedError + + @abstractmethod + def load_from_path(self, path: str): + """ + Load the run. + """ + raise NotImplementedError diff --git a/src/jimgw/data.py b/src/jimgw/data.py deleted file mode 100644 index df31b7ea..00000000 --- a/src/jimgw/data.py +++ /dev/null @@ -1,11 +0,0 @@ -from abc import ABC, abstractmethod - - -class Data(ABC): - @abstractmethod - def __init__(self): - raise NotImplementedError - - @abstractmethod - def fetch(self): - raise NotImplementedError diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 6c05eb03..1961d9c7 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -1,15 +1,17 @@ -from jimgw.likelihood import LikelihoodBase +from jaxtyping import Array, Float, PRNGKeyArray +import jax +import jax.numpy as jnp + from flowMC.sampler.Sampler import Sampler from flowMC.sampler.MALA import MALA from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.utils.PRNG_keys import initialize_rng_keys from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer -from jimgw.prior import Prior -from jaxtyping import Array, Float, PRNGKeyArray -import jax -import jax.numpy as jnp from flowMC.sampler.flowHMC import flowHMC +from jimgw.prior import Prior +from jimgw.base import LikelihoodBase + class Jim(object): """ @@ -20,6 +22,7 @@ class Jim(object): def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): self.Likelihood = likelihood self.Prior = prior + seed = kwargs.get("seed", 0) n_chains = kwargs.get("n_chains", 20) @@ -62,6 +65,19 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): **kwargs, ) + def posterior(self, params: Float[Array, " n_dim"], data: dict): + prior_params = self.Prior.add_name(params.T) + prior = self.Prior.log_prob(prior_params) + return ( + 
self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior + ) + + def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): + if initial_guess.size == 0: + initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + self.Sampler.sample(initial_guess, None) # type: ignore + def maximize_likelihood( self, bounds: Float[Array, " n_dim 2"], @@ -87,20 +103,7 @@ def negative_posterior(x: Float[Array, " n_dim"]): best_fit = optimizer.get_result()[0] return best_fit - def posterior(self, params: Array, data: dict): - prior_params = self.Prior.add_name(params.T) - prior = self.Prior.log_prob(prior_params) - return ( - self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior - ) - - def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): - if initial_guess.size == 0: - initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T - self.Sampler.sample(initial_guess, None) # type: ignore - - def print_summary(self): + def print_summary(self, transform: bool = True): """ Generate summary of the run @@ -109,23 +112,27 @@ def print_summary(self): train_summary = self.Sampler.get_sampler_state(training=True) production_summary = self.Sampler.get_sampler_state(training=False) - training_chain: Array = train_summary["chains"] - training_log_prob: Array = train_summary["log_prob"] - training_local_acceptance: Array = train_summary["local_accs"] - training_global_acceptance: Array = train_summary["global_accs"] - training_loss: Array = train_summary["loss_vals"] - - production_chain: Array = production_summary["chains"] - production_log_prob: Array = production_summary["log_prob"] - production_local_acceptance: Array = production_summary["local_accs"] - production_global_acceptance: Array = production_summary["global_accs"] + training_chain = 
train_summary["chains"].reshape(-1, self.Prior.n_dim).T + training_chain = self.Prior.add_name(training_chain) + if transform: + training_chain = self.Prior.transform(training_chain) + training_log_prob = train_summary["log_prob"] + training_local_acceptance = train_summary["local_accs"] + training_global_acceptance = train_summary["global_accs"] + training_loss = train_summary["loss_vals"] + + production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T + production_chain = self.Prior.add_name(production_chain) + if transform: + production_chain = self.Prior.transform(production_chain) + production_log_prob = production_summary["log_prob"] + production_local_acceptance = production_summary["local_accs"] + production_global_acceptance = production_summary["global_accs"] print("Training summary") print("=" * 10) - for index in range(len(self.Prior.naming)): - print( - f"{self.Prior.naming[index]}: {training_chain[:, :, index].mean():.3f} +/- {training_chain[:, :, index].std():.3f}" - ) + for key, value in training_chain.items(): + print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") print( f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" ) @@ -141,10 +148,8 @@ def print_summary(self): print("Production summary") print("=" * 10) - for index in range(len(self.Prior.naming)): - print( - f"{self.Prior.naming[index]}: {production_chain[:, :, index].mean():.3f} +/- {production_chain[:, :, index].std():.3f}" - ) + for key, value in production_chain.items(): + print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}") print( f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" ) diff --git a/src/jimgw/model.py b/src/jimgw/model.py deleted file mode 100644 index add21c2b..00000000 --- a/src/jimgw/model.py +++ /dev/null @@ -1,14 +0,0 @@ -import equinox as eqx -from abc import abstractmethod -from jaxtyping import Array - - -class Model(eqx.Module): - params: dict - - @abstractmethod 
- def __init__(self): - raise NotImplementedError - - def __call__(self, x: Array) -> float: - raise NotImplementedError diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 764d64a8..7577ffd7 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -95,12 +95,16 @@ class Uniform(Prior): xmin: Float = 0.0 xmax: Float = 1.0 + def __repr__(self): + return f"Uniform(xmin={self.xmin}, xmax={self.xmax})" + def __init__( self, xmin: Float, xmax: Float, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, ): super().__init__(naming, transforms) assert self.n_dim == 1, "Uniform needs to be 1D distributions" @@ -146,12 +150,16 @@ class Unconstrained_Uniform(Prior): xmin: Float = 0.0 xmax: Float = 1.0 + def __repr__(self): + return f"Unconstrained_Uniform(xmin={self.xmin}, xmax={self.xmax})" + def __init__( self, xmin: Float, xmax: Float, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, ): super().__init__(naming, transforms) assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" @@ -219,7 +227,10 @@ class Sphere(Prior): Magnitude is sampled from a uniform distribution. """ - def __init__(self, naming: str): + def __repr__(self): + return f"Sphere(naming={self.naming})" + + def __init__(self, naming: str, **kwargs): self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"] self.transforms = { self.naming[0]: ( @@ -256,7 +267,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped -class Alignedspin(Prior): +class AlignedSpin(Prior): """ Prior distribution for the aligned (z) component of the spin. 
@@ -276,11 +287,15 @@ class Alignedspin(Prior): chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) + def __repr__(self): + return f"Alignedspin(amax={self.amax}, naming={self.naming})" + def __init__( self, amax: Float, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, ): super().__init__(naming, transforms) assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" @@ -358,7 +373,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped -class Powerlaw(Prior): +class PowerLaw(Prior): """ A prior following the power-law with alpha in the range [xmin, xmax). @@ -370,6 +385,9 @@ class Powerlaw(Prior): alpha: Float = 0.0 normalization: Float = 1.0 + def __repr__(self): + return f"Powerlaw(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + def __init__( self, xmin: Float, @@ -377,6 +395,7 @@ def __init__( alpha: Union[Int, Float], naming: list[str], transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, ): super().__init__(naming, transforms) if alpha < 0.0: @@ -436,8 +455,14 @@ def log_prob(self, x: dict[str, Float]) -> Float: class Composite(Prior): priors: list[Prior] = field(default_factory=list) + def __repr__(self): + return f"Composite(priors={self.priors}, naming={self.naming})" + def __init__( - self, priors: list[Prior], transforms: dict[str, tuple[str, Callable]] = {} + self, + priors: list[Prior], + transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, ): naming = [] self.transforms = {} diff --git a/src/jimgw/single_event/__init__.py b/src/jimgw/single_event/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/jimgw/detector.py b/src/jimgw/single_event/detector.py similarity index 95% rename from src/jimgw/detector.py rename to src/jimgw/single_event/detector.py index eb8176ad..0194c4b2 100644 --- a/src/jimgw/detector.py +++ 
b/src/jimgw/single_event/detector.py @@ -10,7 +10,7 @@ from scipy.signal.windows import tukey from jimgw.constants import EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS, C_SI -from jimgw.wave import Polarization +from jimgw.single_event.wave import Polarization DEG_TO_RAD = jnp.pi / 180 @@ -22,15 +22,6 @@ } -def np2(x): - """ - Returns the next power of two as big as or larger than x.""" - p = 1 - while p < x: - p = p << 1 - return p - - class Detector(ABC): """ Base class for all detectors. @@ -42,16 +33,13 @@ class Detector(ABC): data: Float[Array, " n_sample"] psd: Float[Array, " n_sample"] - @abstractmethod - def load_data(self, data): - raise NotImplementedError - @abstractmethod def fd_response( self, frequency: Float[Array, " n_sample"], - h: dict[str, Float[Array, " n_sample"]], + h_sky: dict[str, Float[Array, " n_sample"]], params: dict, + **kwargs, ) -> Float[Array, " n_sample"]: """ Modulate the waveform in the sky frame by the detector response @@ -62,8 +50,9 @@ def fd_response( def td_response( self, time: Float[Array, " n_sample"], - h: dict[str, Float[Array, " n_sample"]], + h_sky: dict[str, Float[Array, " n_sample"]], params: dict, + **kwargs, ) -> Float[Array, " n_sample"]: """ Modulate the waveform in the sky frame by the detector response @@ -85,6 +74,9 @@ class GroundBased2G(Detector): yarm_tilt: Float = 0 elevation: Float = 0 + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name})" + def __init__(self, name: str, **kwargs) -> None: self.name = name @@ -239,18 +231,16 @@ def load_data( self.name, trigger_time - gps_start_pad, trigger_time + gps_end_pad, - **gwpy_kwargs + **gwpy_kwargs, ) assert isinstance(data_td, TimeSeries), "Data is not a TimeSeries object." 
segment_length = data_td.duration.value n = len(data_td) - delta_t = data_td.dt.value + delta_t = data_td.dt.value # type: ignore data = jnp.fft.rfft(jnp.array(data_td.value) * tukey(n, tukey_alpha)) * delta_t freq = jnp.fft.rfftfreq(n, delta_t) # TODO: Check if this is the right way to fetch PSD - start_psd = ( - int(trigger_time) - gps_start_pad - 2 * psd_pad - ) # What does Int do here? + start_psd = int(trigger_time) - gps_start_pad - 2 * psd_pad end_psd = int(trigger_time) - gps_start_pad - psd_pad print("Fetching PSD data...") @@ -275,6 +265,7 @@ def fd_response( frequency: Float[Array, " n_sample"], h_sky: dict[str, Float[Array, " n_sample"]], params: dict[str, Float], + **kwargs, ) -> Array: """ Modulate the waveform in the sky frame by the detector response in the frequency domain. @@ -291,7 +282,13 @@ def fd_response( ) return jnp.sum(jnp.stack(jax.tree_util.tree_leaves(h_detector)), axis=0) - def td_response(self, time: Array, h: Array, params: Array) -> Array: + def td_response( + self, + time: Float[Array, " n_sample"], + h_sky: dict[str, Float[Array, " n_sample"]], + params: dict, + **kwargs, + ) -> Array: """ Modulate the waveform in the sky frame by the detector response in the time domain. 
""" @@ -405,6 +402,7 @@ def inject_signal( align_time = jnp.exp( -1j * 2 * jnp.pi * freqs * (params["epoch"] + params["t_c"]) ) + signal = self.fd_response(freqs, h_sky, params) * align_time self.data = signal + noise_real + 1j * noise_imag @@ -462,3 +460,9 @@ def load_psd( elevation=51.884, mode="pc", ) + +detector_preset = { + "H1": H1, + "L1": L1, + "V1": V1, +} diff --git a/src/jimgw/likelihood.py b/src/jimgw/single_event/likelihood.py similarity index 92% rename from src/jimgw/likelihood.py rename to src/jimgw/single_event/likelihood.py index f00cbabc..d6f36195 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -1,5 +1,3 @@ -from abc import ABC, abstractmethod - import jax import jax.numpy as jnp import numpy as np @@ -9,53 +7,22 @@ from jaxtyping import Array, Float from scipy.interpolate import interp1d -from jimgw.detector import Detector +from jimgw.single_event.detector import Detector from jimgw.prior import Prior -from jimgw.waveform import Waveform - - -class LikelihoodBase(ABC): - """ - Base class for likelihoods. - Note that this likelihood class should work - for a some what general class of problems. - In light of that, this class would be some what abstract, - but the idea behind it is this handles two main components of a likelihood: - the data and the model. - It should be able to take the data and model and evaluate the likelihood for - a given set of parameters. - - """ - - _model: object - _data: object - - @property - def model(self): - """ - The model for the likelihood. - """ - return self._model - - @property - def data(self): - """ - The data for the likelihood. - """ - return self._data +from jimgw.single_event.waveform import Waveform +from jimgw.base import LikelihoodBase - @abstractmethod - def evaluate(self, params: dict[str, Float], data: dict) -> Float: - """ - Evaluate the likelihood for a given set of parameters. 
- """ - raise NotImplementedError - -class TransientLikelihoodFD(LikelihoodBase): +class SingleEventLiklihood(LikelihoodBase): detectors: list[Detector] waveform: Waveform + def __init__(self, detectors: list[Detector], waveform: Waveform) -> None: + self.detectors = detectors + self.waveform = waveform + + +class TransientLikelihoodFD(SingleEventLiklihood): def __init__( self, detectors: list[Detector], @@ -63,6 +30,7 @@ def __init__( trigger_time: float = 0, duration: float = 4, post_trigger_duration: float = 2, + **kwargs, ) -> None: self.detectors = detectors assert jnp.all( @@ -167,6 +135,7 @@ def __init__( post_trigger_duration: float = 2, popsize: int = 100, n_loops: int = 2000, + **kwargs, ) -> None: super().__init__( detectors, waveform, trigger_time, duration, post_trigger_duration @@ -476,6 +445,7 @@ def y(x): return prior.transform(prior.add_name(best_fit)) -class PopulationLikelihood(LikelihoodBase): - events: Float[Array, " n_events n_samples n_dim"] - reference_pop: Float[Array, " n_det n_dim"] +likelihood_presets = { + "TransientLikelihoodFD": TransientLikelihoodFD, + "HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD, +} diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py new file mode 100644 index 00000000..bed1ee3d --- /dev/null +++ b/src/jimgw/single_event/runManager.py @@ -0,0 +1,344 @@ +from jimgw.base import RunManager +from dataclasses import dataclass, field, asdict +from jimgw.single_event.likelihood import likelihood_presets, SingleEventLiklihood +from jimgw.single_event.detector import detector_preset, Detector +from jimgw.single_event.waveform import waveform_preset, Waveform +from jimgw import prior +from jimgw.jim import Jim +import jax.numpy as jnp +import jax +import yaml +from astropy.time import Time +from jaxtyping import Array, Float, PyTree +import matplotlib.pyplot as plt +from jaxlib.xla_extension import ArrayImpl + + +def jaxarray_representer(dumper: yaml.Dumper, data: 
ArrayImpl): + return dumper.represent_list(data.tolist()) + + +yaml.add_representer(ArrayImpl, jaxarray_representer) + +prior_presets = { + "Unconstrained_Uniform": prior.Unconstrained_Uniform, + "Uniform": prior.Uniform, + "Sphere": prior.Sphere, + "AlignedSpin": prior.AlignedSpin, + "PowerLaw": prior.PowerLaw, + "Composite": prior.Composite, + "MassRatio": lambda **kwargs: prior.Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, + ), + "CosIota": lambda **kwargs: prior.Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos(params["cos_iota"]), + ) + }, + ), + "SinDec": lambda **kwargs: prior.Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin(params["sin_dec"]), + ) + }, + ), +} + + +@dataclass +class SingleEventRun: + seed: int + path: str + + detectors: list[str] + priors: dict[ + str, dict[str, str | float | int | bool] + ] # Transform cannot be included in this way, add it to preset if used often. 
+ jim_parameters: dict[str, str | float | int | bool | dict] + injection_parameters: dict[str, float] + injection: bool = False + likelihood_parameters: dict[str, str | float | int | bool | PyTree] = field( + default_factory=lambda: {"name": "TransientLikelihoodFD"} + ) + waveform_parameters: dict[str, str | float | int | bool] = field( + default_factory=lambda: {"name": ""} + ) + data_parameters: dict[str, float | int] = field( + default_factory=lambda: { + "trigger_time": 0.0, + "duration": 0, + "post_trigger_duration": 0, + "f_min": 0.0, + "f_max": 0.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + } + ) + + +class SingleEventPERunManager(RunManager): + run: SingleEventRun + jim: Jim + + @property + def waveform(self): + return self.run.waveform_parameters["name"] + + @property + def detectors(self): + return self.run.detectors + + @property + def data(self): + return [detector.data for detector in self.likelihood.detectors] + + @property + def psds(self): + return self.run.detectors + + def __init__(self, **kwargs): + if "run" in kwargs: + print("Run instance provided. Loading from instance.") + self.run = kwargs["run"] + elif "path" in kwargs: + print("Run instance not provided. 
Loading from path.") + self.run = self.load_from_path(kwargs["path"]) + else: + print("Neither run instance nor path provided.") + raise ValueError + + local_prior = self.initialize_prior() + local_likelihood = self.initialize_likelihood(local_prior) + self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) + + def save(self, path: str): + output_dict = asdict(self.run) + with open(path + ".yaml", "w") as f: + yaml.dump(output_dict, f, sort_keys=False) + + def load_from_path(self, path: str) -> SingleEventRun: + with open(path, "r") as f: + data = yaml.safe_load(f) + return SingleEventRun(**data) + + ### Initialization functions ### + + def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: + """ + Since prior contains information about types, naming and ranges of parameters, + some of the likelihood class require the prior to be initialized, such as the + heterodyned likelihood. + + """ + detectors = self.initialize_detector() + waveform = self.initialize_waveform() + name = self.run.likelihood_parameters["name"] + assert isinstance(name, str), "Likelihood name must be a string." 
+ if self.run.injection: + freqs = jnp.linspace( + self.run.data_parameters["f_min"], + self.run.data_parameters["f_sampling"] / 2, + int( + self.run.data_parameters["f_sampling"] + * self.run.data_parameters["duration"] + ), + ) + freqs = freqs[ + (freqs >= self.run.data_parameters["f_min"]) + & (freqs <= self.run.data_parameters["f_max"]) + ] + gmst = ( + Time(self.run.data_parameters["trigger_time"], format="gps") + .sidereal_time("apparent", "greenwich") + .rad + ) + h_sky = waveform(freqs, self.run.injection_parameters) + detector_parameters = { + "ra": self.run.injection_parameters["ra"], + "dec": self.run.injection_parameters["dec"], + "psi": self.run.injection_parameters["psi"], + "t_c": self.run.injection_parameters["t_c"], + "gmst": gmst, + "epoch": self.run.data_parameters["duration"] + - self.run.data_parameters["post_trigger_duration"], + } + key, subkey = jax.random.split(jax.random.PRNGKey(self.run.seed + 1901)) + for detector in detectors: + detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore + key, subkey = jax.random.split(key) + return likelihood_presets[name]( + detectors, + waveform, + prior=prior, + **self.run.likelihood_parameters, + **self.run.data_parameters, + ) + + def initialize_prior(self) -> prior.Prior: + priors = [] + for name, parameters in self.run.priors.items(): + if parameters["name"] not in prior_presets: + raise ValueError(f"Prior {name} not recognized.") + priors.append( + prior_presets[parameters["name"]](naming=[name], **parameters) + ) + return prior.Composite(priors) + + def initialize_detector(self) -> list[Detector]: + """ + Initialize the detectors. 
+ """ + print("Initializing detectors.") + trigger_time = self.run.data_parameters["trigger_time"] + duration = self.run.data_parameters["duration"] + post_trigger_duration = self.run.data_parameters["post_trigger_duration"] + f_min = self.run.data_parameters["f_min"] + f_max = self.run.data_parameters["f_max"] + tukey_alpha = self.run.data_parameters["tukey_alpha"] + + assert trigger_time >= 0, "Trigger time must be positive." + assert duration > 0, "Duration must be positive." + assert post_trigger_duration >= 0, "Post trigger duration must be positive." + assert f_min >= 0, "f_min must be positive." + assert f_max > f_min, "f_max must be greater than f_min." + assert 0 <= tukey_alpha <= 1, "Tukey alpha must be between 0 and 1." + + epoch = duration - post_trigger_duration + detectors = [] + for name in self.run.detectors: + detector = detector_preset[name] + if not self.run.injection: + print("Loading real data.") + detector.load_data( + trigger_time=trigger_time, + gps_start_pad=int(epoch), + gps_end_pad=int(post_trigger_duration), + psd_pad=int(duration * 4), + f_min=f_min, + f_max=f_max, + tukey_alpha=tukey_alpha, + ) + else: + print("Injection mode. Need to wait until waveform model is loaded.") + detectors.append(detector_preset[name]) + return detectors + + def initialize_waveform(self) -> Waveform: + """ + Initialize the waveform. + """ + print("Initializing waveform.") + name = self.run.waveform_parameters["name"] + if name not in waveform_preset: + raise ValueError(f"Waveform {name} not recognized.") + waveform = waveform_preset[name](**self.run.waveform_parameters) + return waveform + + ### Utility functions ### + + def get_detector_waveform( + self, params: dict[str, float] + ) -> tuple[ + Float[Array, " n_sample"], + dict[str, Float[Array, " n_sample"]], + dict[str, Float[Array, " n_sample"]], + ]: + """ + Get the waveform in each detector. 
+ """ + if not self.run.injection: + raise ValueError("No injection provided.") + freqs = jnp.linspace( + self.run.data_parameters["f_min"], + self.run.data_parameters["f_sampling"] / 2, + int( + self.run.data_parameters["f_sampling"] + * self.run.data_parameters["duration"] + ), + ) + freqs = freqs[ + (freqs >= self.run.data_parameters["f_min"]) + & (freqs <= self.run.data_parameters["f_max"]) + ] + gmst = ( + Time(self.run.data_parameters["trigger_time"], format="gps") + .sidereal_time("apparent", "greenwich") + .rad + ) + h_sky = self.jim.Likelihood.waveform(freqs, params) # type: ignore + align_time = jnp.exp( + -1j * 2 * jnp.pi * freqs * (self.jim.Likelihood.epoch + params["t_c"]) # type: ignore + ) + detector_parameters = { + "ra": params["ra"], + "dec": params["dec"], + "psi": params["psi"], + "t_c": params["t_c"], + "gmst": gmst, + "epoch": self.run.data_parameters["duration"] + - self.run.data_parameters["post_trigger_duration"], + } + print(detector_parameters) + detector_waveforms = {} + for detector in self.jim.Likelihood.detectors: # type: ignore + detector_waveforms[detector.name] = ( + detector.fd_response(freqs, h_sky, detector_parameters) * align_time + ) + return freqs, detector_waveforms, h_sky + + def plot_injection_waveform(self, path: str): + """ + Plot the injection waveform. + """ + freqs, waveforms, h_sky = self.get_detector_waveform( + self.run.injection_parameters + ) + plt.figure() + for detector in self.jim.Likelihood.detectors: # type: ignore + plt.loglog( + freqs, + jnp.abs(waveforms[detector.name]), + label=detector.name + " (injection)", + ) + plt.loglog( + freqs, jnp.sqrt(jnp.abs(detector.psd)), label=detector.name + " (PSD)" + ) + plt.xlabel("Frequency (Hz)") + plt.ylabel("Amplitude") + plt.legend() + plt.savefig(path) + + def plot_data(self, path: str): + """ + Plot the data. 
+ """ + + plt.figure() + for detector in self.jim.Likelihood.detectors: # type: ignore + plt.loglog( + detector.freqs, + jnp.abs(detector.data), + label=detector.name + " (data)", + ) + plt.loglog( + detector.freqs, + jnp.sqrt(jnp.abs(detector.psd)), + label=detector.name + " (PSD)", + ) + plt.xlabel("Frequency (Hz)") + plt.ylabel("Amplitude") + plt.legend() + plt.savefig(path) diff --git a/src/jimgw/utils.py b/src/jimgw/single_event/utils.py similarity index 100% rename from src/jimgw/utils.py rename to src/jimgw/single_event/utils.py diff --git a/src/jimgw/wave.py b/src/jimgw/single_event/wave.py similarity index 100% rename from src/jimgw/wave.py rename to src/jimgw/single_event/wave.py diff --git a/src/jimgw/waveform.py b/src/jimgw/single_event/waveform.py similarity index 83% rename from src/jimgw/waveform.py rename to src/jimgw/single_event/waveform.py index 582267cf..6c043f15 100644 --- a/src/jimgw/waveform.py +++ b/src/jimgw/single_event/waveform.py @@ -18,7 +18,7 @@ def __call__( class RippleIMRPhenomD(Waveform): f_ref: float - def __init__(self, f_ref: float = 20.0): + def __init__(self, f_ref: float = 20.0, **kwargs): self.f_ref = f_ref def __call__( @@ -42,11 +42,14 @@ def __call__( output["c"] = hc return output + def __repr__(self): + return f"RippleIMRPhenomD(f_ref={self.f_ref})" + class RippleIMRPhenomPv2(Waveform): f_ref: float - def __init__(self, f_ref: float = 20.0): + def __init__(self, f_ref: float = 20.0, **kwargs): self.f_ref = f_ref def __call__( @@ -73,3 +76,12 @@ def __call__( output["p"] = hp output["c"] = hc return output + + def __repr__(self): + return f"RippleIMRPhenomPv2(f_ref={self.f_ref})" + + +waveform_preset = { + "RippleIMRPhenomD": RippleIMRPhenomD, + "RippleIMRPhenomPv2": RippleIMRPhenomPv2, +}