From e1f428d6d73dd79bd4d159beb1046c119ed21bce Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Thu, 27 Jun 2024 00:07:29 -0700 Subject: [PATCH 1/6] Added Normal prior --- src/jimgw/prior.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 6408318d..bbf1846b 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -551,6 +551,57 @@ def log_prob(self, x: dict[str, Float]) -> Float: log_p = self.alpha * variable + jnp.log(self.normalization) return log_p + log_in_range +@jaxtyped(typechecker=typechecker) +class Normal(Prior): + mean: Float = 0.0 + std: Float = 1.0 + + def __repr__(self): + return f"Normal(mean={self.mean}, std={self.std})" + + def __init__( + self, + mean: Float, + std: Float, + naming: list[str], + transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, + ): + super().__init__(naming, transforms) + assert self.n_dim == 1, "Normal needs to be 1D distributions" + self.mean = mean + self.std = std + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a normal distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + samples = jax.random.normal( + rng_key, (n_samples,) + ) + samples = self.mean + samples * self.std + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Array]) -> Float: + variable = x[self.naming[0]] + output = - 0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.std) - 0.5 * ((variable - self.mean) / self.std) ** 2 + return output + class Composite(Prior): priors: list[Prior] = field(default_factory=list) From 9bbd6d8f4c224b524f0c6bfaf6789976afc03d9f Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Thu, 27 Jun 2024 12:43:30 -0700 Subject: [PATCH 2/6] change for numpy v2 typing --- src/jimgw/single_event/likelihood.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 58ffb083..d519d5cb 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -437,7 +437,7 @@ def evaluate_original( @staticmethod def max_phase_diff( - f: npt.NDArray[np.float_], + f: npt.NDArray[np.floating], f_low: float, f_high: float, chi: Float = 1.0, @@ -469,7 +469,7 @@ def max_phase_diff( return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) def make_binning_scheme( - self, freqs: npt.NDArray[np.float_], n_bins: int, chi: float = 1 + self, freqs: npt.NDArray[np.floating], n_bins: int, chi: float = 1 ) -> tuple[Float[Array, " n_bins+1"], Float[Array, " n_bins"]]: """ Make a binning scheme based on the maximum phase difference between the From 7637049032d4c37ccc3006eeadcc70deab92bc75 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Thu, 27 Jun 2024 17:25:42 -0400 Subject: [PATCH 3/6] Update likelihood.py --- src/jimgw/single_event/likelihood.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index d519d5cb..3e3e290a 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -437,7 +437,7 @@ def evaluate_original( @staticmethod def max_phase_diff( - f: npt.NDArray[np.floating], + f: npt.NDArray[float], f_low: float, f_high: float, chi: Float = 1.0, @@ -469,7 +469,7 @@ def max_phase_diff( return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) def make_binning_scheme( - self, freqs: npt.NDArray[np.floating], n_bins: int, chi: float = 1 + self, freqs: npt.NDArray[float], n_bins: int, chi: float = 1 ) -> tuple[Float[Array, " n_bins+1"], Float[Array, " n_bins"]]: """ Make a binning scheme based on the maximum phase difference between the From 841f7cdc9136bea5bc37a141338ee69d7e21de3f Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Thu, 27 Jun 2024 17:35:02 -0400 Subject: [PATCH 4/6] Update setup.cfg --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 9a2ed7e7..357b339b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,6 @@ install_requires = jax>=0.4.12 jaxlib>=0.4.12 flowMC>=0.2.4 - numpy>=1,<2 ripplegw gwpy corner From 8b497175e680dce115d99a7fb547434414a7fb69 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 28 Jun 2024 09:14:27 -0400 Subject: [PATCH 5/6] Update likelihood.py Walk back typing changes --- src/jimgw/single_event/likelihood.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 3e3e290a..d519d5cb 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -437,7 +437,7 @@ def evaluate_original( @staticmethod def max_phase_diff( - f: npt.NDArray[float], + f: npt.NDArray[np.floating], f_low: float, f_high: float, chi: Float = 1.0, @@ -469,7 +469,7 @@ def max_phase_diff( return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) def make_binning_scheme( - self, freqs: npt.NDArray[float], n_bins: int, chi: float = 1 + self, freqs: npt.NDArray[np.floating], n_bins: int, chi: float = 1 ) -> tuple[Float[Array, " n_bins+1"], Float[Array, " n_bins"]]: """ Make a binning scheme based on the maximum phase difference between the From 113dd0f3b7be30028539f83b1e1cdc37b4cc470e Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 1 Jul 2024 15:12:46 -0400 Subject: [PATCH 6/6] commit Normal prior --- src/jimgw/prior.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index bbf1846b..712abf1a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -551,6 +551,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: log_p = self.alpha * variable + jnp.log(self.normalization) return log_p + log_in_range + @jaxtyped(typechecker=typechecker) class Normal(Prior): mean: Float = 0.0 @@ -591,15 +592,17 @@ def sample( Samples from the distribution. The keys are the names of the parameters. """ - samples = jax.random.normal( - rng_key, (n_samples,) - ) + samples = jax.random.normal(rng_key, (n_samples,)) samples = self.mean + samples * self.std return self.add_name(samples[None]) def log_prob(self, x: dict[str, Array]) -> Float: variable = x[self.naming[0]] - output = - 0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.std) - 0.5 * ((variable - self.mean) / self.std) ** 2 + output = ( + -0.5 * jnp.log(2 * jnp.pi) + - jnp.log(self.std) + - 0.5 * ((variable - self.mean) / self.std) ** 2 + ) return output