From 7bd5491f977aa3b407aad91c72b1a06efcb4f49c Mon Sep 17 00:00:00 2001 From: Zipeng Wang Date: Fri, 2 Feb 2024 10:38:29 -0500 Subject: [PATCH 1/2] Update prior.py --- src/jimgw/prior.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 384eb8e5..c42104db 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -65,8 +65,11 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: A dictionary of parameters with the transforms applied. """ output = {} + #print("transform input:", x) for value in self.transforms.values(): output[value[0]] = value[1](x) + #print("transform output:", output) + return output def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: @@ -264,11 +267,17 @@ def sample( def log_prob(self, x: dict[str, Float]) -> Float: mag = x[self.naming[2]] + phi = x[self.naming[1]] output = jnp.where( (mag > 1) | (mag < 0), jnp.zeros_like(0) - jnp.inf, jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), ) + output = jnp.where( + (phi > 2* jnp.pi) | (phi < 0), + jnp.zeros_like(0) - jnp.inf, + output, + ) return output From 79f0cb494fd286909a3c1aa231dfb33828584a99 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 2 Feb 2024 18:33:49 -0500 Subject: [PATCH 2/2] Update prior.py --- src/jimgw/prior.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index c42104db..941cf933 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -65,11 +65,8 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: A dictionary of parameters with the transforms applied. """ output = {} - #print("transform input:", x) for value in self.transforms.values(): output[value[0]] = value[1](x) - #print("transform output:", output) - return output def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: @@ -266,18 +263,14 @@ def sample( return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) def log_prob(self, x: dict[str, Float]) -> Float: - mag = x[self.naming[2]] + theta = x[self.naming[0]] phi = x[self.naming[1]] + mag = x[self.naming[2]] output = jnp.where( - (mag > 1) | (mag < 0), + (mag > 1) | (mag < 0) | (phi > 2* jnp.pi) | (phi < 0) | (theta > 1) | (theta < -1), jnp.zeros_like(0) - jnp.inf, jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), ) - output = jnp.where( - (phi > 2* jnp.pi) | (phi < 0), - jnp.zeros_like(0) - jnp.inf, - output, - ) return output