Commit

Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-class' into 98-moving-naming-tracking-into-jim-class-from-prior-class
xuyuon authored Aug 1, 2024
2 parents 47af9cf + 06eb3ad commit 7826122
Showing 4 changed files with 93 additions and 78 deletions.
7 changes: 4 additions & 3 deletions src/jimgw/jim.py
@@ -7,7 +7,7 @@
from jaxtyping import Array, Float, PRNGKeyArray

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior, trace_prior_parent
from jimgw.prior import Prior
from jimgw.transforms import BijectiveTransform, NtoMTransform


@@ -48,8 +48,9 @@ def __init__(
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print("No likelihood transforms provided. Using prior parameters as likelihood parameters")

print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)

seed = kwargs.get("seed", 0)

4 changes: 2 additions & 2 deletions src/jimgw/prior.py
@@ -12,8 +12,8 @@
ScaleTransform,
OffsetTransform,
ArcSineTransform,
# PowerLawTransform,
# ParetoTransform,
PowerLawTransform,
ParetoTransform,
)


142 changes: 80 additions & 62 deletions src/jimgw/transforms.py
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from abc import ABC
from typing import Callable

import jax
import jax.numpy as jnp
from chex import assert_rank
from beartype import beartype as typechecker
from jaxtyping import Float, Array, jaxtyped

@@ -261,7 +260,6 @@ def __init__(

@jaxtyped(typechecker=typechecker)
class BoundToBound(BijectiveTransform):

"""
Bound to bound transformation
"""
@@ -300,6 +298,7 @@ def __init__(
for i in range(len(name_mapping[1]))
}


@jaxtyped(typechecker=typechecker)
class BoundToUnbound(BijectiveTransform):
"""
@@ -315,7 +314,7 @@ def __init__(
original_lower_bound: Float,
original_upper_bound: Float,
):

def logit(x):
return jnp.log(x / (1 - x))

@@ -331,17 +330,13 @@ def logit(x):
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
self.original_upper_bound - self.original_lower_bound
)
/ (
1
+ jnp.exp(-x[name_mapping[1][i]])
)
name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound)
/ (1 + jnp.exp(-x[name_mapping[1][i]]))
+ self.original_lower_bound[i]
for i in range(len(name_mapping[1]))
}
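
For reference, the forward map in BoundToUnbound is a logit of the value rescaled to the unit interval, and the inverse above is the matching scaled sigmoid. A minimal standalone round-trip check of that math, using hypothetical scalar bounds rather than the jimgw classes:

import jax.numpy as jnp

lo, hi = 2.0, 5.0  # hypothetical original bounds

def to_unbound(x):
    # logit of the value rescaled to [0, 1]
    u = (x - lo) / (hi - lo)
    return jnp.log(u / (1.0 - u))

def to_bound(y):
    # scaled sigmoid, mirroring inverse_transform_func above
    return (hi - lo) / (1.0 + jnp.exp(-y)) + lo

x = jnp.linspace(2.1, 4.9, 5)
assert jnp.allclose(to_bound(to_unbound(x)), x, atol=1e-5)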


class SingleSidedUnboundTransform(BijectiveTransform):
"""
Unbound upper limit transformation
@@ -368,55 +363,78 @@ def __init__(
}


class PowerLawTransform(BijectiveTransform):
"""
PowerLaw transformation
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

xmin: Float
xmax: Float
alpha: Float

# class PowerLawTransform(UnivariateTransform):
# """
# PowerLaw transformation
# Parameters
# ----------
# name_mapping : tuple[list[str], list[str]]
# The name mapping between the input and output dictionary.
# """

# xmin: Float
# xmax: Float
# alpha: Float

# def __init__(
# self,
# name_mapping: tuple[list[str], list[str]],
# xmin: Float,
# xmax: Float,
# alpha: Float,
# ):
# super().__init__(name_mapping)
# self.xmin = xmin
# self.xmax = xmax
# self.alpha = alpha
# self.transform_func = lambda x: (
# self.xmin ** (1.0 + self.alpha)
# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
# ) ** (1.0 / (1.0 + self.alpha))


# class ParetoTransform(UnivariateTransform):
# """
# Pareto transformation: Power law when alpha = -1
# Parameters
# ----------
# name_mapping : tuple[list[str], list[str]]
# The name mapping between the input and output dictionary.
# """

# def __init__(
# self,
# name_mapping: tuple[list[str], list[str]],
# xmin: Float,
# xmax: Float,
# ):
# super().__init__(name_mapping)
# self.xmin = xmin
# self.xmax = xmax
# self.transform_func = lambda x: self.xmin * jnp.exp(
# x * jnp.log(self.xmax / self.xmin)
# )
def __init__(
self,
name_mapping: tuple[list[str], list[str]],
xmin: Float,
xmax: Float,
alpha: Float,
):
super().__init__(name_mapping)
self.xmin = xmin
self.xmax = xmax
self.alpha = alpha
self.transform_func = lambda x: {
name_mapping[1][i]: (
self.xmin ** (1.0 + self.alpha)
+ x[name_mapping[0][i]]
* (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
)
** (1.0 / (1.0 + self.alpha))
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
(
x[name_mapping[1][i]] ** (1.0 + self.alpha)
- self.xmin ** (1.0 + self.alpha)
)
/ (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
)
for i in range(len(name_mapping[1]))
}
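
The forward map here is the inverse CDF of a power law p(x) proportional to x**alpha on [xmin, xmax], so a uniform sample on [0, 1] comes out power-law distributed, and inverse_transform_func is the corresponding CDF. A standalone sketch of that pair with hypothetical values (alpha must differ from -1):

import jax.numpy as jnp

xmin, xmax, alpha = 1.0, 100.0, -2.3  # hypothetical values, alpha != -1

def forward(u):
    # inverse CDF of p(x) ∝ x**alpha on [xmin, xmax]
    return (
        xmin ** (1.0 + alpha) + u * (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha))
    ) ** (1.0 / (1.0 + alpha))

def inverse(x):
    # CDF, mapping back to [0, 1]
    return (x ** (1.0 + alpha) - xmin ** (1.0 + alpha)) / (
        xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)
    )

u = jnp.linspace(0.01, 0.99, 5)
assert jnp.allclose(inverse(forward(u)), u, atol=1e-5)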


class ParetoTransform(BijectiveTransform):
"""
Pareto transformation: Power law when alpha = -1
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
xmin: Float,
xmax: Float,
):
super().__init__(name_mapping)
self.xmin = xmin
self.xmax = xmax
self.transform_func = lambda x: {
name_mapping[1][i]: self.xmin
* jnp.exp(x[name_mapping[0][i]] * jnp.log(self.xmax / self.xmin))
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
jnp.log(x[name_mapping[1][i]] / self.xmin)
/ jnp.log(self.xmax / self.xmin)
)
for i in range(len(name_mapping[1]))
}
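
ParetoTransform is the alpha = -1 limit, where the power-law inverse CDF above is undefined: it maps a uniform [0, 1] sample to a log-uniform value on [xmin, xmax]. A standalone round-trip sketch with hypothetical bounds:

import jax.numpy as jnp

xmin, xmax = 1.0, 100.0  # hypothetical bounds

def forward(u):
    # log-uniform inverse CDF, mirroring transform_func above
    return xmin * jnp.exp(u * jnp.log(xmax / xmin))

def inverse(x):
    # matching CDF, mirroring inverse_transform_func above
    return jnp.log(x / xmin) / jnp.log(xmax / xmin)

u = jnp.linspace(0.0, 1.0, 5)
assert jnp.allclose(inverse(forward(u)), u, atol=1e-5)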
18 changes: 7 additions & 11 deletions test/unit/test_prior.py
@@ -96,20 +96,16 @@ def func(alpha):
assert jnp.all(jnp.isfinite(powerlaw_samples['x']))

# Check that all the log_probs are finite
samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base']
base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples})
assert jnp.all(jnp.isfinite(base_log_p))
samples = p.sample(jax.random.PRNGKey(0), 10000)
log_p = jax.vmap(p.log_prob, [0])(samples)
assert jnp.all(jnp.isfinite(log_p))

# Check that the log_prob is correct in the support
samples = jnp.linspace(-10.0, 10.0, 1000)
transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x']
# cut off the samples that are outside the support
samples = samples[transformed_samples >= xmin]
transformed_samples = transformed_samples[transformed_samples >= xmin]
samples = samples[transformed_samples <= xmax]
transformed_samples = transformed_samples[transformed_samples <= xmax]
samples = p.sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax)
# log pdf of powerlaw
assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4)
assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4)

# Test Pareto Transform
func(-1.0)
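
The updated test samples directly from the prior and compares its log_prob against an analytic reference, powerlaw_log_pdf, which is defined elsewhere in test_prior.py. As an illustration only (not the repository's helper), a normalised power-law log density for alpha != -1 could be written as:

import jax.numpy as jnp

def reference_powerlaw_log_pdf(x, alpha, xmin, xmax):
    # assumed form: p(x) ∝ x**alpha on [xmin, xmax], normalised, valid for alpha != -1
    log_norm = jnp.log((1.0 + alpha) / (xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)))
    return alpha * jnp.log(x) + log_norm

The Pareto case exercised by func(-1.0) would instead use log p(x) = -log(x) - log(log(xmax / xmin)).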
