Skip to content

Commit

Permalink
Merge pull request #75 from tsunhopang/expotential_prior
Browse files Browse the repository at this point in the history
Adding exponential distribution
  • Loading branch information
kazewong authored Mar 17, 2024
2 parents 1589dbb + 7ce6e35 commit 32a4306
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,86 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return log_p + log_in_range


@jaxtyped
class Exponential(Prior):
"""
A prior following the power-law with alpha in the range [xmin, xmax).
p(x) ~ exp(\alpha x)
"""

xmin: Float = 0.0
xmax: Float = jnp.inf
alpha: Float = -1.0
normalization: Float = 1.0

def __repr__(self):
return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})"

def __init__(
self,
xmin: Float,
xmax: Float,
alpha: Union[Int, Float],
naming: list[str],
transforms: dict[str, tuple[str, Callable]] = {},
**kwargs,
):
super().__init__(naming, transforms)
if alpha < 0.0:
assert xmin != -jnp.inf, "With negative alpha, xmin must finite"
if alpha > 0.0:
assert xmax != jnp.inf, "With positive alpha, xmax must finite"
assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead"
assert self.n_dim == 1, "Exponential needs to be 1D distributions"

self.xmax = xmax
self.xmin = xmin
self.alpha = alpha

self.normalization = self.alpha / (
jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin)
)

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
"""
Sample from a exponential distribution.
Parameters
----------
rng_key : PRNGKeyArray
A random key to use for sampling.
n_samples : int
The number of samples to draw.
Returns
-------
samples : dict
Samples from the distribution. The keys are the names of the parameters.
"""
q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0)
samples = (
self.xmin
+ jnp.log1p(
q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0)
)
/ self.alpha
)
return self.add_name(samples[None])

def log_prob(self, x: dict[str, Float]) -> Float:
variable = x[self.naming[0]]
log_in_range = jnp.where(
(variable >= self.xmax) | (variable <= self.xmin),
jnp.zeros_like(variable) - jnp.inf,
jnp.zeros_like(variable),
)
log_p = self.alpha * variable + jnp.log(self.normalization)
return log_p + log_in_range


class Composite(Prior):
priors: list[Prior] = field(default_factory=list)

Expand Down

0 comments on commit 32a4306

Please sign in to comment.