Skip to content

Commit

Permalink
Add transformation function
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Jul 25, 2024
1 parent 8ab92a4 commit d8b2d1f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 15 additions & 4 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
output = self.base_prior.sample(rng_key, n_samples)
for transform in self.transforms:
output = jax.vmap(transform.forward)(output)
return output
return jax.vmap(self.transform)(output)

def log_prob(self, x: dict[str, Float]) -> Float:
"""
Expand All @@ -141,6 +139,11 @@ def log_prob(self, x: dict[str, Float]) -> Float:
x, log_jacobian = transform.transform(x)
output -= log_jacobian
return output

def transform(self, x: dict[str, Float]) -> dict[str, Float]:
for transform in self.transforms:
x = transform.forward(x)
return x

# class Combine(Prior):
# """
Expand Down Expand Up @@ -185,7 +188,7 @@ def log_prob(self, x: dict[str, Float]) -> Float:

@jaxtyped(typechecker=typechecker)
class Uniform(Prior):
_dist: Prior
_dist: SequentialTransform

xmin: float
xmax: float
Expand Down Expand Up @@ -218,6 +221,14 @@ def sample(

def log_prob(self, x: dict[str, Array]) -> Float:
return self._dist.log_prob(x)

def sample_base(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
return self._dist.base_prior.sample(rng_key, n_samples)

def transform(self, x: dict[str, Float]) -> dict[str, Float]:
return self._dist.transform(x)

# ====================== Things below may need rework ======================

Expand Down
2 changes: 1 addition & 1 deletion test/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def test_logistic(self):

def test_uniform(self):
p = Uniform(0.0, 10.0, ['x'])
samples = p.sample(jax.random.PRNGKey(0), 10000)
samples = p._dist.base_prior.sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.allclose(log_prob, -jnp.log(10.0))

Expand Down

0 comments on commit d8b2d1f

Please sign in to comment.