Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time and phase marginalization #71

Merged
merged 11 commits into from
Feb 22, 2024
76 changes: 57 additions & 19 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

from jimgw.single_event.detector import Detector
from jimgw.prior import Prior
from jimgw.single_event.utils import (
original_likelihood,
phase_marginalized_likelihood,
time_marginalized_likelihood,
phase_time_marginalized_likelihood,
)
from jimgw.single_event.waveform import Waveform
from jimgw.base import LikelihoodBase

Expand Down Expand Up @@ -51,6 +57,46 @@ def __init__(
self.trigger_time = trigger_time
self.duration = duration
self.post_trigger_duration = post_trigger_duration
self.kwargs = kwargs
if "marginalization" in self.kwargs:
marginalization = self.kwargs["marginalization"]
assert marginalization in [
"phase",
"phase-time",
"time",
], "Only support time, phase and phase+time marginalzation"
self.marginalization = marginalization
if self.marginalization == "phase-time":
self.param_func = lambda x: {**x, "phase_c": 0.0, "t_c": 0.0}
self.likelihood_function = phase_time_marginalized_likelihood
print("Marginalizing over phase and time")
elif self.marginalization == "time":
self.param_func = lambda x: {**x, "t_c": 0.0}
self.likelihood_function = time_marginalized_likelihood
print("Marginalizing over time")
elif self.marginalization == "phase":
self.param_func = lambda x: {**x, "phase_c": 0.0}
self.likelihood_function = phase_marginalized_likelihood
print("Marginalizing over phase")

if "time" in self.marginalization:
fs = kwargs["sampling_rate"]
self.kwargs["tc_array"] = jnp.fft.fftfreq(
int(duration * fs / 2), 1.0 / duration
)
self.kwargs["pad_low"] = jnp.zeros(int(self.frequencies[0] * duration))
if jnp.isclose(self.frequencies[-1], fs / 2.0 - 1.0 / duration):
self.kwargs["pad_high"] = jnp.array([])
else:
self.kwargs["pad_high"] = jnp.zeros(
int(
(fs / 2.0 - 1.0 / duration - self.frequencies[-1])
* duration
)
)
else:
self.param_func = lambda x: x
self.likelihood_function = original_likelihood

@property
def epoch(self):
Expand All @@ -71,31 +117,23 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float:
"""
Evaluate the likelihood for a given set of parameters.
"""
log_likelihood = 0
frequencies = self.frequencies
df = frequencies[1] - frequencies[0]
params["gmst"] = self.gmst
# adjust the params due to different marginalzation scheme
params = self.param_func(params)
# evaluate the waveform as usual
waveform_sky = self.waveform(frequencies, params)
align_time = jnp.exp(
-1j * 2 * jnp.pi * frequencies * (self.epoch + params["t_c"])
)
for detector in self.detectors:
waveform_dec = (
detector.fd_response(frequencies, waveform_sky, params) * align_time
)
match_filter_SNR = (
4
* jnp.sum(
(jnp.conj(waveform_dec) * detector.data) / detector.psd * df
).real
)
optimal_SNR = (
4
* jnp.sum(
jnp.conj(waveform_dec) * waveform_dec / detector.psd * df
).real
)
log_likelihood += match_filter_SNR - optimal_SNR / 2
log_likelihood = self.likelihood_function(
params,
waveform_sky,
self.detectors,
frequencies,
align_time,
**self.kwargs,
)
return log_likelihood


Expand Down
142 changes: 141 additions & 1 deletion src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import jax.numpy as jnp
from jax.scipy.special import i0e, logsumexp
from jax.scipy.integrate import trapezoid
from jax import jit
from jaxtyping import Float, Array

Expand Down Expand Up @@ -34,7 +36,7 @@ def inner_product(
# psd_interp = jnp.interp(frequency, psd_frequency, psd)
df = frequency[1] - frequency[0]
integrand = jnp.conj(h1) * h2 / psd
return 4.0 * jnp.real(jnp.trapz(integrand, dx=df))
return 4.0 * jnp.real(trapezoid(integrand, dx=df))


@jit
Expand Down Expand Up @@ -140,3 +142,141 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa
phi = ra - gmst
theta = jnp.pi / 2 - dec
return theta, phi


def log_i0(x):
"""
A numerically stable method to evaluate log of
a modified Bessel function of order 0.
It is used in the phase-marginalized likelihood.

Parameters
==========
x: array-like
Value(s) at which to evaluate the function

Returns
=======
array-like:
The natural logarithm of the bessel function
"""
return jnp.log(i0e(x)) + x


def original_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to put the likelihood function in the likelihood.py file instead of utils.py. You can leave the logi0 in.

log_likelihood = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
match_filter_SNR = (
4 * jnp.sum((jnp.conj(h_dec) * detector.data) / detector.psd * df).real
)
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += match_filter_SNR - optimal_SNR / 2

return log_likelihood


def phase_marginalized_likelihood(
params, h_sky, detectors, freqs, align_time, **kwargs
):
log_likelihood = 0.0
complex_d_inner_h = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_d_inner_h += 4 * jnp.sum(
(jnp.conj(h_dec) * detector.data) / detector.psd * df
)
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

log_likelihood += log_i0(jnp.absolute(complex_d_inner_h))

return log_likelihood


def time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs):
log_likelihood = 0.0
df = freqs[1] - freqs[0]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
tc_range = kwargs["tc_range"]
tc_array = kwargs["tc_array"]
pad_low = kwargs["pad_low"]
pad_high = kwargs["pad_high"]

# padding the complex_h_inner_d
# this array is the hd*/S for f in [0, fs / 2 - df]
complex_h_inner_d_positive_f = jnp.concatenate(
(pad_low, complex_h_inner_d, pad_high)
)

# make use of the fft
# which then return the <h|d>exp(-i2pift_c)
# w.r.t. the tc_array
fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward")

# set the values to -inf when it is outside the tc range
# so that they will disappear after the logsumexp
fft_h_inner_d = jnp.where(
(tc_array > tc_range[0]) & (tc_array < tc_range[1]),
fft_h_inner_d.real,
jnp.zeros_like(fft_h_inner_d.real) - jnp.inf,
)

# using the logsumexp to marginalize over the tc prior range
log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array))

return log_likelihood


def phase_time_marginalized_likelihood(
params, h_sky, detectors, freqs, align_time, **kwargs
):
log_likelihood = 0.0
df = freqs[1] - freqs[0]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
tc_range = kwargs["tc_range"]
tc_array = kwargs["tc_array"]
pad_low = kwargs["pad_low"]
pad_high = kwargs["pad_high"]

# padding the complex_h_inner_d
# this array is the hd*/S for f in [0, fs / 2 - df]
complex_h_inner_d_positive_f = jnp.concatenate(
(pad_low, complex_h_inner_d, pad_high)
)

# make use of the fft
# which then return the <h|d>exp(-i2pift_c)
# w.r.t. the tc_array
fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward")

# set the values to -inf when it is outside the tc range
# so that they will disappear after the logsumexp
log_i0_abs_fft = jnp.where(
(tc_array > tc_range[0]) & (tc_array < tc_range[1]),
log_i0(jnp.absolute(fft_h_inner_d)),
jnp.zeros_like(fft_h_inner_d.real) - jnp.inf,
)

# using the logsumexp to marginalize over the tc prior range
log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array))

return log_likelihood