diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 32107e1e..a658b494 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -470,6 +470,86 @@ def log_prob(self, x: dict[str, Float]) -> Float: return log_p + log_in_range +@jaxtyped +class Exponential(Prior): + """ + A prior following the power-law with alpha in the range [xmin, xmax). + p(x) ~ exp(\alpha x) + """ + + xmin: Float = 0.0 + xmax: Float = jnp.inf + alpha: Float = -1.0 + normalization: Float = 1.0 + + def __repr__(self): + return f"Exponential(xmin={self.xmin}, xmax={self.xmax}, alpha={self.alpha}, naming={self.naming})" + + def __init__( + self, + xmin: Float, + xmax: Float, + alpha: Union[Int, Float], + naming: list[str], + transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, + ): + super().__init__(naming, transforms) + if alpha < 0.0: + assert xmin != -jnp.inf, "With negative alpha, xmin must finite" + if alpha > 0.0: + assert xmax != jnp.inf, "With positive alpha, xmax must finite" + assert not jnp.isclose(alpha, 0.0), "alpha=zero is given, use Uniform instead" + assert self.n_dim == 1, "Exponential needs to be 1D distributions" + + self.xmax = xmax + self.xmin = xmin + self.alpha = alpha + + self.normalization = self.alpha / ( + jnp.exp(self.alpha * self.xmax) - jnp.exp(self.alpha * self.xmin) + ) + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a exponential distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0) + samples = ( + self.xmin + + jnp.log1p( + q_samples * (jnp.exp(self.alpha * (self.xmax - self.xmin)) - 1.0) + ) + / self.alpha + ) + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Float]) -> Float: + variable = x[self.naming[0]] + log_in_range = jnp.where( + (variable >= self.xmax) | (variable <= self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) + log_p = self.alpha * variable + jnp.log(self.normalization) + return log_p + log_in_range + + class Composite(Prior): priors: list[Prior] = field(default_factory=list)