From 42760db9cd16a7cd45738cddb22b42c72a08d930 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Fri, 16 Feb 2024 18:41:16 +0100 Subject: [PATCH 01/10] initial commit for marginalization --- src/jimgw/single_event/likelihood.py | 54 +++++---- src/jimgw/single_event/utils.py | 161 +++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 19 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 79907c90..c669583e 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -9,6 +9,12 @@ from jimgw.single_event.detector import Detector from jimgw.prior import Prior +from jimgw.single_event.utils import ( + original_likelihood, + phase_marginalized_likelihood, + time_marginalized_likelihood, + phase_time_marginalized_likelihood +) from jimgw.single_event.waveform import Waveform from jimgw.base import LikelihoodBase @@ -51,6 +57,24 @@ def __init__( self.trigger_time = trigger_time self.duration = duration self.post_trigger_duration = post_trigger_duration + if 'marginalization' in kwargs: + marginalization = kwargs['marginalization'] + assert marginalization in ['phase', 'phase-time', 'time'], \ + "Only support time, phase and phase+time marginalzation" + self.marginalization = marginalization + if self.marginalization == 'phase-time': + self.param_func = lambda x: {**x, 'phi_c': 0., 't_c': 0.} + self.likelihood_function = phase_time_marginalized_likelihood + elif self.marginalization == 'time': + self.param_func = lambda x: {**x, 't_c': 0.} + self.likelihood_function = time_marginalized_likelihood + elif self.marginalization == 'phase': + self.param_func = lambda x: {**x, 'phi_c': 0.} + self.likelihood_function = phase_marginalized_likelihood + else: + self.param_func = lambda x: x + self.likelihood_function = original_likelihood + self.kwargs = kwargs @property def epoch(self): @@ -71,31 +95,23 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: """ Evaluate the likelihood for a given set of parameters. 
""" - log_likelihood = 0 frequencies = self.frequencies - df = frequencies[1] - frequencies[0] params["gmst"] = self.gmst + # adjust the params due to different marginalzation scheme + params = self.param_func(params) + # evaluate the waveform as usual waveform_sky = self.waveform(frequencies, params) align_time = jnp.exp( -1j * 2 * jnp.pi * frequencies * (self.epoch + params["t_c"]) ) - for detector in self.detectors: - waveform_dec = ( - detector.fd_response(frequencies, waveform_sky, params) * align_time - ) - match_filter_SNR = ( - 4 - * jnp.sum( - (jnp.conj(waveform_dec) * detector.data) / detector.psd * df - ).real - ) - optimal_SNR = ( - 4 - * jnp.sum( - jnp.conj(waveform_dec) * waveform_dec / detector.psd * df - ).real - ) - log_likelihood += match_filter_SNR - optimal_SNR / 2 + log_likelihood = self.likelihood_function( + params, + waveform_sky, + self.detectors, + frequencies, + align_time, + **self.kwargs, + ) return log_likelihood diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index aba78a02..dbdb8904 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +from jax.scipy.special import i0e, logsumexp from jax import jit from jaxtyping import Float, Array @@ -140,3 +141,163 @@ def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Floa phi = ra - gmst theta = jnp.pi / 2 - dec return theta, phi + + +def log_i0(x): + """ + A numerically stable method to evaluate log of + a modified Bessel function of order 0. + It is used in the phase-marginalized likelihood. + + Parameters + ========== + x: array-like + Value(s) at which to evaluate the function + + Returns + ======= + array-like: + The natural logarithm of the bessel function + """ + return jnp.log(i0e(x)) + x + + +def original_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): + log_likelihood = 0. + df = freqs[1] - freqs[0] + for detector in detectors: + h_dec = ( + detector.fd_response(freqs, h_sky, params) * align_time + ) + match_filter_SNR = ( + 4 + * jnp.sum( + (jnp.conj(h_dec) * detector.data) / detector.psd * df + ).real + ) + optimal_SNR = ( + 4 + * jnp.sum( + jnp.conj(h_dec) * h_dec / detector.psd * df + ).real + ) + log_likelihood += match_filter_SNR - optimal_SNR / 2 + + return log_likelihood + + +def phase_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): + log_likelihood = 0. + complex_d_inner_h = 0. + df = freqs[1] - freqs[0] + for detector in detectors: + h_dec = ( + detector.fd_response(freqs, h_sky, params) * align_time + ) + complex_d_inner_h += ( + 4 + * jnp.sum( + (jnp.conj(h_dec) * detector.data) / detector.psd * df + ) + ) + optimal_SNR = ( + 4 + * jnp.sum( + jnp.conj(h_dec) * h_dec / detector.psd * df + ).real + ) + log_likelihood += -optimal_SNR / 2 + + log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) + + return log_likelihood + + +def time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): + log_likelihood = 0. 
+ df = freqs[1] - freqs[0] + # using instead of + complex_h_inner_d = jnp.zeros_like(freqs) + for detector in detectors: + h_dec = ( + detector.fd_response(freqs, h_sky, params) * align_time + ) + complex_h_inner_d += ( + 4 * h_dec * jnp.conj(detector.data) / detector.psd * df + ) + optimal_SNR = ( + 4 + * jnp.sum( + jnp.conj(h_dec) * h_dec / detector.psd * df + ).real + ) + log_likelihood += -optimal_SNR / 2 + + # padding the complex_d_inner_h before feeding to the fft + # lower and higher frequency padding + pad_low = jnp.arange(0, freqs[0], df) + pad_high = jnp.arange(freqs[1], kwargs['sampling_rate'], df) + complex_h_inner_d = jnp.concatenate((pad_low, complex_h_inner_d, pad_high)) + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d) + # abusing the fftfreq to get the corresponding tc array + tc_array = jnp.fft.fftfreq(n=len(fft_h_inner_d), d=df) + + # fetch the range of valid tc + tc_range = kwargs['tc_range'] + # set the values to -inf when it is outside the tc range + # so that they will disappear after the logsumexp + fft_h_inner_d = jnp.where( + tc_array > tc_range[0] and tc_array < tc_range[1], + fft_h_inner_d, + jnp.zeros_like(fft_h_inner_d) - jnp.inf + ) + + # using the logsumexp to marginalize over the tc prior range + log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(fft_h_inner_d)) + + return log_likelihood + + +def phase_time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): + log_likelihood = 0. + df = freqs[1] - freqs[0] + # using instead of + complex_h_inner_d = jnp.zeros_like(freqs) + for detector in detectors: + h_dec = ( + detector.fd_response(freqs, h_sky, params) * align_time + ) + complex_h_inner_d += ( + 4 * h_dec * jnp.conj(detector.data) / detector.psd * df + ) + optimal_SNR = ( + 4 + * jnp.sum( + jnp.conj(h_dec) * h_dec / detector.psd * df + ).real + ) + log_likelihood += -optimal_SNR / 2 + + # padding the complex_d_inner_h before feeding to the fft + # lower and higher frequency padding + pad_low = jnp.arange(0, freqs[0], df) + pad_high = jnp.arange(freqs[1], kwargs['sampling_rate'], df) + complex_h_inner_d = jnp.concatenate((pad_low, complex_h_inner_d, pad_high)) + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d) + # abusing the fftfreq to get the corresponding tc array + tc_array = jnp.fft.fftfreq(n=len(fft_h_inner_d), d=df) + + # fetch the range of valid tc + tc_range = kwargs['tc_range'] + # set the values to -inf when it is outside the tc range + # so that they will disappear after the logsumexp + log_i0_abs_fft = jnp.where( + tc_array > tc_range[0] and tc_array < tc_range[1], + log_i0(jnp.absolute(fft_h_inner_d)), + jnp.zeros_like(fft_h_inner_d) - jnp.inf + ) + + # using the logsumexp to marginalize over the tc prior range + log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(fft_h_inner_d)) + + return log_likelihood From c7f15e2cb9b8ef8cd8bff2eb9b730a72595986e3 Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Mon, 19 Feb 2024 11:47:50 +0100 Subject: [PATCH 02/10] Further commit for the marginalization --- src/jimgw/single_event/likelihood.py | 50 ++++--- src/jimgw/single_event/utils.py | 188 +++++++++++++++------------ 2 files changed, 138 insertions(+), 100 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index c669583e..c50e9f04 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -13,7 +13,7 @@ original_likelihood, phase_marginalized_likelihood, time_marginalized_likelihood, - phase_time_marginalized_likelihood + phase_time_marginalized_likelihood, ) from jimgw.single_event.waveform import Waveform from jimgw.base import LikelihoodBase @@ -57,24 +57,44 @@ def __init__( self.trigger_time = trigger_time self.duration = duration self.post_trigger_duration = post_trigger_duration - if 'marginalization' in kwargs: - marginalization = kwargs['marginalization'] - assert marginalization in ['phase', 'phase-time', 'time'], \ - "Only support time, phase and phase+time marginalzation" + self.kwargs = kwargs + if "marginalization" in self.kwargs: + marginalization = self.kwargs["marginalization"] + assert marginalization in [ + "phase", + "phase-time", + "time", + ], "Only support time, phase and phase+time marginalzation" self.marginalization = marginalization - if self.marginalization == 'phase-time': - self.param_func = lambda x: {**x, 'phi_c': 0., 't_c': 0.} + if self.marginalization == "phase-time": + self.param_func = lambda x: {**x, "phase_c": 0.0, "t_c": 0.0} self.likelihood_function = phase_time_marginalized_likelihood - elif self.marginalization == 'time': - self.param_func = lambda x: {**x, 't_c': 0.} + print("Marginalizing over phase and time") + elif self.marginalization == "time": + self.param_func = lambda x: {**x, "t_c": 0.0} self.likelihood_function = time_marginalized_likelihood - elif self.marginalization == 'phase': - self.param_func = lambda x: {**x, 'phi_c': 0.} + print("Marginalizing over time") + elif self.marginalization == "phase": + self.param_func = lambda x: {**x, "phase_c": 0.0} self.likelihood_function = phase_marginalized_likelihood - else: - self.param_func = lambda x: x - self.likelihood_function = original_likelihood - self.kwargs = kwargs + print("Marginalizing over phase") + + if 'time' in self.marginalization: + fs = kwargs['sampling_rate'] + self.kwargs['tc_array'] = jnp.fft.fftfreq( + int(duration * fs), + 1. / duration + ) + self.kwargs['pad_low'] = jnp.zeros(int(self.frequencies[0] * duration)) + if jnp.isclose(self.frequencies[-1], fs / 2. - 1. / duration): + self.kwargs['pad_high'] = jnp.array([]) + else: + self.kwargs['pad_high'] = jnp.zeros( + int((fs / 2. - 1. / duration - self.frequencies[-1]) * duration) + ) + else: + self.param_func = lambda x: x + self.likelihood_function = original_likelihood @property def epoch(self): diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index dbdb8904..2075a87f 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -163,49 +163,31 @@ def log_i0(x): def original_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): - log_likelihood = 0. 
+ log_likelihood = 0.0 df = freqs[1] - freqs[0] for detector in detectors: - h_dec = ( - detector.fd_response(freqs, h_sky, params) * align_time - ) + h_dec = detector.fd_response(freqs, h_sky, params) * align_time match_filter_SNR = ( - 4 - * jnp.sum( - (jnp.conj(h_dec) * detector.data) / detector.psd * df - ).real - ) - optimal_SNR = ( - 4 - * jnp.sum( - jnp.conj(h_dec) * h_dec / detector.psd * df - ).real + 4 * jnp.sum((jnp.conj(h_dec) * detector.data) / detector.psd * df).real ) + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real log_likelihood += match_filter_SNR - optimal_SNR / 2 return log_likelihood -def phase_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): - log_likelihood = 0. - complex_d_inner_h = 0. +def phase_marginalized_likelihood( + params, h_sky, detectors, freqs, align_time, **kwargs +): + log_likelihood = 0.0 + complex_d_inner_h = 0.0 df = freqs[1] - freqs[0] for detector in detectors: - h_dec = ( - detector.fd_response(freqs, h_sky, params) * align_time - ) - complex_d_inner_h += ( - 4 - * jnp.sum( - (jnp.conj(h_dec) * detector.data) / detector.psd * df - ) - ) - optimal_SNR = ( - 4 - * jnp.sum( - jnp.conj(h_dec) * h_dec / detector.psd * df - ).real + h_dec = detector.fd_response(freqs, h_sky, params) * align_time + complex_d_inner_h += 4 * jnp.sum( + (jnp.conj(h_dec) * detector.data) / detector.psd * df ) + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real log_likelihood += -optimal_SNR / 2 log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) @@ -214,90 +196,126 @@ def phase_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, * def time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): - log_likelihood = 0. 
+ log_likelihood = 0.0 df = freqs[1] - freqs[0] # using instead of complex_h_inner_d = jnp.zeros_like(freqs) for detector in detectors: - h_dec = ( - detector.fd_response(freqs, h_sky, params) * align_time - ) - complex_h_inner_d += ( - 4 * h_dec * jnp.conj(detector.data) / detector.psd * df - ) - optimal_SNR = ( - 4 - * jnp.sum( - jnp.conj(h_dec) * h_dec / detector.psd * df - ).real - ) + h_dec = detector.fd_response(freqs, h_sky, params) * align_time + complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real log_likelihood += -optimal_SNR / 2 - # padding the complex_d_inner_h before feeding to the fft - # lower and higher frequency padding - pad_low = jnp.arange(0, freqs[0], df) - pad_high = jnp.arange(freqs[1], kwargs['sampling_rate'], df) - complex_h_inner_d = jnp.concatenate((pad_low, complex_h_inner_d, pad_high)) - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d) - # abusing the fftfreq to get the corresponding tc array - tc_array = jnp.fft.fftfreq(n=len(fft_h_inner_d), d=df) - - # fetch the range of valid tc + # fetch the tc range tc_array, lower padding and higher padding tc_range = kwargs['tc_range'] + tc_array = kwargs['tc_array'] + pad_low = kwargs['pad_low'] + pad_high = kwargs['pad_high'] + fs = kwargs['sampling_rate'] + + # padding the complex_h_inner_d + # this array is the hd*/S for f in [-fs / 2, -df] + complex_h_inner_d_negative_f = jnp.concatenate( + (jnp.zeros(len(pad_high) + 1), + jnp.flip(complex_h_inner_d).conj(), jnp.zeros(len(pad_low) - 1)) + ) + # this array is the hd*/S for f in [0, fs / 2 - df] + complex_h_inner_d_positive_f = jnp.concatenate( + (pad_low, complex_h_inner_d, pad_high) + ) + + # combing to get the complete f + # using the convention of fftfreq in numpy + # i.e. f in [-fs / 2, -fs / 2 + df... -df, 0, df, ... fs / 2 - df] + complex_h_inner_d_full_f = jnp.concatenate( + (complex_h_inner_d_negative_f, complex_h_inner_d_positive_f) + ) + # since we go from one-sided to two-sided frequency range + # we need to introduce a factor of 2 correction + complex_h_inner_d_full_f /= 2. + + # make use of the fft + # which then return the exp(-i2pift_c) + # w.r.t. the tc_array + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_full_f, norm='backward') + # this extra factor is due to f = -fs / 2 + j * df, where j is the array index + # since the f2 is usally a power of 2, it usally has no effect + # but it is here to impreove the code readibility + fft_h_inner_d *= jnp.exp(-1j * jnp.pi * fs) + # set the values to -inf when it is outside the tc range # so that they will disappear after the logsumexp fft_h_inner_d = jnp.where( - tc_array > tc_range[0] and tc_array < tc_range[1], - fft_h_inner_d, - jnp.zeros_like(fft_h_inner_d) - jnp.inf + (tc_array > tc_range[0]) & (tc_array < tc_range[1]), + fft_h_inner_d.real, + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, ) # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(fft_h_inner_d)) + log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array)) return log_likelihood -def phase_time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): - log_likelihood = 0. 
+def phase_time_marginalized_likelihood( + params, h_sky, detectors, freqs, align_time, **kwargs +): + log_likelihood = 0.0 df = freqs[1] - freqs[0] # using instead of complex_h_inner_d = jnp.zeros_like(freqs) for detector in detectors: - h_dec = ( - detector.fd_response(freqs, h_sky, params) * align_time - ) - complex_h_inner_d += ( - 4 * h_dec * jnp.conj(detector.data) / detector.psd * df - ) - optimal_SNR = ( - 4 - * jnp.sum( - jnp.conj(h_dec) * h_dec / detector.psd * df - ).real - ) + h_dec = detector.fd_response(freqs, h_sky, params) * align_time + complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real log_likelihood += -optimal_SNR / 2 - # padding the complex_d_inner_h before feeding to the fft - # lower and higher frequency padding - pad_low = jnp.arange(0, freqs[0], df) - pad_high = jnp.arange(freqs[1], kwargs['sampling_rate'], df) - complex_h_inner_d = jnp.concatenate((pad_low, complex_h_inner_d, pad_high)) - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d) - # abusing the fftfreq to get the corresponding tc array - tc_array = jnp.fft.fftfreq(n=len(fft_h_inner_d), d=df) - - # fetch the range of valid tc + # fetch the tc range tc_array, lower padding and higher padding tc_range = kwargs['tc_range'] + tc_array = kwargs['tc_array'] + pad_low = kwargs['pad_low'] + pad_high = kwargs['pad_high'] + fs = kwargs['sampling_rate'] + + # padding the complex_h_inner_d + # this array is the hd*/S for f in [-fs / 2, -df] + complex_h_inner_d_negative_f = jnp.concatenate( + (jnp.zeros(len(pad_high) + 1), + jnp.flip(complex_h_inner_d).conj(), jnp.zeros(len(pad_low) - 1)) + ) + # this array is the hd*/S for f in [0, fs / 2 - df] + complex_h_inner_d_positive_f = jnp.concatenate( + (pad_low, complex_h_inner_d, pad_high) + ) + + # combing to get the complete f + # using the convention of fftfreq in numpy + # i.e. f in [-fs / 2, -fs / 2 + df... -df, 0, df, ... fs / 2 - df] + complex_h_inner_d_full_f = jnp.concatenate( + (complex_h_inner_d_negative_f, complex_h_inner_d_positive_f) + ) + # since we go from one-sided to two-sided frequency range + # we need to introduce a factor of 2 correction + complex_h_inner_d_full_f /= 2. + + # make use of the fft + # which then return the exp(-i2pift_c) + # w.r.t. the tc_array + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_full_f, norm='backward') + # this extra factor is due to f = -fs / 2 + j * df, where j is the array index + # since the f2 is usally a power of 2, it usally has no effect + # but it is here to impreove the code readibility + fft_h_inner_d *= jnp.exp(-1j * jnp.pi * fs) + # set the values to -inf when it is outside the tc range # so that they will disappear after the logsumexp log_i0_abs_fft = jnp.where( - tc_array > tc_range[0] and tc_array < tc_range[1], + (tc_array > tc_range[0]) & (tc_array < tc_range[1]), log_i0(jnp.absolute(fft_h_inner_d)), - jnp.zeros_like(fft_h_inner_d) - jnp.inf + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, ) # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(fft_h_inner_d)) + log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array)) return log_likelihood From d476e68e0bb3051a1558762d417fd96a24dd017f Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Tue, 20 Feb 2024 12:09:11 +0100 Subject: [PATCH 03/10] Switching to positive frequency only --- src/jimgw/single_event/likelihood.py | 22 +++++----- src/jimgw/single_event/utils.py | 60 +++++----------------------- 2 files changed, 22 insertions(+), 60 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index c50e9f04..594317cd 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -79,18 +79,20 @@ def __init__( self.likelihood_function = phase_marginalized_likelihood print("Marginalizing over phase") - if 'time' in self.marginalization: - fs = kwargs['sampling_rate'] - self.kwargs['tc_array'] = jnp.fft.fftfreq( - int(duration * fs), - 1. / duration + if "time" in self.marginalization: + fs = kwargs["sampling_rate"] + self.kwargs["tc_array"] = jnp.fft.fftfreq( + int(duration * fs / 2), 1.0 / duration ) - self.kwargs['pad_low'] = jnp.zeros(int(self.frequencies[0] * duration)) - if jnp.isclose(self.frequencies[-1], fs / 2. - 1. / duration): - self.kwargs['pad_high'] = jnp.array([]) + self.kwargs["pad_low"] = jnp.zeros(int(self.frequencies[0] * duration)) + if jnp.isclose(self.frequencies[-1], fs / 2.0 - 1.0 / duration): + self.kwargs["pad_high"] = jnp.array([]) else: - self.kwargs['pad_high'] = jnp.zeros( - int((fs / 2. - 1. / duration - self.frequencies[-1]) * duration) + self.kwargs["pad_high"] = jnp.zeros( + int( + (fs / 2.0 - 1.0 / duration - self.frequencies[-1]) + * duration + ) ) else: self.param_func = lambda x: x diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 2075a87f..480f98f3 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -207,41 +207,21 @@ def time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, ** log_likelihood += -optimal_SNR / 2 # fetch the tc range tc_array, lower padding and higher padding - tc_range = kwargs['tc_range'] - tc_array = kwargs['tc_array'] - pad_low = kwargs['pad_low'] - pad_high = kwargs['pad_high'] - fs = kwargs['sampling_rate'] + tc_range = kwargs["tc_range"] + tc_array = kwargs["tc_array"] + pad_low = kwargs["pad_low"] + pad_high = kwargs["pad_high"] # padding the complex_h_inner_d - # this array is the hd*/S for f in [-fs / 2, -df] - complex_h_inner_d_negative_f = jnp.concatenate( - (jnp.zeros(len(pad_high) + 1), - jnp.flip(complex_h_inner_d).conj(), jnp.zeros(len(pad_low) - 1)) - ) # this array is the hd*/S for f in [0, fs / 2 - df] complex_h_inner_d_positive_f = jnp.concatenate( (pad_low, complex_h_inner_d, pad_high) ) - # combing to get the complete f - # using the convention of fftfreq in numpy - # i.e. f in [-fs / 2, -fs / 2 + df... -df, 0, df, ... fs / 2 - df] - complex_h_inner_d_full_f = jnp.concatenate( - (complex_h_inner_d_negative_f, complex_h_inner_d_positive_f) - ) - # since we go from one-sided to two-sided frequency range - # we need to introduce a factor of 2 correction - complex_h_inner_d_full_f /= 2. - # make use of the fft # which then return the exp(-i2pift_c) # w.r.t. 
the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_full_f, norm='backward') - # this extra factor is due to f = -fs / 2 + j * df, where j is the array index - # since the f2 is usally a power of 2, it usally has no effect - # but it is here to impreove the code readibility - fft_h_inner_d *= jnp.exp(-1j * jnp.pi * fs) + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") # set the values to -inf when it is outside the tc range # so that they will disappear after the logsumexp @@ -271,41 +251,21 @@ def phase_time_marginalized_likelihood( log_likelihood += -optimal_SNR / 2 # fetch the tc range tc_array, lower padding and higher padding - tc_range = kwargs['tc_range'] - tc_array = kwargs['tc_array'] - pad_low = kwargs['pad_low'] - pad_high = kwargs['pad_high'] - fs = kwargs['sampling_rate'] + tc_range = kwargs["tc_range"] + tc_array = kwargs["tc_array"] + pad_low = kwargs["pad_low"] + pad_high = kwargs["pad_high"] # padding the complex_h_inner_d - # this array is the hd*/S for f in [-fs / 2, -df] - complex_h_inner_d_negative_f = jnp.concatenate( - (jnp.zeros(len(pad_high) + 1), - jnp.flip(complex_h_inner_d).conj(), jnp.zeros(len(pad_low) - 1)) - ) # this array is the hd*/S for f in [0, fs / 2 - df] complex_h_inner_d_positive_f = jnp.concatenate( (pad_low, complex_h_inner_d, pad_high) ) - # combing to get the complete f - # using the convention of fftfreq in numpy - # i.e. f in [-fs / 2, -fs / 2 + df... -df, 0, df, ... fs / 2 - df] - complex_h_inner_d_full_f = jnp.concatenate( - (complex_h_inner_d_negative_f, complex_h_inner_d_positive_f) - ) - # since we go from one-sided to two-sided frequency range - # we need to introduce a factor of 2 correction - complex_h_inner_d_full_f /= 2. - # make use of the fft # which then return the exp(-i2pift_c) # w.r.t. the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_full_f, norm='backward') - # this extra factor is due to f = -fs / 2 + j * df, where j is the array index - # since the f2 is usally a power of 2, it usally has no effect - # but it is here to impreove the code readibility - fft_h_inner_d *= jnp.exp(-1j * jnp.pi * fs) + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") # set the values to -inf when it is outside the tc range # so that they will disappear after the logsumexp From 145b41a8050a6821f9b06d6431f12263663cf837 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 20 Feb 2024 12:13:58 +0100 Subject: [PATCH 04/10] Fixing conflict --- src/jimgw/single_event/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 480f98f3..df9df401 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,5 +1,6 @@ import jax.numpy as jnp from jax.scipy.special import i0e, logsumexp +from jax.scipy.integrate import trapezoid from jax import jit from jaxtyping import Float, Array @@ -35,7 +36,7 @@ def inner_product( # psd_interp = jnp.interp(frequency, psd_frequency, psd) df = frequency[1] - frequency[0] integrand = jnp.conj(h1) * h2 / psd - return 4.0 * jnp.real(jnp.trapz(integrand, dx=df)) + return 4.0 * jnp.real(trapezoid(integrand, dx=df)) @jit From 0cbf45c4b8c7601bbf45604a3d38cdb3d5c253ef Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 20 Feb 2024 11:44:13 -0500 Subject: [PATCH 05/10] Move internel likelihood function from utils.py to likelihood.py. 
--- src/jimgw/single_event/likelihood.py | 151 +++++++++++++++++++++++++-- src/jimgw/single_event/utils.py | 121 +-------------------- 2 files changed, 146 insertions(+), 126 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 594317cd..57069c73 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -1,5 +1,6 @@ import jax import jax.numpy as jnp +from jax.scipy.special import logsumexp import numpy as np import numpy.typing as npt from astropy.time import Time @@ -9,12 +10,7 @@ from jimgw.single_event.detector import Detector from jimgw.prior import Prior -from jimgw.single_event.utils import ( - original_likelihood, - phase_marginalized_likelihood, - time_marginalized_likelihood, - phase_time_marginalized_likelihood, -) +from jimgw.single_event.utils import log_i0 from jimgw.single_event.waveform import Waveform from jimgw.base import LikelihoodBase @@ -491,3 +487,146 @@ def y(x): "TransientLikelihoodFD": TransientLikelihoodFD, "HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD, } + + +def original_likelihood( + params: dict[str, Float], + h_sky: dict[str, Float[Array, " n_dim"]], + detectors: list[Detector], + freqs: Float[Array, " n_dim"], + align_time: Float, + **kwargs, +) -> Float: + log_likelihood = 0.0 + df = freqs[1] - freqs[0] + for detector in detectors: + h_dec = detector.fd_response(freqs, h_sky, params) * align_time + match_filter_SNR = ( + 4 * jnp.sum((jnp.conj(h_dec) * detector.data) / detector.psd * df).real + ) + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real + log_likelihood += match_filter_SNR - optimal_SNR / 2 + + return log_likelihood + + +def phase_marginalized_likelihood( + params: dict[str, Float], + h_sky: dict[str, Float[Array, " n_dim"]], + detectors: list[Detector], + freqs: Float[Array, " n_dim"], + align_time: Float, + **kwargs, +) -> Float: + log_likelihood = 0.0 + complex_d_inner_h = 0.0 + df = freqs[1] - freqs[0] + for detector in detectors: + h_dec = detector.fd_response(freqs, h_sky, params) * align_time + complex_d_inner_h += 4 * jnp.sum( + (jnp.conj(h_dec) * detector.data) / detector.psd * df + ) + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real + log_likelihood += -optimal_SNR / 2 + + log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) + + return log_likelihood + + +def time_marginalized_likelihood( + params: dict[str, Float], + h_sky: dict[str, Float[Array, " n_dim"]], + detectors: list[Detector], + freqs: Float[Array, " n_dim"], + align_time: Float, + **kwargs, +) -> Float: + log_likelihood = 0.0 + df = freqs[1] - freqs[0] + # using instead of + complex_h_inner_d = jnp.zeros_like(freqs) + for detector in detectors: + h_dec = detector.fd_response(freqs, h_sky, params) * align_time + complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real + log_likelihood += -optimal_SNR / 2 + + # fetch the tc range tc_array, lower padding and higher padding + tc_range = kwargs["tc_range"] + tc_array = kwargs["tc_array"] + pad_low = kwargs["pad_low"] + pad_high = kwargs["pad_high"] + + # padding the complex_h_inner_d + # this array is the hd*/S for f in [0, fs / 2 - df] + complex_h_inner_d_positive_f = jnp.concatenate( + (pad_low, complex_h_inner_d, pad_high) + ) + + # make use of the fft + # which then return the exp(-i2pift_c) + # w.r.t. 
the tc_array + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") + + # set the values to -inf when it is outside the tc range + # so that they will disappear after the logsumexp + fft_h_inner_d = jnp.where( + (tc_array > tc_range[0]) & (tc_array < tc_range[1]), + fft_h_inner_d.real, + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, + ) + + # using the logsumexp to marginalize over the tc prior range + log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array)) + + return log_likelihood + + +def phase_time_marginalized_likelihood( + params: dict[str, Float], + h_sky: dict[str, Float[Array, " n_dim"]], + detectors: list[Detector], + freqs: Float[Array, " n_dim"], + align_time: Float, + **kwargs, +) -> Float: + log_likelihood = 0.0 + df = freqs[1] - freqs[0] + # using instead of + complex_h_inner_d = jnp.zeros_like(freqs) + for detector in detectors: + h_dec = detector.fd_response(freqs, h_sky, params) * align_time + complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real + log_likelihood += -optimal_SNR / 2 + + # fetch the tc range tc_array, lower padding and higher padding + tc_range = kwargs["tc_range"] + tc_array = kwargs["tc_array"] + pad_low = kwargs["pad_low"] + pad_high = kwargs["pad_high"] + + # padding the complex_h_inner_d + # this array is the hd*/S for f in [0, fs / 2 - df] + complex_h_inner_d_positive_f = jnp.concatenate( + (pad_low, complex_h_inner_d, pad_high) + ) + + # make use of the fft + # which then return the exp(-i2pift_c) + # w.r.t. the tc_array + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") + + # set the values to -inf when it is outside the tc range + # so that they will disappear after the logsumexp + log_i0_abs_fft = jnp.where( + (tc_array > tc_range[0]) & (tc_array < tc_range[1]), + log_i0(jnp.absolute(fft_h_inner_d)), + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, + ) + + # using the logsumexp to marginalize over the tc prior range + log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array)) + + return log_likelihood diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index df9df401..d8e5fa5a 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -1,5 +1,5 @@ import jax.numpy as jnp -from jax.scipy.special import i0e, logsumexp +from jax.scipy.special import i0e from jax.scipy.integrate import trapezoid from jax import jit from jaxtyping import Float, Array @@ -161,122 +161,3 @@ def log_i0(x): The natural logarithm of the bessel function """ return jnp.log(i0e(x)) + x - - -def original_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): - log_likelihood = 0.0 - df = freqs[1] - freqs[0] - for detector in detectors: - h_dec = detector.fd_response(freqs, h_sky, params) * align_time - match_filter_SNR = ( - 4 * jnp.sum((jnp.conj(h_dec) * detector.data) / detector.psd * df).real - ) - optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real - log_likelihood += match_filter_SNR - optimal_SNR / 2 - - return log_likelihood - - -def phase_marginalized_likelihood( - params, h_sky, detectors, freqs, align_time, **kwargs -): - log_likelihood = 0.0 - complex_d_inner_h = 0.0 - df = freqs[1] - freqs[0] - for detector in detectors: - h_dec = detector.fd_response(freqs, h_sky, params) * align_time - complex_d_inner_h += 4 * jnp.sum( - (jnp.conj(h_dec) * detector.data) / detector.psd * df - ) - optimal_SNR = 4 * 
jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real - log_likelihood += -optimal_SNR / 2 - - log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) - - return log_likelihood - - -def time_marginalized_likelihood(params, h_sky, detectors, freqs, align_time, **kwargs): - log_likelihood = 0.0 - df = freqs[1] - freqs[0] - # using instead of - complex_h_inner_d = jnp.zeros_like(freqs) - for detector in detectors: - h_dec = detector.fd_response(freqs, h_sky, params) * align_time - complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df - optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real - log_likelihood += -optimal_SNR / 2 - - # fetch the tc range tc_array, lower padding and higher padding - tc_range = kwargs["tc_range"] - tc_array = kwargs["tc_array"] - pad_low = kwargs["pad_low"] - pad_high = kwargs["pad_high"] - - # padding the complex_h_inner_d - # this array is the hd*/S for f in [0, fs / 2 - df] - complex_h_inner_d_positive_f = jnp.concatenate( - (pad_low, complex_h_inner_d, pad_high) - ) - - # make use of the fft - # which then return the exp(-i2pift_c) - # w.r.t. the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - - # set the values to -inf when it is outside the tc range - # so that they will disappear after the logsumexp - fft_h_inner_d = jnp.where( - (tc_array > tc_range[0]) & (tc_array < tc_range[1]), - fft_h_inner_d.real, - jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, - ) - - # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array)) - - return log_likelihood - - -def phase_time_marginalized_likelihood( - params, h_sky, detectors, freqs, align_time, **kwargs -): - log_likelihood = 0.0 - df = freqs[1] - freqs[0] - # using instead of - complex_h_inner_d = jnp.zeros_like(freqs) - for detector in detectors: - h_dec = detector.fd_response(freqs, h_sky, params) * align_time - complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df - optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real - log_likelihood += -optimal_SNR / 2 - - # fetch the tc range tc_array, lower padding and higher padding - tc_range = kwargs["tc_range"] - tc_array = kwargs["tc_array"] - pad_low = kwargs["pad_low"] - pad_high = kwargs["pad_high"] - - # padding the complex_h_inner_d - # this array is the hd*/S for f in [0, fs / 2 - df] - complex_h_inner_d_positive_f = jnp.concatenate( - (pad_low, complex_h_inner_d, pad_high) - ) - - # make use of the fft - # which then return the exp(-i2pift_c) - # w.r.t. 
the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - - # set the values to -inf when it is outside the tc range - # so that they will disappear after the logsumexp - log_i0_abs_fft = jnp.where( - (tc_array > tc_range[0]) & (tc_array < tc_range[1]), - log_i0(jnp.absolute(fft_h_inner_d)), - jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, - ) - - # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array)) - - return log_likelihood From 20c52ff848c912c994862ca96c67d3601be50500 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 20 Feb 2024 12:06:48 -0500 Subject: [PATCH 06/10] bypassing yaml type check --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 61b82c5a..3ca31369 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -19,7 +19,7 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): return dumper.represent_list(data.tolist()) -yaml.add_representer(ArrayImpl, jaxarray_representer) +yaml.add_representer(ArrayImpl, jaxarray_representer) # type: ignore prior_presets = { "Unconstrained_Uniform": prior.Unconstrained_Uniform, From 65fd20822ded9a197dae07a381f3c22b5687a9f3 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 20 Feb 2024 18:17:53 +0100 Subject: [PATCH 07/10] Allowing the reference parameters for Heterodyned likelihood to be provided --- src/jimgw/single_event/likelihood.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 57069c73..d4d90b52 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -169,6 +169,7 @@ def __init__( post_trigger_duration: float = 2, popsize: int = 100, n_loops: int = 2000, + ref_params: dict = {}, **kwargs, ) -> None: super().__init__( @@ -187,11 +188,15 @@ def __init__( ) self.freq_grid_low = freq_grid[:-1] - print("Finding reference parameters..") - - self.ref_params = self.maximize_likelihood( - bounds=bounds, prior=prior, popsize=popsize, n_loops=n_loops - ) + if not ref_params: + print("No reference parameters are provided, finding it...") + self.ref_params = self.maximize_likelihood( + bounds=bounds, prior=prior, popsize=popsize, n_loops=n_loops + ) + print(f"The reference parameters are {self.ref_params}") + else: + self.ref_params = ref_params + print(f"Reference parameters provided, which are {self.ref_params}") print("Constructing reference waveforms..") From d5496bae15e31ea7e3eb6f94594cd8b13f819564 Mon Sep 17 00:00:00 2001 From: "Peter T. H. 
Pang" Date: Tue, 20 Feb 2024 18:52:20 +0100 Subject: [PATCH 08/10] Initial commit of adding phase marginalization in relative binning likelihood --- src/jimgw/single_event/likelihood.py | 191 +++++++++++++++++++++------ 1 file changed, 147 insertions(+), 44 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index d4d90b52..a21c6c28 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -178,8 +178,26 @@ def __init__( print("Initializing heterodyned likelihood..") - # Get the original frequency grid + self.kwargs = kwargs + if "marginalization" in self.kwargs: + marginalization = self.kwargs["marginalization"] + assert marginalization in [ + "phase", + ], "Heterodyned likelihood only support phase marginalzation" + self.marginalization = marginalization + if self.marginalization == "phase": + self.param_func = lambda x: {**x, "phase_c": 0.0} + self.likelihood_function = phase_marginalized_likelihood + self.rb_likelihood_function = ( + phase_marginalized_relative_binning_likelihood + ) + print("Marginalizing over phase") + else: + self.param_func = lambda x: x + self.likelihood_function = original_likelihood + self.rb_likelihood_function = original_relative_binning_likelihood + # Get the original frequency grid frequency_original = self.frequencies # Get the grid of the relative binning scheme (contains the final endpoint) # and the center points @@ -201,6 +219,8 @@ def __init__( print("Constructing reference waveforms..") self.ref_params["gmst"] = self.gmst + # adjust the params due to different marginalzation scheme + self.ref_params = self.param_func(self.ref_params) self.waveform_low_ref = {} self.waveform_center_ref = {} @@ -292,10 +312,12 @@ def __init__( self.B1_array[detector.name] = B1[mask_heterodyne_center] def evaluate(self, params: dict[str, Float], data: dict) -> Float: - log_likelihood = 0 frequencies_low = self.freq_grid_low frequencies_center = self.freq_grid_center params["gmst"] = self.gmst + # adjust the params due to different marginalzation scheme + params = self.param_func(params) + # evaluate the waveforms as usual waveform_sky_low = self.waveform(frequencies_low, params) waveform_sky_center = self.waveform(frequencies_center, params) align_time_low = jnp.exp( @@ -304,30 +326,23 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: align_time_center = jnp.exp( -1j * 2 * jnp.pi * frequencies_center * (self.epoch + params["t_c"]) ) - for detector in self.detectors: - waveform_low = ( - detector.fd_response(frequencies_low, waveform_sky_low, params) - * align_time_low - ) - waveform_center = ( - detector.fd_response(frequencies_center, waveform_sky_center, params) - * align_time_center - ) - - r0 = waveform_center / self.waveform_center_ref[detector.name] - r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( - frequencies_low - frequencies_center - ) - match_filter_SNR = jnp.sum( - self.A0_array[detector.name] * r0.conj() - + self.A1_array[detector.name] * r1.conj() - ) - optimal_SNR = jnp.sum( - self.B0_array[detector.name] * jnp.abs(r0) ** 2 - + 2 * self.B1_array[detector.name] * (r0 * r1.conj()).real - ) - log_likelihood += (match_filter_SNR - optimal_SNR / 2).real - + log_likelihood = self.rb_likelihood_function( + params, + waveform_sky_low, + self.A0_array, + self.A1_array, + self.B0_array, + self.B1_array, + waveform_sky_center, + self.waveform_low_ref, + self.waveform_center_ref, + self.detectors, + frequencies_low, + frequencies_center, + 
align_time_low, + align_time_center, + **self.kwargs, + ) return log_likelihood def evaluate_original( @@ -340,29 +355,22 @@ def evaluate_original( """ log_likelihood = 0 frequencies = self.frequencies - df = frequencies[1] - frequencies[0] params["gmst"] = self.gmst + # adjust the params due to different marginalzation scheme + params = self.param_func(params) + # evaluate the waveform as usual waveform_sky = self.waveform(frequencies, params) align_time = jnp.exp( -1j * 2 * jnp.pi * frequencies * (self.epoch + params["t_c"]) ) - for detector in self.detectors: - waveform_dec = ( - detector.fd_response(frequencies, waveform_sky, params) * align_time - ) - match_filter_SNR = ( - 4 - * jnp.sum( - (jnp.conj(waveform_dec) * detector.data) / detector.psd * df - ).real - ) - optimal_SNR = ( - 4 - * jnp.sum( - jnp.conj(waveform_dec) * waveform_dec / detector.psd * df - ).real - ) - log_likelihood += match_filter_SNR - optimal_SNR / 2 + log_likelihood = self.likelihood_function( + params, + waveform_sky, + self.detectors, + frequencies, + align_time, + **self.kwargs, + ) return log_likelihood @staticmethod @@ -635,3 +643,98 @@ def phase_time_marginalized_likelihood( log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array)) return log_likelihood + + +def original_relative_binning_likelihood( + params, + A0_array, + A1_array, + B0_array, + B1_array, + waveform_sky_low, + waveform_sky_center, + waveform_low_ref, + waveform_center_ref, + detectors, + frequencies_low, + frequencies_center, + align_time_low, + align_time_center, + **kwargs, +): + + log_likelihood = 0.0 + + for detector in detectors: + waveform_low = ( + detector.fd_response(frequencies_low, waveform_sky_low, params) + * align_time_low + ) + waveform_center = ( + detector.fd_response(frequencies_center, waveform_sky_center, params) + * align_time_center + ) + + r0 = waveform_center / waveform_center_ref[detector.name] + r1 = (waveform_low / waveform_low_ref[detector.name] - r0) / ( + frequencies_low - frequencies_center + ) + match_filter_SNR = jnp.sum( + A0_array[detector.name] * r0.conj() + A1_array[detector.name] * r1.conj() + ) + optimal_SNR = jnp.sum( + B0_array[detector.name] * jnp.abs(r0) ** 2 + + 2 * B1_array[detector.name] * (r0 * r1.conj()).real + ) + log_likelihood += (match_filter_SNR - optimal_SNR / 2).real + + return log_likelihood + + +def phase_marginalized_relative_binning_likelihood( + params, + A0_array, + A1_array, + B0_array, + B1_array, + waveform_sky_low, + waveform_sky_center, + waveform_low_ref, + waveform_center_ref, + detectors, + frequencies_low, + frequencies_center, + align_time_low, + align_time_center, + **kwargs, +): + + log_likelihood = 0.0 + + complex_d_inner_h = jnp.zeros_like(A0_array) + + for detector in detectors: + waveform_low = ( + detector.fd_response(frequencies_low, waveform_sky_low, params) + * align_time_low + ) + waveform_center = ( + detector.fd_response(frequencies_center, waveform_sky_center, params) + * align_time_center + ) + r0 = waveform_center / waveform_center_ref[detector.name] + r1 = (waveform_low / waveform_low_ref[detector.name] - r0) / ( + frequencies_low - frequencies_center + ) + complex_d_inner_h += ( + A0_array[detector.name] * r0.conj() + A1_array[detector.name] * r1.conj() + ) + optimal_SNR = jnp.sum( + B0_array[detector.name] * jnp.abs(r0) ** 2 + + 2 * B1_array[detector.name] * (r0 * r1.conj()).real + ) + log_likelihood += -optimal_SNR.real / 2 + + log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) + + return log_likelihood From 
c8ae3f20da38fe9970daaa2afa642942054803f6 Mon Sep 17 00:00:00 2001 From: Tsun Ho Pang Date: Tue, 20 Feb 2024 11:04:59 -0800 Subject: [PATCH 09/10] Minor fix --- src/jimgw/single_event/likelihood.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index a21c6c28..b0bee81b 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -328,11 +328,11 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float: ) log_likelihood = self.rb_likelihood_function( params, - waveform_sky_low, self.A0_array, self.A1_array, self.B0_array, self.B1_array, + waveform_sky_low, waveform_sky_center, self.waveform_low_ref, self.waveform_center_ref, @@ -708,10 +708,8 @@ def phase_marginalized_relative_binning_likelihood( align_time_center, **kwargs, ): - log_likelihood = 0.0 - - complex_d_inner_h = jnp.zeros_like(A0_array) + complex_d_inner_h = 0.0 for detector in detectors: waveform_low = ( @@ -726,7 +724,7 @@ def phase_marginalized_relative_binning_likelihood( r1 = (waveform_low / waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center ) - complex_d_inner_h += ( + complex_d_inner_h += jnp.sum( A0_array[detector.name] * r0.conj() + A1_array[detector.name] * r1.conj() ) optimal_SNR = jnp.sum( From ee7489ae48e40cd9ec99dc59d23ca40a3387ab15 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Tue, 20 Feb 2024 11:27:22 -0800 Subject: [PATCH 10/10] Adjusting the reference parameter's eta in case it is too close to 0.25 --- src/jimgw/single_event/likelihood.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index b0bee81b..08c9e315 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -208,14 +208,22 @@ def __init__( if not ref_params: print("No reference parameters are provided, finding it...") - self.ref_params = self.maximize_likelihood( + ref_params = self.maximize_likelihood( bounds=bounds, prior=prior, popsize=popsize, n_loops=n_loops ) + self.ref_params = {key: float(value) for key, value in ref_params.items()} print(f"The reference parameters are {self.ref_params}") else: self.ref_params = ref_params print(f"Reference parameters provided, which are {self.ref_params}") + # safe guard for the reference parameters + # since ripple cannot handle eta=0.25 + if jnp.isclose(self.ref_params["eta"], 0.25): + self.ref_params["eta"] = 0.249995 + print("The eta of the reference parameter is close to 0.25") + print(f"The eta is adjusted to {self.ref_params['eta']}") + print("Constructing reference waveforms..") self.ref_params["gmst"] = self.gmst
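
The sketch below is not part of the patches above; it is a standalone toy reproduction of the marginalization arithmetic they introduce, so the FFT and Bessel-function steps can be run and inspected in isolation. The segment length, sampling rate, flat PSD, t_c prior range, and the toy template h and data d are all invented placeholders for the demo; the jimgw likelihood functions themselves operate on detector objects, PSDs, and waveform dictionaries as shown in the diffs.

import jax.numpy as jnp
from jax.scipy.special import i0e, logsumexp


def log_i0(x):
    # numerically stable log of the zeroth-order modified Bessel function
    return jnp.log(i0e(x)) + x


duration = 4.0                                   # segment length in seconds (toy value)
fs = 2048.0                                      # sampling rate in Hz (toy value)
df = 1.0 / duration
freqs = jnp.arange(20.0, fs / 2.0, df)           # one-sided analysis band
psd = jnp.ones_like(freqs)                       # flat PSD, for simplicity

# toy frequency-domain template and "data": the data is the template with an
# overall phase offset, so both marginalizations have something to recover
h = jnp.exp(-2j * jnp.pi * freqs * 0.1) / jnp.sqrt(freqs)
d = h * jnp.exp(1j * 0.7)

# complex <d|h> and real <h|h>
d_inner_h = 4.0 * jnp.sum(jnp.conj(h) * d / psd * df)
optimal_snr_sq = 4.0 * jnp.sum(jnp.conj(h) * h / psd * df).real

# (a) phase marginalization: Re<d|h> is replaced by log I0(|<d|h>|)
logl_phase = log_i0(jnp.abs(d_inner_h)) - optimal_snr_sq / 2.0

# (b) time marginalization: an FFT of the per-frequency integrand evaluates
# <h|d>(t_c) on the fftfreq grid; pad the band back to [0, fs/2) first
integrand = 4.0 * h * jnp.conj(d) / psd * df
n_bins = int(duration * fs / 2)                  # number of bins in [0, fs/2)
pad_low = jnp.zeros(int(round(float(freqs[0]) * duration)))
pad_high = jnp.zeros(n_bins - pad_low.size - freqs.size)
full_band = jnp.concatenate((pad_low, integrand, pad_high))
h_inner_d_tc = jnp.fft.fft(full_band)            # one value of <h|d> per t_c
tc_array = jnp.fft.fftfreq(n_bins, d=df)         # matching t_c grid
tc_range = (-0.01, 0.01)                         # uniform t_c prior range (toy value)
in_range = (tc_array > tc_range[0]) & (tc_array < tc_range[1])
masked = jnp.where(in_range, h_inner_d_tc.real, -jnp.inf)
logl_time = logsumexp(masked) - jnp.log(n_bins) - optimal_snr_sq / 2.0

print(logl_phase, logl_time)

The design choice this mirrors: a single FFT of the positive-frequency integrand evaluates <h|d>(t_c) for every t_c on a regular grid, and masking points outside the prior range with -inf before the logsumexp implements a discrete uniform prior over t_c, which is exactly the structure of time_marginalized_likelihood and phase_time_marginalized_likelihood in the patches above.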