From 7227fdcbde46e179f332e55721c9873599f205dc Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 27 Nov 2023 15:20:57 -0500 Subject: [PATCH] Fixing convention bugs in sphere prior --- src/jimgw/prior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 8a32e2c3..91befa2a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -228,16 +228,16 @@ def __init__(self, naming: str): def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: rng_keys = jax.random.split(rng_key, 3) theta = jax.random.uniform( - rng_keys[0], (n_samples,), minval=0, maxval=2 * jnp.pi + rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0 ) phi = jnp.arccos( - jax.random.uniform(rng_keys[1], (n_samples,), minval=-1.0, maxval=1.0) + jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi) ) mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) def log_prob(self, x: dict) -> Float: - return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[1]])) + return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) class Composite(Prior):