Skip to content

Commit

Permalink
Fixing convention bugs in sphere prior
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Nov 27, 2023
1 parent d011aca commit 7227fdc
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,16 @@ def __init__(self, naming: str):
def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
rng_keys = jax.random.split(rng_key, 3)
theta = jax.random.uniform(
rng_keys[0], (n_samples,), minval=0, maxval=2 * jnp.pi
rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0
)
phi = jnp.arccos(
jax.random.uniform(rng_keys[1], (n_samples,), minval=-1.0, maxval=1.0)
jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi)
)
mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1)
return self.add_name(jnp.stack([theta, phi, mag], axis=1).T)

def log_prob(self, x: dict) -> Float:
return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[1]]))
return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]]))


class Composite(Prior):
Expand Down

0 comments on commit 7227fdc

Please sign in to comment.