Skip to content

Commit

Permalink
Merge branch 'main' into time-phase-marginalization
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong authored Feb 20, 2024
2 parents 145b41a + a5dc8c2 commit 7444f19
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
9 changes: 6 additions & 3 deletions example/GW150914_PV2.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@
[-1.0, 1.0],
]
)
# likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)
likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)
likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)
# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2)


mass_matrix = jnp.eye(prior.n_dim)
Expand All @@ -135,7 +135,10 @@
keep_quantile=0.,
train_thinning=1,
output_thinning=30,
num_layers=6,
hidden_size=[64, 64],
num_bins=8,
local_sampler_arg=local_sampler_arg,
)

jim.sample(jax.random.PRNGKey(42))
jim.sample(jax.random.PRNGKey(42))1
10 changes: 6 additions & 4 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def log_prob(self, x: dict[str, Float]) -> Float:


class Sphere(Prior):

"""
A prior on a sphere represented by Cartesian coordinates.
Expand Down Expand Up @@ -267,7 +266,12 @@ def log_prob(self, x: dict[str, Float]) -> Float:
phi = x[self.naming[1]]
mag = x[self.naming[2]]
output = jnp.where(
(mag > 1) | (mag < 0) | (phi > 2* jnp.pi) | (phi < 0) | (theta > 1) | (theta < -1),
(mag > 1)
| (mag < 0)
| (phi > 2 * jnp.pi)
| (phi < 0)
| (theta > 1)
| (theta < -1),
jnp.zeros_like(0) - jnp.inf,
jnp.log(mag**2 * jnp.sin(x[self.naming[0]])),
)
Expand All @@ -276,7 +280,6 @@ def log_prob(self, x: dict[str, Float]) -> Float:

@jaxtyped
class AlignedSpin(Prior):

"""
Prior distribution for the aligned (z) component of the spin.
Expand Down Expand Up @@ -390,7 +393,6 @@ def log_prob(self, x: dict[str, Float]) -> Float:

@jaxtyped
class PowerLaw(Prior):

"""
A prior following the power-law with alpha in the range [xmin, xmax).
p(x) ~ x^{\alpha}
Expand Down
8 changes: 8 additions & 0 deletions src/jimgw/single_event/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,14 @@ def inject_signal(
signal = self.fd_response(freqs, h_sky, params) * align_time
self.data = signal + noise_real + 1j * noise_imag

# also calculate the optimal SNR and match filter SNR
optimal_SNR = jnp.sqrt(jnp.sum(signal * signal.conj() / var).real)
match_filter_SNR = jnp.sum(self.data * signal.conj() / var) / optimal_SNR

print(f"For detector {self.name}:")
print(f"The injected optimal SNR is {optimal_SNR}")
print(f"The injected match filter SNR is {match_filter_SNR}")

@jaxtyped
def load_psd(
self, freqs: Float[Array, " n_sample"], psd_file: str = ""
Expand Down
15 changes: 7 additions & 8 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
from jimgw.base import RunManager
from dataclasses import dataclass, field, asdict
from jimgw.single_event.likelihood import likelihood_presets, SingleEventLiklihood
Expand Down Expand Up @@ -65,18 +66,18 @@ class SingleEventRun:

detectors: list[str]
priors: dict[
str, dict[str, str | float | int | bool]
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, str | float | int | bool | dict]
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
injection_parameters: dict[str, float]
injection: bool = False
likelihood_parameters: dict[str, str | float | int | bool | PyTree] = field(
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
default_factory=lambda: {"name": "TransientLikelihoodFD"}
)
waveform_parameters: dict[str, str | float | int | bool] = field(
waveform_parameters: dict[str, Union[str, float, int, bool]] = field(
default_factory=lambda: {"name": ""}
)
data_parameters: dict[str, float | int] = field(
data_parameters: dict[str, Union[float, int]] = field(
default_factory=lambda: {
"trigger_time": 0.0,
"duration": 0,
Expand Down Expand Up @@ -249,9 +250,7 @@ def initialize_waveform(self) -> Waveform:

### Utility functions ###

def get_detector_waveform(
self, params: dict[str, float]
) -> tuple[
def get_detector_waveform(self, params: dict[str, float]) -> tuple[
Float[Array, " n_sample"],
dict[str, Float[Array, " n_sample"]],
dict[str, Float[Array, " n_sample"]],
Expand Down

0 comments on commit 7444f19

Please sign in to comment.