diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b318c29a..e0a35d8c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,16 +1,16 @@ files: src/ repos: - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 23.12.0 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 'v0.1.7' + rev: 'v0.1.8' hooks: - id: ruff args: ["--fix"] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.339 + rev: v1.1.340 hooks: - id: pyright additional_dependencies: [beartype, jax, jaxtyping, pytest, typing_extensions, flowMC, ripplegw, gwpy, astropy] diff --git a/pyproject.toml b/pyproject.toml index 482af776..006e01c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ [build-system] requires = ["setuptools","wheel"] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" + +[tool.pyright] +reportIncompatibleMethodOverride = "warning" \ No newline at end of file diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index 9d60c8df..bc859b99 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -5,7 +5,7 @@ import numpy as np import requests from gwpy.timeseries import TimeSeries -from jaxtyping import Array, PRNGKeyArray, Float +from jaxtyping import Array, PRNGKeyArray, Float, jaxtyped from scipy.interpolate import interp1d from scipy.signal.windows import tukey @@ -160,13 +160,13 @@ def arms(self) -> tuple[Float[Array, " 3"], Float[Array, " 3"]]: return x, y @property - def tensor(self) -> Float[Array, " 3, 3"]: + def tensor(self) -> Float[Array, " 3 3"]: """ Detector tensor defining the strain measurement. Returns ------- - tensor : Float[Array, " 3, 3"] + tensor : Float[Array, " 3 3"] detector tensor. """ # TODO: this could easily be generalized for other detector geometries @@ -389,6 +389,7 @@ def inject_signal( signal = self.fd_response(freqs, h_sky, params) * align_time self.data = signal + noise_real + 1j * noise_imag + @jaxtyped def load_psd( self, freqs: Float[Array, " n_sample"], psd_file: str = "" ) -> Float[Array, " n_sample"]: @@ -401,8 +402,7 @@ def load_psd( else: f, asd_vals = np.loadtxt(psd_file, unpack=True) psd_vals = asd_vals**2 - assert isinstance(f, Float[Array, "n_sample"]) - assert isinstance(psd_vals, Float[Array, "n_sample"]) + psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) # type: ignore psd = jnp.array(psd) return psd diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 02a9db83..6c05eb03 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -95,7 +95,7 @@ def posterior(self, params: Array, data: dict): ) def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): - if initial_guess is jnp.array([]): + if initial_guess.size == 0: initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T self.Sampler.sample(initial_guess, None) # type: ignore diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 1811b4f3..f00cbabc 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -185,7 +185,7 @@ def __init__( self.freq_grid_low = freq_grid[:-1] print("Finding reference parameters..") - + self.ref_params = self.maximize_likelihood( bounds=bounds, prior=prior, popsize=popsize, n_loops=n_loops ) @@ -474,3 +474,8 @@ def y(x): _ = optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return prior.transform(prior.add_name(best_fit)) + + +class PopulationLikelihood(LikelihoodBase): + events: Float[Array, " n_events n_samples n_dim"] + reference_pop: Float[Array, " n_det n_dim"] diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index d96eacab..d7cb58cb 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float, Int, PRNGKeyArray +from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped from typing import Callable, Union from dataclasses import field @@ -90,6 +90,7 @@ def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError +@jaxtyped class Uniform(Prior): xmin: Float = 0.0 xmax: Float = 1.0 @@ -102,8 +103,6 @@ def __init__( transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(xmin, Float), "xmin must be a Float" - assert isinstance(xmax, Float), "xmax must be a Float" assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin @@ -142,6 +141,7 @@ def log_prob(self, x: dict[str, Array]) -> Float: return output + jnp.log(1.0 / (self.xmax - self.xmin)) +@jaxtyped class Unconstrained_Uniform(Prior): xmin: Float = 0.0 xmax: Float = 1.0 @@ -154,8 +154,6 @@ def __init__( transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(xmin, Float), "xmin must be a Float" - assert isinstance(xmax, Float), "xmax must be a Float" assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin @@ -257,6 +255,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) +@jaxtyped class Alignedspin(Prior): """ @@ -284,7 +283,6 @@ def __init__( transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(amax, Float), "xmin must be a Float" assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" self.amax = amax @@ -359,6 +357,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: return log_p +@jaxtyped class Powerlaw(Prior): """ @@ -380,9 +379,6 @@ def __init__( transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(xmin, Float), "xmin must be a Float" - assert isinstance(xmax, Float), "xmax must be a Float" - assert isinstance(alpha, (Int, Float)), "alpha must be a int or a Float" if alpha < 0.0: assert xmin > 0.0, "With negative alpha, xmin must > 0" assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" diff --git a/src/jimgw/wave.py b/src/jimgw/wave.py index 63acd56d..12dd4628 100644 --- a/src/jimgw/wave.py +++ b/src/jimgw/wave.py @@ -27,7 +27,7 @@ def __init__(self, name: str): def tensor_from_basis( self, x: Float[Array, " 3"], y: Float[Array, " 3"] - ) -> Float[Array, " 3, 3"]: + ) -> Float[Array, " 3 3"]: """Constructor to obtain polarization tensor from waveframe basis defined by orthonormal vectors (x, y) in arbitrary Cartesian coordinates. @@ -52,7 +52,7 @@ def tensor_from_basis( def tensor_from_sky( self, ra: Float, dec: Float, psi: Float, gmst: Float - ) -> Float[Array, " 3, 3"]: + ) -> Float[Array, " 3 3"]: """Computes {name} polarization tensor in celestial coordinates from sky location and orientation parameters.