Skip to content

Commit

Permalink
commit Normal prior
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Jul 1, 2024
1 parent 8b49717 commit 113dd0f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ def log_prob(self, x: dict[str, Float]) -> Float:
log_p = self.alpha * variable + jnp.log(self.normalization)
return log_p + log_in_range


@jaxtyped(typechecker=typechecker)
class Normal(Prior):
mean: Float = 0.0
Expand Down Expand Up @@ -591,15 +592,17 @@ def sample(
Samples from the distribution. The keys are the names of the parameters.
"""
samples = jax.random.normal(
rng_key, (n_samples,)
)
samples = jax.random.normal(rng_key, (n_samples,))
samples = self.mean + samples * self.std
return self.add_name(samples[None])

def log_prob(self, x: dict[str, Array]) -> Float:
variable = x[self.naming[0]]
output = - 0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.std) - 0.5 * ((variable - self.mean) / self.std) ** 2
output = (
-0.5 * jnp.log(2 * jnp.pi)
- jnp.log(self.std)
- 0.5 * ((variable - self.mean) / self.std) ** 2
)
return output


Expand Down

0 comments on commit 113dd0f

Please sign in to comment.