From 1c2160595a4c5ce016c71eb1f65f489d9d9dc9d3 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 10:58:46 +0800 Subject: [PATCH 1/7] Remove legacy test files --- test/test_prior.py | 1 - test/test_transform.py | 48 ------------------------------------------ 2 files changed, 49 deletions(-) delete mode 100644 test/test_prior.py delete mode 100644 test/test_transform.py diff --git a/test/test_prior.py b/test/test_prior.py deleted file mode 100644 index 8b137891..00000000 --- a/test/test_prior.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/test_transform.py b/test/test_transform.py deleted file mode 100644 index 5d9fe464..00000000 --- a/test/test_transform.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import jax.numpy as jnp - -class TestTransform: - def test_sky_location_transform(self): - from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_earth_to_sky - from bilby.gw.detector.networks import InterferometerList - - from jimgw.single_event.utils import zenith_azimuth_to_ra_dec as jimgw_earth_to_sky - from jimgw.single_event.detector import detector_preset - from astropy.time import Time - - ifos = ["H1", "L1"] - geocent_time = 1000000000 - - import matplotlib.pyplot as plt - - for zenith in np.linspace(0, np.pi, 10): - for azimuth in np.linspace(0, 2*np.pi, 10): - bilby_sky_location = np.array(bilby_earth_to_sky(zenith, azimuth, geocent_time, InterferometerList(ifos))) - jimgw_sky_location = np.array(jimgw_earth_to_sky(zenith, azimuth, Time(geocent_time, format="gps").sidereal_time("apparent", "greenwich").rad, detector_preset[ifos[0]].vertex - detector_preset[ifos[1]].vertex)) - assert np.allclose(bilby_sky_location, jimgw_sky_location, atol=1e-4) - - def test_spin_transform(self): - from bilby.gw.conversion import bilby_to_lalsimulation_spins as bilby_spin_transform - from bilby.gw.conversion import symmetric_mass_ratio_to_mass_ratio, chirp_mass_and_mass_ratio_to_component_masses - - from jimgw.single_event.utils import spin_to_cartesian_spin as jimgw_spin_transform - - for _ in range(100): - thetaJN = jnp.array(np.random.uniform(0, np.pi)) - phiJL = jnp.array(np.random.uniform(0, np.pi)) - theta1 = jnp.array(np.random.uniform(0, np.pi)) - theta2 = jnp.array(np.random.uniform(0, np.pi)) - phi12 = jnp.array(np.random.uniform(0, np.pi)) - chi1 = jnp.array(np.random.uniform(0, 1)) - chi2 = jnp.array(np.random.uniform(0, 1)) - M_c = jnp.array(np.random.uniform(1, 100)) - eta = jnp.array(np.random.uniform(0.1, 0.25)) - fRef = jnp.array(np.random.uniform(10, 1000)) - phiRef = jnp.array(np.random.uniform(0, 2*np.pi)) - - q = symmetric_mass_ratio_to_mass_ratio(eta) - m1, m2 = chirp_mass_and_mass_ratio_to_component_masses(M_c, q) - MsunInkg = 1.9884e30 - bilby_spin = jnp.array(bilby_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, m1*MsunInkg, m2*MsunInkg, fRef, phiRef)) - jimgw_spin = jnp.array(jimgw_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, M_c, eta, fRef, phiRef)) - assert np.allclose(bilby_spin, jimgw_spin, atol=1e-4) From 19c4a29892ad149149761d789d7ee3e4da49f3ff Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 11:08:52 +0800 Subject: [PATCH 2/7] Update utils.py --- src/jimgw/single_event/utils.py | 159 -------------------------------- 1 file changed, 159 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index fb35bf27..a15bd7bf 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -276,165 +276,6 @@ def eta_to_q(eta: Float) -> Float: return temp - (temp**2 - 1) ** 0.5 -def spin_to_cartesian_spin( - thetaJN: Float, - phiJL: Float, - theta1: Float, - theta2: Float, - phi12: Float, - chi1: Float, - chi2: Float, - M_c: Float, - eta: Float, - fRef: Float, - phiRef: Float, -) -> tuple[Float, Float, Float, Float, Float, Float, Float]: - """ - Transforming the spin parameters - - The code is based on the approach used in LALsimulation: - https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html - - Parameters: - ------- - thetaJN: Float - Zenith angle between the total angular momentum and the line of sight - phiJL: Float - Difference between total and orbital angular momentum azimuthal angles - theta1: Float - Zenith angle between the spin and orbital angular momenta for the primary object - theta2: Float - Zenith angle between the spin and orbital angular momenta for the secondary object - phi12: Float - Difference between the azimuthal angles of the individual spin vector projections - onto the orbital plane - chi1: Float - Primary object aligned spin: - chi2: Float - Secondary object aligned spin: - M_c: Float - The chirp mass - eta: Float - The symmetric mass ratio - fRef: Float - The reference frequency - phiRef: Float - Binary phase at a reference frequency - - Returns: - ------- - iota: Float - Zenith angle between the orbital angular momentum and the line of sight - S1x: Float - The x-component of the primary spin - S1y: Float - The y-component of the primary spin - S1z: Float - The z-component of the primary spin - S2x: Float - The x-component of the secondary spin - S2y: Float - The y-component of the secondary spin - S2z: Float - The z-component of the secondary spin - """ - - def rotate_y(angle, vec): - """ - Rotate the vector (x, y, z) about y-axis - """ - cos_angle = jnp.cos(angle) - sin_angle = jnp.sin(angle) - rotation_matrix = jnp.array( - [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] - ) - rotated_vec = jnp.dot(rotation_matrix, vec) - return rotated_vec - - def rotate_z(angle, vec): - """ - Rotate the vector (x, y, z) about z-axis - """ - cos_angle = jnp.cos(angle) - sin_angle = jnp.sin(angle) - rotation_matrix = jnp.array( - [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] - ) - rotated_vec = jnp.dot(rotation_matrix, vec) - return rotated_vec - - LNh = jnp.array([0.0, 0.0, 1.0]) - - s1hat = jnp.array( - [ - jnp.sin(theta1) * jnp.cos(phiRef), - jnp.sin(theta1) * jnp.sin(phiRef), - jnp.cos(theta1), - ] - ) - s2hat = jnp.array( - [ - jnp.sin(theta2) * jnp.cos(phi12 + phiRef), - jnp.sin(theta2) * jnp.sin(phi12 + phiRef), - jnp.cos(theta2), - ] - ) - - temp = 1 / eta / 2 - 1 - q = temp - (temp**2 - 1) ** 0.5 - m1, m2 = Mc_q_to_m1m2(M_c, q) - v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) - - Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) - s1 = m1 * m1 * chi1 * s1hat - s2 = m2 * m2 * chi2 * s2hat - J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) - - Jhat = J / jnp.linalg.norm(J) - theta0 = jnp.arccos(Jhat[2]) - phi0 = jnp.arctan2(Jhat[1], Jhat[0]) - - # Rotation 1: - s1hat = rotate_z(-phi0, s1hat) - s2hat = rotate_z(-phi0, s2hat) - - # Rotation 2: - LNh = rotate_y(-theta0, LNh) - s1hat = rotate_y(-theta0, s1hat) - s2hat = rotate_y(-theta0, s2hat) - - # Rotation 3: - LNh = rotate_z(phiJL - jnp.pi, LNh) - s1hat = rotate_z(phiJL - jnp.pi, s1hat) - s2hat = rotate_z(phiJL - jnp.pi, s2hat) - - # Compute iota - N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) - iota = jnp.arccos(jnp.dot(N, LNh)) - - thetaLJ = jnp.arccos(LNh[2]) - phiL = jnp.arctan2(LNh[1], LNh[0]) - - # Rotation 4: - s1hat = rotate_z(-phiL, s1hat) - s2hat = rotate_z(-phiL, s2hat) - N = rotate_z(-phiL, N) - - # Rotation 5: - s1hat = rotate_y(-thetaLJ, s1hat) - s2hat = rotate_y(-thetaLJ, s2hat) - N = rotate_y(-thetaLJ, N) - - # Rotation 6: - phiN = jnp.arctan2(N[1], N[0]) - s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) - s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) - - S1 = s1hat * chi1 - S2 = s2hat * chi2 - return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] - - def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x From 169aaa29e38914e5e17a90a11ea562837687599f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 11:16:49 +0800 Subject: [PATCH 3/7] Reformat --- src/jimgw/single_event/utils.py | 174 ++++++++++++++++---------------- 1 file changed, 87 insertions(+), 87 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index a15bd7bf..62e8ad6d 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -393,6 +393,93 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F return ra, dec +def zenith_azimuth_to_ra_dec( + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. + + Parameters + ---------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Copied and modified from bilby/gw/utils.py + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + theta, phi = angle_rotation(zenith, azimuth, rotation) + ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) + return ra, dec + + +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + """ + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = angle_rotation(theta, phi, rotation) + return zenith, azimuth + + def spin_to_cartesian_spin( thetaJN: Float, phiJL: Float, @@ -549,90 +636,3 @@ def rotate_z(angle, vec): S1 = s1hat * chi1 S2 = s2hat * chi2 return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] - - -def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] -) -> tuple[Float, Float]: - """ - Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. - - Parameters - ---------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. - gmst : Float - Greenwich mean sidereal time. - rotation : Float[Array, " 3 3"] - The rotation matrix. - - Copied and modified from bilby/gw/utils.py - - Returns - ------- - ra : Float - Right ascension. - dec : Float - Declination. - """ - theta, phi = angle_rotation(zenith, azimuth, rotation) - ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - return ra, dec - - -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: - """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - - Returns - ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. - """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) - return theta, phi - - -def ra_dec_to_zenith_azimuth( - ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] -) -> tuple[Float, Float]: - """ - Transforming the right ascension and declination to the zenith angle and azimuthal angle. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - rotation : Float[Array, " 3 3"] - The rotation matrix. - - Returns - ------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. - """ - theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) - zenith, azimuth = angle_rotation(theta, phi, rotation) - return zenith, azimuth From 0f687f7e06d4554831210289a7c66ea2064480ba Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 14:35:54 +0800 Subject: [PATCH 4/7] Update runmanger --- example/Single_event_runManager.py | 107 ++++++++++++------------ src/jimgw/single_event/detector.py | 2 +- src/jimgw/single_event/runManager.py | 116 +++++++++++++-------------- 3 files changed, 108 insertions(+), 117 deletions(-) diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 8885ff3f..8ecba492 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -1,4 +1,3 @@ - import jax import jax.numpy as jnp @@ -12,57 +11,49 @@ mass_matrix = mass_matrix.at[5, 5].set(1e-3) mass_matrix = mass_matrix * 3e-3 local_sampler_arg = {"step_size": mass_matrix} -bounds = jnp.array( - [ - [10.0, 40.0], - [0.125, 1.0], - [-1.0, 1.0], - [-1.0, 1.0], - [0.0, 2000.0], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) run = SingleEventRun( seed=0, detectors=["H1", "L1"], + data_parameters={ + "trigger_time": 1126259462.4, + "duration": 4, + "post_trigger_duration": 2, + "f_min": 20.0, + "f_max": 1024.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + }, priors={ - "M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0}, - "q": {"name": "MassRatio"}, - "s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, - "s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, - "d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0}, - "t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05}, - "phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, - "cos_iota": {"name": "CosIota"}, - "psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi}, - "ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, - "sin_dec": {"name": "SinDec"}, + "M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0}, + "q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0}, + "s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0}, + "t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "iota": {"name": "SinePrior"}, + "psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "dec": {"name": "CosinePrior"}, }, waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, - jim_parameters={ - "n_loop_training": 10, - "n_loop_production": 10, - "n_local_steps": 15, - "n_global_steps": 15, - "n_chains": 500, - "n_epochs": 10, - "learning_rate": 0.001, - "n_max_examples": 45000, - "momentum": 0.9, - "batch_size": 50000, - "use_global": True, - "keep_quantile": 0.0, - "train_thinning": 1, - "output_thinning": 10, - "local_sampler_arg": local_sampler_arg, - }, - likelihood_parameters={"name": "TransientLikelihoodFD", "bounds": bounds}, + likelihood_parameters={"name": "TransientLikelihoodFD"}, + sample_transforms=[ + {"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,}, + {"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,}, + {"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,}, + {"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + ], + likelihood_transforms=[ + ], injection=True, injection_parameters={ "M_c": 28.6, @@ -77,22 +68,28 @@ "ra": 1.2, "dec": 0.3, }, - data_parameters={ - "trigger_time": 1126259462.4, - "duration": 4, - "post_trigger_duration": 2, - "f_min": 20.0, - "f_max": 1024.0, - "tukey_alpha": 0.2, - "f_sampling": 4096.0, + jim_parameters={ + "n_loop_training": 1, + "n_loop_production": 1, + "n_local_steps": 5, + "n_global_steps": 5, + "n_chains": 4, + "n_epochs": 2, + "learning_rate": 1e-4, + "n_max_examples": 30, + "momentum": 0.9, + "batch_size": 100, + "use_global": True, + "train_thinning": 1, + "output_thinning": 1, + "local_sampler_arg": local_sampler_arg, }, ) run_manager = SingleEventPERunManager(run=run) -run_manager.jim.sample(jax.random.PRNGKey(42)) +run_manager.sample(jax.random.PRNGKey(42)) # plot the corner plot and diagnostic plot run_manager.plot_corner() run_manager.plot_diagnostic() run_manager.save_summary() - diff --git a/src/jimgw/single_event/detector.py b/src/jimgw/single_event/detector.py index 22e9e7e8..580fe6f0 100644 --- a/src/jimgw/single_event/detector.py +++ b/src/jimgw/single_event/detector.py @@ -373,7 +373,7 @@ def inject_signal( h_sky: dict[str, Float[Array, " n_sample"]], params: dict[str, Float], psd_file: str = "", - ) -> None: + ) -> tuple[Float, Float]: """ Inject a signal into the detector data. diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index f603527e..1c31bcf6 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -12,7 +12,9 @@ from jaxlib.xla_extension import ArrayImpl from jaxtyping import Array, Float, PyTree -from jimgw import prior +from jimgw import prior, transforms +from jimgw.single_event import prior as single_event_prior +from jimgw.single_event import transforms as single_event_transforms from jimgw.base import RunManager from jimgw.jim import Jim from jimgw.single_event.detector import Detector, detector_preset @@ -20,51 +22,12 @@ from jimgw.single_event.waveform import Waveform, waveform_preset - def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): return dumper.represent_list(data.tolist()) yaml.add_representer(ArrayImpl, jaxarray_representer) # type: ignore -prior_presets = { - "Unconstrained_Uniform": prior.Unconstrained_Uniform, - "Uniform": prior.Uniform, - "Sphere": prior.Sphere, - "AlignedSpin": prior.AlignedSpin, - "PowerLaw": prior.PowerLaw, - "Composite": prior.Composite, - "MassRatio": lambda **kwargs: prior.Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, - ), - "CosIota": lambda **kwargs: prior.Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos(params["cos_iota"]), - ) - }, - ), - "SinDec": lambda **kwargs: prior.Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin(params["sin_dec"]), - ) - }, - ), - "EarthFrame": prior.EarthFrame, -} - @dataclass class SingleEventRun: @@ -75,7 +38,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - path: str = "./experiment" + path: str = "single_event_run" injection_parameters: dict[str, float] = field(default_factory=lambda: {}) injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( @@ -135,7 +98,8 @@ def __init__(self, **kwargs): local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) - self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) + sample_transforms, likelihood_transforms = self.initialize_transforms() + self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters), def save(self, path: str): output_dict = asdict(self.run) @@ -149,7 +113,7 @@ def load_from_path(self, path: str) -> SingleEventRun: ### Initialization functions ### - def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: + def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLiklihood: """ Since prior contains information about types, naming and ranges of parameters, some of the likelihood class require the prior to be initialized, such as the @@ -205,23 +169,52 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: **self.run.data_parameters, ) - def initialize_prior(self) -> prior.Prior: + def initialize_prior(self) -> prior.CombinePrior: priors = [] for name, parameters in self.run.priors.items(): - if parameters["name"] not in prior_presets: - raise ValueError(f"Prior {name} not recognized.") - if parameters["name"] == "EarthFrame": - priors.append( - prior.EarthFrame( - gps=self.run.data_parameters["trigger_time"], - ifos=self.run.detectors, - ) - ) - else: - priors.append( - prior_presets[parameters["name"]](naming=[name], **parameters) - ) - return prior.Composite(priors) + assert isinstance(parameters, dict), "Prior parameters must be a dictionary." + assert "name" in parameters, "Prior name must be provided." + assert isinstance(parameters["name"], str), "Prior name must be a string." + try : + prior_class = getattr(single_event_prior, parameters["name"]) + except AttributeError: + try: + prior_class = getattr(prior, parameters["name"]) + except AttributeError: + raise ValueError(f"{parameters['name']} not recognized.") + parameters.pop("name") + priors.append(prior_class(parameter_names=[name], **parameters)) + return prior.CombinePrior(priors) + + def initialize_transforms(self) -> tuple[list[prior.BijectiveTransform], list[prior.NtoMTransform]]: + sample_transforms = [] + likelihood_transforms = [] + for transform in self.run.sample_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." + try: + transform_class = getattr(single_event_transforms, transform["name"]) + except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + sample_transforms.append(transform_class(**transform)) + for transform in self.run.likelihood_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." + try: + transform_class = getattr(single_event_transforms, transform["name"]) + except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + likelihood_transforms.append(transform_class(**transform)) def initialize_detector(self) -> list[Detector]: """ @@ -437,10 +430,11 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): plt.savefig(path) plt.close() - def save_summary(self, path: str = "run_manager_summary.txt", **kwargs): - sys.stdout = open(path,'wt') + def save_summary(self, path: str = None, **kwargs): + if path is None: + path = self.run.path + "run_manager_summary.txt" + sys.stdout = open(path, 'wt') self.jim.print_summary() - #print(self.SNRs) for detector, SNR in zip(self.detectors, self.SNRs): print('SNR of detector ' + detector + ' is ' + str(SNR)) networkSNR = jnp.sum(jnp.array(self.SNRs)**2) ** (0.5) From 1dbbd55af41e8bd67f08a6cc6dad19bbe64be461 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 15:02:09 +0800 Subject: [PATCH 5/7] Update runmanager --- example/Single_event_runManager.py | 21 ++-- src/jimgw/single_event/runManager.py | 55 ++++++----- test/integration/.gitignore | 3 + .../test_single_event_run_manager.py | 96 +++++++++++++++++++ 4 files changed, 140 insertions(+), 35 deletions(-) create mode 100644 test/integration/.gitignore create mode 100644 test/integration/test_single_event_run_manager.py diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 8ecba492..26280ead 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -53,6 +53,7 @@ {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, ], likelihood_transforms=[ + {"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]}, ], injection=True, injection_parameters={ @@ -69,25 +70,25 @@ "dec": 0.3, }, jim_parameters={ - "n_loop_training": 1, - "n_loop_production": 1, - "n_local_steps": 5, - "n_global_steps": 5, - "n_chains": 4, - "n_epochs": 2, + "n_loop_training": 100, + "n_loop_production": 20, + "n_local_steps": 10, + "n_global_steps": 1000, + "n_chains": 500, + "n_epochs": 30, "learning_rate": 1e-4, - "n_max_examples": 30, + "n_max_examples": 30000, "momentum": 0.9, - "batch_size": 100, + "batch_size": 30000, "use_global": True, "train_thinning": 1, - "output_thinning": 1, + "output_thinning": 10, "local_sampler_arg": local_sampler_arg, }, ) run_manager = SingleEventPERunManager(run=run) -run_manager.sample(jax.random.PRNGKey(42)) +run_manager.sample() # plot the corner plot and diagnostic plot run_manager.plot_corner() diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 1c31bcf6..7d3b7619 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -58,6 +58,8 @@ class SingleEventRun: "f_sampling": 4096.0, } ) + sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) + likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) @@ -99,7 +101,7 @@ def __init__(self, **kwargs): local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) sample_transforms, likelihood_transforms = self.initialize_transforms() - self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters), + self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters) def save(self, path: str): output_dict = asdict(self.run) @@ -186,35 +188,38 @@ def initialize_prior(self) -> prior.CombinePrior: priors.append(prior_class(parameter_names=[name], **parameters)) return prior.CombinePrior(priors) - def initialize_transforms(self) -> tuple[list[prior.BijectiveTransform], list[prior.NtoMTransform]]: + def initialize_transforms(self) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]: sample_transforms = [] likelihood_transforms = [] - for transform in self.run.sample_transforms: - assert isinstance(transform, dict), "Transform must be a dictionary." - assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." - try: - transform_class = getattr(single_event_transforms, transform["name"]) - except AttributeError: + if self.run.sample_transforms: + for transform in self.run.sample_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." try: - transform_class = getattr(transforms, transform["name"]) + transform_class = getattr(single_event_transforms, transform["name"]) except AttributeError: - raise ValueError(f"{transform['name']} not recognized.") - transform.pop("name") - sample_transforms.append(transform_class(**transform)) - for transform in self.run.likelihood_transforms: - assert isinstance(transform, dict), "Transform must be a dictionary." - assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." - try: - transform_class = getattr(single_event_transforms, transform["name"]) - except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + sample_transforms.append(transform_class(**transform)) + if self.run.likelihood_transforms: + for transform in self.run.likelihood_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." try: - transform_class = getattr(transforms, transform["name"]) + transform_class = getattr(single_event_transforms, transform["name"]) except AttributeError: - raise ValueError(f"{transform['name']} not recognized.") - transform.pop("name") - likelihood_transforms.append(transform_class(**transform)) + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + likelihood_transforms.append(transform_class(**transform)) + return sample_transforms, likelihood_transforms def initialize_detector(self) -> list[Detector]: """ @@ -396,7 +401,7 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): """ plot diagnostic plot of the samples. """ - summary = self.jim.Sampler.get_sampler_state(training=True) + summary = self.jim.sampler.get_sampler_state(training=True) chains, log_prob, local_accs, global_accs, loss_vals = summary.values() log_prob = np.array(log_prob) diff --git a/test/integration/.gitignore b/test/integration/.gitignore new file mode 100644 index 00000000..a0783193 --- /dev/null +++ b/test/integration/.gitignore @@ -0,0 +1,3 @@ +*.txt +*.jpeg +*.jpg diff --git a/test/integration/test_single_event_run_manager.py b/test/integration/test_single_event_run_manager.py new file mode 100644 index 00000000..48bec85f --- /dev/null +++ b/test/integration/test_single_event_run_manager.py @@ -0,0 +1,96 @@ +import jax +import jax.numpy as jnp + +from jimgw.single_event.runManager import (SingleEventPERunManager, + SingleEventRun) + +jax.config.update("jax_enable_x64", True) + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +mass_matrix = mass_matrix * 3e-3 +local_sampler_arg = {"step_size": mass_matrix} + +run = SingleEventRun( + seed=0, + detectors=["H1", "L1"], + data_parameters={ + "trigger_time": 1126259462.4, + "duration": 4, + "post_trigger_duration": 2, + "f_min": 20.0, + "f_max": 1024.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + }, + priors={ + "M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0}, + "q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0}, + "s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0}, + "t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "iota": {"name": "SinePrior"}, + "psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "dec": {"name": "CosinePrior"}, + }, + waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, + likelihood_parameters={"name": "TransientLikelihoodFD"}, + sample_transforms=[ + {"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,}, + {"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,}, + {"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,}, + {"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + ], + likelihood_transforms=[ + {"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]}, + ], + injection=True, + injection_parameters={ + "M_c": 28.6, + "eta": 0.24, + "s1_z": 0.05, + "s2_z": 0.05, + "d_L": 440.0, + "t_c": 0.0, + "phase_c": 0.0, + "iota": 0.5, + "psi": 0.7, + "ra": 1.2, + "dec": 0.3, + }, + jim_parameters={ + "n_loop_training": 1, + "n_loop_production": 1, + "n_local_steps": 5, + "n_global_steps": 5, + "n_chains": 4, + "n_epochs": 2, + "learning_rate": 1e-4, + "n_max_examples": 30, + "momentum": 0.9, + "batch_size": 100, + "use_global": True, + "train_thinning": 1, + "output_thinning": 1, + "local_sampler_arg": local_sampler_arg, + }, +) + +run_manager = SingleEventPERunManager(run=run) +run_manager.sample() + +# plot the corner plot and diagnostic plot +run_manager.plot_corner() +run_manager.plot_diagnostic() +run_manager.save_summary() From cb07d4eb2d8d5f42971d9fbdc2c2b703cb59f6e5 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 15:12:30 +0800 Subject: [PATCH 6/7] Reformat --- src/jimgw/single_event/runManager.py | 57 +++++++++++++++++++--------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 7d3b7619..034ae111 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -58,9 +58,12 @@ class SingleEventRun: "f_sampling": 4096.0, } ) - sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) - likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) - + sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field( + default_factory=lambda: [] + ) + likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field( + default_factory=lambda: [] + ) class SingleEventPERunManager(RunManager): @@ -101,7 +104,13 @@ def __init__(self, **kwargs): local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) sample_transforms, likelihood_transforms = self.initialize_transforms() - self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters) + self.jim = Jim( + local_likelihood, + local_prior, + sample_transforms, + likelihood_transforms, + **self.run.jim_parameters, + ) def save(self, path: str): output_dict = asdict(self.run) @@ -158,11 +167,11 @@ def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLikliho key, subkey = jax.random.split(jax.random.PRNGKey(self.run.seed + 1901)) SNRs = [] for detector in detectors: - optimal_SNR,_ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore + optimal_SNR, _ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore SNRs.append(optimal_SNR) key, subkey = jax.random.split(key) self.SNRs = SNRs - + return likelihood_presets[name]( detectors, waveform, @@ -174,10 +183,12 @@ def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLikliho def initialize_prior(self) -> prior.CombinePrior: priors = [] for name, parameters in self.run.priors.items(): - assert isinstance(parameters, dict), "Prior parameters must be a dictionary." + assert isinstance( + parameters, dict + ), "Prior parameters must be a dictionary." assert "name" in parameters, "Prior name must be provided." assert isinstance(parameters["name"], str), "Prior name must be a string." - try : + try: prior_class = getattr(single_event_prior, parameters["name"]) except AttributeError: try: @@ -187,17 +198,23 @@ def initialize_prior(self) -> prior.CombinePrior: parameters.pop("name") priors.append(prior_class(parameter_names=[name], **parameters)) return prior.CombinePrior(priors) - - def initialize_transforms(self) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]: + + def initialize_transforms( + self, + ) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]: sample_transforms = [] likelihood_transforms = [] if self.run.sample_transforms: for transform in self.run.sample_transforms: assert isinstance(transform, dict), "Transform must be a dictionary." assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." + assert isinstance( + transform["name"], str + ), "Transform name must be a string." try: - transform_class = getattr(single_event_transforms, transform["name"]) + transform_class = getattr( + single_event_transforms, transform["name"] + ) except AttributeError: try: transform_class = getattr(transforms, transform["name"]) @@ -209,9 +226,13 @@ def initialize_transforms(self) -> tuple[list[transforms.BijectiveTransform], li for transform in self.run.likelihood_transforms: assert isinstance(transform, dict), "Transform must be a dictionary." assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." + assert isinstance( + transform["name"], str + ), "Transform name must be a string." try: - transform_class = getattr(single_event_transforms, transform["name"]) + transform_class = getattr( + single_event_transforms, transform["name"] + ) except AttributeError: try: transform_class = getattr(transforms, transform["name"]) @@ -438,9 +459,9 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): def save_summary(self, path: str = None, **kwargs): if path is None: path = self.run.path + "run_manager_summary.txt" - sys.stdout = open(path, 'wt') + sys.stdout = open(path, "wt") self.jim.print_summary() for detector, SNR in zip(self.detectors, self.SNRs): - print('SNR of detector ' + detector + ' is ' + str(SNR)) - networkSNR = jnp.sum(jnp.array(self.SNRs)**2) ** (0.5) - print('network SNR is', networkSNR) + print("SNR of detector " + detector + " is " + str(SNR)) + networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) + print("network SNR is", networkSNR) From e5b6e848e0cfbe0be1171f1565f891cf14d85078 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 15:20:44 +0800 Subject: [PATCH 7/7] Change default path --- src/jimgw/single_event/runManager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 034ae111..c1439b39 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -456,8 +456,8 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): plt.savefig(path) plt.close() - def save_summary(self, path: str = None, **kwargs): - if path is None: + def save_summary(self, path: str = "", **kwargs): + if path == "": path = self.run.path + "run_manager_summary.txt" sys.stdout = open(path, "wt") self.jim.print_summary()