From 3cf1f41781e6e7a2d4aa8a9244e940a1a5102bc6 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Thu, 18 Jan 2024 15:59:49 -0500 Subject: [PATCH] Fix log_prob calculation in Sphere class --- src/jimgw/prior.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7577ffd7..384eb8e5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -263,7 +263,13 @@ def sample( return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) def log_prob(self, x: dict[str, Float]) -> Float: - return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) + mag = x[self.naming[2]] + output = jnp.where( + (mag > 1) | (mag < 0), + jnp.zeros_like(0) - jnp.inf, + jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), + ) + return output @jaxtyped