diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 213d97c8..b5595520 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,17 +4,17 @@ repos: hooks: - id: black - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.290' + rev: 'v0.1.6' hooks: - id: ruff args: ["--fix"] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.327 + rev: v1.1.338 hooks: - id: pyright additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions] - repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.0 + rev: 1.7.1 hooks: - id: nbqa-black additional_dependencies: [ipython==8.12, black] diff --git a/example/GW150914.py b/example/GW150914.py index a9c1c9c8..8ba26ead 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -17,15 +17,17 @@ # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 -start = gps - 2 -end = gps + 2 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration fmin = 20.0 fmax = 1024.0 ifos = ["H1", "L1"] -H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) q_prior = Unconstrained_Uniform( @@ -91,6 +93,7 @@ post_trigger_duration=2, ) +likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) @@ -100,7 +103,7 @@ jim = Jim( likelihood, prior, - n_loop_training=200, + n_loop_training=100, n_loop_production=10, n_local_steps=150, n_global_steps=150, @@ -117,5 +120,4 @@ local_sampler_arg=local_sampler_arg, ) -# jim.maximize_likelihood([prior.xmin, 
prior.xmax]) jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW150914_heterodyne.py b/example/GW150914_heterodyne.py new file mode 100644 index 00000000..08e091b6 --- /dev/null +++ b/example/GW150914_heterodyne.py @@ -0,0 +1,89 @@ +import time +from jimgw.jim import Jim +from jimgw.detector import H1, L1 +from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.waveform import RippleIMRPhenomD +from jimgw.prior import Uniform +import jax.numpy as jnp +import jax + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +prior = Uniform( + xmin=[10, 0.125, -1.0, -1.0, 0.0, -0.05, 0.0, -1, 0.0, 0.0, -1.0], + xmax=[80.0, 1.0, 1.0, 1.0, 2000.0, 0.05, 2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0], + naming=[ + "M_c", + "q", + "s1_z", + "s2_z", + "d_L", + "t_c", + "phase_c", + "cos_iota", + "psi", + "ra", + "sin_dec", + ], + transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), + "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), + "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec +) + +likelihood = HeterodynedTransientLikelihoodFD( + [H1, L1], + prior=prior, + bounds=[prior.xmin, prior.xmax], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=duration, + 
post_trigger_duration=post_trigger_duration, + n_loops=300 +) + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +jim = Jim( + likelihood, + prior, + n_loop_training=100, + n_loop_production=10, + n_local_steps=150, + n_global_steps=150, + n_chains=500, + n_epochs=50, + learning_rate=0.001, + max_samples=45000, + momentum=0.9, + batch_size=50000, + use_global=True, + keep_quantile=0.0, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, +) + +jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW170817.py b/example/GW170817.py new file mode 100644 index 00000000..8f01cc23 --- /dev/null +++ b/example/GW170817.py @@ -0,0 +1,102 @@ +import time +from jimgw.jim import Jim +from jimgw.detector import H1, L1, V1 +from jimgw.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.waveform import RippleIMRPhenomD +from jimgw.prior import Uniform +from gwosc.datasets import event_gps +import jax.numpy as jnp +import jax + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +gps = event_gps("GW170817") +duration = 128 +post_trigger_duration = 32 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"]#, "V1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=0.05, gwpy_kwargs={"version": 2, "cache": False}) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=0.05, gwpy_kwargs={"version": 2, "cache": False}) +# V1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.05) + +prior = Uniform( + xmin=[1.18, 0.125, -0.3, -0.3, 1., -0.1, 0.0, -1, 0.0, 0.0, -1.0], + xmax=[1.21, 1.0, 0.3, 0.3, 75., 0.1, 
2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0], + naming=[ + "M_c", + "q", + "s1_z", + "s2_z", + "d_L", + "t_c", + "phase_c", + "cos_iota", + "psi", + "ra", + "sin_dec", + ], + transforms={ + "q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2), + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ), + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ), + }, # sin and arcsin are periodize cos_iota and sin_dec +) + +likelihood = HeterodynedTransientLikelihoodFD( + [H1], + prior=prior, + bounds=[prior.xmin, prior.xmax], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=duration, + post_trigger_duration=post_trigger_duration, + n_loops=1000 +) + +# mass_matrix = jnp.eye(11) +# mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[5, 5].set(1e-3) +# local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +# jim = Jim( +# likelihood, +# prior, +# n_loop_training=100, +# n_loop_production=10, +# n_local_steps=150, +# n_global_steps=150, +# n_chains=500, +# n_epochs=50, +# learning_rate=0.001, +# max_samples=45000, +# momentum=0.9, +# batch_size=50000, +# use_global=True, +# keep_quantile=0.0, +# train_thinning=1, +# output_thinning=10, +# local_sampler_arg=local_sampler_arg, +# ) + +# jim.sample(jax.random.PRNGKey(42)) diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index 0ff7478a..d7580335 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -1,26 +1,27 @@ -import jax.numpy as jnp -from jimgw.constants import * -from jimgw.wave import Polarization -from scipy.signal.windows import tukey from abc import ABC, abstractmethod -import equinox as eqx -from jaxtyping import Array, PRNGKeyArray + import jax -from gwpy.timeseries import TimeSeries -from typing import Callable -import requests +import jax.numpy as jnp import numpy as np +import requests +from 
gwpy.timeseries import TimeSeries +from jaxtyping import Array, PRNGKeyArray from scipy.interpolate import interp1d +from scipy.signal.windows import tukey -DEG_TO_RAD = jnp.pi/180 +from jimgw.constants import * +from jimgw.wave import Polarization + +DEG_TO_RAD = jnp.pi / 180 # TODO: Need to expand this list. Currently it is only O3. -psd_file_dict= { +psd_file_dict = { "H1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-H1-C01_CLEAN_SUB60HZ-1251752040.0_sensitivity_strain_asd.txt", "L1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-L1-C01_CLEAN_SUB60HZ-1240573680.0_sensitivity_strain_asd.txt", "V1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-V1_sensitivity_strain_asd.txt", } + def np2(x): """ Returns the next power of two as big as or larger than x.""" @@ -29,13 +30,14 @@ def np2(x): p = p << 1 return p + class Detector(ABC): - """ + """ Base class for all detectors. """ - name: str + name: str @abstractmethod def load_data(self, data): @@ -44,20 +46,23 @@ def load_data(self, data): @abstractmethod def fd_response(self, frequency: Array, h: Array, params: dict) -> Array: """ - Modulate the waveform in the sky frame by the detector response in the frequency domain.""" + Modulate the waveform in the sky frame by the detector response + in the frequency domain.""" pass @abstractmethod def td_response(self, time: Array, h: Array, params: dict) -> Array: """ - Modulate the waveform in the sky frame by the detector response in the time domain.""" + Modulate the waveform in the sky frame by the detector response + in the time domain.""" pass - + + class GroundBased2G(Detector): polarization_mode: list[Polarization] frequencies: Array = None - data : Array = None + data: Array = None psd: Array = None latitude: float = 0 @@ -71,14 +76,14 @@ class GroundBased2G(Detector): def __init__(self, name: str, **kwargs) -> None: self.name = name - self.latitude = kwargs.get('latitude', 0) - self.longitude = kwargs.get('longitude', 0) - self.elevation = 
kwargs.get('elevation', 0) - self.xarm_azimuth = kwargs.get('xarm_azimuth', 0) - self.yarm_azimuth = kwargs.get('yarm_azimuth', 0) - self.xarm_tilt = kwargs.get('xarm_tilt', 0) - self.yarm_tilt = kwargs.get('yarm_tilt', 0) - modes = kwargs.get('mode', 'pc') + self.latitude = kwargs.get("latitude", 0) + self.longitude = kwargs.get("longitude", 0) + self.elevation = kwargs.get("elevation", 0) + self.xarm_azimuth = kwargs.get("xarm_azimuth", 0) + self.yarm_azimuth = kwargs.get("yarm_azimuth", 0) + self.xarm_tilt = kwargs.get("xarm_tilt", 0) + self.yarm_tilt = kwargs.get("yarm_tilt", 0) + modes = kwargs.get("mode", "pc") self.polarization_mode = [Polarization(m) for m in modes] @@ -87,7 +92,7 @@ def _get_arm(lat, lon, tilt, azimuth): """ Construct detector-arm vectors in Earth-centric Cartesian coordinates. - Arguments + Parameters --------- lat : float vertex latitude in rad. @@ -99,33 +104,42 @@ def _get_arm(lat, lon, tilt, azimuth): arm azimuth in rad. """ e_lon = jnp.array([-jnp.sin(lon), jnp.cos(lon), 0]) - e_lat = jnp.array([-jnp.sin(lat) * jnp.cos(lon), - -jnp.sin(lat) * jnp.sin(lon), jnp.cos(lat)]) - e_h = jnp.array([jnp.cos(lat) * jnp.cos(lon), - jnp.cos(lat) * jnp.sin(lon), jnp.sin(lat)]) - - return (jnp.cos(tilt) * jnp.cos(azimuth) * e_lon + - jnp.cos(tilt) * jnp.sin(azimuth) * e_lat + - jnp.sin(tilt) * e_h) + e_lat = jnp.array( + [-jnp.sin(lat) * jnp.cos(lon), -jnp.sin(lat) * jnp.sin(lon), jnp.cos(lat)] + ) + e_h = jnp.array( + [jnp.cos(lat) * jnp.cos(lon), jnp.cos(lat) * jnp.sin(lon), jnp.sin(lat)] + ) + + return ( + jnp.cos(tilt) * jnp.cos(azimuth) * e_lon + + jnp.cos(tilt) * jnp.sin(azimuth) * e_lat + + jnp.sin(tilt) * e_h + ) @property def arms(self): """ Detector arm vectors (x, y). 
""" - x = self._get_arm(self.latitude, self.longitude, self.xarm_tilt, self.xarm_azimuth) - y = self._get_arm(self.latitude, self.longitude, self.yarm_tilt, self.yarm_azimuth) + x = self._get_arm( + self.latitude, self.longitude, self.xarm_tilt, self.xarm_azimuth + ) + y = self._get_arm( + self.latitude, self.longitude, self.yarm_tilt, self.yarm_azimuth + ) return x, y - + @property def tensor(self): """ Detector tensor defining the strain measurement. """ - #TODO: this could easily be generalized for other detector geometries + # TODO: this could easily be generalized for other detector geometries arm1, arm2 = self.arms - return 0.5 * (jnp.einsum('i,j->ij', arm1, arm1) - - jnp.einsum('i,j->ij', arm2, arm2)) + return 0.5 * ( + jnp.einsum("i,j->ij", arm1, arm1) - jnp.einsum("i,j->ij", arm2, arm2) + ) @property def vertex(self): @@ -140,19 +154,25 @@ def vertex(self): h = self.elevation major, minor = EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS # compute vertex location - r = major**2*(major**2*jnp.cos(lat)**2 + minor**2*jnp.sin(lat)**2)**(-0.5) + r = major**2 * ( + major**2 * jnp.cos(lat) ** 2 + minor**2 * jnp.sin(lat) ** 2 + ) ** (-0.5) x = (r + h) * jnp.cos(lat) * jnp.cos(lon) y = (r + h) * jnp.cos(lat) * jnp.sin(lon) - z = ((minor / major)**2 * r + h)*jnp.sin(lat) + z = ((minor / major) ** 2 * r + h) * jnp.sin(lat) return jnp.array([x, y, z]) - def load_data(self, trigger_time:float, - gps_start_pad: int, - gps_end_pad: int, - f_min: float, - f_max: float, - psd_pad: int = 16, - tukey_alpha: float = 0.2) -> None: + def load_data( + self, + trigger_time: float, + gps_start_pad: int, + gps_end_pad: int, + f_min: float, + f_max: float, + psd_pad: int = 16, + tukey_alpha: float = 0.2, + gwpy_kwargs: dict = {"cache": True}, + ) -> None: """ Load data from the detector. 
@@ -176,50 +196,65 @@ def load_data(self, trigger_time:float, """ print("Fetching data from {}...".format(self.name)) - data_td = TimeSeries.fetch_open_data(self.name, trigger_time - gps_start_pad, trigger_time + gps_end_pad, cache=True) + data_td = TimeSeries.fetch_open_data( + self.name, + trigger_time - gps_start_pad, + trigger_time + gps_end_pad, + **gwpy_kwargs + ) segment_length = data_td.duration.value n = len(data_td) delta_t = data_td.dt.value - data = jnp.fft.rfft(jnp.array(data_td.value)*tukey(n, tukey_alpha))*delta_t + data = jnp.fft.rfft(jnp.array(data_td.value) * tukey(n, tukey_alpha)) * delta_t freq = jnp.fft.rfftfreq(n, delta_t) # TODO: Check if this is the right way to fetch PSD - start_psd = int(trigger_time) - gps_start_pad - psd_pad # What does Int do here? - end_psd = int(trigger_time) + gps_end_pad + psd_pad + start_psd = ( + int(trigger_time) - gps_start_pad - 2 * psd_pad + ) # What does Int do here? + end_psd = int(trigger_time) - gps_start_pad - psd_pad print("Fetching PSD data...") - psd_data_td = TimeSeries.fetch_open_data(self.name, start_psd, end_psd, cache=True) - psd = psd_data_td.psd(fftlength=segment_length).value # TODO: Check whether this is sright. + psd_data_td = TimeSeries.fetch_open_data( + self.name, start_psd, end_psd, **gwpy_kwargs + ) + psd = psd_data_td.psd( + fftlength=segment_length + ).value # TODO: Check whether this is sright. 
- print("Finished generating data.") + print("Finished loading data.") - self.frequencies = freq[(freq>f_min)&(freq<f_max)] - self.data = data[(freq>f_min)&(freq<f_max)] - self.psd = psd[(freq>f_min)&(freq<f_max)] + self.frequencies = freq[(freq > f_min) & (freq < f_max)] + self.data = data[(freq > f_min) & (freq < f_max)] + self.psd = psd[(freq > f_min) & (freq < f_max)] def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array: """ Modulate the waveform in the sky frame by the detector response in the frequency domain.""" - ra, dec, psi, gmst = params['ra'], params['dec'], params['psi'], params['gmst'] + ra, dec, psi, gmst = params["ra"], params["dec"], params["psi"], params["gmst"] antenna_pattern = self.antenna_pattern(ra, dec, psi, gmst) timeshift = self.delay_from_geocenter(ra, dec, gmst) - h_detector = jax.tree_util.tree_map(lambda h, antenna: h * antenna * jnp.exp(-2j * jnp.pi * frequency * timeshift), h_sky, antenna_pattern) - return jnp.sum(jnp.stack(jax.tree_util.tree_leaves(h_detector)),axis=0) + h_detector = jax.tree_util.tree_map( + lambda h, antenna: h + * antenna + * jnp.exp(-2j * jnp.pi * frequency * timeshift), + h_sky, + antenna_pattern, + ) + return jnp.sum(jnp.stack(jax.tree_util.tree_leaves(h_detector)), axis=0) def td_response(self, time: Array, h: Array, params: Array) -> Array: """ Modulate the waveform in the sky frame by the detector response in the time domain.""" pass - - def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: - """ + """ Calculate time delay between two detectors in geocentric coordinates based on XLALArrivaTimeDiff in TimeDelay.c https://lscsoft.docs.ligo.org/lalsuite/lal/group___time_delay__h.html - Arguments + Parameters --------- ra : float right ascension of the source in rad. 
@@ -236,12 +271,16 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: gmst = jnp.mod(gmst, 2 * jnp.pi) phi = ra - gmst theta = jnp.pi / 2 - dec - omega = jnp.array([jnp.sin(theta)*jnp.cos(phi), - jnp.sin(theta)*jnp.sin(phi), - jnp.cos(theta)]) + omega = jnp.array( + [ + jnp.sin(theta) * jnp.cos(phi), + jnp.sin(theta) * jnp.sin(phi), + jnp.cos(theta), + ] + ) return jnp.dot(omega, delta_d) / C_SI - def antenna_pattern(self, ra:float, dec:float, psi:float, gmst:float) -> dict: + def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dict: """ Computes {name} antenna patterns for {modes} polarizations at the specified sky location, orientation and GMST. @@ -250,7 +289,7 @@ def antenna_pattern(self, ra:float, dec:float, psi:float, gmst:float) -> dict: given polarization is the dyadic product between the detector tensor and the corresponding polarization tensor. - Arguments + Parameters --------- ra : float source right ascension in radians. @@ -262,78 +301,90 @@ def antenna_pattern(self, ra:float, dec:float, psi:float, gmst:float) -> dict: Greenwich mean sidereal time (GMST) in radians. modes : str string of polarizations to include, defaults to tensor modes: 'pc'. - + Returns ------- result : list antenna pattern values for {modes}. 
- """ + """ detector_tensor = self.tensor antenna_patterns = {} for polarization in self.polarization_mode: wave_tensor = polarization.tensor_from_sky(ra, dec, psi, gmst) - antenna_patterns[polarization.name] = jnp.einsum('ij,ij->', detector_tensor, wave_tensor) + antenna_patterns[polarization.name] = jnp.einsum( + "ij,ij->", detector_tensor, wave_tensor + ) return antenna_patterns - def inject_signal(self, - key: PRNGKeyArray, - freqs: Array, - h_sky: dict, - params: dict, - psd_file: str = None) -> None: - """ - """ + def inject_signal( + self, + key: PRNGKeyArray, + freqs: Array, + h_sky: dict, + params: dict, + psd_file: str = None, + ) -> None: + """ """ self.frequencies = freqs self.psd = self.load_psd(freqs, psd_file) key, subkey = jax.random.split(key, 2) var = self.psd / (4 * (freqs[1] - freqs[0])) - noise_real = jax.random.normal(key, shape=freqs.shape)*jnp.sqrt(var) - noise_imag = jax.random.normal(subkey, shape=freqs.shape)*jnp.sqrt(var) - align_time = jnp.exp(-1j*2*jnp.pi*freqs*(params['epoch']+params['t_c'])) + noise_real = jax.random.normal(key, shape=freqs.shape) * jnp.sqrt(var) + noise_imag = jax.random.normal(subkey, shape=freqs.shape) * jnp.sqrt(var) + align_time = jnp.exp( + -1j * 2 * jnp.pi * freqs * (params["epoch"] + params["t_c"]) + ) signal = self.fd_response(freqs, h_sky, params) * align_time - self.data = signal + noise_real + 1j*noise_imag + self.data = signal + noise_real + 1j * noise_imag def load_psd(self, freqs: Array, psd_file: str = None) -> None: if psd_file is None: - print("Grabbing GWTC-2 PSD for "+self.name) + print("Grabbing GWTC-2 PSD for " + self.name) url = psd_file_dict[self.name] data = requests.get(url) - open(self.name+".txt", "wb").write(data.content) - f, asd_vals = np.loadtxt(self.name+".txt", unpack=True) + open(self.name + ".txt", "wb").write(data.content) + f, asd_vals = np.loadtxt(self.name + ".txt", unpack=True) else: f, asd_vals = np.loadtxt(psd_file, unpack=True) psd_vals = asd_vals**2 psd = interp1d(f, 
psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) return psd -H1 = GroundBased2G('H1', -latitude = (46 + 27. / 60 + 18.528 / 3600) * DEG_TO_RAD, -longitude = -(119 + 24. / 60 + 27.5657 / 3600) * DEG_TO_RAD, -xarm_azimuth = 125.9994 * DEG_TO_RAD, -yarm_azimuth = 215.9994 * DEG_TO_RAD, -xarm_tilt = -6.195e-4, -yarm_tilt = 1.25e-5, -elevation = 142.554, -mode='pc') - -L1 = GroundBased2G('L1', -latitude = (30 + 33. / 60 + 46.4196 / 3600) * DEG_TO_RAD, -longitude = -(90 + 46. / 60 + 27.2654 / 3600) * DEG_TO_RAD, -xarm_azimuth = 197.7165 * DEG_TO_RAD, -yarm_azimuth = 287.7165 * DEG_TO_RAD, -xarm_tilt = 0 , -yarm_tilt = 0, -elevation = -6.574, -mode='pc') - -V1 = GroundBased2G('V1', -latitude = (43 + 37. / 60 + 53.0921 / 3600) * DEG_TO_RAD, -longitude = (10 + 30. / 60 + 16.1887 / 3600) * DEG_TO_RAD, -xarm_azimuth = 243. * DEG_TO_RAD, -yarm_azimuth = 333. * DEG_TO_RAD, -xarm_tilt = 0 , -yarm_tilt = 0, -elevation = 51.884, -mode='pc') \ No newline at end of file + +H1 = GroundBased2G( + "H1", + latitude=(46 + 27.0 / 60 + 18.528 / 3600) * DEG_TO_RAD, + longitude=-(119 + 24.0 / 60 + 27.5657 / 3600) * DEG_TO_RAD, + xarm_azimuth=125.9994 * DEG_TO_RAD, + yarm_azimuth=215.9994 * DEG_TO_RAD, + xarm_tilt=-6.195e-4, + yarm_tilt=1.25e-5, + elevation=142.554, + mode="pc", +) + +L1 = GroundBased2G( + "L1", + latitude=(30 + 33.0 / 60 + 46.4196 / 3600) * DEG_TO_RAD, + longitude=-(90 + 46.0 / 60 + 27.2654 / 3600) * DEG_TO_RAD, + xarm_azimuth=197.7165 * DEG_TO_RAD, + yarm_azimuth=287.7165 * DEG_TO_RAD, + xarm_tilt=0, + yarm_tilt=0, + elevation=-6.574, + mode="pc", +) + +V1 = GroundBased2G( + "V1", + latitude=(43 + 37.0 / 60 + 53.0921 / 3600) * DEG_TO_RAD, + longitude=(10 + 30.0 / 60 + 16.1887 / 3600) * DEG_TO_RAD, + xarm_azimuth=243.0 * DEG_TO_RAD, + yarm_azimuth=333.0 * DEG_TO_RAD, + xarm_tilt=0, + yarm_tilt=0, + elevation=51.884, + mode="pc", +) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 41dd8ef9..5df9cf48 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -10,10 
+10,11 @@ import jax.numpy as jnp from flowMC.sampler.flowHMC import flowHMC + class Jim(object): """ Master class for interfacing with flowMC - + """ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): @@ -24,12 +25,15 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): rng_key_set = initialize_rng_keys(n_chains, seed=seed) num_layers = kwargs.get("num_layers", 10) - hidden_size = kwargs.get("hidden_size", [128,128]) + hidden_size = kwargs.get("hidden_size", [128, 128]) num_bins = kwargs.get("num_bins", 8) local_sampler_arg = kwargs.get("local_sampler_arg", {}) - local_sampler = MALA(self.posterior, True, local_sampler_arg) # Remember to add routine to find automated mass matrix + local_sampler = MALA( + self.posterior, True, local_sampler_arg + ) # Remember to add routine to find automated mass matrix + flowHMC_params = kwargs.get("flowHMC_params", {}) model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) @@ -71,7 +75,7 @@ def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 10 print("Done compiling") print("Starting the optimizer") - optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose = True) + optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True) state = optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return best_fit @@ -109,39 +113,62 @@ def print_summary(self): production_global_acceptance: Array = production_summary["global_accs"] print("Training summary") - print('=' * 10) + print("=" * 10) for index in range(len(self.Prior.naming)): - print(f"{self.Prior.naming[index]}: {training_chain[:, :, index].mean():.3f} +/- {training_chain[:, :, index].std():.3f}") - print(f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}") - print(f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}") - print(f"Global acceptance: 
{training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}") - print(f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}") + print( + f"{self.Prior.naming[index]}: {training_chain[:, :, index].mean():.3f} +/- {training_chain[:, :, index].std():.3f}" + ) + print( + f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" + ) + print( + f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}" + ) + print( + f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}" + ) + print( + f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}" + ) print("Production summary") - print('=' * 10) + print("=" * 10) for index in range(len(self.Prior.naming)): - print(f"{self.Prior.naming[index]}: {production_chain[:, :, index].mean():.3f} +/- {production_chain[:, :, index].std():.3f}") - print(f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}") - print(f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}") - print(f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}") + print( + f"{self.Prior.naming[index]}: {production_chain[:, :, index].mean():.3f} +/- {production_chain[:, :, index].std():.3f}" + ) + print( + f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" + ) + print( + f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}" + ) + print( + f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}" + ) def get_samples(self, training: bool = False) -> dict: """ Get the samples from the sampler - Args: - training (bool, optional): If True, return the training samples. Defaults to False. 
+ Parameters + ---------- + training : bool, optional + Whether to get the training samples or the production samples, by default False + + Returns + ------- + dict + Dictionary of samples - Returns: - Array: Samples """ if training: chains = self.Sampler.get_sampler_state(training=True)["chains"] else: chains = self.Sampler.get_sampler_state(training=False)["chains"] - chains = self.Prior.add_name(chains.transpose(2,0,1), transform_name=True) + chains = self.Prior.add_name(chains.transpose(2, 0, 1), transform_name=True) return chains def plot(self): - pass \ No newline at end of file + pass diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 3320c7a2..63164805 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -1,27 +1,34 @@ from abc import ABC, abstractmethod -from jaxtyping import Array, Float -from jimgw.waveform import Waveform -from jimgw.detector import Detector + +import jax import jax.numpy as jnp -from astropy.time import Time import numpy as np -from scipy.interpolate import interp1d -import jax +from astropy.time import Time from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer +from jaxtyping import Array, Float +from scipy.interpolate import interp1d + +from jimgw.detector import Detector from jimgw.prior import Prior +from jimgw.waveform import Waveform class LikelihoodBase(ABC): """ Base class for likelihoods. - Note that this likelihood class should work for a some what general class of problems. - In light of that, this class would be some what abstract, but the idea behind it is this - handles two main components of a likelihood: the data and the model. - - It should be able to take the data and model and evaluate the likelihood for a given set of parameters. + Note that this likelihood class should work + for a some what general class of problems. 
+ In light of that, this class would be some what abstract, + but the idea behind it is this handles two main components of a likelihood: + the data and the model. + It should be able to take the data and model and evaluate the likelihood for + a given set of parameters. """ + _model: object + _data: object + @property def model(self): """ @@ -45,7 +52,6 @@ def evaluate(self, params) -> float: class TransientLikelihoodFD(LikelihoodBase): - detectors: list[Detector] waveform: Waveform @@ -84,7 +90,9 @@ def ifos(self): def evaluate( self, params: Array, data: dict - ) -> float: # TODO: Test whether we need to pass data in or with class changes is fine. + ) -> ( + float + ): # TODO: Test whether we need to pass data in or with class changes is fine. """ Evaluate the likelihood for a given set of parameters. """ @@ -117,7 +125,6 @@ def evaluate( class HeterodynedTransientLikelihoodFD(TransientLikelihoodFD): - n_bins: int # Number of bins to use for the likelihood ref_params: dict # Reference parameters for the likelihood freq_grid_low: Array # Heterodyned frequency grid @@ -139,7 +146,7 @@ def __init__( waveform: Waveform, prior: Prior, bounds: tuple[Array, Array], - n_bins: int = 101, + n_bins: int = 100, trigger_time: float = 0, duration: float = 4, post_trigger_duration: float = 2, @@ -150,16 +157,35 @@ def __init__( detectors, waveform, trigger_time, duration, post_trigger_duration ) + print("Initializing heterodyned likelihood..") + + # Get the original frequency grid + + assert jnp.all( + jnp.array( + [ + (self.detectors[0].frequencies == detector.frequencies).all() + for detector in self.detectors + ] + ) + ), "The detectors must have the same frequency grid" + frequency_original = self.detectors[0].frequencies + # Get the grid of the relative binning scheme (contains the final endpoint) + # and the center points freq_grid, self.freq_grid_center = self.make_binning_scheme( - np.array(frequency_original), n_bins + 1 + np.array(frequency_original), n_bins ) 
self.freq_grid_low = freq_grid[:-1] + print("Finding reference parameters..") + self.ref_params = self.maximize_likelihood( bounds=bounds, prior=prior, popsize=popsize, n_loops=n_loops ) + print("Constructing reference waveforms..") + self.ref_params["gmst"] = self.gmst self.waveform_low_ref = {} @@ -170,21 +196,31 @@ def __init__( self.B1_array = {} h_sky = self.waveform(frequency_original, self.ref_params) - h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) - h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) - f_valid = frequency_original[jnp.where((jnp.abs(h_sky['p'])+jnp.abs(h_sky['c']))>0)[0]] + # Get frequency masks to be applied, for both original + # and heterodyne frequency grid + h_amp = jnp.sum( + jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]), axis=0 + ) + f_valid = frequency_original[jnp.where(h_amp > 0)[0]] f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) - h_sky = h_sky[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] - h_sky_low = h_sky_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] - h_sky_center = h_sky_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] + mask_heterodyne_grid = jnp.where((freq_grid <= f_max) & (freq_grid >= f_min))[0] + mask_heterodyne_low = jnp.where( + (self.freq_grid_low <= f_max) & (self.freq_grid_low >= f_min) + )[0] + mask_heterodyne_center = jnp.where( + (self.freq_grid_center <= f_max) & (self.freq_grid_center >= f_min) + )[0] + freq_grid = freq_grid[mask_heterodyne_grid] + self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low] + self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center] - frequency_original = frequency_original[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] - self.freq_grid_low = self.freq_grid_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] - self.freq_grid_center = 
self.freq_grid_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] + h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) + h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) + # Get phase shifts to align time of coalescence align_time = jnp.exp( -1j * 2 @@ -208,6 +244,7 @@ def __init__( ) for detector in self.detectors: + # Get the reference waveforms waveform_ref = ( detector.fd_response(frequency_original, h_sky, self.ref_params) * align_time @@ -227,13 +264,14 @@ def __init__( waveform_ref, detector.psd, frequency_original, - self.freq_grid_low, + freq_grid, self.freq_grid_center, ) - self.A0_array[detector.name] = A0 - self.A1_array[detector.name] = A1 - self.B0_array[detector.name] = B0 - self.B1_array[detector.name] = B1 + + self.A0_array[detector.name] = A0[mask_heterodyne_center] + self.A1_array[detector.name] = A1[mask_heterodyne_center] + self.B0_array[detector.name] = B0[mask_heterodyne_center] + self.B1_array[detector.name] = B1[mask_heterodyne_center] def evaluate(self, params: Array, data: dict) -> float: log_likelihood = 0 @@ -257,6 +295,7 @@ def evaluate(self, params: Array, data: dict) -> float: detector.fd_response(frequencies_center, waveform_sky_center, params) * align_time_center ) + r0 = waveform_center / self.waveform_center_ref[detector.name] r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center @@ -275,7 +314,9 @@ def evaluate(self, params: Array, data: dict) -> float: def evaluate_original( self, params: Array, data: dict - ) -> float: # TODO: Test whether we need to pass data in or with class changes is fine. + ) -> ( + float + ): # TODO: Test whether we need to pass data in or with class changes is fine. """ Evaluate the likelihood for a given set of parameters. 
""" @@ -307,18 +348,66 @@ def evaluate_original( return log_likelihood @staticmethod - def max_phase_diff(f, f_low, f_high, chi=1): + def max_phase_diff( + f: Float[Array, "n_dim"], + f_low: float, + f_high: float, + chi: float = 1, + ): + """ + Compute the maximum phase difference between the frequencies in the array. + + Parameters + ---------- + f: Float[Array, "n_dim"] + Array of frequencies to be binned. + f_low: float + Lower frequency bound. + f_high: float + Upper frequency bound. + chi: float + Power law index. + + Returns + ------- + Float[Array, "n_dim"] + Maximum phase difference between the frequencies in the array. + """ + gamma = np.arange(-5, 6, 1) / 3.0 f = np.repeat(f[:, None], len(gamma), axis=1) f_star = np.repeat(f_low, len(gamma)) f_star[gamma >= 0] = f_high return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) - def make_binning_scheme(self, freqs, n_bins, chi=1): - phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=1) + def make_binning_scheme( + self, freqs: Float[Array, "n_dim"], n_bins: int, chi: float = 1 + ) -> tuple[Float[Array, "n_bins+1"], Float[Array, "n_bins"]]: + """ + Make a binning scheme based on the maximum phase difference between the + frequencies in the array. + + Parameters + ---------- + freqs: Float[Array, "dim"] + Array of frequencies to be binned. + n_bins: int + Number of bins to be used. + chi: float = 1 + The chi parameter used in the phase difference calculation. + + Returns + ------- + f_bins: Float[Array, "n_bins+1"] + The bin edges. + f_bins_center: Float[Array, "n_bins"] + The bin centers. 
+ """ + + phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=chi) bin_f = interp1d(phase_diff_array, freqs) f_bins = np.array([]) - for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins): + for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1): f_bins = np.append(f_bins, bin_f(i)) f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2 return f_bins, f_bins_center @@ -372,9 +461,11 @@ def maximize_likelihood( bounds = jnp.array(bounds).T popsize = popsize # TODO remove this? - y = lambda x: -self.evaluate_original( - prior.add_name(x, transform_name=True, transform_value=True), None - ) + def y(x): + return -self.evaluate_original( + prior.add_name(x, transform_name=True, transform_value=True), None + ) + y = jax.jit(jax.vmap(y)) print("Starting the optimizer") diff --git a/src/jimgw/waveform.py b/src/jimgw/waveform.py index c94b81ad..11220021 100644 --- a/src/jimgw/waveform.py +++ b/src/jimgw/waveform.py @@ -9,7 +9,7 @@ class Waveform(ABC): def __init__(self): return NotImplemented - def __call__(self, axis: Array, params: Array) -> Array: + def __call__(self, axis: Array, params: Array) -> dict: return NotImplemented @@ -47,7 +47,7 @@ class RippleIMRPhenomPv2(Waveform): def __init__(self, f_ref: float = 20.0): self.f_ref = f_ref - def __call__(self, frequency: Array, params: dict) -> Array: + def __call__(self, frequency: Array, params: dict) -> dict: output = {} theta = [ params["M_c"],