From 49d604d3db54a09376145449fc69480d3226c7f2 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 13:31:09 -0400 Subject: [PATCH 1/5] Fix mass transform test --- src/jimgw/jim.py | 16 ++++++++++++---- test/integration/test_GW150914.py | 19 +++++++++++-------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 7d86fecf..a13d0d3d 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,10 +104,18 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: - initial_guess_named = self.prior.sample(key, self.sampler.n_chains) - for transform in self.sample_transforms: - initial_guess_named = jax.vmap(transform.forward)(initial_guess_named) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + initial_guess = [] + for i in range(self.sampler.n_chains): + flag = True + while flag: + key = jax.random.split(key)[1] + guess = self.prior.sample(key, 1) + for transform in self.sample_transforms: + guess = transform.forward(guess) + guess = jnp.array([i for i in guess.values()]).T[0] + flag = not jnp.all(jnp.isfinite(guess)) + initial_guess.append(guess) + initial_guess = jnp.array(initial_guess) self.sampler.sample(initial_guess, None) # type: ignore def maximize_likelihood( diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914.py index 6fddf9ea..90cba0f1 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914.py @@ -9,7 +9,7 @@ from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam @@ -35,8 +35,10 @@ for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) @@ -49,8 +51,8 @@ prior = CombinePrior( [ - Mc_prior, - q_prior, + m_1_prior, + m_2_prior, s1z_prior, s2z_prior, dL_prior, @@ -64,8 +66,9 @@ ) sample_transforms = [ - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), @@ -79,7 +82,7 @@ ] likelihood_transforms = [ - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), ] likelihood = TransientLikelihoodFD( From ac0b1f57ab4ebb07e7d984f79735e9a03e49be86 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 13:43:49 -0400 Subject: [PATCH 2/5] Use PowerLaw for distance --- test/integration/{test_GW150914.py => test_GW150914_D.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename test/integration/{test_GW150914.py => test_GW150914_D.py} (98%) diff --git a/test/integration/test_GW150914.py b/test/integration/test_GW150914_D.py similarity index 98% rename from test/integration/test_GW150914.py rename to test/integration/test_GW150914_D.py index 90cba0f1..9a39434e 100644 --- a/test/integration/test_GW150914.py +++ b/test/integration/test_GW150914_D.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD @@ -41,7 +41,7 @@ m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) -dL_prior = UniformPrior(0.0, 2000.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) From fa134f9b0f575f7d5921e7083afa5b4c81b99400 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 15:04:51 -0400 Subject: [PATCH 3/5] Add heterodyne test --- src/jimgw/jim.py | 2 +- src/jimgw/single_event/likelihood.py | 66 +++++++-- test/integration/test_GW150914_D.py | 4 - .../integration/test_GW150914_D_heterodyne.py | 131 ++++++++++++++++++ 4 files changed, 184 insertions(+), 19 deletions(-) create mode 100644 test/integration/test_GW150914_D_heterodyne.py diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index a13d0d3d..2063b0bf 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -105,7 +105,7 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess = [] - for i in range(self.sampler.n_chains): + for _ in range(self.sampler.n_chains): flag = True while flag: key = jax.random.split(key)[1] diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index ce2e8f0e..0ccb1ce8 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -14,6 +14,7 @@ from jimgw.single_event.detector import Detector from jimgw.utils import log_i0 from jimgw.single_event.waveform import Waveform +from jimgw.transforms import BijectiveTransform, NtoMTransform class SingleEventLiklihood(LikelihoodBase): @@ -184,8 +185,6 @@ def __init__( self, detectors: list[Detector], waveform: Waveform, - prior: Prior, - bounds: Float[Array, " n_dim 2"], n_bins: int = 100, trigger_time: float = 0, duration: float = 4, @@ -194,6 +193,9 @@ def __init__( n_steps: int = 2000, ref_params: dict = {}, reference_waveform: Optional[Waveform] = None, + prior: Optional[Prior] = None, + sample_transforms: Optional[list[BijectiveTransform]] = [], + likelihood_transforms: Optional[list[NtoMTransform]] = [], **kwargs, ) -> None: super().__init__( @@ -254,17 +256,24 @@ def __init__( ) self.freq_grid_low = freq_grid[:-1] - if not ref_params: + if ref_params: + self.ref_params = ref_params + print(f"Reference parameters provided, which are {self.ref_params}") + elif prior: print("No reference parameters are provided, finding it...") ref_params = self.maximize_likelihood( - bounds=bounds, prior=prior, popsize=popsize, n_steps=n_steps + prior=prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + popsize=popsize, + n_steps=n_steps, ) self.ref_params = {key: float(value) for key, value in ref_params.items()} print(f"The reference parameters are {self.ref_params}") else: - self.ref_params = ref_params - print(f"Reference parameters provided, which are {self.ref_params}") - + raise ValueError( + "Either reference parameters or parameter names must be provided" + ) # safe guard for the reference parameters # since ripple cannot handle eta=0.25 if jnp.isclose(self.ref_params["eta"], 0.25): @@ -542,25 +551,54 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): def maximize_likelihood( self, prior: Prior, - bounds: Float[Array, " n_dim 2"], + likelihood_transforms: list[BijectiveTransform], + sample_transforms: list[NtoMTransform], popsize: int = 100, n_steps: int = 2000, ): + parameter_names = prior.parameter_names + for transform in sample_transforms: + parameter_names = transform.propagate_name(parameter_names) + def y(x): - return -self.evaluate_original(prior.transform(prior.add_name(x)), {}) + named_params = dict(zip(parameter_names, x)) + for transform in reversed(sample_transforms): + named_params = transform.backward(named_params) + for transform in likelihood_transforms: + named_params = transform.forward(named_params) + return -self.evaluate_original(named_params, {}) print("Starting the optimizer") + optimizer = optimization_Adam( n_steps=n_steps, learning_rate=0.001, noise_level=1 ) - initial_position = jnp.array( - list(prior.sample(jax.random.PRNGKey(0), popsize).values()) - ).T + + key = jax.random.PRNGKey(0) + initial_position = [] + for _ in range(popsize): + flag = True + while flag: + key = jax.random.split(key)[1] + guess = prior.sample(key, 1) + for transform in sample_transforms: + guess = transform.forward(guess) + guess = jnp.array([i for i in guess.values()]).T[0] + flag = not jnp.all(jnp.isfinite(guess)) + initial_position.append(guess) + initial_position = jnp.array(initial_position) rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position ) - best_fit = optimized_positions[jnp.nanargmin(summary["final_log_prob"])] - return prior.transform(prior.add_name(best_fit)) + + best_fit = optimized_positions[jnp.argmin(summary["final_log_prob"])] + + named_params = dict(zip(parameter_names, best_fit)) + for transform in reversed(sample_transforms): + named_params = transform.backward(named_params) + for transform in likelihood_transforms: + named_params = transform.forward(named_params) + return named_params likelihood_presets = { diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index 9a39434e..ba3ce2d6 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -1,5 +1,3 @@ -import time - import jax import jax.numpy as jnp @@ -19,8 +17,6 @@ ########## First we grab data ############# ########################################### -total_time_start = time.time() - # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 duration = 4 diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py new file mode 100644 index 00000000..66093b88 --- /dev/null +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -0,0 +1,131 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = HeterodynedTransientLikelihoodFD( + ifos, + prior=prior, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) From 2c9c6a62fac657e9d54137ff151e9cc66e24385c Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 15:21:02 -0400 Subject: [PATCH 4/5] Fix typecheck --- src/jimgw/jim.py | 8 ++++---- src/jimgw/single_event/likelihood.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2063b0bf..043c4672 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -102,8 +102,8 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): named_params = transform.forward(named_params) return self.likelihood.evaluate(named_params, data) + prior - def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): - if initial_guess.size == 0: + def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): + if initial_position.size == 0: initial_guess = [] for _ in range(self.sampler.n_chains): flag = True @@ -115,8 +115,8 @@ def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): guess = jnp.array([i for i in guess.values()]).T[0] flag = not jnp.all(jnp.isfinite(guess)) initial_guess.append(guess) - initial_guess = jnp.array(initial_guess) - self.sampler.sample(initial_guess, None) # type: ignore + initial_position = jnp.array(initial_guess) + self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( self, diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 0ccb1ce8..00e6ce6b 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -194,8 +194,8 @@ def __init__( ref_params: dict = {}, reference_waveform: Optional[Waveform] = None, prior: Optional[Prior] = None, - sample_transforms: Optional[list[BijectiveTransform]] = [], - likelihood_transforms: Optional[list[NtoMTransform]] = [], + sample_transforms: list[BijectiveTransform] = [], + likelihood_transforms: list[NtoMTransform] = [], **kwargs, ) -> None: super().__init__( @@ -551,8 +551,8 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): def maximize_likelihood( self, prior: Prior, - likelihood_transforms: list[BijectiveTransform], - sample_transforms: list[NtoMTransform], + likelihood_transforms: list[NtoMTransform], + sample_transforms: list[BijectiveTransform], popsize: int = 100, n_steps: int = 2000, ): From 15e40748844a9f5ae33bf60f3aa0880ff5632dfa Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 2 Aug 2024 15:25:09 -0400 Subject: [PATCH 5/5] Shorten test runtime --- test/integration/test_GW150914_D_heterodyne.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py index 66093b88..b5945cee 100644 --- a/test/integration/test_GW150914_D_heterodyne.py +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -90,6 +90,8 @@ post_trigger_duration=2, sample_transforms=sample_transforms, likelihood_transforms=likelihood_transforms, + n_steps=5, + popsize=10, )