From c114ba01652d4d4ec8d35e4c80e344c4dd17fee5 Mon Sep 17 00:00:00 2001 From: Max Isi Date: Fri, 18 Oct 2024 11:55:51 -0400 Subject: [PATCH] propagating detector changes to likelihood --- src/jimgw/single_event/data.py | 119 ++++++++++---- src/jimgw/single_event/detector.py | 232 ++++++++++----------------- src/jimgw/single_event/likelihood.py | 41 +++-- 3 files changed, 194 insertions(+), 198 deletions(-) diff --git a/src/jimgw/single_event/data.py b/src/jimgw/single_event/data.py index 2acd9531..af9deda8 100644 --- a/src/jimgw/single_event/data.py +++ b/src/jimgw/single_event/data.py @@ -1,41 +1,36 @@ __include__ = ["Data", "PowerSpectrum"] -from abc import ABC, abstractmethod +from abc import ABC -import jax import jax.numpy as jnp import numpy as np -import requests from gwpy.timeseries import TimeSeries -from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped, Complex -from typing import Optional, Any -from beartype import beartype as typechecker +from jaxtyping import Array, Float, Complex, PRNGKeyArray +from typing import Optional +# from beartype import beartype as typechecker from scipy.interpolate import interp1d import scipy.signal as sig from scipy.signal.windows import tukey - -from jimgw.constants import C_SI, EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS -from jimgw.single_event.wave import Polarization import logging +import jax + DEG_TO_RAD = jnp.pi / 180 # TODO: Need to expand this list. Currently it is only O3. asd_file_dict = { - "H1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-H1-C01_CLEAN_SUB60HZ-1251752040.0_sensitivity_strain_asd.txt", - "L1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-L1-C01_CLEAN_SUB60HZ-1240573680.0_sensitivity_strain_asd.txt", - "V1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-V1_sensitivity_strain_asd.txt", + "H1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-H1-C01_CLEAN_SUB60HZ-1251752040.0_sensitivity_strain_asd.txt", # noqa: E501 + "L1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-L1-C01_CLEAN_SUB60HZ-1240573680.0_sensitivity_strain_asd.txt", # noqa: E501 + "V1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-V1_sensitivity_strain_asd.txt", # noqa: E501 } class Data(ABC): - """ - Base class for all data. The time domain data are considered the primary - entitiy; the Fourier domain data are derived from an FFT after applying a - window. The structure is set up so that :attr:`td` and :attr:`fd` are always - Fourier conjugates of each other: the one-sided Fourier series is complete - up to the Nyquist frequency - + """Base class for all data. The time domain data are considered the primary + entity; the Fourier domain data are derived from an FFT after applying a + window. The structure is set up so that :attr:`td` and :attr:`fd` are + always Fourier conjugates of each other: the one-sided Fourier series is + complete up to the Nyquist frequency. """ name: str @@ -108,13 +103,14 @@ def __init__(self, td: Float[Array, " n_time"] = jnp.array([]), self.delta_t = delta_t self.epoch = epoch if window is None: - self.window = jnp.ones_like(self.td) + self.set_tukey_window() else: self.window = window self.name = name or '' def __repr__(self): - return f"{self.__class__.__name__}(name='{self.name}', delta_t={self.delta_t}, epoch={self.epoch})" + return f"{self.__class__.__name__}(name='{self.name}', " + \ + f"delta_t={self.delta_t}, epoch={self.epoch})" def __bool__(self) -> bool: """Check if the data is empty.""" @@ -215,7 +211,8 @@ def from_gwosc(cls, data_td = TimeSeries.fetch_open_data(ifo, gps_start_time, gps_end_time, cache=cache, **kws) - return cls(data_td.value, data_td.dt.value, data_td.epoch.value, ifo) + return cls(data_td.value, data_td.dt.value, data_td.epoch.value, ifo) # type: ignore # noqa: E501 + class PowerSpectrum(ABC): name: str @@ -240,7 +237,7 @@ def duration(self) -> Float: @property def sampling_frequency(self) -> Float: """Sampling frequency of the data in Hz.""" - return self.frequencies[-1] * 2 + return self.frequencies[-1] * 2 def __init__(self, values: Float[Array, " n_freq"] = jnp.array([]), frequencies: Float[Array, " n_freq"] = jnp.array([]), @@ -252,10 +249,15 @@ def __init__(self, values: Float[Array, " n_freq"] = jnp.array([]), self.name = name or '' def __repr__(self) -> str: - return f"{self.__class__.__name__}(name='{self.name}', frequencies={self.frequencies})" + return f"{self.__class__.__name__}(name='{self.name}', " + \ + f"frequencies={self.frequencies})" + + def __bool__(self) -> bool: + """Check if the power spectrum is empty.""" + return len(self.values) > 0 - def slice(self, f_min: float, f_max: float) -> \ - tuple[Float[Array, " n_sample"], Float[Array, " n_sample"]]: + def frequency_slice(self, f_min: float, f_max: float) -> \ + tuple[Float[Array, " n_sample"], Float[Array, " n_sample"]]: """Slice the power spectrum. Arguments @@ -270,24 +272,75 @@ def slice(self, f_min: float, f_max: float) -> \ psd_slice: PowerSpectrum Sliced power spectrum. """ - values = self.values[(self.frequencies >= f_min) & - (self.frequencies <= f_max)] - frequencies = self.frequencies[(self.frequencies >= f_min) & - (self.frequencies <= f_max)] - return values, frequencies + mask = (self.frequencies >= f_min) & (self.frequencies <= f_max) + return self.values[mask], self.frequencies[mask] - def interpolate(self, f: Float[Array, " n_sample"]) -> "PowerSpectrum": + def interpolate(self, f: Float[Array, " n_sample"], + kind: str = 'cubic', **kws) -> "PowerSpectrum": """Interpolate the power spectrum to a new set of frequencies. Arguments --------- f: array Frequencies to interpolate the power spectrum to. + kind: str, optional + Interpolation method (default: 'cubic') + **kws: dict, optional + Keyword arguments for `scipy.interpolate.interp1d` Returns ------- psd_interp: array Interpolated power spectrum. """ - interp = interp1d(self.frequencies, self.values, kind='cubic') + interp = interp1d(self.frequencies, self.values, kind=kind, **kws) return PowerSpectrum(interp(f), f, self.name) + + def simulate_data( + self, + key: PRNGKeyArray, + # freqs: Float[Array, " n_sample"], + # h_sky: dict[str, Float[Array, " n_sample"]], + # params: dict[str, Float], + # psd_file: str = "", + ) -> Complex[Array, " n_sample"]: + """ + Inject a signal into the detector data. + + Parameters + ---------- + key : PRNGKeyArray + JAX PRNG key. + h_sky : dict[str, Float[Array, " n_sample"]] + Array of waveforms in the sky frame. The key is the polarization + mode. + params : dict[str, Float] + Dictionary of parameters. + psd_file : str + Path to the PSD file. + + Returns + ------- + None + """ + key, subkey = jax.random.split(key, 2) + var = self.values / (4 * self.delta_f) + noise_real = jax.random.normal(key, shape=var.shape) * jnp.sqrt(var) + noise_imag = jax.random.normal(subkey, shape=var.shape) * jnp.sqrt(var) + return noise_real + 1j * noise_imag + + # WIP: this should be moved to Detector class + + # align_time = jnp.exp( + # -1j * 2 * jnp.pi * freqs * (params["epoch"] + params["t_c"]) + # ) + # signal = self.fd_response(freqs, h_sky, params) * align_time + # self.data = signal + noise_real + 1j * noise_imag + + # # also calculate the optimal SNR and match filter SNR + # optimal_SNR = jnp.sqrt(jnp.sum(signal * signal.conj() / var).real) + # match_filter_SNR = jnp.sum(self.data * signal.conj() / var) / optimal_SNR + + # print(f"For detector {self.name}:") + # print(f"The injected optimal SNR is {optimal_SNR}") + # print(f"The injected match filter SNR is {match_filter_SNR}") diff --git a/src/jimgw/single_event/detector.py b/src/jimgw/single_event/detector.py index fce30d72..7a26f876 100644 --- a/src/jimgw/single_event/detector.py +++ b/src/jimgw/single_event/detector.py @@ -10,6 +10,7 @@ from scipy.interpolate import interp1d from scipy.signal.windows import tukey from . import data as jd +from typing import Optional from jimgw.constants import C_SI, EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS from jimgw.single_event.wave import Polarization @@ -28,15 +29,17 @@ class Detector(ABC): - """ - Base class for all detectors. - + """Base class for all detectors. """ name: str - data: Float[Array, " n_sample"] - psd: Float[Array, " n_sample"] + # NOTE: for some detectors (e.g. LISA, ET) data could be a list of Data + # objects so this might be worth revisiting + data: jd.Data + psd: jd.PowerSpectrum + + frequency_bounds: tuple[float, float] = (0., float("inf")) @abstractmethod def fd_response( @@ -46,9 +49,9 @@ def fd_response( params: dict, **kwargs, ) -> Float[Array, " n_sample"]: + """Modulate the waveform in the sky frame by the detector response + in the frequency domain. """ - Modulate the waveform in the sky frame by the detector response - in the frequency domain.""" pass @abstractmethod @@ -59,11 +62,37 @@ def td_response( params: dict, **kwargs, ) -> Float[Array, " n_sample"]: + """Modulate the waveform in the sky frame by the detector response + in the time domain. """ - Modulate the waveform in the sky frame by the detector response - in the time domain.""" pass + def set_frequency_bounds(self, f_min: Optional[float] = None, + f_max: Optional[float] = None) -> None: + """Set the frequency bounds for the detector. + + Parameters + ---------- + f_min : float + Minimum frequency. + f_max : float + Maximum frequency. + """ + bounds = list(self.frequency_bounds) + if f_min is not None: + bounds[0] = f_min + if f_max is not None: + bounds[1] = f_max + self.frequency_bounds = tuple(bounds) # type: ignore + + @property + def fd_data_slice(self): + return self.data.frequency_slice(*self.frequency_bounds) + + @property + def psd_slice(self): + return self.psd.frequency_slice(*self.frequency_bounds) + class GroundBased2G(Detector): """Object representing a ground-based detector. Contains information @@ -115,8 +144,8 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.name})" def __init__(self, name: str, latitude: float = 0, longitude: float = 0, - elevation: float = 0, xarm_azimuth: float = 0, - yarm_azimuth: float = 0, xarm_tilt: float = 0, + elevation: float = 0, xarm_azimuth: float = 0, + yarm_azimuth: float = 0, xarm_tilt: float = 0, yarm_tilt: float = 0, modes: str = "pc"): self.name = name @@ -130,10 +159,7 @@ def __init__(self, name: str, latitude: float = 0, longitude: float = 0, self.polarization_mode = [Polarization(m) for m in modes] self.data = jd.Data() - - # self.frequencies = jnp.array([]) - # self.data = jnp.array([]) - # self.psd = jnp.array([]) + self.psd = jd.PowerSpectrum() @staticmethod def _get_arm( @@ -159,10 +185,12 @@ def _get_arm( """ e_lon = jnp.array([-jnp.sin(lon), jnp.cos(lon), 0]) e_lat = jnp.array( - [-jnp.sin(lat) * jnp.cos(lon), -jnp.sin(lat) * jnp.sin(lon), jnp.cos(lat)] + [-jnp.sin(lat) * jnp.cos(lon), -jnp.sin(lat) + * jnp.sin(lon), jnp.cos(lat)] ) e_h = jnp.array( - [jnp.cos(lat) * jnp.cos(lon), jnp.cos(lat) * jnp.sin(lon), jnp.sin(lat)] + [jnp.cos(lat) * jnp.cos(lon), jnp.cos(lat) + * jnp.sin(lon), jnp.sin(lat)] ) return ( @@ -209,9 +237,8 @@ def tensor(self) -> Float[Array, " 3 3"]: """ # TODO: this could easily be generalized for other detector geometries arm1, arm2 = self.arms - return 0.5 * ( - jnp.einsum("i,j->ij", arm1, arm1) - jnp.einsum("i,j->ij", arm2, arm2) - ) + return 0.5 * jnp.einsum("i,j->ij", arm1, arm1) - \ + jnp.einsum("i,j->ij", arm2, arm2) @property def vertex(self) -> Float[Array, " 3"]: @@ -238,85 +265,6 @@ def vertex(self) -> Float[Array, " 3"]: z = ((minor / major) ** 2 * r + h) * jnp.sin(lat) return jnp.array([x, y, z]) - def load_data( - self, - trigger_time: Float, - gps_start_pad: int, - gps_end_pad: int, - f_min: Float, - f_max: Float, - psd_pad: int = 16, - tukey_alpha: Float = 0.2, - gwpy_kwargs: dict | None = None, - ) -> None: - """Load open GW detector data from GWOSC using GWpy. Essentially, this - is a wrapper around the GWpy :meth:`TimeSeries.fetch_open_data` - method. - - Parameters - ---------- - trigger_time : Float - The GPS time of the trigger. - gps_start_pad : int - The amount of time before the trigger to fetch data. - gps_end_pad : int - The amount of time after the trigger to fetch data. - f_min : Float - The minimum frequency to fetch data. - f_max : Float - The maximum frequency to fetch data. - tukey_alpha : Float - The ``alpha`` parameter for the Tukey window; this represents - the fraction of the segment duration that is tapered on each end - (defaults to 0.2). - gwpy_kwargs : dict, optional - Additional keyword arguments to pass to the GWpy - :meth:`TimeSeries.fetch_open_data` method, defaults to - {}. - """ - if gwpy_kwargs is None: - gwpy_kwargs = _DEF_GWPY_KWARGS - - duration = gps_end_pad + gps_start_pad - logging.info(f"Fetching {duration} s of {self.name} data around " - f"{trigger_time} from GWOSC.") - - data_td = TimeSeries.fetch_open_data( - self.name, - trigger_time - gps_start_pad, - trigger_time + gps_end_pad, - **gwpy_kwargs, - ) - assert isinstance(data_td, TimeSeries), "Data is not a TimeSeries object." - segment_length = data_td.duration.value - n = len(data_td) - delta_t = data_td.dt.value # type: ignore - data = jnp.fft.rfft(jnp.array(data_td.value) * tukey(n, tukey_alpha)) * delta_t - freq = jnp.fft.rfftfreq(n, delta_t) - # TODO: Check if this is the right way to fetch PSD - start_psd = int(trigger_time) - gps_start_pad - 2 * psd_pad - end_psd = int(trigger_time) - gps_start_pad - psd_pad - - print("Fetching PSD data...") - psd_data_td = TimeSeries.fetch_open_data( - self.name, start_psd, end_psd, **gwpy_kwargs - ) - assert isinstance( - psd_data_td, TimeSeries - ), "PSD data is not a TimeSeries object." - psd = psd_data_td.psd( - fftlength=segment_length - ).value # TODO: Check whether this is sright. - - print("Finished loading data.") - - self.frequencies = freq[(freq > f_min) & (freq < f_max)] - self.data = data[(freq > f_min) & (freq < f_max)] - self.psd = psd[(freq > f_min) & (freq < f_max)] - load_data.__doc__ = load_data.__doc__.format(_DEF_GWPY_KWARGS) - - # def load_data(self, data: ) - def fd_response( self, frequency: Float[Array, " n_sample"], @@ -422,55 +370,6 @@ def antenna_pattern(self, ra: Float, dec: Float, psi: Float, gmst: Float) -> dic return antenna_patterns - def inject_signal( - self, - key: PRNGKeyArray, - freqs: Float[Array, " n_sample"], - h_sky: dict[str, Float[Array, " n_sample"]], - params: dict[str, Float], - psd_file: str = "", - ) -> None: - """ - Inject a signal into the detector data. - - Parameters - ---------- - key : PRNGKeyArray - JAX PRNG key. - freqs : Float[Array, " n_sample"] - Array of frequencies. - h_sky : dict[str, Float[Array, " n_sample"]] - Array of waveforms in the sky frame. The key is the polarization mode. - params : dict[str, Float] - Dictionary of parameters. - psd_file : str - Path to the PSD file. - - Returns - ------- - None - """ - self.frequencies = freqs - self.psd = self.load_psd(freqs, psd_file) - key, subkey = jax.random.split(key, 2) - var = self.psd / (4 * (freqs[1] - freqs[0])) - noise_real = jax.random.normal(key, shape=freqs.shape) * jnp.sqrt(var / 2.0) - noise_imag = jax.random.normal(subkey, shape=freqs.shape) * jnp.sqrt(var / 2.0) - align_time = jnp.exp( - -1j * 2 * jnp.pi * freqs * (params["epoch"] + params["t_c"]) - ) - - signal = self.fd_response(freqs, h_sky, params) * align_time - self.data = signal + noise_real + 1j * noise_imag - - # also calculate the optimal SNR and match filter SNR - optimal_SNR = jnp.sqrt(jnp.sum(signal * signal.conj() / var).real) - match_filter_SNR = jnp.sum(self.data * signal.conj() / var) / optimal_SNR - - print(f"For detector {self.name}:") - print(f"The injected optimal SNR is {optimal_SNR}") - print(f"The injected match filter SNR is {match_filter_SNR}") - @jaxtyped(typechecker=typechecker) def load_psd( self, freqs: Float[Array, " n_sample"], psd_file: str = "" @@ -485,10 +384,45 @@ def load_psd( else: f, psd_vals = np.loadtxt(psd_file, unpack=True) - psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) # type: ignore + psd = interp1d(f, psd_vals, fill_value=( + psd_vals[0], psd_vals[-1]))(freqs) # type: ignore psd = jnp.array(psd) return psd + def set_data(self, data: jd.Data | Array, **kws) -> None: + """Add data to detector. + + Arguments + --------- + data : jd.Data | Array + Data to be added to the detector, either as a `jd.Data` object + or as a timeseries array. + kws : dict + Additional keyword arguments to pass to `jd.Data` constructor. + """ + if isinstance(data, jd.Data): + self.data = data + else: + self.data = jd.Data(data, **kws) + + def set_psd(self, psd: jd.PowerSpectrum | Array, **kws) -> None: + """Add PSD to detector. + + Arguments + --------- + psd : jd.PowerSpectrum | Array + PSD to be added to the detector, either as a `jd.PowerSpectrum` + object or as a timeseries array. + kws : dict + Additional keyword arguments to pass to `jd.PowerSpectrum` + constructor. + """ + if isinstance(psd, jd.PowerSpectrum): + self.psd = psd + else: + # not clear if we want to support this + self.psd = jd.PowerSpectrum(psd, **kws) + H1 = GroundBased2G( "H1", diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index f8f69a7a..fe280f1b 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -11,10 +11,11 @@ from jimgw.base import LikelihoodBase from jimgw.prior import Prior -from jimgw.single_event.detector import Detector +from jimgw.single_event.detector import Detector, GroundBased2G from jimgw.utils import log_i0 from jimgw.single_event.waveform import Waveform from jimgw.transforms import BijectiveTransform, NtoMTransform +import logging class SingleEventLikelihood(LikelihoodBase): @@ -40,31 +41,37 @@ def __init__( self, detectors: list[Detector], waveform: Waveform, + f_min: float = 0, + f_max: float = float("inf"), trigger_time: float = 0, - duration: float = 4, post_trigger_duration: float = 2, - # TODO: apply f_min and f_max and get frequency domain data - # here **kwargs, ) -> None: self.detectors = detectors - assert jnp.all( - jnp.array( - [ - (self.detectors[0].frequencies == detector.frequencies).all() # type: ignore - for detector in self.detectors - ] - ) + + # TODO: we can probably make this a bit more elegant + for det in detectors: + if not det.data.has_fd: + logging.info("Computing FFT with default window") + det.data.fft() + det.set_frequency_bounds(f_min, f_max) + + freqs = [d.data.frequency_slice(f_min, f_max)[1] for d in detectors] + assert all([ + (freqs[0] + == freq).all() # noqa: W503 + for freq in freqs] ), "The detectors must have the same frequency grid" - self.frequencies = self.detectors[0].frequencies # type: ignore + self.frequencies = freqs[0] # type: ignore self.waveform = waveform self.trigger_time = trigger_time self.gmst = ( - Time(trigger_time, format="gps").sidereal_time("apparent", "greenwich").rad + Time(trigger_time, format="gps").sidereal_time("apparent", + "greenwich").rad ) self.trigger_time = trigger_time - self.duration = duration + self.duration = duration = self.detectors[0].data.duration self.post_trigger_duration = post_trigger_duration self.kwargs = kwargs if "marginalization" in self.kwargs: @@ -647,10 +654,12 @@ def original_likelihood( df = freqs[1] - freqs[0] for detector in detectors: h_dec = detector.fd_response(freqs, h_sky, params) * align_time + data = detector.fd_data_slice + psd = detector.psd_slice match_filter_SNR = ( - 4 * jnp.sum((jnp.conj(h_dec) * detector.data) / detector.psd * df).real + 4 * jnp.sum((jnp.conj(h_dec) * data) / psd * df).real ) - optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real + optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / psd * df).real log_likelihood += match_filter_SNR - optimal_SNR / 2 return log_likelihood