From f33a7824f03de5d602b22cfbf2ec93b46e278316 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 12 Aug 2024 15:49:26 +0200 Subject: [PATCH 01/31] Adding transform from geocentric arrival time to detector arrival time --- src/jimgw/single_event/transforms.py | 54 ++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index c3e77846..b1735f5a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -170,6 +170,60 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): + """ + Transform the geocentric arrival time to detector arrival time + + In the geocentric convention, the arrival time of the signal at the + center of Earth is gps_time + t_c + + In the detector convention, the arrival time of the signal at the + detecotr is gps_time + time_delay_from_geo_to_det + t_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + t_c: Float + t_det: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifo: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + + assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + + def named_transform(x): + t_det = x["t_c"] + self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + return { + "t_det": t_det, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + t_c = x["t_det"] - self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + return { + "t_c": t_c, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class SpinToCartesianSpinTransform(NtoNTransform): """ From 3505394c6e1f974134ecf59e326507c376c54621 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 12 Aug 2024 16:01:11 +0200 Subject: [PATCH 02/31] Adding transform from distance to SNR weighted distance --- src/jimgw/single_event/transforms.py | 88 +++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index b1735f5a..6404fb7e 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -208,7 +208,9 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] def named_transform(x): - t_det = x["t_c"] + self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + t_det = x["t_c"] + self.ifo.delay_from_geocenter( + x["ra"], x["dec"], self.gmst + ) return { "t_det": t_det, } @@ -216,7 +218,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - t_c = x["t_det"] - self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + t_c = x["t_det"] - self.ifo.delay_from_geocenter( + x["ra"], x["dec"], self.gmst + ) return { "t_c": t_c, } @@ -224,6 +228,86 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): + """ + Transform the luminosity distance to network SNR weighted distance + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifos: list[GroundBased2G], + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifos = ifos + + assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + + def named_transform(x): + d_L, M_c, ra, dec, psi, iota = ( + x["d_L"], + x["M_c"], + x["ra"], + x["dec"], + x["psi"], + x["iota"], + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_ks2 = 0.0 + for ifo in self.ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_ks2 += p_mode_term**2 + c_mode_term**2 + R_ks = jnp.sqrt(R_ks2) + d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_ks + return { + "d_hat": d_hat, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + d_hat, M_c, ra, dec, psi, iota = ( + x["d_hat"], + x["M_c"], + x["ra"], + x["dec"], + x["psi"], + x["iota"], + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_ks2 = 0.0 + for ifo in self.ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_ks2 += p_mode_term**2 + c_mode_term**2 + R_ks = jnp.sqrt(R_ks2) + d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_ks + return { + "d_L": d_L, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class SpinToCartesianSpinTransform(NtoNTransform): """ From df75cebc636eec563bf361e10c476465c161fe38 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 10:12:45 +0200 Subject: [PATCH 03/31] updating the typing for object attributes --- src/jimgw/single_event/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 6404fb7e..e1a49e58 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -189,8 +189,7 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): """ gmst: Float - t_c: Float - t_det: Float + ifo: GroundBased2G def __init__( self, @@ -241,6 +240,7 @@ class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): """ gmst: Float + ifos: list[GroundBased2G] def __init__( self, From 9f2f52b323bd03923043a0be4203030229225f71 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 11:35:34 +0200 Subject: [PATCH 04/31] Adding geocentric phase to detector phase --- src/jimgw/single_event/transforms.py | 106 +++++++++++++++++++++------ 1 file changed, 83 insertions(+), 23 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index e1a49e58..bf060b8a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -227,6 +227,72 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): + """ + Transform the geocentric arrival phase to detector arrival phase + + In the geocentric convention, the arrival phase of the signal at the + center of Earth is phi_c / 2 (in ripple, phi_c is the orbital phase) + + In the detector convention, the arrival phase of the signal at the + detecotr is phi_det = phi_c / 2 + arg R_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + ifo: GroundBased2G + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + gps_time: Float, + ifo: GroundBased2G, + ): + super().__init__(name_mapping) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + + assert "phi_c" in name_mapping[0] and "phi_det" in name_mapping[1] + + def _calc_R_det(x): + ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + + return p_mode_term - 1j * c_mode_term + + def named_transform(x): + R_det = _calc_R_det(x) + phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 + return { + "phi_det": phi_det, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + R_det = _calc_R_det(x) + phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 + return { + "phi_c": phi_c, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): """ @@ -257,10 +323,8 @@ def __init__( assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] - def named_transform(x): - d_L, M_c, ra, dec, psi, iota = ( - x["d_L"], - x["M_c"], + def _calc_R_dets(x): + ra, dec, psi, iota = ( x["ra"], x["dec"], x["psi"], @@ -268,14 +332,22 @@ def named_transform(x): ) p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - R_ks2 = 0.0 + R_dets2 = 0.0 for ifo in self.ifos: antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] - R_ks2 += p_mode_term**2 + c_mode_term**2 - R_ks = jnp.sqrt(R_ks2) - d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_ks + R_dets2 += p_mode_term**2 + c_mode_term**2 + + return jnp.sqrt(R_dets2) + + def named_transform(x): + d_L, M_c = ( + x["d_L"], + x["M_c"], + ) + R_dets = _calc_R_dets(x) + d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets return { "d_hat": d_hat, } @@ -283,24 +355,12 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - d_hat, M_c, ra, dec, psi, iota = ( + d_hat, M_c = ( x["d_hat"], x["M_c"], - x["ra"], - x["dec"], - x["psi"], - x["iota"], ) - p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 - c_iota_term = jnp.cos(iota) - R_ks2 = 0.0 - for ifo in self.ifos: - antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) - p_mode_term = p_iota_term * antenna_pattern["p"] - c_mode_term = c_iota_term * antenna_pattern["c"] - R_ks2 += p_mode_term**2 + c_mode_term**2 - R_ks = jnp.sqrt(R_ks2) - d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_ks + R_dets = _calc_R_dets(x) + d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets return { "d_L": d_L, } From b62970f2a8f05bae53408870763aa115ea957fdc Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 11:51:22 +0200 Subject: [PATCH 05/31] Adding ZeroLikelihood for testing purpose --- src/jimgw/single_event/likelihood.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 00e6ce6b..9e775b33 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -26,6 +26,15 @@ def __init__(self, detectors: list[Detector], waveform: Waveform) -> None: self.waveform = waveform +class ZeroLikelihood(LikelihoodBase): + + def __init__(self): + pass + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + return 0.0 + + class TransientLikelihoodFD(SingleEventLiklihood): def __init__( self, From 4ea332224fa35363961220b1dffccc167599a6c3 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 12:09:47 +0200 Subject: [PATCH 06/31] Adding the missing mode 2pi for phasing transform --- src/jimgw/single_event/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index bf060b8a..6f4f361a 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -278,7 +278,7 @@ def named_transform(x): R_det = _calc_R_det(x) phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 return { - "phi_det": phi_det, + "phi_det": phi_det % (2. * jnp.pi), } self.transform_func = named_transform @@ -287,7 +287,7 @@ def named_inverse_transform(x): R_det = _calc_R_det(x) phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 return { - "phi_c": phi_c, + "phi_c": phi_c % (2. * jnp.pi), } self.inverse_transform_func = named_inverse_transform From 7a4bae0e8d06735458b6f1e004258d1afb15aac2 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 13:38:23 +0200 Subject: [PATCH 07/31] Test wip --- test/integration/test_extrinsic.py | 108 ++++++++++++++++++ .../integration/test_extrinsic_no_distance.py | 90 +++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 test/integration/test_extrinsic.py create mode 100644 test/integration/test_extrinsic_no_distance.py diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py new file mode 100644 index 00000000..c4719dd8 --- /dev/null +++ b/test/integration/test_extrinsic.py @@ -0,0 +1,108 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import ZeroLikelihood +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 + +ifos = [H1, L1, V1] + +M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + M_c_prior, + q_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], gps_time=gps, ifos=ifos), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=2.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] + +likelihood = ZeroLikelihood() + +mass_matrix = jnp.eye(9) +#mass_matrix = mass_matrix.at[1, 1].set(1e-3) +#mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() diff --git a/test/integration/test_extrinsic_no_distance.py b/test/integration/test_extrinsic_no_distance.py new file mode 100644 index 00000000..d1b1e559 --- /dev/null +++ b/test/integration/test_extrinsic_no_distance.py @@ -0,0 +1,90 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import ZeroLikelihood +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 + +ifos = [H1, L1, V1] + +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + t_c_prior, + phase_c_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # all the user reparametrization transform + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), +] + +likelihood_transforms = [] + +likelihood = ZeroLikelihood() + +mass_matrix = jnp.eye(9) +#mass_matrix = mass_matrix.at[1, 1].set(1e-3) +#mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], +) + +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() From d5f86e52156c788a06d84c1ac5d0958ae4d0b8fb Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 13 Aug 2024 13:42:45 +0200 Subject: [PATCH 08/31] Phase renaming --- src/jimgw/single_event/transforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 6f4f361a..1242b40b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -233,10 +233,10 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): Transform the geocentric arrival phase to detector arrival phase In the geocentric convention, the arrival phase of the signal at the - center of Earth is phi_c / 2 (in ripple, phi_c is the orbital phase) + center of Earth is phase_c / 2 (in ripple, phase_c is the orbital phase) In the detector convention, the arrival phase of the signal at the - detecotr is phi_det = phi_c / 2 + arg R_det + detecotr is phase_det = phase_c / 2 + arg R_det Parameters ---------- @@ -261,7 +261,7 @@ def __init__( ) self.ifo = ifo - assert "phi_c" in name_mapping[0] and "phi_det" in name_mapping[1] + assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] def _calc_R_det(x): ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] @@ -276,18 +276,18 @@ def _calc_R_det(x): def named_transform(x): R_det = _calc_R_det(x) - phi_det = jnp.angle(R_det) + x["phi_c"] / 2.0 + phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 return { - "phi_det": phi_det % (2. * jnp.pi), + "phase_det": phase_det % (2. * jnp.pi), } self.transform_func = named_transform def named_inverse_transform(x): R_det = _calc_R_det(x) - phi_c = (-jnp.angle(R_det) + x["phi_det"]) * 2.0 + phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 return { - "phi_c": phi_c % (2. * jnp.pi), + "phase_c": phase_c % (2. * jnp.pi), } self.inverse_transform_func = named_inverse_transform From 0a2e68c22611fea72a2643bbea8ab358c67c3650 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Tue, 13 Aug 2024 06:09:13 -0700 Subject: [PATCH 09/31] wip --- src/jimgw/single_event/transforms.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 1242b40b..538d6b8e 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -208,7 +208,7 @@ def __init__( def named_transform(x): t_det = x["t_c"] + self.ifo.delay_from_geocenter( - x["ra"], x["dec"], self.gmst + x["ra"][0], x["dec"][0], self.gmst ) return { "t_det": t_det, @@ -217,8 +217,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): + import pdb; pdb.set_trace() t_c = x["t_det"] - self.ifo.delay_from_geocenter( - x["ra"], x["dec"], self.gmst + x["ra"][0], x["dec"][0], self.gmst ) return { "t_c": t_c, @@ -268,7 +269,7 @@ def _calc_R_det(x): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] @@ -278,7 +279,7 @@ def named_transform(x): R_det = _calc_R_det(x) phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 return { - "phase_det": phase_det % (2. * jnp.pi), + "phase_det": phase_det % (2.0 * jnp.pi), } self.transform_func = named_transform @@ -287,7 +288,7 @@ def named_inverse_transform(x): R_det = _calc_R_det(x) phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 return { - "phase_c": phase_c % (2. * jnp.pi), + "phase_c": phase_c % (2.0 * jnp.pi), } self.inverse_transform_func = named_inverse_transform From b96512c64139661b38fd7df1f2506d480c527ebb Mon Sep 17 00:00:00 2001 From: kazewong Date: Wed, 14 Aug 2024 13:30:43 -0400 Subject: [PATCH 10/31] Push conditional bijective transform --- src/jimgw/transforms.py | 54 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..4d2ebb45 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -61,8 +61,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: class NtoNTransform(NtoMTransform): - transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property def n_dim(self) -> int: return len(self.name_mapping[0]) @@ -162,6 +160,58 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]: list(output_params.keys()), ) return y_copy + +class ConditionalBijectiveTransform(BijectiveTransform): + + conditional_names: list[str] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], + ): + super().__init__(name_mapping) + self.conditional_names = conditional_names + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + x_copy = x.copy() + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + transform_params.update( + dict((key, x_copy[key]) for key in self.conditional_names) + ) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian + + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + y_copy = y.copy() + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + transform_params.update( + dict((key, y_copy[key]) for key in self.conditional_names) + ) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) + jacobian = jnp.array(jax.tree.leaves(jacobian)) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy, jacobian @jaxtyped(typechecker=typechecker) From 526e33cbb8eae4c32782113cd2a7495d657e5ea9 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:44:04 -0700 Subject: [PATCH 11/31] Switch to using conditional transform --- src/jimgw/single_event/transforms.py | 46 ++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 538d6b8e..2fb757c2 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,7 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import BijectiveTransform, NtoNTransform +from jimgw.transforms import ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -171,7 +171,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): +class GeocentricArrivalTimeToDetectorArrivalTimeTransform(ConditionalBijectiveTransform): """ Transform the geocentric arrival time to detector arrival time @@ -194,10 +194,11 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -205,11 +206,22 @@ def __init__( self.ifo = ifo assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + ) + + def _calc_delay(x): + ra, dec = x["ra"], x["dec"] + if hasattr(ra, "shape") and len(ra.shape) > 0: + delay = self.ifo.delay_from_geocenter(ra[0], dec[0], self.gmst) + else: + delay = self.ifo.delay_from_geocenter(ra, dec, self.gmst) + return delay def named_transform(x): - t_det = x["t_c"] + self.ifo.delay_from_geocenter( - x["ra"][0], x["dec"][0], self.gmst - ) + delay = _calc_delay(x) + t_det = x["t_c"] + delay return { "t_det": t_det, } @@ -217,10 +229,8 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - import pdb; pdb.set_trace() - t_c = x["t_det"] - self.ifo.delay_from_geocenter( - x["ra"][0], x["dec"][0], self.gmst - ) + delay = _calc_delay(x) + t_c = x["t_det"] - delay return { "t_c": t_c, } @@ -229,7 +239,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(ConditionalBijectiveTransform): """ Transform the geocentric arrival phase to detector arrival phase @@ -252,10 +262,11 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -263,13 +274,22 @@ def __init__( self.ifo = ifo assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + ) def _calc_R_det(x): ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + if hasattr(ra, "shape") and len(ra.shape) > 0: + antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + else: + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] From dbf3f3064e5135575a5c7548ee9e1c5e9157553b Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:45:40 -0700 Subject: [PATCH 12/31] Switch to using conditional transform --- src/jimgw/single_event/transforms.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 2fb757c2..dfcd01b8 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,11 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform +from jimgw.transforms import ( + ConditionalBijectiveTransform, + BijectiveTransform, + NtoNTransform, +) from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -171,7 +175,9 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalTimeToDetectorArrivalTimeTransform(ConditionalBijectiveTransform): +class GeocentricArrivalTimeToDetectorArrivalTimeTransform( + ConditionalBijectiveTransform +): """ Transform the geocentric arrival time to detector arrival time @@ -206,10 +212,7 @@ def __init__( self.ifo = ifo assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] - assert ( - "ra" in conditional_names - and "dec" in conditional_names - ) + assert "ra" in conditional_names and "dec" in conditional_names def _calc_delay(x): ra, dec = x["ra"], x["dec"] @@ -239,7 +242,9 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(ConditionalBijectiveTransform): +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( + ConditionalBijectiveTransform +): """ Transform the geocentric arrival phase to detector arrival phase @@ -287,7 +292,9 @@ def _calc_R_det(x): c_iota_term = jnp.cos(iota) if hasattr(ra, "shape") and len(ra.shape) > 0: - antenna_pattern = self.ifo.antenna_pattern(ra[0], dec[0], psi[0], self.gmst) + antenna_pattern = self.ifo.antenna_pattern( + ra[0], dec[0], psi[0], self.gmst + ) else: antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) p_mode_term = p_iota_term * antenna_pattern["p"] From a3753612926dd2b95b440bbc14043cdf7412df29 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Wed, 14 Aug 2024 12:46:04 -0700 Subject: [PATCH 13/31] Fixing jacobian handling --- src/jimgw/transforms.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 4d2ebb45..a26ad9af 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -160,7 +160,8 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]: list(output_params.keys()), ) return y_copy - + + class ConditionalBijectiveTransform(BijectiveTransform): conditional_names: list[str] @@ -181,8 +182,14 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: ) output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[0]} + for key1 in self.name_mapping[1] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -192,7 +199,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: list(output_params.keys()), ) return x_copy, jacobian - + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: y_copy = y.copy() transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) @@ -201,8 +208,14 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: ) output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) - jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[1]} + for key1 in self.name_mapping[0] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], From d79af97d25e29b41a3f6690bd9df4db6ef46d54f Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 02:56:38 -0700 Subject: [PATCH 14/31] Both arrival phase and time transform are fully vectorized --- src/jimgw/single_event/transforms.py | 43 ++++++++++++---------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index dfcd01b8..047c0607 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -214,17 +214,10 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names - def _calc_delay(x): - ra, dec = x["ra"], x["dec"] - if hasattr(ra, "shape") and len(ra.shape) > 0: - delay = self.ifo.delay_from_geocenter(ra[0], dec[0], self.gmst) - else: - delay = self.ifo.delay_from_geocenter(ra, dec, self.gmst) - return delay - def named_transform(x): - delay = _calc_delay(x) - t_det = x["t_c"] + delay + t_det = x["t_c"] + jnp.vectorize(self.ifo.delay_from_geocenter)( + x["ra"], x["dec"], self.gmst + ) return { "t_det": t_det, } @@ -232,8 +225,9 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - delay = _calc_delay(x) - t_c = x["t_det"] - delay + t_c = x["t_det"] - jnp.vectorize(self.ifo.delay_from_geocenter)( + x["ra"], x["dec"], self.gmst + ) return { "t_c": t_c, } @@ -286,25 +280,22 @@ def __init__( and "iota" in conditional_names ) - def _calc_R_det(x): - ra, dec, psi, iota = x["ra"], x["dec"], x["psi"], x["iota"] + @jnp.vectorize + def _calc_R_det_arg(ra, dec, psi, iota, gmst): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) - if hasattr(ra, "shape") and len(ra.shape) > 0: - antenna_pattern = self.ifo.antenna_pattern( - ra[0], dec[0], psi[0], self.gmst - ) - else: - antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, self.gmst) + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, gmst) p_mode_term = p_iota_term * antenna_pattern["p"] c_mode_term = c_iota_term * antenna_pattern["c"] - return p_mode_term - 1j * c_mode_term + return jnp.angle(p_mode_term - 1j * c_mode_term) def named_transform(x): - R_det = _calc_R_det(x) - phase_det = jnp.angle(R_det) + x["phase_c"] / 2.0 + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + ) + phase_det = R_det_arg + x["phase_c"] / 2.0 return { "phase_det": phase_det % (2.0 * jnp.pi), } @@ -312,8 +303,10 @@ def named_transform(x): self.transform_func = named_transform def named_inverse_transform(x): - R_det = _calc_R_det(x) - phase_c = (-jnp.angle(R_det) + x["phase_det"]) * 2.0 + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + ) + phase_c = (-R_det_arg + x["phase_det"]) * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), } From bcbcbe2e8d6e1d565544cadf9b9ea8e333602cd0 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 03:56:26 -0700 Subject: [PATCH 15/31] Shifting distance transform to conditional --- src/jimgw/single_event/transforms.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 047c0607..4b6ada1f 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -315,7 +315,7 @@ def named_inverse_transform(x): @jaxtyped(typechecker=typechecker) -class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): +class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): """ Transform the luminosity distance to network SNR weighted distance @@ -332,10 +332,11 @@ class DistanceToSNRWeightedDistanceTransform(BijectiveTransform): def __init__( self, name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], ): - super().__init__(name_mapping) + super().__init__(name_mapping, conditional_names) self.gmst = ( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad @@ -343,14 +344,16 @@ def __init__( self.ifos = ifos assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + and "M_c" in conditional_names + ) - def _calc_R_dets(x): - ra, dec, psi, iota = ( - x["ra"], - x["dec"], - x["psi"], - x["iota"], - ) + @jnp.vectorize + def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) R_dets2 = 0.0 @@ -367,7 +370,7 @@ def named_transform(x): x["d_L"], x["M_c"], ) - R_dets = _calc_R_dets(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets return { "d_hat": d_hat, @@ -380,7 +383,7 @@ def named_inverse_transform(x): x["d_hat"], x["M_c"], ) - R_dets = _calc_R_dets(x) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets return { "d_L": d_L, From 8dab27b6e259063f03db8f381e8613ca7e34a9d4 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 03:56:48 -0700 Subject: [PATCH 16/31] update example --- test/integration/test_extrinsic.py | 62 ++++++++++--- .../integration/test_extrinsic_no_distance.py | 90 ------------------- 2 files changed, 48 insertions(+), 104 deletions(-) delete mode 100644 test/integration/test_extrinsic_no_distance.py diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index c4719dd8..988f6ea7 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -1,3 +1,12 @@ +import psutil +p = psutil.Process() +p.cpu_affinity([0]) + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +from astropy.time import Time + import jax import jax.numpy as jnp @@ -21,7 +30,6 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1.0, parameter_names=["q"]) dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) @@ -33,7 +41,6 @@ prior = CombinePrior( [ M_c_prior, - q_prior, dL_prior, t_c_prior, phase_c_prior, @@ -44,31 +51,56 @@ ] ) +# calculate the d_hat range +@jnp.vectorize +def calc_R_dets(ra, dec, psi, iota): + gmst = ( + Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad + ) + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_dets2 = 0.0 + for ifo in ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_dets2 += p_mode_term**2 + c_mode_term**2 + + return jnp.sqrt(R_dets2) + +key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(1234), 4) +# generate 10000 samples for each +ra_samples = ra_prior.sample(key1, 10000)["ra"] +dec_samples = dec_prior.sample(key2, 10000)["dec"] +psi_samples = psi_prior.sample(key3, 10000)["psi"] +iota_samples = iota_prior.sample(key4, 10000)["iota"] +R_dets_samples = calc_R_dets(ra_samples, dec_samples, psi_samples, iota_samples) + +d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) +d_hat_max = dL_prior.xmax / jnp.power(M_c_prior.xmin, 5. / 6.) / jnp.amin(R_dets_samples) + sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], gps_time=gps, ifos=ifos), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=2.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), + BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=d_hat_min, original_upper_bound=d_hat_max), ] -likelihood_transforms = [ - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), -] +likelihood_transforms = [] likelihood = ZeroLikelihood() -mass_matrix = jnp.eye(9) +mass_matrix = jnp.eye(len(prior.base_prior)) #mass_matrix = mass_matrix.at[1, 1].set(1e-3) #mass_matrix = mass_matrix.at[5, 5].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} @@ -100,9 +132,11 @@ train_thinning=1, output_thinning=1, local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], + strategies=["default"], ) -jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() +print("Start sampling") +key = jax.random.PRNGKey(42) +jim.sample(key) jim.print_summary() +samples = jim.get_samples() diff --git a/test/integration/test_extrinsic_no_distance.py b/test/integration/test_extrinsic_no_distance.py deleted file mode 100644 index d1b1e559..00000000 --- a/test/integration/test_extrinsic_no_distance.py +++ /dev/null @@ -1,90 +0,0 @@ -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior -from jimgw.single_event.detector import H1, L1, V1 -from jimgw.single_event.likelihood import ZeroLikelihood -from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform -from flowMC.strategy.optimization import optimization_Adam - -jax.config.update("jax_enable_x64", True) - -########################################### -########## First we grab data ############# -########################################### - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 - -ifos = [H1, L1, V1] - -t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) -phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) -dec_prior = CosinePrior(parameter_names=["dec"]) - -prior = CombinePrior( - [ - t_c_prior, - phase_c_prior, - ra_prior, - dec_prior, - ] -) - -sample_transforms = [ - # all the user reparametrization transform - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c", "phase_det"]], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c", "t_det"]], gps_time=gps, ifo=ifos[0]), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), - # all the bound to unbound transform - BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), -] - -likelihood_transforms = [] - -likelihood = ZeroLikelihood() - -mass_matrix = jnp.eye(9) -#mass_matrix = mass_matrix.at[1, 1].set(1e-3) -#mass_matrix = mass_matrix.at[5, 5].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} - -Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) - -n_epochs = 2 -n_loop_training = 1 -learning_rate = 1e-4 - - -jim = Jim( - likelihood, - prior, - sample_transforms=sample_transforms, - likelihood_transforms=likelihood_transforms, - n_loop_training=n_loop_training, - n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30, - n_flow_samples=100, - momentum=0.9, - batch_size=100, - use_global=True, - train_thinning=1, - output_thinning=1, - local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], -) - -jim.sample(jax.random.PRNGKey(42)) -jim.get_samples() -jim.print_summary() From fd338825bd10557c67da5b55bac9dc59a75db379 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Fri, 16 Aug 2024 14:17:17 -0700 Subject: [PATCH 17/31] Fixing the single sided unbound transform --- src/jimgw/transforms.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index a26ad9af..8b4e2d75 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -418,17 +418,26 @@ class SingleSidedUnboundTransform(BijectiveTransform): """ + original_lower_bound: Float + def __init__( self, name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, ): super().__init__(name_mapping) + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.transform_func = lambda x: { - name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + name_mapping[1][i]: jnp.log( + x[name_mapping[0][i]] - self.original_lower_bound[i] + ) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + name_mapping[0][i]: jnp.exp( + x[name_mapping[1][i]] + self.original_lower_bound[i] + ) for i in range(len(name_mapping[1])) } From 03e76dc1731c793039504b6bf709da79953e17dc Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Sat, 17 Aug 2024 01:52:18 -0700 Subject: [PATCH 18/31] Update extrinsic test --- test/integration/test_extrinsic.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 988f6ea7..7cb5bf32 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -1,10 +1,3 @@ -import psutil -p = psutil.Process() -p.cpu_affinity([0]) - -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - from astropy.time import Time import jax @@ -14,7 +7,7 @@ from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1, V1 from jimgw.single_event.likelihood import ZeroLikelihood -from jimgw.transforms import BoundToUnbound +from jimgw.transforms import BoundToUnbound, SingleSidedUnboundTransform from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform from flowMC.strategy.optimization import optimization_Adam @@ -30,7 +23,7 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(10.0, 200.0, -2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) @@ -93,7 +86,7 @@ def calc_R_dets(ra, dec, psi, iota): BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), - BoundToUnbound(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=d_hat_min, original_upper_bound=d_hat_max), + SingleSidedUnboundTransform(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=float(d_hat_min)), ] likelihood_transforms = [] @@ -119,9 +112,9 @@ def calc_R_dets(ra, dec, psi, iota): likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, + n_local_steps=1, + n_global_steps=1, + n_chains=10, n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30, From 8fe4b5fdcf43c3a2c6430ba868aeddb3e96bbc72 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 14:52:39 +0200 Subject: [PATCH 19/31] bugfix for single sided transform --- src/jimgw/transforms.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 8b4e2d75..f7d4c702 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -435,9 +435,8 @@ def __init__( for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.exp( - x[name_mapping[1][i]] + self.original_lower_bound[i] - ) + name_mapping[0][i]: jnp.exp(x[name_mapping[1][i]]) + + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } From a19b556a0db3565b141b7508ddf85f3e4250e8ef Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 05:59:11 -0700 Subject: [PATCH 20/31] Update test --- test/integration/test_extrinsic.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 7cb5bf32..55979402 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -23,7 +23,7 @@ ifos = [H1, L1, V1] M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -dL_prior = PowerLawPrior(10.0, 200.0, -2.0, parameter_names=["d_L"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) iota_prior = SinePrior(parameter_names=["iota"]) @@ -44,33 +44,7 @@ ] ) -# calculate the d_hat range -@jnp.vectorize -def calc_R_dets(ra, dec, psi, iota): - gmst = ( - Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad - ) - p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 - c_iota_term = jnp.cos(iota) - R_dets2 = 0.0 - for ifo in ifos: - antenna_pattern = ifo.antenna_pattern(ra, dec, psi, gmst) - p_mode_term = p_iota_term * antenna_pattern["p"] - c_mode_term = c_iota_term * antenna_pattern["c"] - R_dets2 += p_mode_term**2 + c_mode_term**2 - - return jnp.sqrt(R_dets2) - -key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(1234), 4) -# generate 10000 samples for each -ra_samples = ra_prior.sample(key1, 10000)["ra"] -dec_samples = dec_prior.sample(key2, 10000)["dec"] -psi_samples = psi_prior.sample(key3, 10000)["psi"] -iota_samples = iota_prior.sample(key4, 10000)["iota"] -R_dets_samples = calc_R_dets(ra_samples, dec_samples, psi_samples, iota_samples) - d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) -d_hat_max = dL_prior.xmax / jnp.power(M_c_prior.xmin, 5. / 6.) / jnp.amin(R_dets_samples) sample_transforms = [ # all the user reparametrization transform From 6d2cd9704c10a8345aebfde506ed8bcbeb3e83fe Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 16:25:39 +0200 Subject: [PATCH 21/31] update distance transform --- src/jimgw/single_event/transforms.py | 36 +++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 4b6ada1f..5d8e688b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -328,6 +328,8 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): gmst: Float ifos: list[GroundBased2G] + d_L_min: Float + d_L_max: Float def __init__( self, @@ -335,6 +337,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], + d_L_min: Float, + d_L_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -342,8 +346,10 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifos = ifos + self.d_L_min = d_L_min + self.d_L_max = d_L_max - assert "d_L" in name_mapping[0] and "d_hat" in name_mapping[1] + assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] assert ( "ra" in conditional_names and "dec" in conditional_names @@ -371,20 +377,38 @@ def named_transform(x): x["M_c"], ) R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) - d_hat = d_L / jnp.power(M_c, 5.0 / 6.0) / R_dets + + scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + d_hat = scale_factor * d_L + + d_hat_min = scale_factor * self.d_L_min + d_hat_max = scale_factor * self.d_L_max + + y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min) + d_hat_unbounded = jnp.log(y / (1.0 - y)) + return { - "d_hat": d_hat, + "d_hat_unbounded": d_hat_unbounded, } self.transform_func = named_transform def named_inverse_transform(x): - d_hat, M_c = ( - x["d_hat"], + d_hat_unbounded, M_c = ( + x["d_hat_unbounded"], x["M_c"], ) R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) - d_L = d_hat * jnp.power(M_c, 5.0 / 6.0) * R_dets + + scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + + d_hat_min = scale_factor * self.d_L_min + d_hat_max = scale_factor * self.d_L_max + + d_hat = (d_hat_max - d_hat_min) / ( + 1.0 + jnp.exp(-d_hat_unbounded) + ) + d_hat_min + d_L = d_hat / scale_factor return { "d_L": d_L, } From 6993dd9072780203bf99046343f62328c63c7848 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 08:58:57 -0700 Subject: [PATCH 22/31] Update test --- test/integration/test_extrinsic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index 55979402..a5dc5c7b 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -44,11 +44,10 @@ ] ) -d_hat_min = dL_prior.xmin / jnp.power(M_c_prior.xmax, 5. / 6.) sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), @@ -60,7 +59,6 @@ BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), - SingleSidedUnboundTransform(name_mapping = [["d_hat"], ["d_hat_unbounded"]], original_lower_bound=float(d_hat_min)), ] likelihood_transforms = [] From ff65fcf2e75a0655c8e33ddfb32c378c84b3dc8e Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 21:58:57 +0200 Subject: [PATCH 23/31] Update arrival time transform --- src/jimgw/single_event/transforms.py | 58 ++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 5d8e688b..1070adaf 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -196,6 +196,8 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform( gmst: Float ifo: GroundBased2G + tc_min: Float + tc_max: Float def __init__( self, @@ -203,6 +205,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, + tc_min: Float, + tc_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -210,24 +214,46 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifo = ifo + self.tc_min = tc_min + self.tc_max = tc_max - assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1] + assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names + @jnp.vectorize + def time_delay(ra, dec, gmst): + return self.ifo.delay_from_geocenter(ra, dec, gmst) + def named_transform(x): - t_det = x["t_c"] + jnp.vectorize(self.ifo.delay_from_geocenter)( - x["ra"], x["dec"], self.gmst - ) + + time_shift = time_delay(x["ra"], x["dec"], self.gmst) + + t_det = x["t_c"] + time_shift + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + + y = (t_det - t_det_min) / (t_det_max - t_det_min) + t_det_unbounded = jnp.log(y / (1.0 - y)) return { - "t_det": t_det, + "t_det_unbounded": t_det_unbounded, } self.transform_func = named_transform def named_inverse_transform(x): - t_c = x["t_det"] - jnp.vectorize(self.ifo.delay_from_geocenter)( + + time_shift = jnp.vectorize(self.ifo.delay_from_geocenter)( x["ra"], x["dec"], self.gmst ) + + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + t_det = (t_det_max - t_det_min) / ( + 1.0 + jnp.exp(-x["t_det_unbounded"]) + ) + t_det_min + + t_c = t_det - time_shift + return { "t_c": t_c, } @@ -328,8 +354,8 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): gmst: Float ifos: list[GroundBased2G] - d_L_min: Float - d_L_max: Float + dL_min: Float + dL_max: Float def __init__( self, @@ -337,8 +363,8 @@ def __init__( conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], - d_L_min: Float, - d_L_max: Float, + dL_min: Float, + dL_max: Float, ): super().__init__(name_mapping, conditional_names) @@ -346,8 +372,8 @@ def __init__( Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad ) self.ifos = ifos - self.d_L_min = d_L_min - self.d_L_max = d_L_max + self.dL_min = dL_min + self.dL_max = dL_max assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] assert ( @@ -381,8 +407,8 @@ def named_transform(x): scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets d_hat = scale_factor * d_L - d_hat_min = scale_factor * self.d_L_min - d_hat_max = scale_factor * self.d_L_max + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min) d_hat_unbounded = jnp.log(y / (1.0 - y)) @@ -402,8 +428,8 @@ def named_inverse_transform(x): scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets - d_hat_min = scale_factor * self.d_L_min - d_hat_max = scale_factor * self.d_L_max + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max d_hat = (d_hat_max - d_hat_min) / ( 1.0 + jnp.exp(-d_hat_unbounded) From b98d783db7b16a14c91ca83e7e3cd606697682d7 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Mon, 19 Aug 2024 13:04:47 -0700 Subject: [PATCH 24/31] Update test --- test/integration/test_extrinsic.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index a5dc5c7b..ff79723e 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -49,7 +49,7 @@ # all the user reparametrization transform DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det"]], conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), @@ -58,7 +58,6 @@ BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["t_det"], ["t_det_unbounded"]], original_lower_bound=-0.1, original_upper_bound=0.1), ] likelihood_transforms = [] @@ -84,8 +83,8 @@ likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=1, - n_local_steps=1, - n_global_steps=1, + n_local_steps=2, + n_global_steps=2, n_chains=10, n_epochs=n_epochs, learning_rate=learning_rate, From e399a5e9c7609fe1f07ba6951182e5899ac4b692 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Mon, 19 Aug 2024 22:18:31 +0200 Subject: [PATCH 25/31] Fix typo --- test/integration/test_extrinsic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index ff79723e..f0e089fe 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -47,7 +47,7 @@ sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, d_L_min=dL_prior.xmin, d_L_max=dL_prior.xmax), + DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), From 583b759de0859d0d2096bc95d7cccfa1d0b6b824 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 20 Aug 2024 00:43:21 +0200 Subject: [PATCH 26/31] Fix typo --- src/jimgw/single_event/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 1070adaf..084fe368 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -332,7 +332,7 @@ def named_inverse_transform(x): R_det_arg = _calc_R_det_arg( x["ra"], x["dec"], x["psi"], x["iota"], self.gmst ) - phase_c = (-R_det_arg + x["phase_det"]) * 2.0 + phase_c = -R_det_arg + x["phase_det"] * 2.0 return { "phase_c": phase_c % (2.0 * jnp.pi), } From 3a5cbf71522f21ba4f5ead4508f8f4e1dedbd9f5 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Wed, 18 Sep 2024 21:46:38 +0200 Subject: [PATCH 27/31] remove duplicated SpinToCartesianSpinTransform --- src/jimgw/single_event/transforms.py | 106 +++++++++------------------ test/integration/test_extrinsic.py | 8 +- 2 files changed, 37 insertions(+), 77 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index da7864ba..62a09fd6 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -140,13 +140,13 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform( def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, tc_min: Float, tc_max: Float, ): + name_mapping = [["t_c"], ["t_det_unbounded"]] + conditional_names = ["ra", "dec"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -225,11 +225,11 @@ class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifo: GroundBased2G, ): + name_mapping = [["phase_c"], ["phase_det"]] + conditional_names = ["ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -305,6 +305,8 @@ def __init__( dL_min: Float, dL_max: Float, ): + name_mapping = [["d_L"], ["d_hat_unbounded"]] + conditional_names = ["M_c", "ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -381,96 +383,54 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform -@jaxtyped(typechecker=typechecker) -class SpinToCartesianSpinTransform(NtoNTransform): - """ - Spin to Cartesian spin transformation - """ - - freq_ref: Float - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - freq_ref: Float, - ): - super().__init__(name_mapping) - - self.freq_ref = freq_ref - - assert ( - "theta_jn" in name_mapping[0] - and "phi_jl" in name_mapping[0] - and "theta_1" in name_mapping[0] - and "theta_2" in name_mapping[0] - and "phi_12" in name_mapping[0] - and "a_1" in name_mapping[0] - and "a_2" in name_mapping[0] - and "iota" in name_mapping[1] - and "s1_x" in name_mapping[1] - and "s1_y" in name_mapping[1] - and "s1_z" in name_mapping[1] - and "s2_x" in name_mapping[1] - and "s2_y" in name_mapping[1] - and "s2_z" in name_mapping[1] - ) - - def named_transform(x): - iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( - x["theta_jn"], - x["phi_jl"], - x["theta_1"], - x["theta_2"], - x["phi_12"], - x["a_1"], - x["a_2"], - x["M_c"], - x["q"], - self.freq_ref, - x["phase_c"], - ) - return { - "iota": iota, - "s1_x": s1x, - "s1_y": s1y, - "s1_z": s1z, - "s2_x": s2x, - "s2_y": s2y, - "s2_z": s2z, - } - - self.transform_func = named_transform - - def named_m1_m2_to_Mc_q(x): Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - - + return {"M_c": Mc, "q": q} + + def named_Mc_q_to_m1_m2(x): m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "q"]) +) ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q -ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = ( + named_Mc_q_to_m1_m2 +) + def named_m1_m2_to_Mc_eta(x): Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) return {"M_c": Mc, "eta": eta} + def named_Mc_eta_to_m1_m2(x): m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) -ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta -ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "eta"]) +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = ( + named_m1_m2_to_Mc_eta +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = ( + named_Mc_eta_to_m1_m2 +) + def named_q_to_eta(x): return {"eta": q_to_eta(x["q"])} + + def named_eta_to_q(x): return {"q": eta_to_q(x["eta"])} + + MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py index f0e089fe..300f2132 100644 --- a/test/integration/test_extrinsic.py +++ b/test/integration/test_extrinsic.py @@ -47,10 +47,10 @@ sample_transforms = [ # all the user reparametrization transform - DistanceToSNRWeightedDistanceTransform(name_mapping=[["d_L"], ["d_hat_unbounded"]], conditional_names=["M_c","ra", "dec", "psi", "iota"], gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(name_mapping = [["phase_c"], ["phase_det"]], conditional_names=["ra", "dec", "psi", "iota"], gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(name_mapping = [["t_c"], ["t_det_unbounded"]], tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, conditional_names=["ra", "dec"], gps_time=gps, ifo=ifos[0]), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), # all the bound to unbound transform BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), From c25048f79a3968c1527cbf9aea33357f319eb30d Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Wed, 18 Sep 2024 22:07:05 +0200 Subject: [PATCH 28/31] Minor fix --- src/jimgw/single_event/transforms.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 62a09fd6..d325590b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -298,8 +298,6 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): def __init__( self, - name_mapping: tuple[list[str], list[str]], - conditional_names: list[str], gps_time: Float, ifos: list[GroundBased2G], dL_min: Float, From 482296291f3f80449b814d0131b19ef3ba5ec52d Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 20 Sep 2024 12:25:15 -0400 Subject: [PATCH 29/31] replace vectorize with vmap --- src/jimgw/jim.py | 41 ++++++++++++++++++++-------- src/jimgw/single_event/transforms.py | 13 +++------ 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2f0086ac..3047c0a4 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,19 +104,38 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): if initial_position.size == 0: - initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan - - while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): - non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + initial_position = ( + jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan + ) + + while not jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] key, subkey = jax.random.split(key) guess = self.prior.sample(subkey, self.sampler.n_chains) for transform in self.sample_transforms: guess = jax.vmap(transform.forward)(guess) - guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T - finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in self.parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] common_length = min(len(finite_guess), len(non_finite_index)) - initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( @@ -157,7 +176,7 @@ def print_summary(self, transform: bool = True): training_chain = self.add_name(training_chain) if transform: for sample_transform in reversed(self.sample_transforms): - training_chain = sample_transform.backward(training_chain) + training_chain = jax.vmap(sample_transform.backward)(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] @@ -167,7 +186,7 @@ def print_summary(self, transform: bool = True): production_chain = self.add_name(production_chain) if transform: for sample_transform in reversed(self.sample_transforms): - production_chain = sample_transform.backward(production_chain) + production_chain = jax.vmap(sample_transform.backward)(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -223,10 +242,10 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.sampler.get_sampler_state(training=False)["chains"] - chains = chains.transpose(2, 0, 1) + chains = chains.reshape(-1, self.prior.n_dim) chains = self.add_name(chains) for sample_transform in reversed(self.sample_transforms): - chains = sample_transform.backward(chains) + chains = jax.vmap(sample_transform.backward)(chains) return chains def plot(self): diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index d325590b..1bf8f3a7 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -145,7 +145,7 @@ def __init__( tc_min: Float, tc_max: Float, ): - name_mapping = [["t_c"], ["t_det_unbounded"]] + name_mapping = (["t_c"], ["t_det_unbounded"]) conditional_names = ["ra", "dec"] super().__init__(name_mapping, conditional_names) @@ -159,7 +159,6 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names - @jnp.vectorize def time_delay(ra, dec, gmst): return self.ifo.delay_from_geocenter(ra, dec, gmst) @@ -181,9 +180,7 @@ def named_transform(x): def named_inverse_transform(x): - time_shift = jnp.vectorize(self.ifo.delay_from_geocenter)( - x["ra"], x["dec"], self.gmst - ) + time_shift = self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) t_det_min = self.tc_min + time_shift t_det_max = self.tc_max + time_shift @@ -228,7 +225,7 @@ def __init__( gps_time: Float, ifo: GroundBased2G, ): - name_mapping = [["phase_c"], ["phase_det"]] + name_mapping = (["phase_c"], ["phase_det"]) conditional_names = ["ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) @@ -245,7 +242,6 @@ def __init__( and "iota" in conditional_names ) - @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) @@ -303,7 +299,7 @@ def __init__( dL_min: Float, dL_max: Float, ): - name_mapping = [["d_L"], ["d_hat_unbounded"]] + name_mapping = (["d_L"], ["d_hat_unbounded"]) conditional_names = ["M_c", "ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) @@ -323,7 +319,6 @@ def __init__( and "M_c" in conditional_names ) - @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) From 617f3e77d594d89545c7c83458d1b3cd7f9ffded Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 20 Sep 2024 12:38:04 -0400 Subject: [PATCH 30/31] Try fixing github precommit --- .github/workflows/pre-commit.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index c58fd662..4a7d2c2f 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -25,4 +25,8 @@ jobs: python -m pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi python -m pip install . + - id: changed-files + uses: tj-actions/changed-files@v36 - uses: pre-commit/action@v3.0.0 + with: + extra_args: pip-compile --files ${{ steps.changed-files.outputs.all_changed_files }} From 766c36e5940a02cde644593813c1e0ccd3ec0d9e Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 20 Sep 2024 12:39:46 -0400 Subject: [PATCH 31/31] Try fixing github precommit --- .github/workflows/pre-commit.yml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 4a7d2c2f..30cbfb35 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -25,8 +25,4 @@ jobs: python -m pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi python -m pip install . - - id: changed-files - uses: tj-actions/changed-files@v36 - - uses: pre-commit/action@v3.0.0 - with: - extra_args: pip-compile --files ${{ steps.changed-files.outputs.all_changed_files }} + - uses: pre-commit/action@v3.0.1