Skip to content

Commit

Permalink
Move internel likelihood function from utils.py to likelihood.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Feb 20, 2024
1 parent 7444f19 commit 0cbf45c
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 126 deletions.
151 changes: 145 additions & 6 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import numpy as np
import numpy.typing as npt
from astropy.time import Time
Expand All @@ -9,12 +10,7 @@

from jimgw.single_event.detector import Detector
from jimgw.prior import Prior
from jimgw.single_event.utils import (
original_likelihood,
phase_marginalized_likelihood,
time_marginalized_likelihood,
phase_time_marginalized_likelihood,
)
from jimgw.single_event.utils import log_i0
from jimgw.single_event.waveform import Waveform
from jimgw.base import LikelihoodBase

Expand Down Expand Up @@ -491,3 +487,146 @@ def y(x):
"TransientLikelihoodFD": TransientLikelihoodFD,
"HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD,
}


def original_likelihood(
params: dict[str, Float],
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
match_filter_SNR = (
4 * jnp.sum((jnp.conj(h_dec) * detector.data) / detector.psd * df).real
)
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += match_filter_SNR - optimal_SNR / 2

return log_likelihood


def phase_marginalized_likelihood(
params: dict[str, Float],
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
complex_d_inner_h = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_d_inner_h += 4 * jnp.sum(
(jnp.conj(h_dec) * detector.data) / detector.psd * df
)
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

log_likelihood += log_i0(jnp.absolute(complex_d_inner_h))

return log_likelihood


def time_marginalized_likelihood(
params: dict[str, Float],
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
df = freqs[1] - freqs[0]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
tc_range = kwargs["tc_range"]
tc_array = kwargs["tc_array"]
pad_low = kwargs["pad_low"]
pad_high = kwargs["pad_high"]

# padding the complex_h_inner_d
# this array is the hd*/S for f in [0, fs / 2 - df]
complex_h_inner_d_positive_f = jnp.concatenate(
(pad_low, complex_h_inner_d, pad_high)
)

# make use of the fft
# which then return the <h|d>exp(-i2pift_c)
# w.r.t. the tc_array
fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward")

# set the values to -inf when it is outside the tc range
# so that they will disappear after the logsumexp
fft_h_inner_d = jnp.where(
(tc_array > tc_range[0]) & (tc_array < tc_range[1]),
fft_h_inner_d.real,
jnp.zeros_like(fft_h_inner_d.real) - jnp.inf,
)

# using the logsumexp to marginalize over the tc prior range
log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array))

return log_likelihood


def phase_time_marginalized_likelihood(
params: dict[str, Float],
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
df = freqs[1] - freqs[0]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
tc_range = kwargs["tc_range"]
tc_array = kwargs["tc_array"]
pad_low = kwargs["pad_low"]
pad_high = kwargs["pad_high"]

# padding the complex_h_inner_d
# this array is the hd*/S for f in [0, fs / 2 - df]
complex_h_inner_d_positive_f = jnp.concatenate(
(pad_low, complex_h_inner_d, pad_high)
)

# make use of the fft
# which then return the <h|d>exp(-i2pift_c)
# w.r.t. the tc_array
fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward")

# set the values to -inf when it is outside the tc range
# so that they will disappear after the logsumexp
log_i0_abs_fft = jnp.where(
(tc_array > tc_range[0]) & (tc_array < tc_range[1]),
log_i0(jnp.absolute(fft_h_inner_d)),
jnp.zeros_like(fft_h_inner_d.real) - jnp.inf,
)

# using the logsumexp to marginalize over the tc prior range
log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array))

return log_likelihood
121 changes: 1 addition & 120 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import jax.numpy as jnp
from jax.scipy.special import i0e, logsumexp
from jax.scipy.special import i0e
from jax.scipy.integrate import trapezoid
from jax import jit
from jaxtyping import Float, Array
Expand Down Expand Up @@ -161,122 +161,3 @@ def log_i0(x):
The natural logarithm of the bessel function
"""
return jnp.log(i0e(x)) + x


def original_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs):
log_likelihood = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
match_filter_SNR = (
4 * jnp.sum((jnp.conj(h_dec) * detector.data) / detector.psd * df).real
)
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += match_filter_SNR - optimal_SNR / 2

return log_likelihood


def phase_marginalized_likelihood(
params, h_sky, detectors, freqs, align_time, **kwargs
):
log_likelihood = 0.0
complex_d_inner_h = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_d_inner_h += 4 * jnp.sum(
(jnp.conj(h_dec) * detector.data) / detector.psd * df
)
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

log_likelihood += log_i0(jnp.absolute(complex_d_inner_h))

return log_likelihood


def time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs):
log_likelihood = 0.0
df = freqs[1] - freqs[0]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
tc_range = kwargs["tc_range"]
tc_array = kwargs["tc_array"]
pad_low = kwargs["pad_low"]
pad_high = kwargs["pad_high"]

# padding the complex_h_inner_d
# this array is the hd*/S for f in [0, fs / 2 - df]
complex_h_inner_d_positive_f = jnp.concatenate(
(pad_low, complex_h_inner_d, pad_high)
)

# make use of the fft
# which then return the <h|d>exp(-i2pift_c)
# w.r.t. the tc_array
fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward")

# set the values to -inf when it is outside the tc range
# so that they will disappear after the logsumexp
fft_h_inner_d = jnp.where(
(tc_array > tc_range[0]) & (tc_array < tc_range[1]),
fft_h_inner_d.real,
jnp.zeros_like(fft_h_inner_d.real) - jnp.inf,
)

# using the logsumexp to marginalize over the tc prior range
log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array))

return log_likelihood


def phase_time_marginalized_likelihood(
params, h_sky, detectors, freqs, align_time, **kwargs
):
log_likelihood = 0.0
df = freqs[1] - freqs[0]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
tc_range = kwargs["tc_range"]
tc_array = kwargs["tc_array"]
pad_low = kwargs["pad_low"]
pad_high = kwargs["pad_high"]

# padding the complex_h_inner_d
# this array is the hd*/S for f in [0, fs / 2 - df]
complex_h_inner_d_positive_f = jnp.concatenate(
(pad_low, complex_h_inner_d, pad_high)
)

# make use of the fft
# which then return the <h|d>exp(-i2pift_c)
# w.r.t. the tc_array
fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward")

# set the values to -inf when it is outside the tc range
# so that they will disappear after the logsumexp
log_i0_abs_fft = jnp.where(
(tc_array > tc_range[0]) & (tc_array < tc_range[1]),
log_i0(jnp.absolute(fft_h_inner_d)),
jnp.zeros_like(fft_h_inner_d.real) - jnp.inf,
)

# using the logsumexp to marginalize over the tc prior range
log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array))

return log_likelihood

0 comments on commit 0cbf45c

Please sign in to comment.