From ee08171aeb855c92a1c9dfdbcefb65ceea15e1db Mon Sep 17 00:00:00 2001 From: zipengwang98 Date: Wed, 31 Jul 2024 12:24:52 -0400 Subject: [PATCH 01/27] summary output in run manager --- example/Single_event_runManager.py | 9 +++++---- src/jimgw/single_event/detector.py | 4 +++- src/jimgw/single_event/runManager.py | 19 ++++++++++++++++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 5c678b22..8885ff3f 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -48,10 +48,10 @@ jim_parameters={ "n_loop_training": 10, "n_loop_production": 10, - "n_local_steps": 150, - "n_global_steps": 150, + "n_local_steps": 15, + "n_global_steps": 15, "n_chains": 500, - "n_epochs": 50, + "n_epochs": 10, "learning_rate": 0.001, "n_max_examples": 45000, "momentum": 0.9, @@ -62,7 +62,7 @@ "output_thinning": 10, "local_sampler_arg": local_sampler_arg, }, - likelihood_parameters={"name": "HeterodynedTransientLikelihoodFD", "bounds": bounds}, + likelihood_parameters={"name": "TransientLikelihoodFD", "bounds": bounds}, injection=True, injection_parameters={ "M_c": 28.6, @@ -94,4 +94,5 @@ # plot the corner plot and diagnostic plot run_manager.plot_corner() run_manager.plot_diagnostic() +run_manager.save_summary() diff --git a/src/jimgw/single_event/detector.py b/src/jimgw/single_event/detector.py index 6c3079cf..22e9e7e8 100644 --- a/src/jimgw/single_event/detector.py +++ b/src/jimgw/single_event/detector.py @@ -392,7 +392,7 @@ def inject_signal( Returns ------- - None + SNR """ self.frequencies = freqs self.psd = self.load_psd(freqs, psd_file) @@ -415,6 +415,8 @@ def inject_signal( print(f"The injected optimal SNR is {optimal_SNR}") print(f"The injected match filter SNR is {match_filter_SNR}") + return optimal_SNR, match_filter_SNR + @jaxtyped(typechecker=typechecker) def load_psd( self, freqs: Float[Array, " n_sample"], psd_file: str = "" diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 7b720c30..f603527e 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import corner +import sys import numpy as np import yaml from astropy.time import Time @@ -19,6 +20,7 @@ from jimgw.single_event.waveform import Waveform, waveform_preset + def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): return dumper.represent_list(data.tolist()) @@ -95,9 +97,11 @@ class SingleEventRun: ) + class SingleEventPERunManager(RunManager): run: SingleEventRun jim: Jim + SNRs: list[float] @property def waveform(self): @@ -186,9 +190,13 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: - self.run.data_parameters["post_trigger_duration"], } key, subkey = jax.random.split(jax.random.PRNGKey(self.run.seed + 1901)) + SNRs = [] for detector in detectors: - detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore + optimal_SNR,_ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore + SNRs.append(optimal_SNR) key, subkey = jax.random.split(key) + self.SNRs = SNRs + return likelihood_presets[name]( detectors, waveform, @@ -428,3 +436,12 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): plt.savefig(path) plt.close() + + def save_summary(self, path: str = "run_manager_summary.txt", **kwargs): + sys.stdout = open(path,'wt') + self.jim.print_summary() + #print(self.SNRs) + for detector, SNR in zip(self.detectors, self.SNRs): + print('SNR of detector ' + detector + ' is ' + str(SNR)) + networkSNR = jnp.sum(jnp.array(self.SNRs)**2) ** (0.5) + print('network SNR is', networkSNR) From 1c2160595a4c5ce016c71eb1f65f489d9d9dc9d3 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 10:58:46 +0800 Subject: [PATCH 02/27] Remove legacy test files --- test/test_prior.py | 1 - test/test_transform.py | 48 ------------------------------------------ 2 files changed, 49 deletions(-) delete mode 100644 test/test_prior.py delete mode 100644 test/test_transform.py diff --git a/test/test_prior.py b/test/test_prior.py deleted file mode 100644 index 8b137891..00000000 --- a/test/test_prior.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/test_transform.py b/test/test_transform.py deleted file mode 100644 index 5d9fe464..00000000 --- a/test/test_transform.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import jax.numpy as jnp - -class TestTransform: - def test_sky_location_transform(self): - from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_earth_to_sky - from bilby.gw.detector.networks import InterferometerList - - from jimgw.single_event.utils import zenith_azimuth_to_ra_dec as jimgw_earth_to_sky - from jimgw.single_event.detector import detector_preset - from astropy.time import Time - - ifos = ["H1", "L1"] - geocent_time = 1000000000 - - import matplotlib.pyplot as plt - - for zenith in np.linspace(0, np.pi, 10): - for azimuth in np.linspace(0, 2*np.pi, 10): - bilby_sky_location = np.array(bilby_earth_to_sky(zenith, azimuth, geocent_time, InterferometerList(ifos))) - jimgw_sky_location = np.array(jimgw_earth_to_sky(zenith, azimuth, Time(geocent_time, format="gps").sidereal_time("apparent", "greenwich").rad, detector_preset[ifos[0]].vertex - detector_preset[ifos[1]].vertex)) - assert np.allclose(bilby_sky_location, jimgw_sky_location, atol=1e-4) - - def test_spin_transform(self): - from bilby.gw.conversion import bilby_to_lalsimulation_spins as bilby_spin_transform - from bilby.gw.conversion import symmetric_mass_ratio_to_mass_ratio, chirp_mass_and_mass_ratio_to_component_masses - - from jimgw.single_event.utils import spin_to_cartesian_spin as jimgw_spin_transform - - for _ in range(100): - thetaJN = jnp.array(np.random.uniform(0, np.pi)) - phiJL = jnp.array(np.random.uniform(0, np.pi)) - theta1 = jnp.array(np.random.uniform(0, np.pi)) - theta2 = jnp.array(np.random.uniform(0, np.pi)) - phi12 = jnp.array(np.random.uniform(0, np.pi)) - chi1 = jnp.array(np.random.uniform(0, 1)) - chi2 = jnp.array(np.random.uniform(0, 1)) - M_c = jnp.array(np.random.uniform(1, 100)) - eta = jnp.array(np.random.uniform(0.1, 0.25)) - fRef = jnp.array(np.random.uniform(10, 1000)) - phiRef = jnp.array(np.random.uniform(0, 2*np.pi)) - - q = symmetric_mass_ratio_to_mass_ratio(eta) - m1, m2 = chirp_mass_and_mass_ratio_to_component_masses(M_c, q) - MsunInkg = 1.9884e30 - bilby_spin = jnp.array(bilby_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, m1*MsunInkg, m2*MsunInkg, fRef, phiRef)) - jimgw_spin = jnp.array(jimgw_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, M_c, eta, fRef, phiRef)) - assert np.allclose(bilby_spin, jimgw_spin, atol=1e-4) From 19c4a29892ad149149761d789d7ee3e4da49f3ff Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 11:08:52 +0800 Subject: [PATCH 03/27] Update utils.py --- src/jimgw/single_event/utils.py | 159 -------------------------------- 1 file changed, 159 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index fb35bf27..a15bd7bf 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -276,165 +276,6 @@ def eta_to_q(eta: Float) -> Float: return temp - (temp**2 - 1) ** 0.5 -def spin_to_cartesian_spin( - thetaJN: Float, - phiJL: Float, - theta1: Float, - theta2: Float, - phi12: Float, - chi1: Float, - chi2: Float, - M_c: Float, - eta: Float, - fRef: Float, - phiRef: Float, -) -> tuple[Float, Float, Float, Float, Float, Float, Float]: - """ - Transforming the spin parameters - - The code is based on the approach used in LALsimulation: - https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html - - Parameters: - ------- - thetaJN: Float - Zenith angle between the total angular momentum and the line of sight - phiJL: Float - Difference between total and orbital angular momentum azimuthal angles - theta1: Float - Zenith angle between the spin and orbital angular momenta for the primary object - theta2: Float - Zenith angle between the spin and orbital angular momenta for the secondary object - phi12: Float - Difference between the azimuthal angles of the individual spin vector projections - onto the orbital plane - chi1: Float - Primary object aligned spin: - chi2: Float - Secondary object aligned spin: - M_c: Float - The chirp mass - eta: Float - The symmetric mass ratio - fRef: Float - The reference frequency - phiRef: Float - Binary phase at a reference frequency - - Returns: - ------- - iota: Float - Zenith angle between the orbital angular momentum and the line of sight - S1x: Float - The x-component of the primary spin - S1y: Float - The y-component of the primary spin - S1z: Float - The z-component of the primary spin - S2x: Float - The x-component of the secondary spin - S2y: Float - The y-component of the secondary spin - S2z: Float - The z-component of the secondary spin - """ - - def rotate_y(angle, vec): - """ - Rotate the vector (x, y, z) about y-axis - """ - cos_angle = jnp.cos(angle) - sin_angle = jnp.sin(angle) - rotation_matrix = jnp.array( - [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] - ) - rotated_vec = jnp.dot(rotation_matrix, vec) - return rotated_vec - - def rotate_z(angle, vec): - """ - Rotate the vector (x, y, z) about z-axis - """ - cos_angle = jnp.cos(angle) - sin_angle = jnp.sin(angle) - rotation_matrix = jnp.array( - [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] - ) - rotated_vec = jnp.dot(rotation_matrix, vec) - return rotated_vec - - LNh = jnp.array([0.0, 0.0, 1.0]) - - s1hat = jnp.array( - [ - jnp.sin(theta1) * jnp.cos(phiRef), - jnp.sin(theta1) * jnp.sin(phiRef), - jnp.cos(theta1), - ] - ) - s2hat = jnp.array( - [ - jnp.sin(theta2) * jnp.cos(phi12 + phiRef), - jnp.sin(theta2) * jnp.sin(phi12 + phiRef), - jnp.cos(theta2), - ] - ) - - temp = 1 / eta / 2 - 1 - q = temp - (temp**2 - 1) ** 0.5 - m1, m2 = Mc_q_to_m1m2(M_c, q) - v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) - - Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) - s1 = m1 * m1 * chi1 * s1hat - s2 = m2 * m2 * chi2 * s2hat - J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) - - Jhat = J / jnp.linalg.norm(J) - theta0 = jnp.arccos(Jhat[2]) - phi0 = jnp.arctan2(Jhat[1], Jhat[0]) - - # Rotation 1: - s1hat = rotate_z(-phi0, s1hat) - s2hat = rotate_z(-phi0, s2hat) - - # Rotation 2: - LNh = rotate_y(-theta0, LNh) - s1hat = rotate_y(-theta0, s1hat) - s2hat = rotate_y(-theta0, s2hat) - - # Rotation 3: - LNh = rotate_z(phiJL - jnp.pi, LNh) - s1hat = rotate_z(phiJL - jnp.pi, s1hat) - s2hat = rotate_z(phiJL - jnp.pi, s2hat) - - # Compute iota - N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) - iota = jnp.arccos(jnp.dot(N, LNh)) - - thetaLJ = jnp.arccos(LNh[2]) - phiL = jnp.arctan2(LNh[1], LNh[0]) - - # Rotation 4: - s1hat = rotate_z(-phiL, s1hat) - s2hat = rotate_z(-phiL, s2hat) - N = rotate_z(-phiL, N) - - # Rotation 5: - s1hat = rotate_y(-thetaLJ, s1hat) - s2hat = rotate_y(-thetaLJ, s2hat) - N = rotate_y(-thetaLJ, N) - - # Rotation 6: - phiN = jnp.arctan2(N[1], N[0]) - s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) - s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) - - S1 = s1hat * chi1 - S2 = s2hat * chi2 - return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] - - def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x From 169aaa29e38914e5e17a90a11ea562837687599f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 11:16:49 +0800 Subject: [PATCH 04/27] Reformat --- src/jimgw/single_event/utils.py | 174 ++++++++++++++++---------------- 1 file changed, 87 insertions(+), 87 deletions(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index a15bd7bf..62e8ad6d 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -393,6 +393,93 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F return ra, dec +def zenith_azimuth_to_ra_dec( + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. + + Parameters + ---------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Copied and modified from bilby/gw/utils.py + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + theta, phi = angle_rotation(zenith, azimuth, rotation) + ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) + return ra, dec + + +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + """ + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = angle_rotation(theta, phi, rotation) + return zenith, azimuth + + def spin_to_cartesian_spin( thetaJN: Float, phiJL: Float, @@ -549,90 +636,3 @@ def rotate_z(angle, vec): S1 = s1hat * chi1 S2 = s2hat * chi2 return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] - - -def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] -) -> tuple[Float, Float]: - """ - Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. - - Parameters - ---------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. - gmst : Float - Greenwich mean sidereal time. - rotation : Float[Array, " 3 3"] - The rotation matrix. - - Copied and modified from bilby/gw/utils.py - - Returns - ------- - ra : Float - Right ascension. - dec : Float - Declination. - """ - theta, phi = angle_rotation(zenith, azimuth, rotation) - ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - return ra, dec - - -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: - """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - - Returns - ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. - """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) - return theta, phi - - -def ra_dec_to_zenith_azimuth( - ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] -) -> tuple[Float, Float]: - """ - Transforming the right ascension and declination to the zenith angle and azimuthal angle. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - rotation : Float[Array, " 3 3"] - The rotation matrix. - - Returns - ------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. - """ - theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) - zenith, azimuth = angle_rotation(theta, phi, rotation) - return zenith, azimuth From 0f687f7e06d4554831210289a7c66ea2064480ba Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 14:35:54 +0800 Subject: [PATCH 05/27] Update runmanger --- example/Single_event_runManager.py | 107 ++++++++++++------------ src/jimgw/single_event/detector.py | 2 +- src/jimgw/single_event/runManager.py | 116 +++++++++++++-------------- 3 files changed, 108 insertions(+), 117 deletions(-) diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 8885ff3f..8ecba492 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -1,4 +1,3 @@ - import jax import jax.numpy as jnp @@ -12,57 +11,49 @@ mass_matrix = mass_matrix.at[5, 5].set(1e-3) mass_matrix = mass_matrix * 3e-3 local_sampler_arg = {"step_size": mass_matrix} -bounds = jnp.array( - [ - [10.0, 40.0], - [0.125, 1.0], - [-1.0, 1.0], - [-1.0, 1.0], - [0.0, 2000.0], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) run = SingleEventRun( seed=0, detectors=["H1", "L1"], + data_parameters={ + "trigger_time": 1126259462.4, + "duration": 4, + "post_trigger_duration": 2, + "f_min": 20.0, + "f_max": 1024.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + }, priors={ - "M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0}, - "q": {"name": "MassRatio"}, - "s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, - "s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, - "d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0}, - "t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05}, - "phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, - "cos_iota": {"name": "CosIota"}, - "psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi}, - "ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, - "sin_dec": {"name": "SinDec"}, + "M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0}, + "q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0}, + "s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0}, + "t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "iota": {"name": "SinePrior"}, + "psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "dec": {"name": "CosinePrior"}, }, waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, - jim_parameters={ - "n_loop_training": 10, - "n_loop_production": 10, - "n_local_steps": 15, - "n_global_steps": 15, - "n_chains": 500, - "n_epochs": 10, - "learning_rate": 0.001, - "n_max_examples": 45000, - "momentum": 0.9, - "batch_size": 50000, - "use_global": True, - "keep_quantile": 0.0, - "train_thinning": 1, - "output_thinning": 10, - "local_sampler_arg": local_sampler_arg, - }, - likelihood_parameters={"name": "TransientLikelihoodFD", "bounds": bounds}, + likelihood_parameters={"name": "TransientLikelihoodFD"}, + sample_transforms=[ + {"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,}, + {"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,}, + {"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,}, + {"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + ], + likelihood_transforms=[ + ], injection=True, injection_parameters={ "M_c": 28.6, @@ -77,22 +68,28 @@ "ra": 1.2, "dec": 0.3, }, - data_parameters={ - "trigger_time": 1126259462.4, - "duration": 4, - "post_trigger_duration": 2, - "f_min": 20.0, - "f_max": 1024.0, - "tukey_alpha": 0.2, - "f_sampling": 4096.0, + jim_parameters={ + "n_loop_training": 1, + "n_loop_production": 1, + "n_local_steps": 5, + "n_global_steps": 5, + "n_chains": 4, + "n_epochs": 2, + "learning_rate": 1e-4, + "n_max_examples": 30, + "momentum": 0.9, + "batch_size": 100, + "use_global": True, + "train_thinning": 1, + "output_thinning": 1, + "local_sampler_arg": local_sampler_arg, }, ) run_manager = SingleEventPERunManager(run=run) -run_manager.jim.sample(jax.random.PRNGKey(42)) +run_manager.sample(jax.random.PRNGKey(42)) # plot the corner plot and diagnostic plot run_manager.plot_corner() run_manager.plot_diagnostic() run_manager.save_summary() - diff --git a/src/jimgw/single_event/detector.py b/src/jimgw/single_event/detector.py index 22e9e7e8..580fe6f0 100644 --- a/src/jimgw/single_event/detector.py +++ b/src/jimgw/single_event/detector.py @@ -373,7 +373,7 @@ def inject_signal( h_sky: dict[str, Float[Array, " n_sample"]], params: dict[str, Float], psd_file: str = "", - ) -> None: + ) -> tuple[Float, Float]: """ Inject a signal into the detector data. diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index f603527e..1c31bcf6 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -12,7 +12,9 @@ from jaxlib.xla_extension import ArrayImpl from jaxtyping import Array, Float, PyTree -from jimgw import prior +from jimgw import prior, transforms +from jimgw.single_event import prior as single_event_prior +from jimgw.single_event import transforms as single_event_transforms from jimgw.base import RunManager from jimgw.jim import Jim from jimgw.single_event.detector import Detector, detector_preset @@ -20,51 +22,12 @@ from jimgw.single_event.waveform import Waveform, waveform_preset - def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): return dumper.represent_list(data.tolist()) yaml.add_representer(ArrayImpl, jaxarray_representer) # type: ignore -prior_presets = { - "Unconstrained_Uniform": prior.Unconstrained_Uniform, - "Uniform": prior.Uniform, - "Sphere": prior.Sphere, - "AlignedSpin": prior.AlignedSpin, - "PowerLaw": prior.PowerLaw, - "Composite": prior.Composite, - "MassRatio": lambda **kwargs: prior.Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, - ), - "CosIota": lambda **kwargs: prior.Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos(params["cos_iota"]), - ) - }, - ), - "SinDec": lambda **kwargs: prior.Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin(params["sin_dec"]), - ) - }, - ), - "EarthFrame": prior.EarthFrame, -} - @dataclass class SingleEventRun: @@ -75,7 +38,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - path: str = "./experiment" + path: str = "single_event_run" injection_parameters: dict[str, float] = field(default_factory=lambda: {}) injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( @@ -135,7 +98,8 @@ def __init__(self, **kwargs): local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) - self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) + sample_transforms, likelihood_transforms = self.initialize_transforms() + self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters), def save(self, path: str): output_dict = asdict(self.run) @@ -149,7 +113,7 @@ def load_from_path(self, path: str) -> SingleEventRun: ### Initialization functions ### - def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: + def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLiklihood: """ Since prior contains information about types, naming and ranges of parameters, some of the likelihood class require the prior to be initialized, such as the @@ -205,23 +169,52 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: **self.run.data_parameters, ) - def initialize_prior(self) -> prior.Prior: + def initialize_prior(self) -> prior.CombinePrior: priors = [] for name, parameters in self.run.priors.items(): - if parameters["name"] not in prior_presets: - raise ValueError(f"Prior {name} not recognized.") - if parameters["name"] == "EarthFrame": - priors.append( - prior.EarthFrame( - gps=self.run.data_parameters["trigger_time"], - ifos=self.run.detectors, - ) - ) - else: - priors.append( - prior_presets[parameters["name"]](naming=[name], **parameters) - ) - return prior.Composite(priors) + assert isinstance(parameters, dict), "Prior parameters must be a dictionary." + assert "name" in parameters, "Prior name must be provided." + assert isinstance(parameters["name"], str), "Prior name must be a string." + try : + prior_class = getattr(single_event_prior, parameters["name"]) + except AttributeError: + try: + prior_class = getattr(prior, parameters["name"]) + except AttributeError: + raise ValueError(f"{parameters['name']} not recognized.") + parameters.pop("name") + priors.append(prior_class(parameter_names=[name], **parameters)) + return prior.CombinePrior(priors) + + def initialize_transforms(self) -> tuple[list[prior.BijectiveTransform], list[prior.NtoMTransform]]: + sample_transforms = [] + likelihood_transforms = [] + for transform in self.run.sample_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." + try: + transform_class = getattr(single_event_transforms, transform["name"]) + except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + sample_transforms.append(transform_class(**transform)) + for transform in self.run.likelihood_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." + try: + transform_class = getattr(single_event_transforms, transform["name"]) + except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + likelihood_transforms.append(transform_class(**transform)) def initialize_detector(self) -> list[Detector]: """ @@ -437,10 +430,11 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): plt.savefig(path) plt.close() - def save_summary(self, path: str = "run_manager_summary.txt", **kwargs): - sys.stdout = open(path,'wt') + def save_summary(self, path: str = None, **kwargs): + if path is None: + path = self.run.path + "run_manager_summary.txt" + sys.stdout = open(path, 'wt') self.jim.print_summary() - #print(self.SNRs) for detector, SNR in zip(self.detectors, self.SNRs): print('SNR of detector ' + detector + ' is ' + str(SNR)) networkSNR = jnp.sum(jnp.array(self.SNRs)**2) ** (0.5) From 1dbbd55af41e8bd67f08a6cc6dad19bbe64be461 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 15:02:09 +0800 Subject: [PATCH 06/27] Update runmanager --- example/Single_event_runManager.py | 21 ++-- src/jimgw/single_event/runManager.py | 55 ++++++----- test/integration/.gitignore | 3 + .../test_single_event_run_manager.py | 96 +++++++++++++++++++ 4 files changed, 140 insertions(+), 35 deletions(-) create mode 100644 test/integration/.gitignore create mode 100644 test/integration/test_single_event_run_manager.py diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 8ecba492..26280ead 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -53,6 +53,7 @@ {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, ], likelihood_transforms=[ + {"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]}, ], injection=True, injection_parameters={ @@ -69,25 +70,25 @@ "dec": 0.3, }, jim_parameters={ - "n_loop_training": 1, - "n_loop_production": 1, - "n_local_steps": 5, - "n_global_steps": 5, - "n_chains": 4, - "n_epochs": 2, + "n_loop_training": 100, + "n_loop_production": 20, + "n_local_steps": 10, + "n_global_steps": 1000, + "n_chains": 500, + "n_epochs": 30, "learning_rate": 1e-4, - "n_max_examples": 30, + "n_max_examples": 30000, "momentum": 0.9, - "batch_size": 100, + "batch_size": 30000, "use_global": True, "train_thinning": 1, - "output_thinning": 1, + "output_thinning": 10, "local_sampler_arg": local_sampler_arg, }, ) run_manager = SingleEventPERunManager(run=run) -run_manager.sample(jax.random.PRNGKey(42)) +run_manager.sample() # plot the corner plot and diagnostic plot run_manager.plot_corner() diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 1c31bcf6..7d3b7619 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -58,6 +58,8 @@ class SingleEventRun: "f_sampling": 4096.0, } ) + sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) + likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) @@ -99,7 +101,7 @@ def __init__(self, **kwargs): local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) sample_transforms, likelihood_transforms = self.initialize_transforms() - self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters), + self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters) def save(self, path: str): output_dict = asdict(self.run) @@ -186,35 +188,38 @@ def initialize_prior(self) -> prior.CombinePrior: priors.append(prior_class(parameter_names=[name], **parameters)) return prior.CombinePrior(priors) - def initialize_transforms(self) -> tuple[list[prior.BijectiveTransform], list[prior.NtoMTransform]]: + def initialize_transforms(self) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]: sample_transforms = [] likelihood_transforms = [] - for transform in self.run.sample_transforms: - assert isinstance(transform, dict), "Transform must be a dictionary." - assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." - try: - transform_class = getattr(single_event_transforms, transform["name"]) - except AttributeError: + if self.run.sample_transforms: + for transform in self.run.sample_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." try: - transform_class = getattr(transforms, transform["name"]) + transform_class = getattr(single_event_transforms, transform["name"]) except AttributeError: - raise ValueError(f"{transform['name']} not recognized.") - transform.pop("name") - sample_transforms.append(transform_class(**transform)) - for transform in self.run.likelihood_transforms: - assert isinstance(transform, dict), "Transform must be a dictionary." - assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." - try: - transform_class = getattr(single_event_transforms, transform["name"]) - except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + sample_transforms.append(transform_class(**transform)) + if self.run.likelihood_transforms: + for transform in self.run.likelihood_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance(transform["name"], str), "Transform name must be a string." try: - transform_class = getattr(transforms, transform["name"]) + transform_class = getattr(single_event_transforms, transform["name"]) except AttributeError: - raise ValueError(f"{transform['name']} not recognized.") - transform.pop("name") - likelihood_transforms.append(transform_class(**transform)) + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform.pop("name") + likelihood_transforms.append(transform_class(**transform)) + return sample_transforms, likelihood_transforms def initialize_detector(self) -> list[Detector]: """ @@ -396,7 +401,7 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): """ plot diagnostic plot of the samples. """ - summary = self.jim.Sampler.get_sampler_state(training=True) + summary = self.jim.sampler.get_sampler_state(training=True) chains, log_prob, local_accs, global_accs, loss_vals = summary.values() log_prob = np.array(log_prob) diff --git a/test/integration/.gitignore b/test/integration/.gitignore new file mode 100644 index 00000000..a0783193 --- /dev/null +++ b/test/integration/.gitignore @@ -0,0 +1,3 @@ +*.txt +*.jpeg +*.jpg diff --git a/test/integration/test_single_event_run_manager.py b/test/integration/test_single_event_run_manager.py new file mode 100644 index 00000000..48bec85f --- /dev/null +++ b/test/integration/test_single_event_run_manager.py @@ -0,0 +1,96 @@ +import jax +import jax.numpy as jnp + +from jimgw.single_event.runManager import (SingleEventPERunManager, + SingleEventRun) + +jax.config.update("jax_enable_x64", True) + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +mass_matrix = mass_matrix * 3e-3 +local_sampler_arg = {"step_size": mass_matrix} + +run = SingleEventRun( + seed=0, + detectors=["H1", "L1"], + data_parameters={ + "trigger_time": 1126259462.4, + "duration": 4, + "post_trigger_duration": 2, + "f_min": 20.0, + "f_max": 1024.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + }, + priors={ + "M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0}, + "q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0}, + "s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0}, + "t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "iota": {"name": "SinePrior"}, + "psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "dec": {"name": "CosinePrior"}, + }, + waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, + likelihood_parameters={"name": "TransientLikelihoodFD"}, + sample_transforms=[ + {"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,}, + {"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,}, + {"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,}, + {"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + ], + likelihood_transforms=[ + {"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]}, + ], + injection=True, + injection_parameters={ + "M_c": 28.6, + "eta": 0.24, + "s1_z": 0.05, + "s2_z": 0.05, + "d_L": 440.0, + "t_c": 0.0, + "phase_c": 0.0, + "iota": 0.5, + "psi": 0.7, + "ra": 1.2, + "dec": 0.3, + }, + jim_parameters={ + "n_loop_training": 1, + "n_loop_production": 1, + "n_local_steps": 5, + "n_global_steps": 5, + "n_chains": 4, + "n_epochs": 2, + "learning_rate": 1e-4, + "n_max_examples": 30, + "momentum": 0.9, + "batch_size": 100, + "use_global": True, + "train_thinning": 1, + "output_thinning": 1, + "local_sampler_arg": local_sampler_arg, + }, +) + +run_manager = SingleEventPERunManager(run=run) +run_manager.sample() + +# plot the corner plot and diagnostic plot +run_manager.plot_corner() +run_manager.plot_diagnostic() +run_manager.save_summary() From cb07d4eb2d8d5f42971d9fbdc2c2b703cb59f6e5 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 15:12:30 +0800 Subject: [PATCH 07/27] Reformat --- src/jimgw/single_event/runManager.py | 57 +++++++++++++++++++--------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 7d3b7619..034ae111 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -58,9 +58,12 @@ class SingleEventRun: "f_sampling": 4096.0, } ) - sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) - likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field(default_factory=lambda: []) - + sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field( + default_factory=lambda: [] + ) + likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field( + default_factory=lambda: [] + ) class SingleEventPERunManager(RunManager): @@ -101,7 +104,13 @@ def __init__(self, **kwargs): local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) sample_transforms, likelihood_transforms = self.initialize_transforms() - self.jim = Jim(local_likelihood, local_prior, sample_transforms, likelihood_transforms, **self.run.jim_parameters) + self.jim = Jim( + local_likelihood, + local_prior, + sample_transforms, + likelihood_transforms, + **self.run.jim_parameters, + ) def save(self, path: str): output_dict = asdict(self.run) @@ -158,11 +167,11 @@ def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLikliho key, subkey = jax.random.split(jax.random.PRNGKey(self.run.seed + 1901)) SNRs = [] for detector in detectors: - optimal_SNR,_ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore + optimal_SNR, _ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore SNRs.append(optimal_SNR) key, subkey = jax.random.split(key) self.SNRs = SNRs - + return likelihood_presets[name]( detectors, waveform, @@ -174,10 +183,12 @@ def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLikliho def initialize_prior(self) -> prior.CombinePrior: priors = [] for name, parameters in self.run.priors.items(): - assert isinstance(parameters, dict), "Prior parameters must be a dictionary." + assert isinstance( + parameters, dict + ), "Prior parameters must be a dictionary." assert "name" in parameters, "Prior name must be provided." assert isinstance(parameters["name"], str), "Prior name must be a string." - try : + try: prior_class = getattr(single_event_prior, parameters["name"]) except AttributeError: try: @@ -187,17 +198,23 @@ def initialize_prior(self) -> prior.CombinePrior: parameters.pop("name") priors.append(prior_class(parameter_names=[name], **parameters)) return prior.CombinePrior(priors) - - def initialize_transforms(self) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]: + + def initialize_transforms( + self, + ) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]: sample_transforms = [] likelihood_transforms = [] if self.run.sample_transforms: for transform in self.run.sample_transforms: assert isinstance(transform, dict), "Transform must be a dictionary." assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." + assert isinstance( + transform["name"], str + ), "Transform name must be a string." try: - transform_class = getattr(single_event_transforms, transform["name"]) + transform_class = getattr( + single_event_transforms, transform["name"] + ) except AttributeError: try: transform_class = getattr(transforms, transform["name"]) @@ -209,9 +226,13 @@ def initialize_transforms(self) -> tuple[list[transforms.BijectiveTransform], li for transform in self.run.likelihood_transforms: assert isinstance(transform, dict), "Transform must be a dictionary." assert "name" in transform, "Transform name must be provided." - assert isinstance(transform["name"], str), "Transform name must be a string." + assert isinstance( + transform["name"], str + ), "Transform name must be a string." try: - transform_class = getattr(single_event_transforms, transform["name"]) + transform_class = getattr( + single_event_transforms, transform["name"] + ) except AttributeError: try: transform_class = getattr(transforms, transform["name"]) @@ -438,9 +459,9 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): def save_summary(self, path: str = None, **kwargs): if path is None: path = self.run.path + "run_manager_summary.txt" - sys.stdout = open(path, 'wt') + sys.stdout = open(path, "wt") self.jim.print_summary() for detector, SNR in zip(self.detectors, self.SNRs): - print('SNR of detector ' + detector + ' is ' + str(SNR)) - networkSNR = jnp.sum(jnp.array(self.SNRs)**2) ** (0.5) - print('network SNR is', networkSNR) + print("SNR of detector " + detector + " is " + str(SNR)) + networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) + print("network SNR is", networkSNR) From e5b6e848e0cfbe0be1171f1565f891cf14d85078 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 21 Aug 2024 15:20:44 +0800 Subject: [PATCH 08/27] Change default path --- src/jimgw/single_event/runManager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 034ae111..c1439b39 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -456,8 +456,8 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): plt.savefig(path) plt.close() - def save_summary(self, path: str = None, **kwargs): - if path is None: + def save_summary(self, path: str = "", **kwargs): + if path == "": path = self.run.path + "run_manager_summary.txt" sys.stdout = open(path, "wt") self.jim.print_summary() From b1133d36f4e6ff20667b30be9d7a91d58419d562 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:06 +0800 Subject: [PATCH 09/27] Update runManager.py --- src/jimgw/single_event/runManager.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index aa8d0dc7..0a4b502d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,9 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] = field( - default_factory=lambda: {} - ) + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -125,9 +123,6 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError - if self.run.injection and not self.run.injection_parameters: - raise ValueError("Injection mode requires injection parameters.") - local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) From 2a9d696a20d28fdd1693698a1840fef9ab276578 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:49 +0800 Subject: [PATCH 10/27] Update runManager.py --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 0a4b502d..3f65166d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,7 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} From ec1c90f04b0b142d9a67d8d99c718a2752d37730 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 29 Aug 2024 23:53:08 +0800 Subject: [PATCH 11/27] Fix bug in initializing likelihood --- src/jimgw/single_event/runManager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index c1439b39..eb1a12de 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -102,8 +102,8 @@ def __init__(self, **kwargs): raise ValueError("Injection mode requires injection parameters.") local_prior = self.initialize_prior() - local_likelihood = self.initialize_likelihood(local_prior) sample_transforms, likelihood_transforms = self.initialize_transforms() + local_likelihood = self.initialize_likelihood(local_prior, sample_transforms, likelihood_transforms) self.jim = Jim( local_likelihood, local_prior, @@ -124,7 +124,7 @@ def load_from_path(self, path: str) -> SingleEventRun: ### Initialization functions ### - def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLiklihood: + def initialize_likelihood(self, prior: prior.CombinePrior, sample_transforms: transforms.Transform, likelihood_transforms: transforms.Transform) -> SingleEventLiklihood: """ Since prior contains information about types, naming and ranges of parameters, some of the likelihood class require the prior to be initialized, such as the @@ -176,6 +176,8 @@ def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLikliho detectors, waveform, prior=prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, **self.run.likelihood_parameters, **self.run.data_parameters, ) From fd63c8bda9e05158e0fd8c4cff78da94864b87d7 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 29 Aug 2024 23:55:09 +0800 Subject: [PATCH 12/27] Fix bug in save_summary() --- src/jimgw/single_event/runManager.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index eb1a12de..e8cc65ef 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -463,7 +463,9 @@ def save_summary(self, path: str = "", **kwargs): path = self.run.path + "run_manager_summary.txt" sys.stdout = open(path, "wt") self.jim.print_summary() - for detector, SNR in zip(self.detectors, self.SNRs): - print("SNR of detector " + detector + " is " + str(SNR)) - networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) - print("network SNR is", networkSNR) + if self.run.injection: + for detector, SNR in zip(self.detectors, self.SNRs): + print("SNR of detector " + detector + " is " + str(SNR)) + networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) + print("network SNR is", networkSNR) + sys.stdout.close() From 77a75dc9d2955678e19b60e3be0b0597513fd177 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Fri, 30 Aug 2024 09:21:26 +0800 Subject: [PATCH 13/27] Fix bug --- src/jimgw/single_event/runManager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index c1439b39..3e5bc092 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -195,6 +195,7 @@ def initialize_prior(self) -> prior.CombinePrior: prior_class = getattr(prior, parameters["name"]) except AttributeError: raise ValueError(f"{parameters['name']} not recognized.") + parameters = parameters.copy() parameters.pop("name") priors.append(prior_class(parameter_names=[name], **parameters)) return prior.CombinePrior(priors) From 55a79e8c39afc08b6afc956f20680e0adc6d0bcd Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 30 Aug 2024 09:51:51 +0800 Subject: [PATCH 14/27] Fix transform initialization --- src/jimgw/single_event/runManager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 88770263..ea232441 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -223,6 +223,7 @@ def initialize_transforms( transform_class = getattr(transforms, transform["name"]) except AttributeError: raise ValueError(f"{transform['name']} not recognized.") + transform = transform.copy() transform.pop("name") sample_transforms.append(transform_class(**transform)) if self.run.likelihood_transforms: @@ -241,6 +242,7 @@ def initialize_transforms( transform_class = getattr(transforms, transform["name"]) except AttributeError: raise ValueError(f"{transform['name']} not recognized.") + transform = transform.copy() transform.pop("name") likelihood_transforms.append(transform_class(**transform)) return sample_transforms, likelihood_transforms From f046a703fa789ab0d0e5ff2298e6c2321070761a Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 30 Aug 2024 09:59:16 +0800 Subject: [PATCH 15/27] Fix stdout issue --- src/jimgw/single_event/runManager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index ea232441..e4d95e08 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -464,6 +464,7 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): def save_summary(self, path: str = "", **kwargs): if path == "": path = self.run.path + "run_manager_summary.txt" + orig_stdout = sys.stdout sys.stdout = open(path, "wt") self.jim.print_summary() if self.run.injection: @@ -472,3 +473,4 @@ def save_summary(self, path: str = "", **kwargs): networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) print("network SNR is", networkSNR) sys.stdout.close() + sys.stdout=orig_stdout From 5d1458ed335f08969d23c5f1321d9ffeec1ef9b9 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 30 Aug 2024 10:06:59 +0800 Subject: [PATCH 16/27] Reformatted --- src/jimgw/single_event/runManager.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index e4d95e08..2fa00cd9 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -103,7 +103,9 @@ def __init__(self, **kwargs): local_prior = self.initialize_prior() sample_transforms, likelihood_transforms = self.initialize_transforms() - local_likelihood = self.initialize_likelihood(local_prior, sample_transforms, likelihood_transforms) + local_likelihood = self.initialize_likelihood( + local_prior, sample_transforms, likelihood_transforms + ) self.jim = Jim( local_likelihood, local_prior, @@ -124,7 +126,12 @@ def load_from_path(self, path: str) -> SingleEventRun: ### Initialization functions ### - def initialize_likelihood(self, prior: prior.CombinePrior, sample_transforms: transforms.Transform, likelihood_transforms: transforms.Transform) -> SingleEventLiklihood: + def initialize_likelihood( + self, + prior: prior.CombinePrior, + sample_transforms: transforms.Transform, + likelihood_transforms: transforms.Transform, + ) -> SingleEventLiklihood: """ Since prior contains information about types, naming and ranges of parameters, some of the likelihood class require the prior to be initialized, such as the @@ -473,4 +480,4 @@ def save_summary(self, path: str = "", **kwargs): networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) print("network SNR is", networkSNR) sys.stdout.close() - sys.stdout=orig_stdout + sys.stdout = orig_stdout From 522d7c036b0e9da04119dfb22b00c323e3cad731 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 30 Aug 2024 18:59:01 +0800 Subject: [PATCH 17/27] Added MultipleEventPERunManager in runManager.py --- src/jimgw/single_event/runManager.py | 46 ++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 2fa00cd9..3d583baa 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import corner import sys +import os import numpy as np import yaml from astropy.time import Time @@ -481,3 +482,48 @@ def save_summary(self, path: str = "", **kwargs): print("network SNR is", networkSNR) sys.stdout.close() sys.stdout = orig_stdout + + +class MultipleEventPERunManager: + """ + Class to manage multiple events run. + """ + run_config_path: str + output_path: str + + def __init__(self, run_config_path: str, output_path: str = "output") -> None: + """ + Arguments: + run_config_path (str): load the run configuration from the path. + output_path (str, optional): save the output to this path. Defaults to "output". + """ + + self.run_config_path = run_config_path + self.output_path = output_path + + def run(self, plot_corner: bool = True, plot_diagnostic: bool = True, save_summary: bool = True): + """ + Loop over all the configuration files in the run_config_path and run the PE for each configuration. + """ + + if plot_corner and not os.path.exists("corner_plots"): + os.makedirs("corner_plots") + if plot_diagnostic and not os.path.exists("diagnostic_plots"): + os.makedirs("diagnostic_plots") + if save_summary and not os.path.exists("summaries"): + os.makedirs("summaries") + + config_directory = os.fsencode(self.run_config_path) + for file in os.listdir(config_directory): + filename = os.fsdecode(file) + if filename.endswith(".yaml"): + config_path = os.path.join(self.run_config_path, filename) + run_manager = SingleEventPERunManager(path=config_path) + run_manager.sample() + + if plot_corner: + run_manager.plot_corner("corner_plots/" + filename + "_corner.jpeg") + if plot_diagnostic: + run_manager.plot_diagnostic("diagnostic_plots/" + filename + "_diagnostic.jpeg") + if save_summary: + run_manager.save_summary("summaries/" + filename + "_summary.txt") \ No newline at end of file From 06e22b77299128321710469aa220b9a33692483f Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Fri, 30 Aug 2024 18:59:30 +0800 Subject: [PATCH 18/27] Added Multiple_event_runManager.py --- example/Multiple_event_runManager.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 example/Multiple_event_runManager.py diff --git a/example/Multiple_event_runManager.py b/example/Multiple_event_runManager.py new file mode 100644 index 00000000..04967a2c --- /dev/null +++ b/example/Multiple_event_runManager.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp + +from jimgw.single_event.runManager import MultipleEventRunManager + +jax.config.update("jax_enable_x64", True) + +run_manager = MultipleEventRunManager( + run_config_path="config", +) + +run_manager.run() From 186b9d9889ccb767b88b433541f2ba1c4725ddea Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Sat, 31 Aug 2024 09:54:57 +0800 Subject: [PATCH 19/27] Fix bug in load_from_path --- src/jimgw/single_event/runManager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 3d583baa..1a54520d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -123,6 +123,8 @@ def save(self, path: str): def load_from_path(self, path: str) -> SingleEventRun: with open(path, "r") as f: data = yaml.safe_load(f) + if "jim_parameters" in data and "local_sampler_arg" in data["jim_parameters"] and "step_size" in data["jim_parameters"]["local_sampler_arg"]: + data["jim_parameters"]["local_sampler_arg"]["step_size"] = jnp.array(data["jim_parameters"]["local_sampler_arg"]["step_size"]) return SingleEventRun(**data) ### Initialization functions ### From de18edaf1c093161d2c7e2a5ae5cdd706a625c56 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Sat, 31 Aug 2024 10:19:10 +0800 Subject: [PATCH 20/27] Added error log in MultipleEventPERunManager --- src/jimgw/single_event/runManager.py | 46 ++++++++++++++++++---------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 1a54520d..3c0c6ed1 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -508,24 +508,36 @@ def run(self, plot_corner: bool = True, plot_diagnostic: bool = True, save_summa Loop over all the configuration files in the run_config_path and run the PE for each configuration. """ - if plot_corner and not os.path.exists("corner_plots"): - os.makedirs("corner_plots") - if plot_diagnostic and not os.path.exists("diagnostic_plots"): - os.makedirs("diagnostic_plots") - if save_summary and not os.path.exists("summaries"): - os.makedirs("summaries") + if plot_corner and not os.path.exists(self.output_path+"/corner_plots"): + os.makedirs(self.output_path+"/corner_plots") + if plot_diagnostic and not os.path.exists(self.output_path+"/diagnostic_plots"): + os.makedirs(self.output_path+"/diagnostic_plots") + if save_summary and not os.path.exists(self.output_path+"/summaries"): + os.makedirs(self.output_path+"/summaries") + if not os.path.exists(self.output_path+"/error_log"): + os.makedirs(self.output_path+"/error_log") config_directory = os.fsencode(self.run_config_path) for file in os.listdir(config_directory): filename = os.fsdecode(file) - if filename.endswith(".yaml"): - config_path = os.path.join(self.run_config_path, filename) - run_manager = SingleEventPERunManager(path=config_path) - run_manager.sample() - - if plot_corner: - run_manager.plot_corner("corner_plots/" + filename + "_corner.jpeg") - if plot_diagnostic: - run_manager.plot_diagnostic("diagnostic_plots/" + filename + "_diagnostic.jpeg") - if save_summary: - run_manager.save_summary("summaries/" + filename + "_summary.txt") \ No newline at end of file + + try: + if filename.endswith(".yaml"): + config_path = os.path.join(self.run_config_path, filename) + run_manager = SingleEventPERunManager(path=config_path) + run_manager.sample() + + if plot_corner: + run_manager.plot_corner(self.output_path+"/corner_plots/" + filename + "_corner.jpeg") + if plot_diagnostic: + run_manager.plot_diagnostic(self.output_path+"/diagnostic_plots/" + filename + "_diagnostic.jpeg") + if save_summary: + run_manager.save_summary(self.output_path+"/summaries/" + filename + "_summary.txt") + + except Exception as e: + orig_stdout = sys.stdout + sys.stdout = open(self.output_path+"/error_log/"+filename+"error_log.txt", "wt") + print(f"Error in running {filename}. Error: {e}") + sys.stdout.close() + sys.stdout = orig_stdout + continue \ No newline at end of file From 3ddc1fc593711cb82ecf54d1171289b779ad470d Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Sat, 31 Aug 2024 10:19:41 +0800 Subject: [PATCH 21/27] Fixed typo in Multiple_event_runManager.py --- example/Multiple_event_runManager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/Multiple_event_runManager.py b/example/Multiple_event_runManager.py index 04967a2c..860efb92 100644 --- a/example/Multiple_event_runManager.py +++ b/example/Multiple_event_runManager.py @@ -1,11 +1,11 @@ import jax import jax.numpy as jnp -from jimgw.single_event.runManager import MultipleEventRunManager +from jimgw.single_event.runManager import MultipleEventPERunManager jax.config.update("jax_enable_x64", True) -run_manager = MultipleEventRunManager( +run_manager = MultipleEventPERunManager( run_config_path="config", ) From 2d3976fa7924c71d72363624f2a898e5c82d76f8 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Sat, 31 Aug 2024 10:22:14 +0800 Subject: [PATCH 22/27] reformatted --- src/jimgw/single_event/runManager.py | 76 +++++++++++++++++++--------- 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 3c0c6ed1..ff4f571e 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -123,8 +123,14 @@ def save(self, path: str): def load_from_path(self, path: str) -> SingleEventRun: with open(path, "r") as f: data = yaml.safe_load(f) - if "jim_parameters" in data and "local_sampler_arg" in data["jim_parameters"] and "step_size" in data["jim_parameters"]["local_sampler_arg"]: - data["jim_parameters"]["local_sampler_arg"]["step_size"] = jnp.array(data["jim_parameters"]["local_sampler_arg"]["step_size"]) + if ( + "jim_parameters" in data + and "local_sampler_arg" in data["jim_parameters"] + and "step_size" in data["jim_parameters"]["local_sampler_arg"] + ): + data["jim_parameters"]["local_sampler_arg"]["step_size"] = jnp.array( + data["jim_parameters"]["local_sampler_arg"]["step_size"] + ) return SingleEventRun(**data) ### Initialization functions ### @@ -490,54 +496,76 @@ class MultipleEventPERunManager: """ Class to manage multiple events run. """ + run_config_path: str output_path: str - + def __init__(self, run_config_path: str, output_path: str = "output") -> None: """ Arguments: run_config_path (str): load the run configuration from the path. output_path (str, optional): save the output to this path. Defaults to "output". """ - + self.run_config_path = run_config_path self.output_path = output_path - - def run(self, plot_corner: bool = True, plot_diagnostic: bool = True, save_summary: bool = True): + + def run( + self, + plot_corner: bool = True, + plot_diagnostic: bool = True, + save_summary: bool = True, + ): """ Loop over all the configuration files in the run_config_path and run the PE for each configuration. """ - - if plot_corner and not os.path.exists(self.output_path+"/corner_plots"): - os.makedirs(self.output_path+"/corner_plots") - if plot_diagnostic and not os.path.exists(self.output_path+"/diagnostic_plots"): - os.makedirs(self.output_path+"/diagnostic_plots") - if save_summary and not os.path.exists(self.output_path+"/summaries"): - os.makedirs(self.output_path+"/summaries") - if not os.path.exists(self.output_path+"/error_log"): - os.makedirs(self.output_path+"/error_log") - + + if plot_corner and not os.path.exists(self.output_path + "/corner_plots"): + os.makedirs(self.output_path + "/corner_plots") + if plot_diagnostic and not os.path.exists( + self.output_path + "/diagnostic_plots" + ): + os.makedirs(self.output_path + "/diagnostic_plots") + if save_summary and not os.path.exists(self.output_path + "/summaries"): + os.makedirs(self.output_path + "/summaries") + if not os.path.exists(self.output_path + "/error_log"): + os.makedirs(self.output_path + "/error_log") + config_directory = os.fsencode(self.run_config_path) for file in os.listdir(config_directory): filename = os.fsdecode(file) - + try: if filename.endswith(".yaml"): config_path = os.path.join(self.run_config_path, filename) run_manager = SingleEventPERunManager(path=config_path) run_manager.sample() - + if plot_corner: - run_manager.plot_corner(self.output_path+"/corner_plots/" + filename + "_corner.jpeg") + run_manager.plot_corner( + self.output_path + + "/corner_plots/" + + filename + + "_corner.jpeg" + ) if plot_diagnostic: - run_manager.plot_diagnostic(self.output_path+"/diagnostic_plots/" + filename + "_diagnostic.jpeg") + run_manager.plot_diagnostic( + self.output_path + + "/diagnostic_plots/" + + filename + + "_diagnostic.jpeg" + ) if save_summary: - run_manager.save_summary(self.output_path+"/summaries/" + filename + "_summary.txt") - + run_manager.save_summary( + self.output_path + "/summaries/" + filename + "_summary.txt" + ) + except Exception as e: orig_stdout = sys.stdout - sys.stdout = open(self.output_path+"/error_log/"+filename+"error_log.txt", "wt") + sys.stdout = open( + self.output_path + "/error_log/" + filename + "error_log.txt", "wt" + ) print(f"Error in running {filename}. Error: {e}") sys.stdout.close() sys.stdout = orig_stdout - continue \ No newline at end of file + continue From c906c674d76fb8d3586d0c9d9a0dc076e682353b Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:45:36 +0800 Subject: [PATCH 23/27] Added documentation in Multiple_event_runManager.py --- example/Multiple_event_runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/Multiple_event_runManager.py b/example/Multiple_event_runManager.py index 860efb92..eb6d87ce 100644 --- a/example/Multiple_event_runManager.py +++ b/example/Multiple_event_runManager.py @@ -6,7 +6,7 @@ jax.config.update("jax_enable_x64", True) run_manager = MultipleEventPERunManager( - run_config_path="config", + run_config_path="config", # the configuration file is stored in the config folder ) run_manager.run() From 7309e442f93c2622190f4fb7b26aa78dbef82d40 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:48:18 +0800 Subject: [PATCH 24/27] Added try block in load_from_path() --- src/jimgw/single_event/runManager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index ff4f571e..03f13faf 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -128,9 +128,12 @@ def load_from_path(self, path: str) -> SingleEventRun: and "local_sampler_arg" in data["jim_parameters"] and "step_size" in data["jim_parameters"]["local_sampler_arg"] ): - data["jim_parameters"]["local_sampler_arg"]["step_size"] = jnp.array( - data["jim_parameters"]["local_sampler_arg"]["step_size"] - ) + try: + data["jim_parameters"]["local_sampler_arg"]["step_size"] = jnp.array( + data["jim_parameters"]["local_sampler_arg"]["step_size"] + ) + except Exception as e: + print(f"Error in loading step_size: {e}") return SingleEventRun(**data) ### Initialization functions ### From 337c21ee83e8db3f1dedded758b507a39eba613e Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 2 Sep 2024 17:13:14 +0800 Subject: [PATCH 25/27] Fixed try block in load_from_path() --- src/jimgw/single_event/runManager.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 03f13faf..404ea3cb 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -123,17 +123,13 @@ def save(self, path: str): def load_from_path(self, path: str) -> SingleEventRun: with open(path, "r") as f: data = yaml.safe_load(f) - if ( - "jim_parameters" in data - and "local_sampler_arg" in data["jim_parameters"] - and "step_size" in data["jim_parameters"]["local_sampler_arg"] - ): try: data["jim_parameters"]["local_sampler_arg"]["step_size"] = jnp.array( data["jim_parameters"]["local_sampler_arg"]["step_size"] ) - except Exception as e: - print(f"Error in loading step_size: {e}") + except KeyError as e: + print(f"Key {e} not found.") + return SingleEventRun(**data) ### Initialization functions ### From 1fb0b292841cdf9d08ddb58d0d994bbd6f21a6ff Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 2 Sep 2024 17:14:04 +0800 Subject: [PATCH 26/27] reformatted --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 404ea3cb..035d571d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -129,7 +129,7 @@ def load_from_path(self, path: str) -> SingleEventRun: ) except KeyError as e: print(f"Key {e} not found.") - + return SingleEventRun(**data) ### Initialization functions ### From 3a02bdc869e1c88f6cd33f8811f4630bf134e3c5 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 2 Sep 2024 21:18:30 +0800 Subject: [PATCH 27/27] reformatted --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 035d571d..de5d84f7 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -128,7 +128,7 @@ def load_from_path(self, path: str) -> SingleEventRun: data["jim_parameters"]["local_sampler_arg"]["step_size"] ) except KeyError as e: - print(f"Key {e} not found.") + print("No local sampler argument provided in the configuration.") return SingleEventRun(**data)