diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7577ffd7..384eb8e5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -263,7 +263,13 @@ def sample( return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) def log_prob(self, x: dict[str, Float]) -> Float: - return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) + mag = x[self.naming[2]] + output = jnp.where( + (mag > 1) | (mag < 0), + jnp.zeros_like(0) - jnp.inf, + jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), + ) + return output @jaxtyped