From 7bb44eac46ac92a2931a64f6f9452514269d7f8f Mon Sep 17 00:00:00 2001 From: Peter Pang Date: Fri, 16 Feb 2024 20:24:17 +0100 Subject: [PATCH 1/6] Printing injected SNRs --- src/jimgw/single_event/detector.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/jimgw/single_event/detector.py b/src/jimgw/single_event/detector.py index 4cea21af..ddaba2f5 100644 --- a/src/jimgw/single_event/detector.py +++ b/src/jimgw/single_event/detector.py @@ -406,6 +406,14 @@ def inject_signal( signal = self.fd_response(freqs, h_sky, params) * align_time self.data = signal + noise_real + 1j * noise_imag + # also calculate the optimal SNR and match filter SNR + optimal_SNR = jnp.sqrt(jnp.sum(signal * signal.conj() / var).real) + match_filter_SNR = jnp.sum(self.data * signal.conj() / var) / optimal_SNR + + print(f"For detector {self.name}:") + print(f"The injected optimal SNR is {optimal_SNR}") + print(f"The injected match filter SNR is {match_filter_SNR}") + @jaxtyped def load_psd( self, freqs: Float[Array, " n_sample"], psd_file: str = "" From 121777c26773499cec32f342b5be42e381fd1286 Mon Sep 17 00:00:00 2001 From: Peter Pang Date: Fri, 16 Feb 2024 20:42:28 +0100 Subject: [PATCH 2/6] fixing formatting in prior.py --- src/jimgw/prior.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 433b385a..76454075 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -267,7 +267,12 @@ def log_prob(self, x: dict[str, Float]) -> Float: phi = x[self.naming[1]] mag = x[self.naming[2]] output = jnp.where( - (mag > 1) | (mag < 0) | (phi > 2* jnp.pi) | (phi < 0) | (theta > 1) | (theta < -1), + (mag > 1) + | (mag < 0) + | (phi > 2 * jnp.pi) + | (phi < 0) + | (theta > 1) + | (theta < -1), jnp.zeros_like(0) - jnp.inf, jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), ) From 370223cebedaa34045104377a77380feedd2bc36 Mon Sep 17 00:00:00 2001 From: Peter Pang Date: Fri, 16 Feb 2024 20:43:10 +0100 Subject: [PATCH 3/6] fixing formatting in runManager.py --- src/jimgw/single_event/runManager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index bed1ee3d..e2f5f8b2 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -249,9 +249,7 @@ def initialize_waveform(self) -> Waveform: ### Utility functions ### - def get_detector_waveform( - self, params: dict[str, float] - ) -> tuple[ + def get_detector_waveform(self, params: dict[str, float]) -> tuple[ Float[Array, " n_sample"], dict[str, Float[Array, " n_sample"]], dict[str, Float[Array, " n_sample"]], From d1fc1ff735984003fba7df3abb68aee184e6857b Mon Sep 17 00:00:00 2001 From: Peter Pang Date: Fri, 16 Feb 2024 20:48:42 +0100 Subject: [PATCH 4/6] fixing formatting in prior.py 2 --- src/jimgw/prior.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 76454075..f8aab8e5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -220,7 +220,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: class Sphere(Prior): - """ A prior on a sphere represented by Cartesian coordinates. @@ -281,7 +280,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped class AlignedSpin(Prior): - """ Prior distribution for the aligned (z) component of the spin. @@ -395,7 +393,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped class PowerLaw(Prior): - """ A prior following the power-law with alpha in the range [xmin, xmax). p(x) ~ x^{\alpha} From c09aa1ae037015945191682470b3c4c6a1522d79 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Fri, 16 Feb 2024 21:10:51 +0100 Subject: [PATCH 5/6] Futher fixing formatting --- src/jimgw/single_event/runManager.py | 11 ++++++----- src/jimgw/single_event/utils.py | 3 ++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index e2f5f8b2..61b82c5a 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -1,3 +1,4 @@ +from typing import Union from jimgw.base import RunManager from dataclasses import dataclass, field, asdict from jimgw.single_event.likelihood import likelihood_presets, SingleEventLiklihood @@ -65,18 +66,18 @@ class SingleEventRun: detectors: list[str] priors: dict[ - str, dict[str, str | float | int | bool] + str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. - jim_parameters: dict[str, str | float | int | bool | dict] + jim_parameters: dict[str, Union[str, float, int, bool, dict]] injection_parameters: dict[str, float] injection: bool = False - likelihood_parameters: dict[str, str | float | int | bool | PyTree] = field( + likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} ) - waveform_parameters: dict[str, str | float | int | bool] = field( + waveform_parameters: dict[str, Union[str, float, int, bool]] = field( default_factory=lambda: {"name": ""} ) - data_parameters: dict[str, float | int] = field( + data_parameters: dict[str, Union[float, int]] = field( default_factory=lambda: { "trigger_time": 0.0, "duration": 0, diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index aba78a02..84d96228 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +from jax.scipy.integrate import trapezoid from jax import jit from jaxtyping import Float, Array @@ -34,7 +35,7 @@ def inner_product( # psd_interp = jnp.interp(frequency, psd_frequency, psd) df = frequency[1] - frequency[0] integrand = jnp.conj(h1) * h2 / psd - return 4.0 * jnp.real(jnp.trapz(integrand, dx=df)) + return 4.0 * jnp.real(trapezoid(integrand, dx=df)) @jit From a5dc8c2c8b5c4177ea2f28202cc1af667978d045 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 20 Feb 2024 11:26:42 -0500 Subject: [PATCH 6/6] Update likelihood calculation in GW150914_PV2.py --- example/GW150914_PV2.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index a36875e5..b0319ea9 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -109,8 +109,8 @@ [-1.0, 1.0], ] ) -# likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) -likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) +likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) +# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(prior.n_dim) @@ -135,7 +135,10 @@ keep_quantile=0., train_thinning=1, output_thinning=30, + num_layers=6, + hidden_size=[64, 64], + num_bins=8, local_sampler_arg=local_sampler_arg, ) -jim.sample(jax.random.PRNGKey(42)) +jim.sample(jax.random.PRNGKey(42))1