From d8b2d1feb0bdae6b4fd59bef8cc345eca2bb8c33 Mon Sep 17 00:00:00 2001 From: kazewong Date: Thu, 25 Jul 2024 13:48:17 -0400 Subject: [PATCH] Add transformation function --- src/jimgw/prior.py | 19 +++++++++++++++---- test/test_prior.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 61202609..232b4a0f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -126,9 +126,7 @@ def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: output = self.base_prior.sample(rng_key, n_samples) - for transform in self.transforms: - output = jax.vmap(transform.forward)(output) - return output + return jax.vmap(self.transform)(output) def log_prob(self, x: dict[str, Float]) -> Float: """ @@ -141,6 +139,11 @@ def log_prob(self, x: dict[str, Float]) -> Float: x, log_jacobian = transform.transform(x) output -= log_jacobian return output + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + for transform in self.transforms: + x = transform.forward(x) + return x # class Combine(Prior): # """ @@ -185,7 +188,7 @@ def log_prob(self, x: dict[str, Float]) -> Float: @jaxtyped(typechecker=typechecker) class Uniform(Prior): - _dist: Prior + _dist: SequentialTransform xmin: float xmax: float @@ -218,6 +221,14 @@ def sample( def log_prob(self, x: dict[str, Array]) -> Float: return self._dist.log_prob(x) + + def sample_base( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + return self._dist.base_prior.sample(rng_key, n_samples) + + def transform(self, x: dict[str, Float]) -> dict[str, Float]: + return self._dist.transform(x) # ====================== Things below may need rework ====================== diff --git a/test/test_prior.py b/test/test_prior.py index 6890f9b9..b6eb7c87 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -7,7 +7,7 @@ def test_logistic(self): def test_uniform(self): p = Uniform(0.0, 10.0, ['x']) - samples = p.sample(jax.random.PRNGKey(0), 10000) + samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, -jnp.log(10.0))