Skip to content

Commit

Permalink
Fixed powerLaw
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyuon committed Aug 1, 2024
1 parent 1c45acb commit 255fad3
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 28 deletions.
11 changes: 7 additions & 4 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jaxtyping import Array, Float, PRNGKeyArray

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior, trace_prior_parent
from jimgw.prior import Prior
from jimgw.transforms import BijectiveTransform, NtoMTransform


Expand Down Expand Up @@ -39,11 +39,14 @@ def __init__(
self.likelihood_transforms = likelihood_transforms

if len(sample_transforms) == 0:
print("No sample transforms provided. Using prior parameters as sampling parameters")
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)

if len(likelihood_transforms) == 0:
print("No likelihood transforms provided. Using prior parameters as likelihood parameters")

print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)

seed = kwargs.get("seed", 0)

Expand Down
57 changes: 33 additions & 24 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from abc import ABC
from typing import Callable

import jax
import jax.numpy as jnp
from chex import assert_rank
from beartype import beartype as typechecker
from jaxtyping import Float, Array, jaxtyped

Expand Down Expand Up @@ -261,7 +260,6 @@ def __init__(

@jaxtyped(typechecker=typechecker)
class BoundToBound(BijectiveTransform):

"""
Bound to bound transformation
"""
Expand Down Expand Up @@ -300,6 +298,7 @@ def __init__(
for i in range(len(name_mapping[1]))
}


class BoundToUnbound(BijectiveTransform):
"""
Bound to unbound transformation
Expand All @@ -314,7 +313,7 @@ def __init__(
original_lower_bound: Float,
original_upper_bound: Float,
):

def logit(x):
return jnp.log(x / (1 - x))

Expand All @@ -330,17 +329,13 @@ def logit(x):
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
self.original_upper_bound - self.original_lower_bound
)
/ (
1
+ jnp.exp(-x[name_mapping[1][i]])
)
name_mapping[0][i]: (self.original_upper_bound - self.original_lower_bound)
/ (1 + jnp.exp(-x[name_mapping[1][i]]))
+ self.original_lower_bound[i]
for i in range(len(name_mapping[1]))
}


class SingleSidedUnboundTransform(BijectiveTransform):
"""
Unbound upper limit transformation
Expand All @@ -367,7 +362,6 @@ def __init__(
}



class PowerLawTransform(BijectiveTransform):
"""
PowerLaw transformation
Expand All @@ -392,14 +386,22 @@ def __init__(
self.xmin = xmin
self.xmax = xmax
self.alpha = alpha
self.transform_func = lambda x: [(
self.xmin ** (1.0 + self.alpha)
+ x[0] * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
) ** (1.0 / (1.0 + self.alpha))]
self.inverse_transform_func = lambda x: [(
(x[0] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
/ (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
)]
self.transform_func = lambda x: {
name_mapping[1][i]: (
self.xmin ** (1.0 + self.alpha)
+ x[0]
* (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
)
** (1.0 / (1.0 + self.alpha))
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
(x[0] ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
/ (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
)
for i in range(len(name_mapping[1]))
}


class ParetoTransform(BijectiveTransform):
Expand All @@ -420,7 +422,14 @@ def __init__(
super().__init__(name_mapping)
self.xmin = xmin
self.xmax = xmax
self.transform_func = lambda x: [self.xmin * jnp.exp(
x[0] * jnp.log(self.xmax / self.xmin)
)]
self.inverse_transform_func = lambda x: [(jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin))]
self.transform_func = lambda x: {
name_mapping[1][i]: self.xmin
* jnp.exp(x[0] * jnp.log(self.xmax / self.xmin))
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
jnp.log(x[0] / self.xmin) / jnp.log(self.xmax / self.xmin)
)
for i in range(len(name_mapping[1]))
}

0 comments on commit 255fad3

Please sign in to comment.