Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify Uniform and add UniformSphere #109

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 52 additions & 91 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from jimgw.single_event.detector import GroundBased2G, detector_preset
from jimgw.single_event.utils import zenith_azimuth_to_ra_dec
from jimgw.transforms import Transform, Logit, Scale, Offset
from jimgw.transforms import Transform, Logit, Scale, Offset, ArcCosine


class Prior(Distribution):
Expand Down Expand Up @@ -49,7 +49,7 @@ def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
"""

return dict(zip(self.parameter_names, x))

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
Expand Down Expand Up @@ -106,9 +106,7 @@ class SequentialTransform(Prior):
transforms: list[Transform]

def __repr__(self):
return (
f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})"
)
return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})"

def __init__(
self,
Expand All @@ -127,24 +125,28 @@ def sample(
) -> dict[str, Float[Array, " n_samples"]]:
output = self.base_prior.sample(rng_key, n_samples)
return jax.vmap(self.transform)(output)

def log_prob(self, x: dict[str, Float]) -> Float:
"""
log_prob has to be evaluated in the space of the base_prior.


"""
output = self.base_prior.log_prob(x)
for transform in self.transforms:
x, log_jacobian = transform.transform(x)
output -= log_jacobian
return output


def sample_base(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
return self.base_prior.sample(rng_key, n_samples)

def transform(self, x: dict[str, Float]) -> dict[str, Float]:
for transform in self.transforms:
x = transform.forward(x)
return x


class Combine(Prior):
"""
A prior class constructed by joinning multiple priors together to form a multivariate prior.
Expand Down Expand Up @@ -184,111 +186,70 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return output



@jaxtyped(typechecker=typechecker)
class Uniform(SequentialTransform):
    """
    A 1-D uniform prior on the interval [xmin, xmax].

    Implemented as a transformed LogisticDistribution: a Logit maps the
    logistic base sample into (0, 1), then Scale by (xmax - xmin) and
    Offset by xmin map it onto (xmin, xmax).

    Parameters
    ----------
    xmin : float
        Lower bound of the support.
    xmax : float
        Upper bound of the support.
    parameter_names : list[str]
        Name of the single parameter (must be exactly one name).
    """

    xmin: float
    xmax: float

    def __repr__(self):
        return f"Uniform(xmin={self.xmin}, xmax={self.xmax}, parameter_names={self.parameter_names})"

    def __init__(
        self,
        xmin: float,
        xmax: float,
        parameter_names: list[str],
    ):
        # parameter_names must be set before the n_dim check, which
        # presumably derives its value from it — TODO confirm in Prior.
        self.parameter_names = parameter_names
        assert self.n_dim == 1, "Uniform needs to be 1D distributions"
        self.xmax = xmax
        self.xmin = xmin
        super().__init__(
            LogisticDistribution(self.parameter_names),
            [
                Logit((self.parameter_names, self.parameter_names)),
                Scale((self.parameter_names, self.parameter_names), xmax - xmin),
                Offset((self.parameter_names, self.parameter_names), xmin),
            ],
        )

@jaxtyped(typechecker=typechecker)
class UniformSphere(Combine):
    """
    A uniform prior over a solid sphere, parameterized in spherical
    coordinates.

    Given a single vector name ``v``, this prior produces the parameters
    ``v_mag``, ``v_theta`` and ``v_phi``:

    - ``v_mag``   ~ Uniform(0, 1)
    - ``v_theta`` = arccos(u) with u ~ Uniform(-1, 1), i.e. uniform in
      cos(theta), so directions are isotropic on the sphere
    - ``v_phi``   ~ Uniform(0, 2*pi)

    Parameters
    ----------
    parameter_names : list[str]
        A single-element list holding the name of the vector.
    """

    def __repr__(self):
        return f"UniformSphere(parameter_names={self.parameter_names})"

    def __init__(self, parameter_names: list[str], **kwargs):
        assert (
            len(parameter_names) == 1
        ), "UniformSphere only takes the name of the vector"
        parameter_names = parameter_names[0]
        self.parameter_names = [
            f"{parameter_names}_mag",
            f"{parameter_names}_theta",
            f"{parameter_names}_phi",
        ]
        super().__init__(
            [
                Uniform(0.0, 1.0, [self.parameter_names[0]]),
                # Sample cos(theta) uniformly, then map through arccos so
                # that theta is distributed isotropically on [0, pi].
                SequentialTransform(
                    Uniform(-1.0, 1.0, [f"cos_{self.parameter_names[1]}"]),
                    [
                        ArcCosine(
                            (
                                [f"cos_{self.parameter_names[1]}"],
                                [self.parameter_names[1]],
                            )
                        )
                    ],
                ),
                Uniform(0.0, 2 * jnp.pi, [self.parameter_names[2]]),
            ]
        )

# ====================== Things below may need rework ======================


@jaxtyped(typechecker=typechecker)
Expand Down
37 changes: 30 additions & 7 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from dataclasses import field
from typing import Callable, Union
from typing import Callable

import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from chex import assert_rank
from jaxtyping import Array, Float, jaxtyped
from jaxtyping import Float


class Transform(ABC):
"""
Expand All @@ -17,6 +16,7 @@ class Transform(ABC):

name_mapping: tuple[list[str], list[str]]
transform_func: Callable[[dict[str, Float]], dict[str, Float]]

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
Expand Down Expand Up @@ -45,7 +45,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
The log Jacobian determinant.
"""
raise NotImplementedError

@abstractmethod
def forward(self, x: dict[str, Float]) -> dict[str, Float]:
"""
Expand Down Expand Up @@ -92,7 +92,8 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]:
output_params = self.transform_func(input_params)
x[self.name_mapping[1][0]] = output_params
return x



class Scale(UnivariateTransform):
scale: Float

Expand All @@ -105,6 +106,7 @@ def __init__(
self.scale = scale
self.transform_func = lambda x: x * self.scale


class Offset(UnivariateTransform):
offset: Float

Expand All @@ -117,6 +119,7 @@ def __init__(
self.offset = offset
self.transform_func = lambda x: x + self.offset


class Logit(UnivariateTransform):
"""
Logit transform following
Expand All @@ -135,10 +138,11 @@ def __init__(
super().__init__(name_mapping)
self.transform_func = lambda x: 1 / (1 + jnp.exp(-x))


class ArcSine(UnivariateTransform):
"""
ArcSine transformation

Parameters
----------
name_mapping : tuple[list[str], list[str]]
Expand All @@ -152,3 +156,22 @@ def __init__(
):
super().__init__(name_mapping)
self.transform_func = lambda x: jnp.arcsin(x)


class ArcCosine(UnivariateTransform):
    """
    Element-wise arc-cosine transform.

    Maps an input in [-1, 1] to its arc cosine in [0, pi].

    Parameters
    ----------
    name_mapping : tuple[list[str], list[str]]
        The name mapping between the input and output dictionary.

    """

    def __init__(
        self,
        name_mapping: tuple[list[str], list[str]],
    ):
        super().__init__(name_mapping)
        # jnp.arccos is already a callable; no lambda wrapper needed.
        self.transform_func = jnp.arccos
Loading