diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 88ffe52b..5c678b22 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -28,21 +28,20 @@ ] ) - run = SingleEventRun( seed=0, detectors=["H1", "L1"], priors={ - "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0}, + "M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0}, "q": {"name": "MassRatio"}, - "s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, - "s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, - "d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0}, - "t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05}, - "phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0}, + "t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, "cos_iota": {"name": "CosIota"}, - "psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi}, - "ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, "sin_dec": {"name": "SinDec"}, }, waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, @@ -90,3 +89,9 @@ ) run_manager = SingleEventPERunManager(run=run) +run_manager.jim.sample(jax.random.PRNGKey(42)) + +# plot the corner plot and diagnostic plot +run_manager.plot_corner() +run_manager.plot_diagnostic() + diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 3f65166d..7b720c30 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -4,6 +4,8 @@ import jax import jax.numpy as jnp import matplotlib.pyplot as plt +import corner +import numpy as np import yaml from astropy.time import Time from jaxlib.xla_extension import ArrayImpl @@ -71,7 +73,8 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + path: str = "./experiment" + injection_parameters: dict[str, float] = field(default_factory=lambda: {}) injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -123,6 +126,9 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError + if self.run.injection and not self.run.injection_parameters: + raise ValueError("Injection mode requires injection parameters.") + local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) @@ -150,6 +156,7 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: waveform = self.initialize_waveform() name = self.run.likelihood_parameters["name"] assert isinstance(name, str), "Likelihood name must be a string." + assert name in likelihood_presets, f"Likelihood {name} not recognized." if self.run.injection: freqs = jnp.linspace( self.run.data_parameters["f_min"], @@ -351,3 +358,73 @@ def plot_data(self, path: str): plt.ylabel("Amplitude") plt.legend() plt.savefig(path) + + def sample(self): + self.jim.sample(jax.random.PRNGKey(self.run.seed)) + + def get_samples(self): + return self.jim.get_samples() + + def plot_corner(self, path: str = "corner.jpeg", **kwargs): + """ + plot corner plot of the samples. + """ + plot_datapoint = kwargs.get("plot_datapoints", False) + title_quantiles = kwargs.get("title_quantiles", [0.16, 0.5, 0.84]) + show_titles = kwargs.get("show_titles", True) + title_fmt = kwargs.get("title_fmt", ".2E") + use_math_text = kwargs.get("use_math_text", True) + + samples = self.jim.get_samples() + param_names = list(samples.keys()) + samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T + corner.corner( + samples, + labels=param_names, + plot_datapoints=plot_datapoint, + title_quantiles=title_quantiles, + show_titles=show_titles, + title_fmt=title_fmt, + use_math_text=use_math_text, + **kwargs, + ) + plt.savefig(path) + plt.close() + + def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): + """ + plot diagnostic plot of the samples. + """ + summary = self.jim.Sampler.get_sampler_state(training=True) + chains, log_prob, local_accs, global_accs, loss_vals = summary.values() + log_prob = np.array(log_prob) + + plt.figure(figsize=(10, 10)) + axs = [plt.subplot(2, 2, i + 1) for i in range(4)] + plt.sca(axs[0]) + plt.title("log probability") + plt.plot(log_prob.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[1]) + plt.title("NF loss") + plt.plot(loss_vals.reshape(-1)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[2]) + plt.title("Local Acceptance") + plt.plot(local_accs.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[3]) + plt.title("Global Acceptance") + plt.plot(global_accs.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + plt.tight_layout() + + plt.savefig(path) + plt.close() diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index a15bd7bf..fb35bf27 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -276,6 +276,165 @@ def eta_to_q(eta: Float) -> Float: return temp - (temp**2 - 1) ** 0.5 +def spin_to_cartesian_spin( + thetaJN: Float, + phiJL: Float, + theta1: Float, + theta2: Float, + phi12: Float, + chi1: Float, + chi2: Float, + M_c: Float, + eta: Float, + fRef: Float, + phiRef: Float, +) -> tuple[Float, Float, Float, Float, Float, Float, Float]: + """ + Transforming the spin parameters + + The code is based on the approach used in LALsimulation: + https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html + + Parameters: + ------- + thetaJN: Float + Zenith angle between the total angular momentum and the line of sight + phiJL: Float + Difference between total and orbital angular momentum azimuthal angles + theta1: Float + Zenith angle between the spin and orbital angular momenta for the primary object + theta2: Float + Zenith angle between the spin and orbital angular momenta for the secondary object + phi12: Float + Difference between the azimuthal angles of the individual spin vector projections + onto the orbital plane + chi1: Float + Primary object aligned spin: + chi2: Float + Secondary object aligned spin: + M_c: Float + The chirp mass + eta: Float + The symmetric mass ratio + fRef: Float + The reference frequency + phiRef: Float + Binary phase at a reference frequency + + Returns: + ------- + iota: Float + Zenith angle between the orbital angular momentum and the line of sight + S1x: Float + The x-component of the primary spin + S1y: Float + The y-component of the primary spin + S1z: Float + The z-component of the primary spin + S2x: Float + The x-component of the secondary spin + S2y: Float + The y-component of the secondary spin + S2z: Float + The z-component of the secondary spin + """ + + def rotate_y(angle, vec): + """ + Rotate the vector (x, y, z) about y-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + def rotate_z(angle, vec): + """ + Rotate the vector (x, y, z) about z-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + LNh = jnp.array([0.0, 0.0, 1.0]) + + s1hat = jnp.array( + [ + jnp.sin(theta1) * jnp.cos(phiRef), + jnp.sin(theta1) * jnp.sin(phiRef), + jnp.cos(theta1), + ] + ) + s2hat = jnp.array( + [ + jnp.sin(theta2) * jnp.cos(phi12 + phiRef), + jnp.sin(theta2) * jnp.sin(phi12 + phiRef), + jnp.cos(theta2), + ] + ) + + temp = 1 / eta / 2 - 1 + q = temp - (temp**2 - 1) ** 0.5 + m1, m2 = Mc_q_to_m1m2(M_c, q) + v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) + + Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) + s1 = m1 * m1 * chi1 * s1hat + s2 = m2 * m2 * chi2 * s2hat + J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) + + Jhat = J / jnp.linalg.norm(J) + theta0 = jnp.arccos(Jhat[2]) + phi0 = jnp.arctan2(Jhat[1], Jhat[0]) + + # Rotation 1: + s1hat = rotate_z(-phi0, s1hat) + s2hat = rotate_z(-phi0, s2hat) + + # Rotation 2: + LNh = rotate_y(-theta0, LNh) + s1hat = rotate_y(-theta0, s1hat) + s2hat = rotate_y(-theta0, s2hat) + + # Rotation 3: + LNh = rotate_z(phiJL - jnp.pi, LNh) + s1hat = rotate_z(phiJL - jnp.pi, s1hat) + s2hat = rotate_z(phiJL - jnp.pi, s2hat) + + # Compute iota + N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) + iota = jnp.arccos(jnp.dot(N, LNh)) + + thetaLJ = jnp.arccos(LNh[2]) + phiL = jnp.arctan2(LNh[1], LNh[0]) + + # Rotation 4: + s1hat = rotate_z(-phiL, s1hat) + s2hat = rotate_z(-phiL, s2hat) + N = rotate_z(-phiL, N) + + # Rotation 5: + s1hat = rotate_y(-thetaLJ, s1hat) + s2hat = rotate_y(-thetaLJ, s2hat) + N = rotate_y(-thetaLJ, N) + + # Rotation 6: + phiN = jnp.arctan2(N[1], N[0]) + s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) + s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) + + S1 = s1hat * chi1 + S2 = s2hat * chi2 + return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] + + def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x diff --git a/test/test_transform.py b/test/test_transform.py new file mode 100644 index 00000000..5d9fe464 --- /dev/null +++ b/test/test_transform.py @@ -0,0 +1,48 @@ +import numpy as np +import jax.numpy as jnp + +class TestTransform: + def test_sky_location_transform(self): + from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_earth_to_sky + from bilby.gw.detector.networks import InterferometerList + + from jimgw.single_event.utils import zenith_azimuth_to_ra_dec as jimgw_earth_to_sky + from jimgw.single_event.detector import detector_preset + from astropy.time import Time + + ifos = ["H1", "L1"] + geocent_time = 1000000000 + + import matplotlib.pyplot as plt + + for zenith in np.linspace(0, np.pi, 10): + for azimuth in np.linspace(0, 2*np.pi, 10): + bilby_sky_location = np.array(bilby_earth_to_sky(zenith, azimuth, geocent_time, InterferometerList(ifos))) + jimgw_sky_location = np.array(jimgw_earth_to_sky(zenith, azimuth, Time(geocent_time, format="gps").sidereal_time("apparent", "greenwich").rad, detector_preset[ifos[0]].vertex - detector_preset[ifos[1]].vertex)) + assert np.allclose(bilby_sky_location, jimgw_sky_location, atol=1e-4) + + def test_spin_transform(self): + from bilby.gw.conversion import bilby_to_lalsimulation_spins as bilby_spin_transform + from bilby.gw.conversion import symmetric_mass_ratio_to_mass_ratio, chirp_mass_and_mass_ratio_to_component_masses + + from jimgw.single_event.utils import spin_to_cartesian_spin as jimgw_spin_transform + + for _ in range(100): + thetaJN = jnp.array(np.random.uniform(0, np.pi)) + phiJL = jnp.array(np.random.uniform(0, np.pi)) + theta1 = jnp.array(np.random.uniform(0, np.pi)) + theta2 = jnp.array(np.random.uniform(0, np.pi)) + phi12 = jnp.array(np.random.uniform(0, np.pi)) + chi1 = jnp.array(np.random.uniform(0, 1)) + chi2 = jnp.array(np.random.uniform(0, 1)) + M_c = jnp.array(np.random.uniform(1, 100)) + eta = jnp.array(np.random.uniform(0.1, 0.25)) + fRef = jnp.array(np.random.uniform(10, 1000)) + phiRef = jnp.array(np.random.uniform(0, 2*np.pi)) + + q = symmetric_mass_ratio_to_mass_ratio(eta) + m1, m2 = chirp_mass_and_mass_ratio_to_component_masses(M_c, q) + MsunInkg = 1.9884e30 + bilby_spin = jnp.array(bilby_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, m1*MsunInkg, m2*MsunInkg, fRef, phiRef)) + jimgw_spin = jnp.array(jimgw_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, M_c, eta, fRef, phiRef)) + assert np.allclose(bilby_spin, jimgw_spin, atol=1e-4)