From 646e159a89cc48903a26b763c7df7728013b1c2d Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 20 Sep 2024 12:52:38 -0400 Subject: [PATCH 01/14] black formatting for GW150914_IMRPhenomD --- example/GW150914_IMRPhenomD.py | 84 ++++++++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 15 deletions(-) diff --git a/example/GW150914_IMRPhenomD.py b/example/GW150914_IMRPhenomD.py index 66619ddc..23d08c7b 100644 --- a/example/GW150914_IMRPhenomD.py +++ b/example/GW150914_IMRPhenomD.py @@ -2,12 +2,22 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.prior import ( + CombinePrior, + UniformPrior, + CosinePrior, + SinePrior, + PowerLawPrior, +) from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.transforms import ( + ComponentMassesToChirpMassSymmetricMassRatioTransform, + SkyFrameToDetectorFrameSkyPositionTransform, + ComponentMassesToChirpMassMassRatioTransform, +) from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam @@ -65,18 +75,62 @@ sample_transforms = [ # ComponentMassesToChirpMassMassRatioTransform, - BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), - BoundToUnbound(name_mapping = (["eta"], ["eta_unbounded"]), original_lower_bound=eta_min, original_upper_bound=eta_max), - BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=1.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]), original_lower_bound=0., original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound( + name_mapping=(["M_c"], ["M_c_unbounded"]), + original_lower_bound=M_c_min, + original_upper_bound=M_c_max, + ), + BoundToUnbound( + name_mapping=(["eta"], ["eta_unbounded"]), + original_lower_bound=eta_min, + original_upper_bound=eta_max, + ), + BoundToUnbound( + name_mapping=(["s1_z"], ["s1_z_unbounded"]), + original_lower_bound=-1.0, + original_upper_bound=1.0, + ), + BoundToUnbound( + name_mapping=(["s2_z"], ["s2_z_unbounded"]), + original_lower_bound=-1.0, + original_upper_bound=1.0, + ), + BoundToUnbound( + name_mapping=(["d_L"], ["d_L_unbounded"]), + original_lower_bound=1.0, + original_upper_bound=2000.0, + ), + BoundToUnbound( + name_mapping=(["t_c"], ["t_c_unbounded"]), + original_lower_bound=-0.05, + original_upper_bound=0.05, + ), + BoundToUnbound( + name_mapping=(["phase_c"], ["phase_c_unbounded"]), + original_lower_bound=0.0, + original_upper_bound=2 * jnp.pi, + ), + BoundToUnbound( + name_mapping=(["iota"], ["iota_unbounded"]), + original_lower_bound=0.0, + original_upper_bound=jnp.pi, + ), + BoundToUnbound( + name_mapping=(["psi"], ["psi_unbounded"]), + original_lower_bound=0.0, + original_upper_bound=jnp.pi, + ), SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), - BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound( + name_mapping=(["zenith"], ["zenith_unbounded"]), + original_lower_bound=0.0, + original_upper_bound=jnp.pi, + ), + BoundToUnbound( + name_mapping=(["azimuth"], ["azimuth_unbounded"]), + original_lower_bound=0.0, + original_upper_bound=2 * jnp.pi, + ), ] likelihood_transforms = [ @@ -125,9 +179,9 @@ output_thinning=10, local_sampler_arg=local_sampler_arg, strategies=[Adam_optimizer, "default"], - verbose=True + verbose=True, ) jim.sample(jax.random.PRNGKey(42)) # jim.get_samples() -# jim.print_summary() \ No newline at end of file +# jim.print_summary() From de6ab602e553c8bfc8705f14284f65af4ddc3dbc Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 20 Sep 2024 13:07:00 -0400 Subject: [PATCH 02/14] Rename IMRPhenomPv2 --- example/GW150914_IMRPhenomPV2.py | 153 ++++++++++++++++++++++++++++ example/GW150914_PV2.py | 165 ------------------------------- 2 files changed, 153 insertions(+), 165 deletions(-) create mode 100644 example/GW150914_IMRPhenomPV2.py delete mode 100644 example/GW150914_PV2.py diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py new file mode 100644 index 00000000..45895c62 --- /dev/null +++ b/example/GW150914_IMRPhenomPV2.py @@ -0,0 +1,153 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.jim import Jim +from jimgw.prior import ( + CombinePrior, + UniformPrior, + CosinePrior, + SinePrior, + PowerLawPrior, + UniformSpherePrior, +) +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomPv2 +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ( + ComponentMassesToChirpMassSymmetricMassRatioTransform, + SkyFrameToDetectorFrameSkyPositionTransform, + ComponentMassesToChirpMassMassRatioTransform, +) +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +start = gps - 2 +end = gps + 2 +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +waveform = RippleIMRPhenomPv2(f_ref=20) + +########################################### +########## Set up priors ################## +########################################### + +prior = [] + +# Mass prior +M_c_min, M_c_max = 10.0, 80.0 +eta_min, eta_max = 0.2, 0.25 +Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"]) +eta_prior = UniformPrior(eta_min, eta_max, parameter_names=["eta"]) + +prior = prior + [Mc_prior, eta_prior] + +# Spin prior +theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) +phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) +theta_1_prior = SinePrior(parameter_names=["theta_1"]) +theta_2_prior = SinePrior(parameter_names=["theta_2"]) +phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) + +prior = prior + [ + theta_jn_prior, + phi_jl_prior, + theta_1_prior, + theta_2_prior, + phi_12_prior, + a_1_prior, + a_2_prior, +] + +# Extrinsic prior +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = prior + [ + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, +] + + +prior = CombinePrior(prior) + +likelihood = TransientLikelihoodFD( + [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 +) + + +mass_matrix = jnp.eye(prior.n_dim) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 1e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +# import optax + +# n_epochs = 20 +# n_loop_training = 100 +# total_epochs = n_epochs * n_loop_training +# start = total_epochs // 10 +# learning_rate = optax.polynomial_schedule( +# 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start +# ) + +# jim = Jim( +# likelihood, +# prior, +# n_loop_training=n_loop_training, +# n_loop_production=20, +# n_local_steps=10, +# n_global_steps=1000, +# n_chains=500, +# n_epochs=n_epochs, +# learning_rate=learning_rate, +# n_max_examples=30000, +# n_flow_sample=100000, +# momentum=0.9, +# batch_size=30000, +# use_global=True, +# keep_quantile=0.0, +# train_thinning=1, +# output_thinning=10, +# local_sampler_arg=local_sampler_arg, +# # strategies=[Adam_optimizer,"default"], +# ) + +# import numpy as np + +# # chains = np.load('./GW150914_init.npz')['chain'] + +# jim.sample(jax.random.PRNGKey(42)) # ,initial_guess=chains) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py deleted file mode 100644 index 06209ba6..00000000 --- a/example/GW150914_PV2.py +++ /dev/null @@ -1,165 +0,0 @@ -import time - -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from jimgw.prior import Composite, Sphere, Unconstrained_Uniform -from jimgw.single_event.detector import H1, L1 -from jimgw.single_event.likelihood import TransientLikelihoodFD -from jimgw.single_event.waveform import RippleIMRPhenomPv2 -from flowMC.strategy.optimization import optimization_Adam - - -jax.config.update("jax_enable_x64", True) - -########################################### -########## First we grab data ############# -########################################### - -total_time_start = time.time() - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 -start = gps - 2 -end = gps + 2 -fmin = 20.0 -fmax = 1024.0 - -ifos = ["H1", "L1"] - -H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -waveform = RippleIMRPhenomPv2(f_ref=20) - -########################################### -########## Set up priors ################## -########################################### - -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1_prior = Sphere(naming="s1") -s2_prior = Sphere(naming="s2") -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) - -prior = Composite( - [ - Mc_prior, - q_prior, - s1_prior, - s2_prior, - dL_prior, - t_c_prior, - phase_c_prior, - cos_iota_prior, - psi_prior, - ra_prior, - sin_dec_prior, - ], -) - -epsilon = 1e-3 -bounds = jnp.array( - [ - [10.0, 80.0], - [0.125, 1.0], - [0, jnp.pi], - [0, 2 * jnp.pi], - [0.0, 1.0], - [0, jnp.pi], - [0, 2 * jnp.pi], - [0.0, 1.0], - [0.0, 2000], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) + jnp.array([[epsilon, -epsilon]]) - -likelihood = TransientLikelihoodFD( - [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 -) -# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) - - -mass_matrix = jnp.eye(prior.n_dim) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[9, 9].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 1e-3} - -Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1, bounds=bounds) - -import optax -n_epochs = 20 -n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) - -jim = Jim( - likelihood, - prior, - n_loop_training=n_loop_training, - n_loop_production=20, - n_local_steps=10, - n_global_steps=1000, - n_chains=500, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30000, - n_flow_sample=100000, - momentum=0.9, - batch_size=30000, - use_global=True, - keep_quantile=0.0, - train_thinning=1, - output_thinning=10, - local_sampler_arg=local_sampler_arg, - # strategies=[Adam_optimizer,"default"], -) - -import numpy as np -# chains = np.load('./GW150914_init.npz')['chain'] - -jim.sample(jax.random.PRNGKey(42))#,initial_guess=chains) From 41abe69f85abcb9d439eba34082a491e8797b796 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 20 Sep 2024 13:17:14 -0400 Subject: [PATCH 03/14] Delete redundant spin function --- src/jimgw/single_event/utils.py | 160 -------------------------------- 1 file changed, 160 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index fb35bf27..e136af02 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -275,166 +275,6 @@ def eta_to_q(eta: Float) -> Float: temp = 1 / eta / 2 - 1 return temp - (temp**2 - 1) ** 0.5 - -def spin_to_cartesian_spin( - thetaJN: Float, - phiJL: Float, - theta1: Float, - theta2: Float, - phi12: Float, - chi1: Float, - chi2: Float, - M_c: Float, - eta: Float, - fRef: Float, - phiRef: Float, -) -> tuple[Float, Float, Float, Float, Float, Float, Float]: - """ - Transforming the spin parameters - - The code is based on the approach used in LALsimulation: - https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html - - Parameters: - ------- - thetaJN: Float - Zenith angle between the total angular momentum and the line of sight - phiJL: Float - Difference between total and orbital angular momentum azimuthal angles - theta1: Float - Zenith angle between the spin and orbital angular momenta for the primary object - theta2: Float - Zenith angle between the spin and orbital angular momenta for the secondary object - phi12: Float - Difference between the azimuthal angles of the individual spin vector projections - onto the orbital plane - chi1: Float - Primary object aligned spin: - chi2: Float - Secondary object aligned spin: - M_c: Float - The chirp mass - eta: Float - The symmetric mass ratio - fRef: Float - The reference frequency - phiRef: Float - Binary phase at a reference frequency - - Returns: - ------- - iota: Float - Zenith angle between the orbital angular momentum and the line of sight - S1x: Float - The x-component of the primary spin - S1y: Float - The y-component of the primary spin - S1z: Float - The z-component of the primary spin - S2x: Float - The x-component of the secondary spin - S2y: Float - The y-component of the secondary spin - S2z: Float - The z-component of the secondary spin - """ - - def rotate_y(angle, vec): - """ - Rotate the vector (x, y, z) about y-axis - """ - cos_angle = jnp.cos(angle) - sin_angle = jnp.sin(angle) - rotation_matrix = jnp.array( - [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] - ) - rotated_vec = jnp.dot(rotation_matrix, vec) - return rotated_vec - - def rotate_z(angle, vec): - """ - Rotate the vector (x, y, z) about z-axis - """ - cos_angle = jnp.cos(angle) - sin_angle = jnp.sin(angle) - rotation_matrix = jnp.array( - [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] - ) - rotated_vec = jnp.dot(rotation_matrix, vec) - return rotated_vec - - LNh = jnp.array([0.0, 0.0, 1.0]) - - s1hat = jnp.array( - [ - jnp.sin(theta1) * jnp.cos(phiRef), - jnp.sin(theta1) * jnp.sin(phiRef), - jnp.cos(theta1), - ] - ) - s2hat = jnp.array( - [ - jnp.sin(theta2) * jnp.cos(phi12 + phiRef), - jnp.sin(theta2) * jnp.sin(phi12 + phiRef), - jnp.cos(theta2), - ] - ) - - temp = 1 / eta / 2 - 1 - q = temp - (temp**2 - 1) ** 0.5 - m1, m2 = Mc_q_to_m1m2(M_c, q) - v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) - - Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) - s1 = m1 * m1 * chi1 * s1hat - s2 = m2 * m2 * chi2 * s2hat - J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) - - Jhat = J / jnp.linalg.norm(J) - theta0 = jnp.arccos(Jhat[2]) - phi0 = jnp.arctan2(Jhat[1], Jhat[0]) - - # Rotation 1: - s1hat = rotate_z(-phi0, s1hat) - s2hat = rotate_z(-phi0, s2hat) - - # Rotation 2: - LNh = rotate_y(-theta0, LNh) - s1hat = rotate_y(-theta0, s1hat) - s2hat = rotate_y(-theta0, s2hat) - - # Rotation 3: - LNh = rotate_z(phiJL - jnp.pi, LNh) - s1hat = rotate_z(phiJL - jnp.pi, s1hat) - s2hat = rotate_z(phiJL - jnp.pi, s2hat) - - # Compute iota - N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) - iota = jnp.arccos(jnp.dot(N, LNh)) - - thetaLJ = jnp.arccos(LNh[2]) - phiL = jnp.arctan2(LNh[1], LNh[0]) - - # Rotation 4: - s1hat = rotate_z(-phiL, s1hat) - s2hat = rotate_z(-phiL, s2hat) - N = rotate_z(-phiL, N) - - # Rotation 5: - s1hat = rotate_y(-thetaLJ, s1hat) - s2hat = rotate_y(-thetaLJ, s2hat) - N = rotate_y(-thetaLJ, N) - - # Rotation 6: - phiN = jnp.arctan2(N[1], N[0]) - s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) - s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) - - S1 = s1hat * chi1 - S2 = s2hat * chi2 - return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] - - def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x From 84ff0b7df83e73dbe6e3fb0b5e91953baf617c41 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 20 Sep 2024 13:18:59 -0400 Subject: [PATCH 04/14] update PV2 example --- example/GW150914_IMRPhenomPV2.py | 125 +++++++++++++++++++------------ 1 file changed, 77 insertions(+), 48 deletions(-) diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index 45895c62..b47567dc 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -18,9 +18,10 @@ from jimgw.single_event.waveform import RippleIMRPhenomPv2 from jimgw.transforms import BoundToUnbound from jimgw.single_event.transforms import ( - ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, - ComponentMassesToChirpMassMassRatioTransform, + SpinToCartesianSpinTransform, + MassRatioToSymmetricMassRatioTransform + ) from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam @@ -40,7 +41,7 @@ fmin = 20.0 fmax = 1024.0 -ifos = ["H1", "L1"] +ifos = [H1, L1] H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) @@ -55,29 +56,30 @@ # Mass prior M_c_min, M_c_max = 10.0, 80.0 -eta_min, eta_max = 0.2, 0.25 +q_min, q_max = 0.125, 1.0 Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"]) -eta_prior = UniformPrior(eta_min, eta_max, parameter_names=["eta"]) +q_prior = UniformPrior(q_min, q_max, parameter_names=["q"]) -prior = prior + [Mc_prior, eta_prior] +prior = prior + [Mc_prior, q_prior] # Spin prior +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) -theta_1_prior = SinePrior(parameter_names=["theta_1"]) -theta_2_prior = SinePrior(parameter_names=["theta_2"]) +tilt_1_prior = SinePrior(parameter_names=["tilt_1"]) +tilt_2_prior = SinePrior(parameter_names=["tilt_2"]) phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) -a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) -a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) + prior = prior + [ + a_1_prior, + a_2_prior, theta_jn_prior, phi_jl_prior, - theta_1_prior, - theta_2_prior, + tilt_1_prior, + tilt_2_prior, phi_12_prior, - a_1_prior, - a_2_prior, ] # Extrinsic prior @@ -99,52 +101,79 @@ dec_prior, ] - prior = CombinePrior(prior) +# Defining Transforms + +sample_transforms = [ + # ComponentMassesToChirpMassMassRatioTransform, + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["eta"], ["eta_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["theta_jn"], ["theta_jn_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["phi_jl"], ["phi_jl_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["theta_1"], ["theta_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["theta_2"], ["theta_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["phi_12"], ["phi_12_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["a_1"], ["a_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["a_2"], ["a_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=10.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform, + SpinToCartesianSpinTransform(freq_ref=20.), +] + + likelihood = TransientLikelihoodFD( [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 ) mass_matrix = jnp.eye(prior.n_dim) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[9, 9].set(1e-3) +# mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[9, 9].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 1e-3} Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -# import optax - -# n_epochs = 20 -# n_loop_training = 100 -# total_epochs = n_epochs * n_loop_training -# start = total_epochs // 10 -# learning_rate = optax.polynomial_schedule( -# 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -# ) - -# jim = Jim( -# likelihood, -# prior, -# n_loop_training=n_loop_training, -# n_loop_production=20, -# n_local_steps=10, -# n_global_steps=1000, -# n_chains=500, -# n_epochs=n_epochs, -# learning_rate=learning_rate, -# n_max_examples=30000, -# n_flow_sample=100000, -# momentum=0.9, -# batch_size=30000, -# use_global=True, -# keep_quantile=0.0, -# train_thinning=1, -# output_thinning=10, -# local_sampler_arg=local_sampler_arg, -# # strategies=[Adam_optimizer,"default"], -# ) +import optax + +n_epochs = 20 +n_loop_training = 100 +total_epochs = n_epochs * n_loop_training +start = total_epochs // 10 +learning_rate = optax.polynomial_schedule( + 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start +) + +jim = Jim( + likelihood, + prior, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_sample=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + keep_quantile=0.0, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + # strategies=[Adam_optimizer,"default"], +) # import numpy as np From c666bcbeb1aab683943f4c1f9a9fb27b908a6f89 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 20 Sep 2024 13:19:27 -0400 Subject: [PATCH 05/14] update Pv2 example --- example/GW150914_IMRPhenomPV2.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index b47567dc..87e23680 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -175,8 +175,5 @@ # strategies=[Adam_optimizer,"default"], ) -# import numpy as np -# # chains = np.load('./GW150914_init.npz')['chain'] - -# jim.sample(jax.random.PRNGKey(42)) # ,initial_guess=chains) +jim.sample(jax.random.PRNGKey(42)) From 799ffc91a37025d9f5d1f53d4567190becce5c58 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 20 Sep 2024 13:24:12 -0400 Subject: [PATCH 06/14] update PV2 --- example/GW150914_IMRPhenomPV2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index 87e23680..1263c85b 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -156,6 +156,8 @@ jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=20, n_local_steps=10, From a94586e8bd0ffa4c3baed0b94813789711bb7653 Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 20 Sep 2024 13:43:47 -0400 Subject: [PATCH 07/14] Update parameters from theta to tilt in accordence to Lalsuite --- example/GW150914_IMRPhenomPV2.py | 8 ++++---- src/jimgw/single_event/transforms.py | 6 +++--- src/jimgw/single_event/utils.py | 21 +++++++++++---------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index 1263c85b..ef35f5cf 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -108,11 +108,11 @@ sample_transforms = [ # ComponentMassesToChirpMassMassRatioTransform, BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), - BoundToUnbound(name_mapping = (["eta"], ["eta_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = (["theta_jn"], ["theta_jn_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = (["phi_jl"], ["phi_jl_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = (["theta_1"], ["theta_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["theta_2"], ["theta_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["tilt_1"], ["tilt_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["tilt_2"], ["tilt_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = (["phi_12"], ["phi_12_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = (["a_1"], ["a_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = (["a_2"], ["a_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), @@ -126,8 +126,8 @@ ] likelihood_transforms = [ - MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform(freq_ref=20.), + MassRatioToSymmetricMassRatioTransform, ] diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 1bf8f3a7..b3062da9 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -37,7 +37,7 @@ def __init__( freq_ref: Float, ): name_mapping = ( - ["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], + ["theta_jn", "phi_jl", "tilt_1", "tilt_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"], ) super().__init__(name_mapping) @@ -48,8 +48,8 @@ def named_transform(x): iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( x["theta_jn"], x["phi_jl"], - x["theta_1"], - x["theta_2"], + x["tilt_1"], + x["tilt_2"], x["phi_12"], x["a_1"], x["a_2"], diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index e136af02..517b2844 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -275,6 +275,7 @@ def eta_to_q(eta: Float) -> Float: temp = 1 / eta / 2 - 1 return temp - (temp**2 - 1) ** 0.5 + def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x @@ -395,8 +396,8 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F def spin_to_cartesian_spin( thetaJN: Float, phiJL: Float, - theta1: Float, - theta2: Float, + tilt1: Float, + tilt2: Float, phi12: Float, chi1: Float, chi2: Float, @@ -417,9 +418,9 @@ def spin_to_cartesian_spin( Zenith angle between the total angular momentum and the line of sight phiJL: Float Difference between total and orbital angular momentum azimuthal angles - theta1: Float + tilt1: Float Zenith angle between the spin and orbital angular momenta for the primary object - theta2: Float + tilt2: Float Zenith angle between the spin and orbital angular momenta for the secondary object phi12: Float Difference between the azimuthal angles of the individual spin vector projections @@ -483,16 +484,16 @@ def rotate_z(angle, vec): s1hat = jnp.array( [ - jnp.sin(theta1) * jnp.cos(phiRef), - jnp.sin(theta1) * jnp.sin(phiRef), - jnp.cos(theta1), + jnp.sin(tilt1) * jnp.cos(phiRef), + jnp.sin(tilt1) * jnp.sin(phiRef), + jnp.cos(tilt1), ] ) s2hat = jnp.array( [ - jnp.sin(theta2) * jnp.cos(phi12 + phiRef), - jnp.sin(theta2) * jnp.sin(phi12 + phiRef), - jnp.cos(theta2), + jnp.sin(tilt2) * jnp.cos(phi12 + phiRef), + jnp.sin(tilt2) * jnp.sin(phi12 + phiRef), + jnp.cos(tilt2), ] ) From 95073104f9b9d0579da5e053bcdae0ed167267e4 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 20 Sep 2024 13:48:25 -0400 Subject: [PATCH 08/14] update Pv2 --- example/GW150914_IMRPhenomPV2.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index ef35f5cf..00a410cc 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -20,8 +20,10 @@ from jimgw.single_event.transforms import ( SkyFrameToDetectorFrameSkyPositionTransform, SpinToCartesianSpinTransform, - MassRatioToSymmetricMassRatioTransform - + MassRatioToSymmetricMassRatioTransform, + DistanceToSNRWeightedDistanceTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform, + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, ) from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam @@ -106,7 +108,10 @@ # Defining Transforms sample_transforms = [ - # ComponentMassesToChirpMassMassRatioTransform, + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = (["theta_jn"], ["theta_jn_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), @@ -116,11 +121,8 @@ BoundToUnbound(name_mapping = (["phi_12"], ["phi_12_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = (["a_1"], ["a_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = (["a_2"], ["a_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=10.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), - SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] From 7dd5145b696bb7e248f5e63ecef510c9207d7d98 Mon Sep 17 00:00:00 2001 From: kazewong Date: Fri, 20 Sep 2024 14:04:45 -0400 Subject: [PATCH 09/14] update pv2 --- example/GW150914_IMRPhenomPV2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index 00a410cc..c87d0fee 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -88,7 +88,6 @@ dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -iota_prior = SinePrior(parameter_names=["iota"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) dec_prior = CosinePrior(parameter_names=["dec"]) @@ -97,7 +96,6 @@ dL_prior, t_c_prior, phase_c_prior, - iota_prior, psi_prior, ra_prior, dec_prior, @@ -108,6 +106,7 @@ # Defining Transforms sample_transforms = [ + SpinToCartesianSpinTransform(freq_ref=20.), DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), @@ -128,7 +127,6 @@ ] likelihood_transforms = [ - SpinToCartesianSpinTransform(freq_ref=20.), MassRatioToSymmetricMassRatioTransform, ] From 4cf6b2748555590247232c37f521074e29024e89 Mon Sep 17 00:00:00 2001 From: kazewong Date: Sat, 5 Oct 2024 16:55:57 -0400 Subject: [PATCH 10/14] update transform name --- src/jimgw/single_event/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index b3062da9..9297029d 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -25,7 +25,7 @@ @jaxtyped(typechecker=typechecker) -class SpinToCartesianSpinTransform(NtoNTransform): +class PrecessingSpinToCartesianSpinTransform(NtoNTransform): """ Spin to Cartesian spin transformation """ From 66027c333b4a5e136b7e27cf1f9eddbe2bd7d943 Mon Sep 17 00:00:00 2001 From: kazewong Date: Mon, 7 Oct 2024 08:38:28 -0400 Subject: [PATCH 11/14] fix Pv2 prior and transform --- example/GW150914_IMRPhenomPV2.py | 40 +++++++++++--------------- src/jimgw/jim.py | 2 +- src/jimgw/single_event/transforms.py | 42 ++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 25 deletions(-) diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index c87d0fee..c291adbf 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -19,7 +19,7 @@ from jimgw.transforms import BoundToUnbound from jimgw.single_event.transforms import ( SkyFrameToDetectorFrameSkyPositionTransform, - SpinToCartesianSpinTransform, + SphereSpinToCartesianSpinTransform, MassRatioToSymmetricMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, @@ -65,23 +65,14 @@ prior = prior + [Mc_prior, q_prior] # Spin prior -a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) -a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) -theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) -phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) -tilt_1_prior = SinePrior(parameter_names=["tilt_1"]) -tilt_2_prior = SinePrior(parameter_names=["tilt_2"]) -phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) - +s1_prior = UniformSpherePrior(parameter_names=["s1"]) +s2_prior = UniformSpherePrior(parameter_names=["s2"]) +iota_prior = SinePrior(parameter_names=["iota"]) prior = prior + [ - a_1_prior, - a_2_prior, - theta_jn_prior, - phi_jl_prior, - tilt_1_prior, - tilt_2_prior, - phi_12_prior, + s1_prior, + s2_prior, + iota_prior, ] # Extrinsic prior @@ -106,20 +97,19 @@ # Defining Transforms sample_transforms = [ - SpinToCartesianSpinTransform(freq_ref=20.), DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), - BoundToUnbound(name_mapping = (["theta_jn"], ["theta_jn_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["phi_jl"], ["phi_jl_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = (["tilt_1"], ["tilt_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["tilt_2"], ["tilt_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["phi_12"], ["phi_12_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = (["a_1"], ["a_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = (["a_2"], ["a_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), + BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), @@ -128,6 +118,8 @@ likelihood_transforms = [ MassRatioToSymmetricMassRatioTransform, + SphereSpinToCartesianSpinTransform("s1"), + SphereSpinToCartesianSpinTransform("s2"), ] diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 7f16532d..805e9268 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -217,7 +217,7 @@ def get_samples(self, training: bool = False) -> dict: chains = self.sampler.get_sampler_state(training=False)["chains"] chains = chains.reshape(-1, self.prior.n_dim) - chains = self.add_name(chains) + chains = jax.vmap(self.add_name)(chains) for sample_transform in reversed(self.sample_transforms): chains = jax.vmap(sample_transform.backward)(chains) return chains diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 9297029d..e9983f06 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -71,6 +71,48 @@ def named_transform(x): self.transform_func = named_transform +@jaxtyped(typechecker=typechecker) +class SphereSpinToCartesianSpinTransform(BijectiveTransform): + """ + Spin to Cartesian spin transformation + """ + + def __init__( + self, + label: str, + ): + name_mapping = ( + [label + "_mag", label + "_theta", label + "_phi"], + [label + "_x", label + "_y", label + "_z"], + ) + super().__init__(name_mapping) + + def named_transform(x): + mag, theta, phi = x[label + "_mag"], x[label + "_theta"], x[label + "_phi"] + x = mag * jnp.sin(theta) * jnp.cos(phi) + y = mag * jnp.sin(theta) * jnp.sin(phi) + z = mag * jnp.cos(theta) + return { + label + "_x": x, + label + "_y": y, + label + "_z": z, + } + + def named_inverse_transform(x): + x, y, z = x[label + "_x"], x[label + "_y"], x[label + "_z"] + mag = jnp.sqrt(x**2 + y**2 + z**2) + theta = jnp.arccos(z / mag) + phi = jnp.arctan2(y, x) + return { + label + "_mag": mag, + label + "_theta": theta, + label + "_phi": phi, + } + + self.transform_func = named_transform + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): """ From d69c05a7d1ac3916afbce1842ea80f6169285e78 Mon Sep 17 00:00:00 2001 From: kazewong Date: Sat, 12 Oct 2024 22:42:07 -0400 Subject: [PATCH 12/14] move prior functions to the right place. Also add magnitude options to UniformSphere --- src/jimgw/prior.py | 19 ++++++++++++++++-- src/jimgw/single_event/prior.py | 35 --------------------------------- 2 files changed, 17 insertions(+), 37 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 5227ffa5..3db3f122 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -331,7 +331,7 @@ class UniformSpherePrior(CombinePrior): def __repr__(self): return f"UniformSpherePrior(parameter_names={self.parameter_names})" - def __init__(self, parameter_names: list[str], **kwargs): + def __init__(self, parameter_names: list[str], max_mag: float = 1.0, **kwargs): self.parameter_names = parameter_names assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" self.parameter_names = [ @@ -341,7 +341,7 @@ def __init__(self, parameter_names: list[str], **kwargs): ] super().__init__( [ - UniformPrior(0.0, 1.0, [self.parameter_names[0]]), + UniformPrior(0.0, max_mag, [self.parameter_names[0]]), SinePrior([self.parameter_names[1]]), UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), ] @@ -397,6 +397,21 @@ def __init__( ], ) +def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: + if prior.composite: + if isinstance(prior.base_prior, list): + for subprior in prior.base_prior: + output = trace_prior_parent(subprior, output) + elif isinstance(prior.base_prior, Prior): + output = trace_prior_parent(prior.base_prior, output) + else: + output.append(prior) + + return output + + + + # ====================== Things below may need rework ====================== diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py index 51a754eb..76ca6376 100644 --- a/src/jimgw/single_event/prior.py +++ b/src/jimgw/single_event/prior.py @@ -11,28 +11,6 @@ ) -@jaxtyped(typechecker=typechecker) -class UniformSpherePrior(CombinePrior): - - def __repr__(self): - return f"UniformSpherePrior(parameter_names={self.parameter_names})" - - def __init__(self, parameter_names: list[str], **kwargs): - self.parameter_names = parameter_names - assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector" - self.parameter_names = [ - f"{self.parameter_names[0]}_mag", - f"{self.parameter_names[0]}_theta", - f"{self.parameter_names[0]}_phi", - ] - super().__init__( - [ - UniformPrior(0.0, 1.0, [self.parameter_names[0]]), - SinePrior([self.parameter_names[1]]), - UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]), - ] - ) - @jaxtyped(typechecker=typechecker) class UniformComponentChirpMassPrior(PowerLawPrior): @@ -50,19 +28,6 @@ def __init__(self, xmin: float, xmax: float): super().__init__(xmin, xmax, 1.0, ["M_c"]) -def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: - if prior.composite: - if isinstance(prior.base_prior, list): - for subprior in prior.base_prior: - output = trace_prior_parent(subprior, output) - elif isinstance(prior.base_prior, Prior): - output = trace_prior_parent(prior.base_prior, output) - else: - output.append(prior) - - return output - - # ====================== Things below may need rework ====================== From d062ee111e58a765256e347a68e91a0d5f580047 Mon Sep 17 00:00:00 2001 From: kazewong Date: Sat, 12 Oct 2024 22:43:05 -0400 Subject: [PATCH 13/14] format --- src/jimgw/prior.py | 4 +-- src/jimgw/single_event/likelihood.py | 41 ++++++++++++++++++++-------- src/jimgw/single_event/prior.py | 6 ---- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3db3f122..87406ca8 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -397,6 +397,7 @@ def __init__( ], ) + def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: if prior.composite: if isinstance(prior.base_prior, list): @@ -410,9 +411,6 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: return output - - - # ====================== Things below may need rework ====================== # @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 9e775b33..96e11e62 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -584,18 +584,35 @@ def y(x): ) key = jax.random.PRNGKey(0) - initial_position = [] - for _ in range(popsize): - flag = True - while flag: - key = jax.random.split(key)[1] - guess = prior.sample(key, 1) - for transform in sample_transforms: - guess = transform.forward(guess) - guess = jnp.array([i for i in guess.values()]).T[0] - flag = not jnp.all(jnp.isfinite(guess)) - initial_position.append(guess) - initial_position = jnp.array(initial_position) + initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan + while not jax.tree.reduce( + jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position) + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] + + key, subkey = jax.random.split(key) + guess = prior.sample(subkey, popsize) + for transform in sample_transforms: + guess = jax.vmap(transform.forward)(guess) + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] + common_length = min(len(finite_guess), len(non_finite_index)) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) + rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position ) diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py index 76ca6376..194262f0 100644 --- a/src/jimgw/single_event/prior.py +++ b/src/jimgw/single_event/prior.py @@ -1,17 +1,11 @@ -import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import jaxtyped from jimgw.prior import ( - Prior, - CombinePrior, - UniformPrior, PowerLawPrior, - SinePrior, ) - @jaxtyped(typechecker=typechecker) class UniformComponentChirpMassPrior(PowerLawPrior): """ From c5957b451011621516602ca21c0fe6d09a41fc6b Mon Sep 17 00:00:00 2001 From: kazewong Date: Sat, 12 Oct 2024 22:43:32 -0400 Subject: [PATCH 14/14] Add GW170817 example --- example/GW170817_IMRPhenomPV2.py | 190 +++++++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) create mode 100644 example/GW170817_IMRPhenomPV2.py diff --git a/example/GW170817_IMRPhenomPV2.py b/example/GW170817_IMRPhenomPV2.py new file mode 100644 index 00000000..55bd211b --- /dev/null +++ b/example/GW170817_IMRPhenomPV2.py @@ -0,0 +1,190 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.jim import Jim +from jimgw.prior import ( + CombinePrior, + UniformPrior, + CosinePrior, + SinePrior, + PowerLawPrior, + UniformSpherePrior, +) +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import TransientLikelihoodFD, HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomPv2 +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ( + SkyFrameToDetectorFrameSkyPositionTransform, + SphereSpinToCartesianSpinTransform, + MassRatioToSymmetricMassRatioTransform, + DistanceToSNRWeightedDistanceTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform, + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, +) +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 + +gps = 1187008882.43 +trigger_time = gps +fmin = 20 +fmax = 2048 +minimum_frequency = fmin +maximum_frequency = fmax +duration = 128 +post_trigger_duration = 2 +epoch = duration - post_trigger_duration +f_ref = fmin + +ifos = [H1, L1, V1] + + +tukey_alpha = 2 / (duration / 2) +H1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) +L1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) +V1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) + + +waveform = RippleIMRPhenomPv2(f_ref=f_ref) + +########################################### +########## Set up priors ################## +########################################### + +prior = [] + +# Mass prior +M_c_min, M_c_max = 1.18, 1.21 +q_min, q_max = 0.125, 1.0 +Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"]) +q_prior = UniformPrior(q_min, q_max, parameter_names=["q"]) + +prior = prior + [Mc_prior, q_prior] + +# Spin prior +s1_prior = UniformSpherePrior(parameter_names=["s1"], max_mag = 0.05) +s2_prior = UniformSpherePrior(parameter_names=["s2"], max_mag = 0.05) +iota_prior = SinePrior(parameter_names=["iota"]) + +prior = prior + [ + s1_prior, + s2_prior, + iota_prior, +] + +# Extrinsic prior +dL_prior = PowerLawPrior(1.0, 75.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.1, 0.1, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = prior + [ + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, +] + +prior = CombinePrior(prior) + +# Defining Transforms + +sample_transforms = [ + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.05), + BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.05), + BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform, + SphereSpinToCartesianSpinTransform("s1"), + SphereSpinToCartesianSpinTransform("s2"), +] + + +#likelihood = TransientLikelihoodFD( +# [H1, L1, V1], waveform=waveform, trigger_time=trigger_time, duration=duration, post_trigger_duration=post_trigger_duration +#) + +likelihood = HeterodynedTransientLikelihoodFD(ifos, waveform=waveform, n_bins = 1000, trigger_time=trigger_time, duration=duration, post_trigger_duration=post_trigger_duration, prior = prior, sample_transforms = sample_transforms, likelihood_transforms = likelihood_transforms, popsize = 10, n_steps = 50) + +mass_matrix = jnp.eye(prior.n_dim) +# mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 1e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +import optax + +n_epochs = 20 +n_loop_training = 100 +total_epochs = n_epochs * n_loop_training +start = total_epochs // 10 +learning_rate = optax.polynomial_schedule( + 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start +) + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_sample=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + keep_quantile=0.0, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + # strategies=[Adam_optimizer,"default"], +) + + +jim.sample(jax.random.PRNGKey(42))