From 113dd0f3b7be30028539f83b1e1cdc37b4cc470e Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 1 Jul 2024 15:12:46 -0400 Subject: [PATCH] commit Normal prior --- src/jimgw/prior.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index bbf1846b..712abf1a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -551,6 +551,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: log_p = self.alpha * variable + jnp.log(self.normalization) return log_p + log_in_range + @jaxtyped(typechecker=typechecker) class Normal(Prior): mean: Float = 0.0 @@ -591,15 +592,17 @@ def sample( Samples from the distribution. The keys are the names of the parameters. """ - samples = jax.random.normal( - rng_key, (n_samples,) - ) + samples = jax.random.normal(rng_key, (n_samples,)) samples = self.mean + samples * self.std return self.add_name(samples[None]) def log_prob(self, x: dict[str, Array]) -> Float: variable = x[self.naming[0]] - output = - 0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.std) - 0.5 * ((variable - self.mean) / self.std) ** 2 + output = ( + -0.5 * jnp.log(2 * jnp.pi) + - jnp.log(self.std) + - 0.5 * ((variable - self.mean) / self.std) ** 2 + ) return output