diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index a36875e5..b0319ea9 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -109,8 +109,8 @@ [-1.0, 1.0], ] ) -# likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) -likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) +likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) +# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(prior.n_dim) @@ -135,7 +135,10 @@ keep_quantile=0., train_thinning=1, output_thinning=30, + num_layers=6, + hidden_size=[64, 64], + num_bins=8, local_sampler_arg=local_sampler_arg, ) -jim.sample(jax.random.PRNGKey(42)) +jim.sample(jax.random.PRNGKey(42))1 diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 433b385a..f8aab8e5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -220,7 +220,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: class Sphere(Prior): - """ A prior on a sphere represented by Cartesian coordinates. @@ -267,7 +266,12 @@ def log_prob(self, x: dict[str, Float]) -> Float: phi = x[self.naming[1]] mag = x[self.naming[2]] output = jnp.where( - (mag > 1) | (mag < 0) | (phi > 2* jnp.pi) | (phi < 0) | (theta > 1) | (theta < -1), + (mag > 1) + | (mag < 0) + | (phi > 2 * jnp.pi) + | (phi < 0) + | (theta > 1) + | (theta < -1), jnp.zeros_like(0) - jnp.inf, jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), ) @@ -276,7 +280,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped class AlignedSpin(Prior): - """ Prior distribution for the aligned (z) component of the spin. @@ -390,7 +393,6 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped class PowerLaw(Prior): - """ A prior following the power-law with alpha in the range [xmin, xmax). p(x) ~ x^{\alpha} diff --git a/src/jimgw/single_event/detector.py b/src/jimgw/single_event/detector.py index 4cea21af..ddaba2f5 100644 --- a/src/jimgw/single_event/detector.py +++ b/src/jimgw/single_event/detector.py @@ -406,6 +406,14 @@ def inject_signal( signal = self.fd_response(freqs, h_sky, params) * align_time self.data = signal + noise_real + 1j * noise_imag + # also calculate the optimal SNR and match filter SNR + optimal_SNR = jnp.sqrt(jnp.sum(signal * signal.conj() / var).real) + match_filter_SNR = jnp.sum(self.data * signal.conj() / var) / optimal_SNR + + print(f"For detector {self.name}:") + print(f"The injected optimal SNR is {optimal_SNR}") + print(f"The injected match filter SNR is {match_filter_SNR}") + @jaxtyped def load_psd( self, freqs: Float[Array, " n_sample"], psd_file: str = "" diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index bed1ee3d..61b82c5a 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -1,3 +1,4 @@ +from typing import Union from jimgw.base import RunManager from dataclasses import dataclass, field, asdict from jimgw.single_event.likelihood import likelihood_presets, SingleEventLiklihood @@ -65,18 +66,18 @@ class SingleEventRun: detectors: list[str] priors: dict[ - str, dict[str, str | float | int | bool] + str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. - jim_parameters: dict[str, str | float | int | bool | dict] + jim_parameters: dict[str, Union[str, float, int, bool, dict]] injection_parameters: dict[str, float] injection: bool = False - likelihood_parameters: dict[str, str | float | int | bool | PyTree] = field( + likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} ) - waveform_parameters: dict[str, str | float | int | bool] = field( + waveform_parameters: dict[str, Union[str, float, int, bool]] = field( default_factory=lambda: {"name": ""} ) - data_parameters: dict[str, float | int] = field( + data_parameters: dict[str, Union[float, int]] = field( default_factory=lambda: { "trigger_time": 0.0, "duration": 0, @@ -249,9 +250,7 @@ def initialize_waveform(self) -> Waveform: ### Utility functions ### - def get_detector_waveform( - self, params: dict[str, float] - ) -> tuple[ + def get_detector_waveform(self, params: dict[str, float]) -> tuple[ Float[Array, " n_sample"], dict[str, Float[Array, " n_sample"]], dict[str, Float[Array, " n_sample"]],