From f33a7824f03de5d602b22cfbf2ec93b46e278316 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 12 Aug 2024 15:49:26 +0200 Subject: [PATCH 01/57] Adding transform from geocentric arrival time to detector arrival time --- src/jimgw/single_event/transforms.py | 54 ++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index c3e77846..b1735f5a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -170,6 +170,60 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): + """ + Transform the geocentric arrival time to detector arrival time + + In the geocentric convention, the arrival time of the signal at the + center of Earth is gps_time + t_c + + In the detector convention, the arrival time of the signal at the + detecotr is gps_time + time_delay_from_geo_to_det + t_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + t_c: Float + t_det: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifo: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + + assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + + def named_transform(x): + t_det = x["t_c"] + self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + return { + "t_det": t_det, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + t_c = x["t_det"] - self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + return { + "t_c": t_c, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class SpinToCartesianSpinTransform(NtoNTransform): """ From 3505394c6e1f974134ecf59e326507c376c54621 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 12 Aug 2024 16:01:11 +0200 Subject: [PATCH 02/57] Adding transform from distance to SNR weighted distance --- src/jimgw/single_event/transforms.py | 88 +++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index b1735f5a..6404fb7e 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -208,7 +208,9 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] def named_transform(x): - t_det = x["t_c"] + self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + t_det = x["t_c"] + self.ifo.delay_from_geocenter( + x["ra"], x["dec"], self.gmst + ) return { "t_det": t_det, } @@ -216,7 +218,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - t_c = x["t_det"] - self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + t_c = x["t_det"] - self.ifo.delay_from_geocenter( + x["ra"], x["dec"], self.gmst + ) return { "t_c": t_c, } @@ -224,6 +228,86 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): + """ + Transform the luminosity distance to network SNR weighted distance + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifos: list[GroundBased2G], + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifos = ifos + + assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + + def named_transform(x): + d_L, M_c, ra, dec, psi, iota = ( + x["d_L"], + x["M_c"], + x["ra"], + x["dec"], + x["psi"], + x["iota"], + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_ks2 = 0.0 + for ifo in self.ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_ks2 += p_mode_term**2 + c_mode_term**2 + R_ks = jnp.sqrt(R_ks2) + d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_ks + return { + "d_hat": d_hat, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + d_hat, M_c, ra, dec, psi, iota = ( + x["d_hat"], + x["M_c"], + x["ra"], + x["dec"], + x["psi"], + x["iota"], + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_ks2 = 0.0 + for ifo in self.ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_ks2 += p_mode_term**2 + c_mode_term**2 + R_ks = jnp.sqrt(R_ks2) + d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_ks + return { + "d_L": d_L, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class SpinToCartesianSpinTransform(NtoNTransform): """ From df75cebc636eec563bf361e10c476465c161fe38 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 10:12:45 +0200 Subject: [PATCH 03/57] updating the typing for object attributes --- src/jimgw/single_event/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 6404fb7e..e1a49e58 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -189,8 +189,7 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): """ gmst: Float - t_c: Float - t_det: Float + ifo: GroundBased2G def __init__( self, @@ -241,6 +240,7 @@ class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): """ gmst: Float + ifos: list[GroundBased2G] def __init__( self, From 9f2f52b323bd03923043a0be4203030229225f71 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 11:35:34 +0200 Subject: [PATCH 04/57] Adding geocentric phase to detector phase --- src/jimgw/single_event/transforms.py | 106 +++++++++++++++++++++------ 1 file changed, 83 insertions(+), 23 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index e1a49e58..bf060b8a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -227,6 +227,72 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): + """ + Transform the geocentric arrival phase to detector arrival phase + + In the geocentric convention, the arrival phase of the signal at the + center of Earth is phi_c / 2 (in ripple, phi_c is the orbital phase) + + In the detector convention, the arrival phase of the signal at the + detecotr is phi_det = phi_c / 2 + arg R_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + ifo: GroundBased2G + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifo: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + + assert "phi_c" in name_mapping[0] and "phi_det" in name_mapping[1] + + def _calc_R_det(x): + ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + + return p_mode_term - 1j * c_mode_term + + def named_transform(x): + R_det = _calc_R_det(x) + phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 + return { + "phi_det": phi_det, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + R_det = _calc_R_det(x) + phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 + return { + "phi_c": phi_c, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): """ @@ -257,10 +323,8 @@ def __init__( assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] - def named_transform(x): - d_L, M_c, ra, dec, psi, iota = ( - x["d_L"], - x["M_c"], + def _calc_R_dets(x): + ra, dec, psi, iota = ( x["ra"], x["dec"], x["psi"], @@ -268,14 +332,22 @@ def named_transform(x): ) p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - R_ks2 = 0.0 + R_dets2 = 0.0 for ifo in self.ifos: antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] - R_ks2 += p_mode_term**2 + c_mode_term**2 - R_ks = jnp.sqrt(R_ks2) - d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_ks + R_dets2 += p_mode_term**2 + c_mode_term**2 + + return jnp.sqrt(R_dets2) + + def named_transform(x): + d_L, M_c = ( + x["d_L"], + x["M_c"], + ) + R_dets = _calc_R_dets(x) + d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets return { "d_hat": d_hat, } @@ -283,24 +355,12 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - d_hat, M_c, ra, dec, psi, iota = ( + d_hat, M_c = ( x["d_hat"], x["M_c"], - x["ra"], - x["dec"], - x["psi"], - x["iota"], ) - p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 - c_iota_term = jnp.cos(iota) - R_ks2 = 0.0 - for ifo in self.ifos: - antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) - p_mode_term = p_iota_term * antenna_pattern["p"] - c_mode_term = c_iota_term * antenna_pattern["c"] - R_ks2 += p_mode_term**2 + c_mode_term**2 - R_ks = jnp.sqrt(R_ks2) - d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_ks + R_dets = _calc_R_dets(x) + d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets return { "d_L": d_L, } From b62970f2a8f05bae53408870763aa115ea957fdc Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 11:51:22 +0200 Subject: [PATCH 05/57] Adding ZeroLikelihood for testing purpose --- src/jimgw/single_event/likelihood.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 00e6ce6b..9e775b33 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -26,6 +26,15 @@ def __init__(self, detectors: list[Detector], waveform: Waveform) -> None: self.waveform = waveform +class ZeroLikelihood(LikelihoodBase): + + def __init__(self): + pass + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + return 0.0 + + class TransientLikelihoodFD(SingleEventLiklihood): def __init__( self, From 4ea332224fa35363961220b1dffccc167599a6c3 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 12:09:47 +0200 Subject: [PATCH 06/57] Adding the missing mode 2pi for phasing transform --- src/jimgw/single_event/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index bf060b8a..6f4f361a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -278,7 +278,7 @@ def named_transform(x): R_det = _calc_R_det(x) phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 return { - "phi_det": phi_det, + "phi_det": phi_det % (2. * jnp.pi), } self.transform_func = named_transform @@ -287,7 +287,7 @@ def named_inverse_transform(x): R_det = _calc_R_det(x) phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 return { - "phi_c": phi_c, + "phi_c": phi_c % (2. * jnp.pi), } self.inverse_transform_func = named_inverse_transform From 7a4bae0e8d06735458b6f1e004258d1afb15aac2 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 13:38:23 +0200 Subject: [PATCH 07/57] Test wip --- test/integration/test_extrinsic.py | 108 ++++++++++++++++++ .../integration/test_extrinsic_no_distance.py | 90 +++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 test/integration/test_extrinsic.py create mode 100644 test/integration/test_extrinsic_no_distance.py diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py new file mode 100644 index 00000000..c4719dd8 --- /dev/null +++ b/test/integration/test_extrinsic.py @@ -0,0 +1,108 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import ZeroLikelihood +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 + +ifos = [H1, L1, V1] + +M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + M_c_prior, + q_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], gps_time=gps, ifos=ifos), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=2.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + +likelihood = ZeroLikelihood() + +mass_matrix = jnp.eye(9) +#mass_matrix = mass_matrix.at[1, 1].set(1e-3) +#mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() diff --git a/test/integration/test_extrinsic_no_distance.py b/test/integration/test_extrinsic_no_distance.py new file mode 100644 index 00000000..d1b1e559 --- /dev/null +++ b/test/integration/test_extrinsic_no_distance.py @@ -0,0 +1,90 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import ZeroLikelihood +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 + +ifos = [H1, L1, V1] + +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + t_c_prior, + phase_c_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), +] + +likelihood_transforms = [] + +likelihood = ZeroLikelihood() + +mass_matrix = jnp.eye(9) +#mass_matrix = mass_matrix.at[1, 1].set(1e-3) +#mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() From d5f86e52156c788a06d84c1ac5d0958ae4d0b8fb Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 13:42:45 +0200 Subject: [PATCH 08/57] Phase renaming --- src/jimgw/single_event/transforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 6f4f361a..1242b40b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -233,10 +233,10 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): Transform the geocentric arrival phase to detector arrival phase In the geocentric convention, the arrival phase of the signal at the - center of Earth is phi_c / 2 (in ripple, phi_c is the orbital phase) + center of Earth is phase_c / 2 (in ripple, phase_c is the orbital phase) In the detector convention, the arrival phase of the signal at the - detecotr is phi_det = phi_c / 2 + arg R_det + detecotr is phase_det = phase_c / 2 + arg R_det Parameters ---------- @@ -261,7 +261,7 @@ def __init__( ) self.ifo = ifo - assert "phi_c" in name_mapping[0] and "phi_det" in name_mapping[1] + assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] def _calc_R_det(x): ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] @@ -276,18 +276,18 @@ def _calc_R_det(x): def named_transform(x): R_det = _calc_R_det(x) - phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 + phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 return { - "phi_det": phi_det % (2. * jnp.pi), + "phase_det": phase_det % (2. * jnp.pi), } self.transform_func = named_transform def named_inverse_transform(x): R_det = _calc_R_det(x) - phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 + phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 return { - "phi_c": phi_c % (2. * jnp.pi), + "phase_c": phase_c % (2. * jnp.pi), } self.inverse_transform_func = named_inverse_transform From 0a2e68c22611fea72a2643bbea8ab358c67c3650 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Tue, 13 Aug 2024 06:09:13 -0700 Subject: [PATCH 09/57] wip --- src/jimgw/single_event/transforms.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 1242b40b..538d6b8e 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -208,7 +208,7 @@ def __init__( def named_transform(x): t_det = x["t_c"] + self.ifo.delay_from_geocenter( - x["ra"], x["dec"], self.gmst + x["ra"][0], x["dec"][0], self.gmst ) return { "t_det": t_det, @@ -217,8 +217,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): + import pdb; pdb.set_trace() t_c = x["t_det"] - self.ifo.delay_from_geocenter( - x["ra"], x["dec"], self.gmst + x["ra"][0], x["dec"][0], self.gmst ) return { "t_c": t_c, @@ -268,7 +269,7 @@ def _calc_R_det(x): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] @@ -278,7 +279,7 @@ def named_transform(x): R_det = _calc_R_det(x) phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 return { - "phase_det": phase_det % (2. * jnp.pi), + "phase_det": phase_det % (2.0 * jnp.pi), } self.transform_func = named_transform @@ -287,7 +288,7 @@ def named_inverse_transform(x): R_det = _calc_R_det(x) phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 return { - "phase_c": phase_c % (2. * jnp.pi), + "phase_c": phase_c % (2.0 * jnp.pi), } self.inverse_transform_func = named_inverse_transform From b96512c64139661b38fd7df1f2506d480c527ebb Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 14 Aug 2024 13:30:43 -0400 Subject: [PATCH 10/57] Push conditional bijective transform --- src/jimgw/transforms.py | 54 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..4d2ebb45 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -61,8 +61,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: class NtoNTransform(NtoMTransform): - transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property def n_dim(self) -> int: return len(self.name_mapping[0]) @@ -162,6 +160,58 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]: list(output_params.keys()), ) return y_copy + +class ConditionalBijectiveTransform(BijectiveTransform): + + conditional_names: list[str] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], + ): + super().__init__(name_mapping) + self.conditional_names = conditional_names + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + x_copy = x.copy() + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + transform_params.update( + dict((key, x_copy[key]) for key in self.conditional_names) + ) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian + + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + y_copy = y.copy() + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + transform_params.update( + dict((key, y_copy[key]) for key in self.conditional_names) + ) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy, jacobian @jaxtyped(typechecker=typechecker) From 526e33cbb8eae4c32782113cd2a7495d657e5ea9 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:44:04 -0700 Subject: [PATCH 11/57] Switch to using conditional transform --- src/jimgw/single_event/transforms.py | 46 ++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 538d6b8e..2fb757c2 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,7 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import BijectiveTransform, NtoNTransform +from jimgw.transforms import ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -171,7 +171,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): +class GeocentricArrivalTimeToDetectorArrivalTimeTransform(ConditionalBijectiveTransform): """ Transform the geocentric arrival time to detector arrival time @@ -194,10 +194,11 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -205,11 +206,22 @@ def __init__( self.ifo = ifo assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + ) + + def _calc_delay(x): + ra, dec = x["ra"], x["dec"] + if hasattr(ra, "shape") and len(ra.shape) > 0: + delay = self.ifo.delay_from_geocenter(ra[0], dec[0], self.gmst) + else: + delay = self.ifo.delay_from_geocenter(ra, dec, self.gmst) + return delay def named_transform(x): - t_det = x["t_c"] + self.ifo.delay_from_geocenter( - x["ra"][0], x["dec"][0], self.gmst - ) + delay = _calc_delay(x) + t_det = x["t_c"] + delay return { "t_det": t_det, } @@ -217,10 +229,8 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - import pdb; pdb.set_trace() - t_c = x["t_det"] - self.ifo.delay_from_geocenter( - x["ra"][0], x["dec"][0], self.gmst - ) + delay = _calc_delay(x) + t_c = x["t_det"] - delay return { "t_c": t_c, } @@ -229,7 +239,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(ConditionalBijectiveTransform): """ Transform the geocentric arrival phase to detector arrival phase @@ -252,10 +262,11 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -263,13 +274,22 @@ def __init__( self.ifo = ifo assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + ) def _calc_R_det(x): ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + if hasattr(ra, "shape") and len(ra.shape) > 0: + antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + else: + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] From dbf3f3064e5135575a5c7548ee9e1c5e9157553b Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:45:40 -0700 Subject: [PATCH 12/57] Switch to using conditional transform --- src/jimgw/single_event/transforms.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 2fb757c2..dfcd01b8 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,11 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform +from jimgw.transforms import ( + ConditionalBijectiveTransform, + BijectiveTransform, + NtoNTransform, +) from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -171,7 +175,9 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalTimeToDetectorArrivalTimeTransform(ConditionalBijectiveTransform): +class GeocentricArrivalTimeToDetectorArrivalTimeTransform( + ConditionalBijectiveTransform +): """ Transform the geocentric arrival time to detector arrival time @@ -206,10 +212,7 @@ def __init__( self.ifo = ifo assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] - assert ( - "ra" in conditional_names - and "dec" in conditional_names - ) + assert "ra" in conditional_names and "dec" in conditional_names def _calc_delay(x): ra, dec = x["ra"], x["dec"] @@ -239,7 +242,9 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(ConditionalBijectiveTransform): +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( + ConditionalBijectiveTransform +): """ Transform the geocentric arrival phase to detector arrival phase @@ -287,7 +292,9 @@ def _calc_R_det(x): c_iota_term = jnp.cos(iota) if hasattr(ra, "shape") and len(ra.shape) > 0: - antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + antenna_pattern = self.ifo.antenna_pattern( + ra[0], dec[0], psi[0], self.gmst + ) else: antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] From a3753612926dd2b95b440bbc14043cdf7412df29 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:46:04 -0700 Subject: [PATCH 13/57] Fixing jacobian handling --- src/jimgw/transforms.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4d2ebb45..a26ad9af 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -160,7 +160,8 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]: list(output_params.keys()), ) return y_copy - + + class ConditionalBijectiveTransform(BijectiveTransform): conditional_names: list[str] @@ -181,8 +182,14 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: ) output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[0]} + for key1 in self.name_mapping[1] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -192,7 +199,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: list(output_params.keys()), ) return x_copy, jacobian - + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: y_copy = y.copy() transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) @@ -201,8 +208,14 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: ) output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[1]} + for key1 in self.name_mapping[0] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], From d79af97d25e29b41a3f6690bd9df4db6ef46d54f Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 02:56:38 -0700 Subject: [PATCH 14/57] Both arrival phase and time transform are fully vectorized --- src/jimgw/single_event/transforms.py | 43 ++++++++++++---------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index dfcd01b8..047c0607 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -214,17 +214,10 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names - def _calc_delay(x): - ra, dec = x["ra"], x["dec"] - if hasattr(ra, "shape") and len(ra.shape) > 0: - delay = self.ifo.delay_from_geocenter(ra[0], dec[0], self.gmst) - else: - delay = self.ifo.delay_from_geocenter(ra, dec, self.gmst) - return delay - def named_transform(x): - delay = _calc_delay(x) - t_det = x["t_c"] + delay + t_det = x["t_c"] + jnp.vectorize(self.ifo.delay_from_geocenter)( + x["ra"], x["dec"], self.gmst + ) return { "t_det": t_det, } @@ -232,8 +225,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - delay = _calc_delay(x) - t_c = x["t_det"] - delay + t_c = x["t_det"] - jnp.vectorize(self.ifo.delay_from_geocenter)( + x["ra"], x["dec"], self.gmst + ) return { "t_c": t_c, } @@ -286,25 +280,22 @@ def __init__( and "iota" in conditional_names ) - def _calc_R_det(x): - ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] + @jnp.vectorize + def _calc_R_det_arg(ra, dec, psi, iota, gmst): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - if hasattr(ra, "shape") and len(ra.shape) > 0: - antenna_pattern = self.ifo.antenna_pattern( - ra[0], dec[0], psi[0], self.gmst - ) - else: - antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] - return p_mode_term - 1j * c_mode_term + return jnp.angle(p_mode_term - 1j * c_mode_term) def named_transform(x): - R_det = _calc_R_det(x) - phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + ) + phase_det = R_det_arg + x["phase_c"] / 2.0 return { "phase_det": phase_det % (2.0 * jnp.pi), } @@ -312,8 +303,10 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - R_det = _calc_R_det(x) - phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + ) + phase_c = (-R_det_arg + x["phase_det"]) * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), } From bcbcbe2e8d6e1d565544cadf9b9ea8e333602cd0 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 03:56:26 -0700 Subject: [PATCH 15/57] Shifting distance transform to conditional --- src/jimgw/single_event/transforms.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 047c0607..4b6ada1f 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -315,7 +315,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): +class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): """ Transform the luminosity distance to network SNR weighted distance @@ -332,10 +332,11 @@ class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -343,14 +344,16 @@ def __init__( self.ifos = ifos assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + and "M_c" in conditional_names + ) - def _calc_R_dets(x): - ra, dec, psi, iota = ( - x["ra"], - x["dec"], - x["psi"], - x["iota"], - ) + @jnp.vectorize + def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) R_dets2 = 0.0 @@ -367,7 +370,7 @@ def named_transform(x): x["d_L"], x["M_c"], ) - R_dets = _calc_R_dets(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets return { "d_hat": d_hat, @@ -380,7 +383,7 @@ def named_inverse_transform(x): x["d_hat"], x["M_c"], ) - R_dets = _calc_R_dets(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets return { "d_L": d_L, From 8dab27b6e259063f03db8f381e8613ca7e34a9d4 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 03:56:48 -0700 Subject: [PATCH 16/57] update example --- test/integration/test_extrinsic.py | 62 ++++++++++--- .../integration/test_extrinsic_no_distance.py | 90 ------------------- 2 files changed, 48 insertions(+), 104 deletions(-) delete mode 100644 test/integration/test_extrinsic_no_distance.py diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index c4719dd8..988f6ea7 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -1,3 +1,12 @@ +import psutil +p = psutil.Process() +p.cpu_affinity([0]) + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +from astropy.time import Time + import jax import jax.numpy as jnp @@ -21,7 +30,6 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) @@ -33,7 +41,6 @@ prior = CombinePrior( [ M_c_prior, - q_prior, dL_prior, t_c_prior, phase_c_prior, @@ -44,31 +51,56 @@ ] ) +# calculate the d_hat range +@jnp.vectorize +def calc_R_dets(ra, dec, psi, iota): + gmst = ( + Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_dets2 = 0.0 + for ifo in ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_dets2 += p_mode_term**2 + c_mode_term**2 + + return jnp.sqrt(R_dets2) + +key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(1234), 4) +# generate 10000 samples for each +ra_samples = ra_prior.sample(key1, 10000)["ra"] +dec_samples = dec_prior.sample(key2, 10000)["dec"] +psi_samples = psi_prior.sample(key3, 10000)["psi"] +iota_samples = iota_prior.sample(key4, 10000)["iota"] +R_dets_samples = calc_R_dets(ra_samples, dec_samples, psi_samples, iota_samples) + +d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) +d_hat_max = dL_prior.xmax / jnp.power(M_c_prior.xmin, 5. / 6.) / jnp.amin(R_dets_samples) + sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], gps_time=gps, ifos=ifos), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=2.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), + BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=d_hat_min, original_upper_bound=d_hat_max), ] -likelihood_transforms = [ - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), -] +likelihood_transforms = [] likelihood = ZeroLikelihood() -mass_matrix = jnp.eye(9) +mass_matrix = jnp.eye(len(prior.base_prior)) #mass_matrix = mass_matrix.at[1, 1].set(1e-3) #mass_matrix = mass_matrix.at[5, 5].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} @@ -100,9 +132,11 @@ train_thinning=1, output_thinning=1, local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], + strategies=["default"], ) -jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() +print("Start sampling") +key = jax.random.PRNGKey(42) +jim.sample(key) jim.print_summary() +samples = jim.get_samples() diff --git a/test/integration/test_extrinsic_no_distance.py b/test/integration/test_extrinsic_no_distance.py deleted file mode 100644 index d1b1e559..00000000 --- a/test/integration/test_extrinsic_no_distance.py +++ /dev/null @@ -1,90 +0,0 @@ -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior -from jimgw.single_event.detector import H1, L1, V1 -from jimgw.single_event.likelihood import ZeroLikelihood -from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform -from flowMC.strategy.optimization import optimization_Adam - -jax.config.update("jax_enable_x64", True) - -########################################### -########## First we grab data ############# -########################################### - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 - -ifos = [H1, L1, V1] - -t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) -phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -dec_prior = CosinePrior(parameter_names=["dec"]) - -prior = CombinePrior( - [ - t_c_prior, - phase_c_prior, - ra_prior, - dec_prior, - ] -) - -sample_transforms = [ - # all the user reparametrization transform - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), - # all the bound to unbound transform - BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), -] - -likelihood_transforms = [] - -likelihood = ZeroLikelihood() - -mass_matrix = jnp.eye(9) -#mass_matrix = mass_matrix.at[1, 1].set(1e-3) -#mass_matrix = mass_matrix.at[5, 5].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} - -Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) - -n_epochs = 2 -n_loop_training = 1 -learning_rate = 1e-4 - - -jim = Jim( - likelihood, - prior, - sample_transforms=sample_transforms, - likelihood_transforms=likelihood_transforms, - n_loop_training=n_loop_training, - n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30, - n_flow_samples=100, - momentum=0.9, - batch_size=100, - use_global=True, - train_thinning=1, - output_thinning=1, - local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], -) - -jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() -jim.print_summary() From fd338825bd10557c67da5b55bac9dc59a75db379 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 14:17:17 -0700 Subject: [PATCH 17/57] Fixing the single sided unbound transform --- src/jimgw/transforms.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index a26ad9af..8b4e2d75 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -418,17 +418,26 @@ class SingleSidedUnboundTransform(BijectiveTransform): """ + original_lower_bound: Float + def __init__( self, name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, ): super().__init__(name_mapping) + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.transform_func = lambda x: { - name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + name_mapping[1][i]: jnp.log( + x[name_mapping[0][i]] - self.original_lower_bound[i] + ) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + name_mapping[0][i]: jnp.exp( + x[name_mapping[1][i]] + self.original_lower_bound[i] + ) for i in range(len(name_mapping[1])) } From 03e76dc1731c793039504b6bf709da79953e17dc Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Sat, 17 Aug 2024 01:52:18 -0700 Subject: [PATCH 18/57] Update extrinsic test --- test/integration/test_extrinsic.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 988f6ea7..7cb5bf32 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -1,10 +1,3 @@ -import psutil -p = psutil.Process() -p.cpu_affinity([0]) - -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - from astropy.time import Time import jax @@ -14,7 +7,7 @@ from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1, V1 from jimgw.single_event.likelihood import ZeroLikelihood -from jimgw.transforms import BoundToUnbound +from jimgw.transforms import BoundToUnbound, SingleSidedUnboundTransform from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform from flowMC.strategy.optimization import optimization_Adam @@ -30,7 +23,7 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(10.0, 200.0, -2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) @@ -93,7 +86,7 @@ def calc_R_dets(ra, dec, psi, iota): BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), - BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=d_hat_min, original_upper_bound=d_hat_max), + SingleSidedUnboundTransform(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=float(d_hat_min)), ] likelihood_transforms = [] @@ -119,9 +112,9 @@ def calc_R_dets(ra, dec, psi, iota): likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, + n_local_steps=1, + n_global_steps=1, + n_chains=10, n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30, From 8fe4b5fdcf43c3a2c6430ba868aeddb3e96bbc72 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 14:52:39 +0200 Subject: [PATCH 19/57] bugfix for single sided transform --- src/jimgw/transforms.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 8b4e2d75..f7d4c702 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -435,9 +435,8 @@ def __init__( for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.exp( - x[name_mapping[1][i]] + self.original_lower_bound[i] - ) + name_mapping[0][i]: jnp.exp(x[name_mapping[1][i]]) + + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } From a19b556a0db3565b141b7508ddf85f3e4250e8ef Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 05:59:11 -0700 Subject: [PATCH 20/57] Update test --- test/integration/test_extrinsic.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 7cb5bf32..55979402 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -23,7 +23,7 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -dL_prior = PowerLawPrior(10.0, 200.0, -2.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) @@ -44,33 +44,7 @@ ] ) -# calculate the d_hat range -@jnp.vectorize -def calc_R_dets(ra, dec, psi, iota): - gmst = ( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad - ) - p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 - c_iota_term = jnp.cos(iota) - R_dets2 = 0.0 - for ifo in ifos: - antenna_pattern = ifo.antenna_pattern(ra, dec, psi, gmst) - p_mode_term = p_iota_term * antenna_pattern["p"] - c_mode_term = c_iota_term * antenna_pattern["c"] - R_dets2 += p_mode_term**2 + c_mode_term**2 - - return jnp.sqrt(R_dets2) - -key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(1234), 4) -# generate 10000 samples for each -ra_samples = ra_prior.sample(key1, 10000)["ra"] -dec_samples = dec_prior.sample(key2, 10000)["dec"] -psi_samples = psi_prior.sample(key3, 10000)["psi"] -iota_samples = iota_prior.sample(key4, 10000)["iota"] -R_dets_samples = calc_R_dets(ra_samples, dec_samples, psi_samples, iota_samples) - d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) -d_hat_max = dL_prior.xmax / jnp.power(M_c_prior.xmin, 5. / 6.) / jnp.amin(R_dets_samples) sample_transforms = [ # all the user reparametrization transform From 6d2cd9704c10a8345aebfde506ed8bcbeb3e83fe Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 16:25:39 +0200 Subject: [PATCH 21/57] update distance transform --- src/jimgw/single_event/transforms.py | 36 +++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 4b6ada1f..5d8e688b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -328,6 +328,8 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): gmst: Float ifos: list[GroundBased2G] + d_L_min: Float + d_L_max: Float def __init__( self, @@ -335,6 +337,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], + d_L_min: Float, + d_L_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -342,8 +346,10 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifos = ifos + self.d_L_min = d_L_min + self.d_L_max = d_L_max - assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] assert ( "ra" in conditional_names and "dec" in conditional_names @@ -371,20 +377,38 @@ def named_transform(x): x["M_c"], ) R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) - d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets + + scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + d_hat = scale_factor * d_L + + d_hat_min = scale_factor * self.d_L_min + d_hat_max = scale_factor * self.d_L_max + + y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min) + d_hat_unbounded = jnp.log(y / (1.0 - y)) + return { - "d_hat": d_hat, + "d_hat_unbounded": d_hat_unbounded, } self.transform_func = named_transform def named_inverse_transform(x): - d_hat, M_c = ( - x["d_hat"], + d_hat_unbounded, M_c = ( + x["d_hat_unbounded"], x["M_c"], ) R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) - d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets + + scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + + d_hat_min = scale_factor * self.d_L_min + d_hat_max = scale_factor * self.d_L_max + + d_hat = (d_hat_max - d_hat_min) / ( + 1.0 + jnp.exp(-d_hat_unbounded) + ) + d_hat_min + d_L = d_hat / scale_factor return { "d_L": d_L, } From 6993dd9072780203bf99046343f62328c63c7848 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 08:58:57 -0700 Subject: [PATCH 22/57] Update test --- test/integration/test_extrinsic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 55979402..a5dc5c7b 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -44,11 +44,10 @@ ] ) -d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), @@ -60,7 +59,6 @@ BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), - SingleSidedUnboundTransform(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=float(d_hat_min)), ] likelihood_transforms = [] From ff65fcf2e75a0655c8e33ddfb32c378c84b3dc8e Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 21:58:57 +0200 Subject: [PATCH 23/57] Update arrival time transform --- src/jimgw/single_event/transforms.py | 58 ++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 5d8e688b..1070adaf 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -196,6 +196,8 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform( gmst: Float ifo: GroundBased2G + tc_min: Float + tc_max: Float def __init__( self, @@ -203,6 +205,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, + tc_min: Float, + tc_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -210,24 +214,46 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifo = ifo + self.tc_min = tc_min + self.tc_max = tc_max - assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names + @jnp.vectorize + def time_delay(ra, dec, gmst): + return self.ifo.delay_from_geocenter(ra, dec, gmst) + def named_transform(x): - t_det = x["t_c"] + jnp.vectorize(self.ifo.delay_from_geocenter)( - x["ra"], x["dec"], self.gmst - ) + + time_shift = time_delay(x["ra"], x["dec"], self.gmst) + + t_det = x["t_c"] + time_shift + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + + y = (t_det - t_det_min) / (t_det_max - t_det_min) + t_det_unbounded = jnp.log(y / (1.0 - y)) return { - "t_det": t_det, + "t_det_unbounded": t_det_unbounded, } self.transform_func = named_transform def named_inverse_transform(x): - t_c = x["t_det"] - jnp.vectorize(self.ifo.delay_from_geocenter)( + + time_shift = jnp.vectorize(self.ifo.delay_from_geocenter)( x["ra"], x["dec"], self.gmst ) + + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + t_det = (t_det_max - t_det_min) / ( + 1.0 + jnp.exp(-x["t_det_unbounded"]) + ) + t_det_min + + t_c = t_det - time_shift + return { "t_c": t_c, } @@ -328,8 +354,8 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): gmst: Float ifos: list[GroundBased2G] - d_L_min: Float - d_L_max: Float + dL_min: Float + dL_max: Float def __init__( self, @@ -337,8 +363,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], - d_L_min: Float, - d_L_max: Float, + dL_min: Float, + dL_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -346,8 +372,8 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifos = ifos - self.d_L_min = d_L_min - self.d_L_max = d_L_max + self.dL_min = dL_min + self.dL_max = dL_max assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] assert ( @@ -381,8 +407,8 @@ def named_transform(x): scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets d_hat = scale_factor * d_L - d_hat_min = scale_factor * self.d_L_min - d_hat_max = scale_factor * self.d_L_max + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min) d_hat_unbounded = jnp.log(y / (1.0 - y)) @@ -402,8 +428,8 @@ def named_inverse_transform(x): scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets - d_hat_min = scale_factor * self.d_L_min - d_hat_max = scale_factor * self.d_L_max + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max d_hat = (d_hat_max - d_hat_min) / ( 1.0 + jnp.exp(-d_hat_unbounded) From b98d783db7b16a14c91ca83e7e3cd606697682d7 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 13:04:47 -0700 Subject: [PATCH 24/57] Update test --- test/integration/test_extrinsic.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index a5dc5c7b..ff79723e 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -49,7 +49,7 @@ # all the user reparametrization transform DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), @@ -58,7 +58,6 @@ BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), ] likelihood_transforms = [] @@ -84,8 +83,8 @@ likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, - n_local_steps=1, - n_global_steps=1, + n_local_steps=2, + n_global_steps=2, n_chains=10, n_epochs=n_epochs, learning_rate=learning_rate, From e399a5e9c7609fe1f07ba6951182e5899ac4b692 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 22:18:31 +0200 Subject: [PATCH 25/57] Fix typo --- test/integration/test_extrinsic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index ff79723e..f0e089fe 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -47,7 +47,7 @@ sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), From 583b759de0859d0d2096bc95d7cccfa1d0b6b824 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 20 Aug 2024 00:43:21 +0200 Subject: [PATCH 26/57] Fix typo --- src/jimgw/single_event/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 1070adaf..084fe368 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -332,7 +332,7 @@ def named_inverse_transform(x): R_det_arg = _calc_R_det_arg( x["ra"], x["dec"], x["psi"], x["iota"], self.gmst ) - phase_c = (-R_det_arg + x["phase_det"]) * 2.0 + phase_c = -R_det_arg + x["phase_det"] * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), } From b1133d36f4e6ff20667b30be9d7a91d58419d562 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:06 +0800 Subject: [PATCH 27/57] Update runManager.py --- src/jimgw/single_event/runManager.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index aa8d0dc7..0a4b502d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,9 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] = field( - default_factory=lambda: {} - ) + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -125,9 +123,6 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError - if self.run.injection and not self.run.injection_parameters: - raise ValueError("Injection mode requires injection parameters.") - local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) From 2a9d696a20d28fdd1693698a1840fef9ab276578 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:49 +0800 Subject: [PATCH 28/57] Update runManager.py --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 0a4b502d..3f65166d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,7 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} From cd559b6ae127b8683933ddcc85776b8b27f078c6 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Thu, 22 Aug 2024 21:53:41 +0200 Subject: [PATCH 29/57] Adding docstring for zerolikelihood --- src/jimgw/single_event/likelihood.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 9e775b33..1508cdfa 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -27,6 +27,21 @@ def __init__(self, detectors: list[Detector], waveform: Waveform) -> None: class ZeroLikelihood(LikelihoodBase): + """ + A likelihood class that always returns a log-likelihood of zero. + + This class is primarily used for testing or debugging purposes. + + Methods + ------- + __init__() -> None + Initializes the ZeroLikelihood object. No parameters are required or set. + + evaluate(params: dict[str, Float], data: dict) -> Float + Evaluates the likelihood for a given set of parameters and data, + always returning 0.0. This method does not perform any computation + based on the input parameters or data, making it useful for debugging. + """ def __init__(self): pass From 5d6a795616003ded35814babc018bf6564f489e4 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Sun, 25 Aug 2024 23:02:51 +0800 Subject: [PATCH 30/57] Added run script --- example/GW150914_PV2.py | 210 +++++++++++++++++++++------------------- 1 file changed, 110 insertions(+), 100 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 06209ba6..9dadb2e9 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -4,13 +4,15 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import Composite, Sphere, Unconstrained_Uniform +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomPv2 +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform, ComponentMassesToChirpMassMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam - jax.config.update("jax_enable_x64", True) ########################################### @@ -21,125 +23,114 @@ # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 -start = gps - 2 -end = gps + 2 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration fmin = 20.0 fmax = 1024.0 -ifos = ["H1", "L1"] - -H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -waveform = RippleIMRPhenomPv2(f_ref=20) - -########################################### -########## Set up priors ################## -########################################### - -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1_prior = Sphere(naming="s1") -s2_prior = Sphere(naming="s2") -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) - -prior = Composite( +ifos = [H1, L1] + +f_ref = 20.0 + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) +phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) +theta_1_prior = SinePrior(parameter_names=["theta_1"]) +theta_2_prior = SinePrior(parameter_names=["theta_2"]) +phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( [ - Mc_prior, - q_prior, - s1_prior, - s2_prior, + m_1_prior, + m_2_prior, + theta_jn_prior, + phi_jl_prior, + theta_1_prior, + theta_2_prior, + phi_12_prior, + a_1_prior, + a_2_prior, dL_prior, t_c_prior, phase_c_prior, - cos_iota_prior, psi_prior, ra_prior, - sin_dec_prior, - ], + dec_prior, + ] ) -epsilon = 1e-3 -bounds = jnp.array( - [ - [10.0, 80.0], - [0.125, 1.0], - [0, jnp.pi], - [0, 2 * jnp.pi], - [0.0, 1.0], - [0, jnp.pi], - [0, 2 * jnp.pi], - [0.0, 1.0], - [0.0, 2000], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) + jnp.array([[epsilon, -epsilon]]) +sample_transforms = [ + # all the user reparametrization transform + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=f_ref), + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] likelihood = TransientLikelihoodFD( - [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 + ifos, + waveform=RippleIMRPhenomPv2(f_ref=f_ref), + trigger_time=gps, + duration=4, + post_trigger_duration=2, ) -# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) -mass_matrix = jnp.eye(prior.n_dim) +mass_matrix = jnp.eye(15) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[9, 9].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 1e-3} -Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1, bounds=bounds) +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -import optax -n_epochs = 20 +n_epochs = 30 n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) +learning_rate = 1e-4 + jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=20, n_local_steps=10, @@ -148,18 +139,37 @@ n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30000, - n_flow_sample=100000, + n_flow_samples=100000, momentum=0.9, batch_size=30000, use_global=True, - keep_quantile=0.0, train_thinning=1, output_thinning=10, local_sampler_arg=local_sampler_arg, - # strategies=[Adam_optimizer,"default"], + strategies=[Adam_optimizer, "default"], ) +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt import numpy as np -# chains = np.load('./GW150914_init.npz')['chain'] -jim.sample(jax.random.PRNGKey(42))#,initial_guess=chains) +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in reversed(jim.sample_transforms): + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW1500914_Pv2_testing_reparam.jpeg") From a844170bed39755aaf9f3045b3ced95cb48723f3 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 00:35:14 +0800 Subject: [PATCH 31/57] Added run script --- example/GW150914_D.py | 155 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 example/GW150914_D.py diff --git a/example/GW150914_D.py b/example/GW150914_D.py new file mode 100644 index 00000000..2fa23c01 --- /dev/null +++ b/example/GW150914_D.py @@ -0,0 +1,155 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +n_epochs = 30 +n_loop_training = 100 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_samples=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in reversed(jim.sample_transforms): + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW1500914_D_reparam.jpeg") \ No newline at end of file From bfa4e4781ef4677ae0c26596d71bae6fb0e36783 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 08:07:10 +0800 Subject: [PATCH 32/57] Added run script --- example/GW150914.py | 138 ------------------------------------------ example/GW150914_D.py | 8 ++- 2 files changed, 7 insertions(+), 139 deletions(-) delete mode 100644 example/GW150914.py diff --git a/example/GW150914.py b/example/GW150914.py deleted file mode 100644 index 559b5b7c..00000000 --- a/example/GW150914.py +++ /dev/null @@ -1,138 +0,0 @@ -import time - -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from jimgw.prior import Composite, Unconstrained_Uniform -from jimgw.single_event.detector import H1, L1 -from jimgw.single_event.likelihood import TransientLikelihoodFD -from jimgw.single_event.waveform import RippleIMRPhenomD -from flowMC.strategy.optimization import optimization_Adam - -jax.config.update("jax_enable_x64", True) - -########################################### -########## First we grab data ############# -########################################### - -total_time_start = time.time() - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 -duration = 4 -post_trigger_duration = 2 -start_pad = duration - post_trigger_duration -end_pad = post_trigger_duration -fmin = 20.0 -fmax = 1024.0 - -ifos = ["H1", "L1"] - -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) -s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) - -prior = Composite( - [ - Mc_prior, - q_prior, - s1z_prior, - s2z_prior, - dL_prior, - t_c_prior, - phase_c_prior, - cos_iota_prior, - psi_prior, - ra_prior, - sin_dec_prior, - ] -) -likelihood = TransientLikelihoodFD( - [H1, L1], - waveform=RippleIMRPhenomD(), - trigger_time=gps, - duration=4, - post_trigger_duration=2, -) - - -mass_matrix = jnp.eye(11) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[5, 5].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} - -Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) - -import optax -n_epochs = 20 -n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) - - -jim = Jim( - likelihood, - prior, - n_loop_training=n_loop_training, - n_loop_production=20, - n_local_steps=10, - n_global_steps=1000, - n_chains=500, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30000, - n_flow_samples=100000, - momentum=0.9, - batch_size=30000, - use_global=True, - train_thinning=1, - output_thinning=10, - local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer,"default"], -) - -jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW150914_D.py b/example/GW150914_D.py index 2fa23c01..ebf2e895 100644 --- a/example/GW150914_D.py +++ b/example/GW150914_D.py @@ -152,4 +152,10 @@ samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array transposed_array = samples.T # transpose the array figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) -plt.savefig("GW1500914_D_reparam.jpeg") \ No newline at end of file +plt.savefig("GW1500914_D_reparam.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW150914_D_reparam.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file From 32b8e2e95f02ad782b5ca41172f1ba84f9bb1c71 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 08:50:30 +0800 Subject: [PATCH 33/57] Added run script --- example/{GW150914_D.py => GW150914_D_reparam.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename example/{GW150914_D.py => GW150914_D_reparam.py} (100%) diff --git a/example/GW150914_D.py b/example/GW150914_D_reparam.py similarity index 100% rename from example/GW150914_D.py rename to example/GW150914_D_reparam.py From 041cf2a7bb79de25b71eb2dc92c83859d0cf1e68 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 08:51:23 +0800 Subject: [PATCH 34/57] Added run script --- example/GW150914_D.py | 158 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 example/GW150914_D.py diff --git a/example/GW150914_D.py b/example/GW150914_D.py new file mode 100644 index 00000000..06ac234c --- /dev/null +++ b/example/GW150914_D.py @@ -0,0 +1,158 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +n_epochs = 30 +n_loop_training = 100 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_samples=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in reversed(jim.sample_transforms): + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW1500914_D.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW150914_D.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file From 0803ae8a43fb2d48862552c5234c5937f6f4cd88 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:44:42 +0800 Subject: [PATCH 35/57] Added run script on gw200112 --- example/GW200112_D.py | 158 +++++++++++++++++++++++++++++++++ example/GW200112_D_reparam.py | 161 ++++++++++++++++++++++++++++++++++ 2 files changed, 319 insertions(+) create mode 100644 example/GW200112_D.py create mode 100644 example/GW200112_D_reparam.py diff --git a/example/GW200112_D.py b/example/GW200112_D.py new file mode 100644 index 00000000..2e7b7f03 --- /dev/null +++ b/example/GW200112_D.py @@ -0,0 +1,158 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW200112 +gps = 1262879936.0 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [V1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +n_epochs = 30 +n_loop_training = 100 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_samples=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in reversed(jim.sample_transforms): + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW200112_D.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW200112_D.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file diff --git a/example/GW200112_D_reparam.py b/example/GW200112_D_reparam.py new file mode 100644 index 00000000..28a8e5ab --- /dev/null +++ b/example/GW200112_D_reparam.py @@ -0,0 +1,161 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW200112 +gps = 1262879936.0 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [V1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +n_epochs = 30 +n_loop_training = 100 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_samples=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in reversed(jim.sample_transforms): + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW200112_D_reparam.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW200112_D_reparam.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file From d9321833f78bd1d40e8abdaa168fce0a6f1e827d Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:23:25 +0800 Subject: [PATCH 36/57] Added functions to calculate iota --- src/jimgw/single_event/utils.py | 80 +++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index a15bd7bf..5741f64b 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -551,6 +551,86 @@ def rotate_z(angle, vec): return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] +def spin_to_iota( + thetaJN: Float, + phiJL: Float, + theta1: Float, + theta2: Float, + phi12: Float, + chi1: Float, + chi2: Float, + M_c: Float, + q: Float, + fRef: Float, + phiRef: Float, +) -> tuple[Float, Float, Float, Float, Float, Float, Float]: + + def rotate_y(angle, vec): + """ + Rotate the vector (x, y, z) about y-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + def rotate_z(angle, vec): + """ + Rotate the vector (x, y, z) about z-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + LNh = jnp.array([0.0, 0.0, 1.0]) + + s1hat = jnp.array( + [ + jnp.sin(theta1) * jnp.cos(phiRef), + jnp.sin(theta1) * jnp.sin(phiRef), + jnp.cos(theta1), + ] + ) + s2hat = jnp.array( + [ + jnp.sin(theta2) * jnp.cos(phi12 + phiRef), + jnp.sin(theta2) * jnp.sin(phi12 + phiRef), + jnp.cos(theta2), + ] + ) + + m1, m2 = Mc_q_to_m1_m2(M_c, q) + eta = q / (1 + q) ** 2 + v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) + + Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) + s1 = m1 * m1 * chi1 * s1hat + s2 = m2 * m2 * chi2 * s2hat + J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) + + Jhat = J / jnp.linalg.norm(J) + theta0 = jnp.arccos(Jhat[2]) + + # Rotation 2: + LNh = rotate_y(-theta0, LNh) + + # Rotation 3: + LNh = rotate_z(phiJL - jnp.pi, LNh) + + # Compute iota + N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) + iota = jnp.arccos(jnp.dot(N, LNh)) + + return iota + + def zenith_azimuth_to_ra_dec( zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] ) -> tuple[Float, Float]: From 7b038ac435dc113545893e92d407302bf4fe9f7b Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:42:04 +0800 Subject: [PATCH 37/57] Updated transforms.py --- src/jimgw/single_event/transforms.py | 49 ++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 084fe368..aba7ee9c 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -20,6 +20,7 @@ zenith_azimuth_to_ra_dec, euler_rotation, spin_to_cartesian_spin, + spin_to_iota, ) @@ -303,8 +304,25 @@ def __init__( "ra" in conditional_names and "dec" in conditional_names and "psi" in conditional_names - and "iota" in conditional_names + and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names)) ) + + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], + ) @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): @@ -319,7 +337,7 @@ def _calc_R_det_arg(ra, dec, psi, iota, gmst): def named_transform(x): R_det_arg = _calc_R_det_arg( - x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + x["ra"], x["dec"], x["psi"], self.get_iota(x), self.gmst ) phase_det = R_det_arg + x["phase_c"] / 2.0 return { @@ -330,7 +348,7 @@ def named_transform(x): def named_inverse_transform(x): R_det_arg = _calc_R_det_arg( - x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + x["ra"], x["dec"], x["psi"], self.get_iota(x), self.gmst ) phase_c = -R_det_arg + x["phase_det"] * 2.0 return { @@ -365,6 +383,7 @@ def __init__( ifos: list[GroundBased2G], dL_min: Float, dL_max: Float, + freq_ref: Float = None, ): super().__init__(name_mapping, conditional_names) @@ -374,16 +393,34 @@ def __init__( self.ifos = ifos self.dL_min = dL_min self.dL_max = dL_max + self.freq_ref = freq_ref assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] assert ( "ra" in conditional_names and "dec" in conditional_names and "psi" in conditional_names - and "iota" in conditional_names + and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names)) and "M_c" in conditional_names ) + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], + ) + @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 @@ -402,7 +439,7 @@ def named_transform(x): x["d_L"], x["M_c"], ) - R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], self.get_iota(x)) scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets d_hat = scale_factor * d_L @@ -424,7 +461,7 @@ def named_inverse_transform(x): x["d_hat_unbounded"], x["M_c"], ) - R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], self.get_iota(x)) scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets From c2f0c79c4fc46d66d05a8a4fab3056a795e6cf34 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:51:02 +0800 Subject: [PATCH 38/57] Updated transforms.py --- src/jimgw/single_event/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index aba7ee9c..18ca7215 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -304,7 +304,7 @@ def __init__( "ra" in conditional_names and "dec" in conditional_names and "psi" in conditional_names - and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names)) + and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names and "M_c" in conditional_names and "q" in conditional_names)) ) if "iota" in conditional_names: @@ -400,7 +400,7 @@ def __init__( "ra" in conditional_names and "dec" in conditional_names and "psi" in conditional_names - and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names)) + and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names and "phase_c" in conditional_names)) and "M_c" in conditional_names ) From 12732461517f35606ab511e62479c15cda1a995e Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:51:24 +0800 Subject: [PATCH 39/57] Updated GW150914_Pv2.py --- example/GW150914_PV2.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 9dadb2e9..41ed4c83 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -78,18 +78,13 @@ sample_transforms = [ # all the user reparametrization transform ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c", "q", "ra", "dec", "psi", "theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2", "phase_c"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "M_c", "q", "theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), @@ -97,6 +92,10 @@ BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ From 0e181d8b50f8cc2c5b563fc806dca7c19479a271 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:52:35 +0800 Subject: [PATCH 40/57] Updated GW150914_Pv2_reparam.py --- example/{GW150914_PV2.py => GW150914_Pv2_reparam.py} | 6 ++++++ 1 file changed, 6 insertions(+) rename example/{GW150914_PV2.py => GW150914_Pv2_reparam.py} (97%) diff --git a/example/GW150914_PV2.py b/example/GW150914_Pv2_reparam.py similarity index 97% rename from example/GW150914_PV2.py rename to example/GW150914_Pv2_reparam.py index 41ed4c83..ad77b9fb 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_Pv2_reparam.py @@ -172,3 +172,9 @@ transposed_array = samples.T # transpose the array figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) plt.savefig("GW1500914_Pv2_testing_reparam.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW150914_Pv2_reparam.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file From 79bdb1f026dc7448a86b47cc85f1c67bb3f885e2 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:04:21 +0800 Subject: [PATCH 41/57] Updated GW150914_Pv2_reparam.py --- example/GW150914_Pv2_reparam.py | 4 ++-- src/jimgw/single_event/transforms.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/example/GW150914_Pv2_reparam.py b/example/GW150914_Pv2_reparam.py index ad77b9fb..23797617 100644 --- a/example/GW150914_Pv2_reparam.py +++ b/example/GW150914_Pv2_reparam.py @@ -78,8 +78,8 @@ sample_transforms = [ # all the user reparametrization transform ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c", "q", "ra", "dec", "psi", "theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2", "phase_c"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "M_c", "q", "theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], gps_time=gps, ifo=ifos[0]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c", "q", "ra", "dec", "psi", "theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2", "phase_c"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax, freq_ref=f_ref), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "M_c", "q", "theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], gps_time=gps, ifo=ifos[0], freq_ref=f_ref), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 18ca7215..70b2fd17 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -284,6 +284,7 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( gmst: Float ifo: GroundBased2G + freq_ref: Float def __init__( self, @@ -291,6 +292,7 @@ def __init__( conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, + freq_ref: Float = None, ): super().__init__(name_mapping, conditional_names) @@ -298,6 +300,7 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifo = ifo + self.freq_ref = freq_ref assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] assert ( @@ -374,6 +377,7 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): ifos: list[GroundBased2G] dL_min: Float dL_max: Float + freq_ref: Float def __init__( self, From 65a15c469817869cd09603b7eeaa9f756b9eac11 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:10:58 +0800 Subject: [PATCH 42/57] Updated transforms.py --- src/jimgw/single_event/transforms.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 70b2fd17..b5a24bd6 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -339,8 +339,9 @@ def _calc_R_det_arg(ra, dec, psi, iota, gmst): return jnp.angle(p_mode_term - 1j * c_mode_term) def named_transform(x): + iota = self.get_iota(x) R_det_arg = _calc_R_det_arg( - x["ra"], x["dec"], x["psi"], self.get_iota(x), self.gmst + x["ra"], x["dec"], x["psi"], iota, self.gmst ) phase_det = R_det_arg + x["phase_c"] / 2.0 return { @@ -350,8 +351,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): + iota = self.get_iota(x) R_det_arg = _calc_R_det_arg( - x["ra"], x["dec"], x["psi"], self.get_iota(x), self.gmst + x["ra"], x["dec"], x["psi"], iota, self.gmst ) phase_c = -R_det_arg + x["phase_det"] * 2.0 return { @@ -443,7 +445,8 @@ def named_transform(x): x["d_L"], x["M_c"], ) - R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], self.get_iota(x)) + iota = self.get_iota(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], iota) scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets d_hat = scale_factor * d_L @@ -465,7 +468,8 @@ def named_inverse_transform(x): x["d_hat_unbounded"], x["M_c"], ) - R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], self.get_iota(x)) + iota = self.get_iota(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], iota) scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets From 6a4fcaae4ad7fd94ad0bfc61a132df785cf50a75 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:55:13 +0800 Subject: [PATCH 43/57] Updated utils.py --- src/jimgw/single_event/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 5741f64b..bc242ad1 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -572,7 +572,7 @@ def rotate_y(angle, vec): cos_angle = jnp.cos(angle) sin_angle = jnp.sin(angle) rotation_matrix = jnp.array( - [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] + [[cos_angle, jnp.zeros_like(angle), sin_angle], [jnp.zeros_like(angle), jnp.ones_like(angle), jnp.zeros_like(angle)], [-sin_angle, jnp.zeros_like(angle), cos_angle]] ) rotated_vec = jnp.dot(rotation_matrix, vec) return rotated_vec @@ -584,7 +584,7 @@ def rotate_z(angle, vec): cos_angle = jnp.cos(angle) sin_angle = jnp.sin(angle) rotation_matrix = jnp.array( - [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] + [[cos_angle, -sin_angle, jnp.zeros_like(angle)], [sin_angle, cos_angle, jnp.zeros_like(angle)], [jnp.zeros_like(angle), jnp.zeros_like(angle), jnp.ones_like(angle)]] ) rotated_vec = jnp.dot(rotation_matrix, vec) return rotated_vec @@ -613,7 +613,7 @@ def rotate_z(angle, vec): Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) s1 = m1 * m1 * chi1 * s1hat s2 = m2 * m2 * chi2 * s2hat - J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) + J = s1 + s2 + jnp.array([jnp.zeros_like(Lmag), jnp.zeros_like(Lmag), Lmag]) Jhat = J / jnp.linalg.norm(J) theta0 = jnp.arccos(Jhat[2]) @@ -625,7 +625,7 @@ def rotate_z(angle, vec): LNh = rotate_z(phiJL - jnp.pi, LNh) # Compute iota - N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) + N = jnp.array([jnp.zeros_like(thetaJN), jnp.sin(thetaJN), jnp.cos(thetaJN)]) iota = jnp.arccos(jnp.dot(N, LNh)) return iota From 2c440a348a743475d42dead664e0b80e0afa85b6 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:15:40 +0800 Subject: [PATCH 44/57] Updated jim.py --- src/jimgw/jim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index fae0bc98..71e8aaad 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -111,7 +111,7 @@ def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): key = jax.random.split(key)[1] guess = self.prior.sample(key, 1) for transform in self.sample_transforms: - guess = transform.forward(guess) + guess = jax.vmap(transform.forward)(guess) guess = jnp.array([i for i in guess.values()]).T[0] flag = not jnp.all(jnp.isfinite(guess)) initial_guess.append(guess) From 39c536456532e670deb1aebb41715dda3bbb8ef9 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 18:55:36 +0800 Subject: [PATCH 45/57] Added run script --- example/GW150914_heterodyne.py | 158 ---------------------------- example/GW191216_D_het.py | 163 +++++++++++++++++++++++++++++ example/GW191216_D_het_reparam.py | 166 ++++++++++++++++++++++++++++++ 3 files changed, 329 insertions(+), 158 deletions(-) delete mode 100644 example/GW150914_heterodyne.py create mode 100644 example/GW191216_D_het.py create mode 100644 example/GW191216_D_het_reparam.py diff --git a/example/GW150914_heterodyne.py b/example/GW150914_heterodyne.py deleted file mode 100644 index c1faed03..00000000 --- a/example/GW150914_heterodyne.py +++ /dev/null @@ -1,158 +0,0 @@ -import time - -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from jimgw.prior import Composite, Unconstrained_Uniform -from jimgw.single_event.detector import H1, L1 -from jimgw.single_event.likelihood import ( - HeterodynedTransientLikelihoodFD, - TransientLikelihoodFD, -) -from jimgw.single_event.waveform import RippleIMRPhenomD -from flowMC.strategy.optimization import optimization_Adam - -jax.config.update("jax_enable_x64", True) - -########################################### -########## First we grab data ############# -########################################### - -total_time_start = time.time() - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 -duration = 4 -post_trigger_duration = 2 -start_pad = duration - post_trigger_duration -end_pad = post_trigger_duration -fmin = 20.0 -fmax = 1024.0 - -ifos = ["H1", "L1"] - -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) -s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) - -prior = Composite( - [ - Mc_prior, - q_prior, - s1z_prior, - s2z_prior, - dL_prior, - t_c_prior, - phase_c_prior, - cos_iota_prior, - psi_prior, - ra_prior, - sin_dec_prior, - ] -) - -bounds = jnp.array( - [ - [10.0, 80.0], - [0.125, 1.0], - [-1.0, 1.0], - [-1.0, 1.0], - [0.0, 2000.0], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) - -likelihood = HeterodynedTransientLikelihoodFD( - [H1, L1], - prior=prior, - bounds=bounds, - waveform=RippleIMRPhenomD(), - trigger_time=gps, - duration=duration, - post_trigger_duration=post_trigger_duration, - n_steps=3000, -) - -mass_matrix = jnp.eye(11) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[5, 5].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} - -Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -import optax -n_epochs = 20 -n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) - -jim = Jim( - likelihood, - prior, - n_loop_training=n_loop_training, - n_loop_production=20, - n_local_steps=10, - n_global_steps=1000, - n_chains=500, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30000, - n_flow_sample=100000, - momentum=0.9, - batch_size=30000, - use_global=True, - keep_quantile=0.0, - train_thinning=1, - output_thinning=10, - local_sampler_arg=local_sampler_arg, - # strategies=[Adam_optimizer,"default"], -) -jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW191216_D_het.py b/example/GW191216_D_het.py new file mode 100644 index 00000000..7543a0c0 --- /dev/null +++ b/example/GW191216_D_het.py @@ -0,0 +1,163 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1260567236.4 +duration = 16 +post_trigger_duration = duration//2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, V1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=duration*4, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = HeterodynedTransientLikelihoodFD( + ifos, + prior=prior, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_steps=5, + popsize=10, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +n_epochs = 30 +n_loop_training = 100 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_samples=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in reversed(jim.sample_transforms): + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW191216_D_het.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW191216_D_het.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file diff --git a/example/GW191216_D_het_reparam.py b/example/GW191216_D_het_reparam.py new file mode 100644 index 00000000..5b084919 --- /dev/null +++ b/example/GW191216_D_het_reparam.py @@ -0,0 +1,166 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1260567236.4 +duration = 16 +post_trigger_duration = duration//2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, V1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=duration*4, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + m_1_prior, + m_2_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + +likelihood = HeterodynedTransientLikelihoodFD( + ifos, + prior=prior, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_steps=5, + popsize=10, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +n_epochs = 30 +n_loop_training = 100 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_samples=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in reversed(jim.sample_transforms): + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW191216_D_het_reparam.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW191216_D_het_reparam.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file From 7457c56dad08704b05ec972d71de40d3174c6b5e Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 26 Aug 2024 13:13:21 +0200 Subject: [PATCH 46/57] Fixing phase inverse transformation --- src/jimgw/single_event/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 084fe368..1070adaf 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -332,7 +332,7 @@ def named_inverse_transform(x): R_det_arg = _calc_R_det_arg( x["ra"], x["dec"], x["psi"], x["iota"], self.gmst ) - phase_c = -R_det_arg + x["phase_det"] * 2.0 + phase_c = (-R_det_arg + x["phase_det"]) * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), } From ddfafc1c3199c40e9ff224e2633c928f7fe6160a Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 26 Aug 2024 19:14:00 +0800 Subject: [PATCH 47/57] Added run script --- example/GW191216_D_het.py | 2 +- example/GW191216_D_het_reparam.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/GW191216_D_het.py b/example/GW191216_D_het.py index 7543a0c0..b8b9f672 100644 --- a/example/GW191216_D_het.py +++ b/example/GW191216_D_het.py @@ -31,7 +31,7 @@ for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=duration*4, tukey_alpha=0.2) -M_c_min, M_c_max = 10.0, 80.0 +M_c_min, M_c_max = 5.0, 15.0 q_min, q_max = 0.125, 1.0 m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) diff --git a/example/GW191216_D_het_reparam.py b/example/GW191216_D_het_reparam.py index 5b084919..f0ead16f 100644 --- a/example/GW191216_D_het_reparam.py +++ b/example/GW191216_D_het_reparam.py @@ -31,7 +31,7 @@ for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=duration*4, tukey_alpha=0.2) -M_c_min, M_c_max = 10.0, 80.0 +M_c_min, M_c_max = 5.0, 15.0 q_min, q_max = 0.125, 1.0 m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) From 86605eac7604dd55e4b22205e49e2ba4e623581c Mon Sep 17 00:00:00 2001 From: Peter Pang Date: Mon, 2 Sep 2024 21:49:48 +0200 Subject: [PATCH 48/57] Remove duplicated import --- src/jimgw/single_event/transforms.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 81111384..e3e4795d 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -5,9 +5,6 @@ from jimgw.single_event.detector import GroundBased2G from jimgw.transforms import ( - ConditionalBijectiveTransform, - BijectiveTransform, - NtoNTransform, ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform, From 2fbfc04c7fb6dceed06f1a2eeae89ea76845c1a1 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 2 Sep 2024 21:54:40 +0200 Subject: [PATCH 49/57] Adding the named_Mc_q_to_m1_m2 back --- src/jimgw/single_event/transforms.py | 44 +++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index e3e4795d..200348fe 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -112,10 +112,6 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform -def named_m1_m2_to_Mc_q(x): - Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - @jaxtyped(typechecker=typechecker) class GeocentricArrivalTimeToDetectorArrivalTimeTransform( @@ -383,28 +379,56 @@ def named_inverse_transform(x): } self.inverse_transform_func = named_inverse_transform - -ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) + +def named_m1_m2_to_Mc_q(x): + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} + + +def named_Mc_q_to_m1_m2(x): + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "q"]) +) ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q -ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = ( + named_Mc_q_to_m1_m2 +) + def named_m1_m2_to_Mc_eta(x): Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) return {"M_c": Mc, "eta": eta} + def named_Mc_eta_to_m1_m2(x): m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) -ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta -ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "eta"]) +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = ( + named_m1_m2_to_Mc_eta +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = ( + named_Mc_eta_to_m1_m2 +) + def named_q_to_eta(x): return {"eta": q_to_eta(x["q"])} + + def named_eta_to_q(x): return {"q": eta_to_q(x["eta"])} + + MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q From ce7b3081f5e2124427b91646fd4508217af424d6 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 2 Sep 2024 22:58:33 +0200 Subject: [PATCH 50/57] Hard-code transform name_mapping and conditional_parameters --- src/jimgw/single_event/transforms.py | 32 ++++++---------------------- test/integration/test_extrinsic.py | 8 +++---- 2 files changed, 10 insertions(+), 30 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 200348fe..5688786b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -140,13 +140,13 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform( def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, tc_min: Float, tc_max: Float, ): + name_mapping = [["t_c"], ["t_det_unbounded"]] + conditional_names = ["ra", "dec"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -156,9 +156,6 @@ def __init__( self.tc_min = tc_min self.tc_max = tc_max - assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1] - assert "ra" in conditional_names and "dec" in conditional_names - @jnp.vectorize def time_delay(ra, dec, gmst): return self.ifo.delay_from_geocenter(ra, dec, gmst) @@ -225,11 +222,11 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): + name_mapping = [["phase_c"], ["phase_det"]] + conditional_names = ["ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -237,14 +234,6 @@ def __init__( ) self.ifo = ifo - assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] - assert ( - "ra" in conditional_names - and "dec" in conditional_names - and "psi" in conditional_names - and "iota" in conditional_names - ) - @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 @@ -298,13 +287,13 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], dL_min: Float, dL_max: Float, ): + name_mapping = [["d_L"], ["d_hat_unbounded"]] + conditional_names = ["M_c", "ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -314,15 +303,6 @@ def __init__( self.dL_min = dL_min self.dL_max = dL_max - assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] - assert ( - "ra" in conditional_names - and "dec" in conditional_names - and "psi" in conditional_names - and "iota" in conditional_names - and "M_c" in conditional_names - ) - @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index f0e089fe..300f2132 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -47,10 +47,10 @@ sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), From c4995d755c7d41dfabcda784757559731f360499 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Thu, 5 Sep 2024 22:35:21 +0200 Subject: [PATCH 51/57] setting phiRef to 0 for getting the iota --- src/jimgw/single_event/transforms.py | 46 ++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index c5c34258..ad91a56a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -243,9 +243,23 @@ def __init__( "ra" in conditional_names and "dec" in conditional_names and "psi" in conditional_names - and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names and "M_c" in conditional_names and "q" in conditional_names)) + and ( + "iota" in conditional_names + or ( + "theta_jn" in conditional_names + and "phi_jl" in conditional_names + and "theta_1" in conditional_names + and "theta_2" in conditional_names + and "phi_12" in conditional_names + and "a_1" in conditional_names + and "a_2" in conditional_names + and "q" in conditional_names + and "M_c" in conditional_names + and "q" in conditional_names + ) + ) ) - + if "iota" in conditional_names: self.get_iota = lambda x: x["iota"] else: @@ -260,7 +274,7 @@ def __init__( x["M_c"], x["q"], self.freq_ref, - x["phase_c"], + 0.0, ) @jnp.vectorize @@ -276,9 +290,7 @@ def _calc_R_det_arg(ra, dec, psi, iota, gmst): def named_transform(x): iota = self.get_iota(x) - R_det_arg = _calc_R_det_arg( - x["ra"], x["dec"], x["psi"], iota, self.gmst - ) + R_det_arg = _calc_R_det_arg(x["ra"], x["dec"], x["psi"], iota, self.gmst) phase_det = R_det_arg + x["phase_c"] / 2.0 return { "phase_det": phase_det % (2.0 * jnp.pi), @@ -288,9 +300,7 @@ def named_transform(x): def named_inverse_transform(x): iota = self.get_iota(x) - R_det_arg = _calc_R_det_arg( - x["ra"], x["dec"], x["psi"], iota, self.gmst - ) + R_det_arg = _calc_R_det_arg(x["ra"], x["dec"], x["psi"], iota, self.gmst) phase_c = (-R_det_arg + x["phase_det"]) * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), @@ -342,7 +352,19 @@ def __init__( "ra" in conditional_names and "dec" in conditional_names and "psi" in conditional_names - and ("iota" in conditional_names or ("theta_jn" in conditional_names and "phi_jl" in conditional_names and "theta_1" in conditional_names and "theta_2" in conditional_names and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names and "q" in conditional_names and "phase_c" in conditional_names)) + and ( + "iota" in conditional_names + or ( + "theta_jn" in conditional_names + and "phi_jl" in conditional_names + and "theta_1" in conditional_names + and "theta_2" in conditional_names + and "phi_12" in conditional_names + and "a_1" in conditional_names + and "a_2" in conditional_names + and "q" in conditional_names + ) + ) and "M_c" in conditional_names ) @@ -360,9 +382,9 @@ def __init__( x["M_c"], x["q"], self.freq_ref, - x["phase_c"], + 0.0, ) - + @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 From 762b7e0f7738284f8943fb842eab430e709a4a93 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Thu, 5 Sep 2024 23:05:02 +0200 Subject: [PATCH 52/57] Remove if-else function --- src/jimgw/single_event/transforms.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index ad91a56a..0af0fe42 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,3 +1,4 @@ +import jax.lax import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped @@ -260,10 +261,10 @@ def __init__( ) ) - if "iota" in conditional_names: - self.get_iota = lambda x: x["iota"] - else: - self.get_iota = lambda x: spin_to_iota( + self.get_iota = lambda x: jax.lax.cond( + "iota" in conditional_names, + lambda _: x["iota"], + lambda _: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -275,7 +276,9 @@ def __init__( x["q"], self.freq_ref, 0.0, - ) + ), + operand=None, + ) @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): @@ -368,10 +371,10 @@ def __init__( and "M_c" in conditional_names ) - if "iota" in conditional_names: - self.get_iota = lambda x: x["iota"] - else: - self.get_iota = lambda x: spin_to_iota( + self.get_iota = lambda x: jax.lax.cond( + "iota" in conditional_names, + lambda _: x["iota"], + lambda _: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -383,7 +386,9 @@ def __init__( x["q"], self.freq_ref, 0.0, - ) + ), + operand=None, + ) @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): From 880520554796104d1f0762ad4d839a7182a0c313 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Thu, 5 Sep 2024 23:36:45 +0200 Subject: [PATCH 53/57] Revert "Remove if-else function" This reverts commit 762b7e0f7738284f8943fb842eab430e709a4a93. --- src/jimgw/single_event/transforms.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 0af0fe42..ad91a56a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,4 +1,3 @@ -import jax.lax import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import Float, Array, jaxtyped @@ -261,10 +260,10 @@ def __init__( ) ) - self.get_iota = lambda x: jax.lax.cond( - "iota" in conditional_names, - lambda _: x["iota"], - lambda _: spin_to_iota( + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -276,9 +275,7 @@ def __init__( x["q"], self.freq_ref, 0.0, - ), - operand=None, - ) + ) @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): @@ -371,10 +368,10 @@ def __init__( and "M_c" in conditional_names ) - self.get_iota = lambda x: jax.lax.cond( - "iota" in conditional_names, - lambda _: x["iota"], - lambda _: spin_to_iota( + if "iota" in conditional_names: + self.get_iota = lambda x: x["iota"] + else: + self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], x["theta_1"], @@ -386,9 +383,7 @@ def __init__( x["q"], self.freq_ref, 0.0, - ), - operand=None, - ) + ) @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): From 5ce686cd76adcfe81b61689c61e3dc393827f651 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Thu, 5 Sep 2024 23:44:33 +0200 Subject: [PATCH 54/57] Improve the precession handling for extrinsic transform --- src/jimgw/single_event/transforms.py | 53 ++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index ad91a56a..16f84429 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,6 +1,6 @@ import jax.numpy as jnp from beartype import beartype as typechecker -from jaxtyping import Float, Array, jaxtyped +from jaxtyping import Float, Array, jaxtyped, Bool from astropy.time import Time from jimgw.single_event.detector import GroundBased2G @@ -227,9 +227,26 @@ def __init__( gps_time: Float, ifo: GroundBased2G, freq_ref: Float = None, + with_precession: Bool = False, ): name_mapping = [["phase_c"], ["phase_det"]] - conditional_names = ["ra", "dec", "psi", "iota"] + if with_precession: + conditional_names = [ + "ra", + "dec", + "psi", + "theta_jn", + "phi_jl", + "theta_1", + "theta_2", + "phi_12", + "a_1", + "a_2", + "q", + "M_c", + ] + else: + conditional_names = ["ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -253,16 +270,13 @@ def __init__( and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names - and "q" in conditional_names and "M_c" in conditional_names and "q" in conditional_names ) ) ) - if "iota" in conditional_names: - self.get_iota = lambda x: x["iota"] - else: + if with_precession: self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], @@ -276,6 +290,8 @@ def __init__( self.freq_ref, 0.0, ) + else: + self.get_iota = lambda x: x["iota"] @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): @@ -334,9 +350,26 @@ def __init__( dL_min: Float, dL_max: Float, freq_ref: Float = None, + with_precession: Bool = False, ): name_mapping = [["d_L"], ["d_hat_unbounded"]] - conditional_names = ["M_c", "ra", "dec", "psi", "iota"] + if with_precession: + conditional_names = [ + "M_c", + "ra", + "dec", + "psi", + "theta_jn", + "phi_jl", + "theta_1", + "theta_2", + "phi_12", + "a_1", + "a_2", + "q", + ] + else: + conditional_names = ["M_c", "ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -368,9 +401,7 @@ def __init__( and "M_c" in conditional_names ) - if "iota" in conditional_names: - self.get_iota = lambda x: x["iota"] - else: + if with_precession: self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], @@ -384,6 +415,8 @@ def __init__( self.freq_ref, 0.0, ) + else: + self.get_iota = lambda x: x["iota"] @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): From de8ca6f72835294da41346b9c0edf89c62260e6f Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 5 Sep 2024 16:33:28 -0400 Subject: [PATCH 55/57] Update GW150914_D script --- example/GW150914_D_reparam.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/example/GW150914_D_reparam.py b/example/GW150914_D_reparam.py index ebf2e895..91b58ff4 100644 --- a/example/GW150914_D_reparam.py +++ b/example/GW150914_D_reparam.py @@ -63,11 +63,11 @@ sample_transforms = [ # all the user reparametrization transform - ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + ComponentMassesToChirpMassMassRatioTransform, + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), @@ -81,7 +81,7 @@ ] likelihood_transforms = [ - ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform, ] likelihood = TransientLikelihoodFD( @@ -154,8 +154,8 @@ figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) plt.savefig("GW1500914_D_reparam.jpeg") -########################################### -############# Save the Run ################ -########################################### -import pickle -pickle.dump(result, open("GW150914_D_reparam.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file +############################################ +############## Save the Run ################ +############################################ +#import pickle +#pickle.dump(result, open("GW150914_D_reparam.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) From ce91d2c3376763d9bd0891ce7d529738e918c997 Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 10 Sep 2024 14:05:25 -0400 Subject: [PATCH 56/57] update xample --- example/GW150914_D.py | 37 +++++------------------------------ example/GW150914_D_reparam.py | 4 ++-- src/jimgw/jim.py | 2 ++ 3 files changed, 9 insertions(+), 34 deletions(-) diff --git a/example/GW150914_D.py b/example/GW150914_D.py index 06ac234c..80eef695 100644 --- a/example/GW150914_D.py +++ b/example/GW150914_D.py @@ -62,7 +62,7 @@ ) sample_transforms = [ - ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + ComponentMassesToChirpMassMassRatioTransform, BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), @@ -72,13 +72,13 @@ BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform, ] likelihood = TransientLikelihoodFD( @@ -98,7 +98,7 @@ Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) n_epochs = 30 -n_loop_training = 100 +n_loop_training = 20 learning_rate = 1e-4 @@ -123,36 +123,9 @@ output_thinning=10, local_sampler_arg=local_sampler_arg, strategies=[Adam_optimizer, "default"], + verbose=True ) jim.sample(jax.random.PRNGKey(42)) jim.get_samples() jim.print_summary() - - -########################################### -########## Visualize the Data ############# -########################################### -import corner -import matplotlib.pyplot as plt -import numpy as np - -production_summary = jim.sampler.get_sampler_state(training=False) -production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T -if jim.sample_transforms: - transformed_chain = jim.add_name(production_chain) - for transform in reversed(jim.sample_transforms): - transformed_chain = transform.backward(transformed_chain) -result = transformed_chain -labels = list(transformed_chain.keys()) - -samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array -transposed_array = samples.T # transpose the array -figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) -plt.savefig("GW1500914_D.jpeg") - -########################################### -############# Save the Run ################ -########################################### -import pickle -pickle.dump(result, open("GW150914_D.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file diff --git a/example/GW150914_D_reparam.py b/example/GW150914_D_reparam.py index 91b58ff4..2877f6f7 100644 --- a/example/GW150914_D_reparam.py +++ b/example/GW150914_D_reparam.py @@ -129,8 +129,8 @@ ) jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() -jim.print_summary() +#jim.get_samples() +#jim.print_summary() ########################################### diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 71e8aaad..1fafc024 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,6 +104,7 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): if initial_position.size == 0: + print("Initial guess not provided, sampling initial guess") initial_guess = [] for _ in range(self.sampler.n_chains): flag = True @@ -116,6 +117,7 @@ def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): flag = not jnp.all(jnp.isfinite(guess)) initial_guess.append(guess) initial_position = jnp.array(initial_guess) + print("Starting sample") self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( From 94ea9fd405a0e62e69ba599c6b9901c8206dcf9e Mon Sep 17 00:00:00 2001 From: kazewong Date: Tue, 10 Sep 2024 14:09:06 -0400 Subject: [PATCH 57/57] fix typing mistake --- example/GW150914_D.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/example/GW150914_D.py b/example/GW150914_D.py index 80eef695..ebfab90d 100644 --- a/example/GW150914_D.py +++ b/example/GW150914_D.py @@ -63,18 +63,18 @@ sample_transforms = [ ComponentMassesToChirpMassMassRatioTransform, - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), - BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=1.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]), original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), - BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [