Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Prior and Transformation Classes #111

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 109 additions & 140 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flowMC.nfmodel.base import Distribution
from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped

from jimgw.transforms import Transform, Logit, Scale, Offset
from jimgw.transforms import Transform, Logit, Scale, Offset, ArcSine, ArcCosine


class Prior(Distribution):
Expand Down Expand Up @@ -59,17 +59,17 @@ def log_prob(self, x: dict[str, Array]) -> Float:
class LogisticDistribution(Prior):

def __repr__(self):
return f"Logistic(parameter_names={self.parameter_names})"
return f"LogisticDistribution(parameter_names={self.parameter_names})"

def __init__(self, parameter_names: list[str], **kwargs):
super().__init__(parameter_names)
assert self.n_dim == 1, "Logit needs to be 1D distributions"
assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions"

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
"""
Sample from a logit distribution.
Sample from a logistic distribution.

Parameters
----------
Expand All @@ -93,6 +93,45 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return -variable - 2 * jnp.log(1 + jnp.exp(-variable))


@jaxtyped(typechecker=typechecker)
class StandardNormalDistribution(Prior):
    """One-dimensional standard normal N(0, 1) prior."""

    def __repr__(self):
        return f"StandardNormalDistribution(parameter_names={self.parameter_names})"

    def __init__(self, parameter_names: list[str], **kwargs):
        super().__init__(parameter_names)
        assert (
            self.n_dim == 1
        ), "StandardNormalDistribution needs to be 1D distributions"

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Draw samples from a standard normal distribution.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Samples from the distribution. The keys are the names of the
            parameters.
        """
        draws = jax.random.normal(rng_key, (n_samples,))
        # add_name expects a leading axis indexing the parameters, hence [None].
        return self.add_name(draws[None])

    def log_prob(self, x: dict[str, Float]) -> Float:
        """Log-density of N(0, 1) at the value of the single named parameter."""
        value = x[self.parameter_names[0]]
        return -0.5 * value**2 - 0.5 * jnp.log(2 * jnp.pi)


class SequentialTransform(Prior):
"""
Transform a prior distribution by applying a sequence of transforms.
Expand All @@ -119,21 +158,24 @@ def __init__(
def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
output = self.base_prior.sample(rng_key, n_samples)
output = self.sample_base(rng_key, n_samples)
return jax.vmap(self.transform)(output)

def log_prob(self, x: dict[str, Float]) -> Float:
    """
    Evaluate the log-probability of the transformed distribution.

    NOTE(review): `x` has to be given in the space of `base_prior`; the
    forward transforms are applied here only to accumulate the
    log-Jacobian corrections, which are subtracted from the base
    log-probability.
    """
    output = self.base_prior.log_prob(x)
    for transform in self.transforms:
        # Each transform returns the mapped point together with the
        # log-Jacobian of the map evaluated at x.
        x, log_jacobian = transform.transform(x)
        output -= log_jacobian
    return output

def sample_base(
    self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
    """Draw `n_samples` samples in the untransformed space of `base_prior`."""
    return self.base_prior.sample(rng_key, n_samples)

def transform(self, x: dict[str, Float]) -> dict[str, Float]:
for transform in self.transforms:
x = transform.forward(x)
Expand All @@ -149,9 +191,7 @@ class Combine(Prior):
priors: list[Prior] = field(default_factory=list)

def __repr__(self):
return (
f"Composite(priors={self.priors}, parameter_names={self.parameter_names})"
)
return f"Combine(priors={self.priors}, parameter_names={self.parameter_names})"

def __init__(
self,
Expand Down Expand Up @@ -180,112 +220,95 @@ def log_prob(self, x: dict[str, Float]) -> Float:


@jaxtyped(typechecker=typechecker)
class Uniform(Prior):
_dist: SequentialTransform

class Uniform(SequentialTransform):
xmin: float
xmax: float

def __repr__(self):
return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})"
return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})"

def __init__(
self,
xmin: float,
xmax: float,
parameter_names: list[str],
):
super().__init__(parameter_names)
self.parameter_names = parameter_names
assert self.n_dim == 1, "Uniform needs to be 1D distributions"
self.xmax = xmax
self.xmin = xmin
self._dist = SequentialTransform(
LogisticDistribution(parameter_names),
super().__init__(
LogisticDistribution(self.parameter_names),
[
Logit((parameter_names, parameter_names)),
Scale((parameter_names, parameter_names), xmax - xmin),
Offset((parameter_names, parameter_names), xmin),
Logit((self.parameter_names, self.parameter_names)),
Scale((self.parameter_names, self.parameter_names), xmax - xmin),
Offset((self.parameter_names, self.parameter_names), xmin),
],
)

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
return self._dist.sample(rng_key, n_samples)

def log_prob(self, x: dict[str, Array]) -> Float:
return self._dist.log_prob(x)
@jaxtyped(typechecker=typechecker)
class Sine(SequentialTransform):
"""
A prior distribution where the pdf is proportional to sin(x) in the range [0, pi].
"""

def sample_base(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
return self._dist.base_prior.sample(rng_key, n_samples)
def __repr__(self):
return f"Sine(parameter_names={self.parameter_names})"

def transform(self, x: dict[str, Float]) -> dict[str, Float]:
return self._dist.transform(x)
def __init__(self, parameter_names: list[str]):
    """
    Parameters
    ----------
    parameter_names : list[str]
        Single-element list naming the angle parameter on [0, pi].
    """
    self.parameter_names = parameter_names
    assert self.n_dim == 1, "Sine needs to be 1D distributions"
    # Sample cos(x) uniformly on [-1, 1] and map it through arccos; the
    # Jacobian of arccos yields the sin(x) density on [0, pi].
    # BUGFIX: the base parameter name must be built from the scalar name
    # (parameter_names[0]), not from the list itself (which would format
    # as e.g. "cos_['theta']"), and Uniform expects a list of names.
    base_name = f"cos_{self.parameter_names[0]}"
    super().__init__(
        Uniform(-1.0, 1.0, [base_name]),
        [ArcCosine(([base_name], [self.parameter_names[0]]))],
    )


# ====================== Things below may need rework ======================
@jaxtyped(typechecker=typechecker)
class Cosine(SequentialTransform):
"""
A prior distribution where the pdf is proportional to cos(x) in the range [-pi/2, pi/2].
"""

def __repr__(self):
return f"Cosine(parameter_names={self.parameter_names})"

# class Sphere(Prior):
# """
# A prior on a sphere represented by Cartesian coordinates.
def __init__(self, parameter_names: list[str]):
    """
    Parameters
    ----------
    parameter_names : list[str]
        Single-element list naming the angle parameter on [-pi/2, pi/2].
    """
    self.parameter_names = parameter_names
    assert self.n_dim == 1, "Cosine needs to be 1D distributions"
    # Sample sin(x) uniformly on [-1, 1] and map it through arcsin; the
    # Jacobian of arcsin yields the cos(x) density on [-pi/2, pi/2].
    # BUGFIX: the base parameter name must be built from the scalar name
    # (parameter_names[0]), not from the list itself (which would format
    # as e.g. "sin_['theta']"), and Uniform expects a list of names.
    base_name = f"sin_{self.parameter_names[0]}"
    super().__init__(
        Uniform(-1.0, 1.0, [base_name]),
        [ArcSine(([base_name], [self.parameter_names[0]]))],
    )

# Magnitude is sampled from a uniform distribution.
# """

# def __repr__(self):
# return f"Sphere(naming={self.naming})"
@jaxtyped(typechecker=typechecker)
class UniformSphere(Combine):

# def __init__(self, naming: list[str], **kwargs):
# name = naming[0]
# self.naming = [f"{name}_theta", f"{name}_phi", f"{name}_mag"]
# self.transforms = {
# self.naming[0]: (
# f"{naming}_x",
# lambda params: jnp.sin(params[self.naming[0]])
# * jnp.cos(params[self.naming[1]])
# * params[self.naming[2]],
# ),
# self.naming[1]: (
# f"{naming}_y",
# lambda params: jnp.sin(params[self.naming[0]])
# * jnp.sin(params[self.naming[1]])
# * params[self.naming[2]],
# ),
# self.naming[2]: (
# f"{naming}_z",
# lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]],
# ),
# }
def __repr__(self):
return f"UniformSphere(parameter_names={self.parameter_names})"

# def sample(
# self, rng_key: PRNGKeyArray, n_samples: int
# ) -> dict[str, Float[Array, " n_samples"]]:
# rng_keys = jax.random.split(rng_key, 3)
# theta = jnp.arccos(
# jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)
# )
# phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi)
# mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1)
# return self.add_name(jnp.stack([theta, phi, mag], axis=1).T)
def __init__(self, parameter_names: list[str], **kwargs):
    """
    Build the three component priors (magnitude, polar angle, azimuth)
    from the single vector name supplied by the caller.

    Parameters
    ----------
    parameter_names : list[str]
        Single-element list holding the name of the vector.
    """
    assert (
        len(parameter_names) == 1
    ), "UniformSphere only takes the name of the vector"
    vector_name = parameter_names[0]
    mag_name = f"{vector_name}_mag"
    theta_name = f"{vector_name}_theta"
    phi_name = f"{vector_name}_phi"
    self.parameter_names = [mag_name, theta_name, phi_name]
    super().__init__(
        [
            Uniform(0.0, 1.0, [mag_name]),
            Sine([theta_name]),
            Uniform(0.0, 2 * jnp.pi, [phi_name]),
        ]
    )

# def log_prob(self, x: dict[str, Float]) -> Float:
# theta = x[self.naming[0]]
# phi = x[self.naming[1]]
# mag = x[self.naming[2]]
# output = jnp.where(
# (mag > 1)
# | (mag < 0)
# | (phi > 2 * jnp.pi)
# | (phi < 0)
# | (theta > jnp.pi)
# | (theta < 0),
# jnp.zeros_like(0) - jnp.inf,
# jnp.log(mag**2 * jnp.sin(x[self.naming[0]])),
# )
# return output

# ====================== Things below may need rework ======================


# @jaxtyped(typechecker=typechecker)
Expand Down Expand Up @@ -634,57 +657,3 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]:
# )
# log_p = self.alpha * variable + jnp.log(self.normalization)
# return log_p + log_in_range


# @jaxtyped(typechecker=typechecker)
# class Normal(Prior):
# mean: Float = 0.0
# std: Float = 1.0

# def __repr__(self):
# return f"Normal(mean={self.mean}, std={self.std})"

# def __init__(
# self,
# mean: Float,
# std: Float,
# naming: list[str],
# transforms: dict[str, tuple[str, Callable]] = {},
# **kwargs,
# ):
# super().__init__(naming, transforms)
# assert self.n_dim == 1, "Normal needs to be 1D distributions"
# self.mean = mean
# self.std = std

# def sample(
# self, rng_key: PRNGKeyArray, n_samples: int
# ) -> dict[str, Float[Array, " n_samples"]]:
# """
# Sample from a normal distribution.

# Parameters
# ----------
# rng_key : PRNGKeyArray
# A random key to use for sampling.
# n_samples : int
# The number of samples to draw.

# Returns
# -------
# samples : dict
# Samples from the distribution. The keys are the names of the parameters.

# """
# samples = jax.random.normal(rng_key, (n_samples,))
# samples = self.mean + samples * self.std
# return self.add_name(samples[None])

# def log_prob(self, x: dict[str, Array]) -> Float:
# variable = x[self.naming[0]]
# output = (
# -0.5 * jnp.log(2 * jnp.pi)
# - jnp.log(self.std)
# - 0.5 * ((variable - self.mean) / self.std) ** 2
# )
# return output
23 changes: 21 additions & 2 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax
import jax.numpy as jnp
from chex import assert_rank
from jaxtyping import Float, Array
from jaxtyping import Float


class Transform(ABC):
Expand All @@ -15,7 +15,7 @@ class Transform(ABC):
"""

name_mapping: tuple[list[str], list[str]]
transform_func: Callable[[Float[Array, " n_dim"]], Float[Array, " n_dim"]]
transform_func: Callable[[dict[str, Float]], dict[str, Float]]

def __init__(
self,
Expand Down Expand Up @@ -156,3 +156,22 @@ def __init__(
):
super().__init__(name_mapping)
self.transform_func = lambda x: jnp.arcsin(x)


class ArcCosine(UnivariateTransform):
    """
    ArcCosine transformation

    Applies jnp.arccos elementwise to the named input.

    Parameters
    ----------
    name_mapping : tuple[list[str], list[str]]
        The name mapping between the input and output dictionary.

    """

    def __init__(self, name_mapping: tuple[list[str], list[str]]):
        super().__init__(name_mapping)
        # A direct function reference; behaviorally identical to wrapping
        # jnp.arccos in a lambda.
        self.transform_func = jnp.arccos
Loading