Skip to content

Commit

Permalink
Removing inverse.
Browse files Browse the repository at this point in the history
Inverse limts the type of transform one can use,
And it doesn't seem to have case that will require log_prob on target space
  • Loading branch information
kazewong committed Jul 25, 2024
1 parent c16eef5 commit 8ab92a4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 63 deletions.
14 changes: 7 additions & 7 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,14 @@ def sample(

def log_prob(self, x: dict[str, Float]) -> Float:
"""
Requiring inverse transform in log_prob may not be the best option,
may need alternative
log_prob has to be evaluated in the space of the base_prior.
"""
output = 0.0
for transform in reversed(self.transforms):
x, log_jacobian = transform.inverse_transform(x)
output += log_jacobian
output += self.base_prior.log_prob(x)
output = self.base_prior.log_prob(x)
for transform in self.transforms:
x, log_jacobian = transform.transform(x)
output -= log_jacobian
return output

# class Combine(Prior):
Expand Down
56 changes: 0 additions & 56 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from chex import assert_rank
from jaxtyping import Array, Float, jaxtyped


class Transform(ABC):
"""
Base class for transform.
Expand All @@ -18,8 +17,6 @@ class Transform(ABC):

name_mapping: tuple[list[str], list[str]]
transform_func: Callable[[dict[str, Float]], dict[str, Float]]
inverse_func: Callable[[dict[str, Float]], dict[str, Float]]

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
Expand Down Expand Up @@ -49,23 +46,6 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
"""
raise NotImplementedError

@abstractmethod
def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]:
"""
Inverse transform the input x to transformed coordinate y.
Parameters
----------
x : dict[str, Float]
The input dictionary.
Returns
-------
y : dict[str, Float]
The transformed dictionary.
"""
raise NotImplementedError

@abstractmethod
def forward(self, x: dict[str, Float]) -> dict[str, Float]:
"""
Expand All @@ -82,23 +62,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]:
The transformed dictionary.
"""
raise NotImplementedError

@abstractmethod
def backward(self, x: dict[str, Float]) -> dict[str, Float]:
"""
Pull back the input x to transformed coordinate y.
Parameters
----------
x : dict[str, Float]
The input dictionary.
Returns
-------
y : dict[str, Float]
The transformed dictionary.
"""
raise NotImplementedError

def propagate_name(self, x: list[str]) -> list[str]:
input_set = set(x)
Expand All @@ -123,29 +86,13 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
x[self.name_mapping[1][0]] = output_params
return x, jnp.log(jacobian)

def inverse_transform(self, x: dict[str, Float]) -> dict[str, Float]:
output_params = x.pop(self.name_mapping[1][0])
assert_rank(output_params, 0)
input_params = self.inverse_func(output_params)
jacobian = jax.jacfwd(self.inverse_func)(output_params)
x[self.name_mapping[0][0]] = input_params
return x, jnp.log(jacobian)

def forward(self, x: dict[str, Float]) -> dict[str, Float]:
input_params = x.pop(self.name_mapping[0][0])
assert_rank(input_params, 0)
output_params = self.transform_func(input_params)
x[self.name_mapping[1][0]] = output_params
return x

def backward(self, x: dict[str, Float]) -> dict[str, Float]:
output_params = x.pop(self.name_mapping[1][0])
assert_rank(output_params, 0)
input_params = self.inverse_func(output_params)
x[self.name_mapping[0][0]] = input_params
return x


class Scale(UnivariateTransform):
scale: Float

Expand All @@ -157,7 +104,6 @@ def __init__(
super().__init__(name_mapping)
self.scale = scale
self.transform_func = lambda x: x * self.scale
self.inverse_func = lambda x: x / self.scale

class Offset(UnivariateTransform):
offset: Float
Expand All @@ -170,7 +116,6 @@ def __init__(
super().__init__(name_mapping)
self.offset = offset
self.transform_func = lambda x: x + self.offset
self.inverse_func = lambda x: x - self.offset

class Logit(UnivariateTransform):
"""
Expand All @@ -189,7 +134,6 @@ def __init__(
):
super().__init__(name_mapping)
self.transform_func = lambda x: 1 / (1 + jnp.exp(-x))
self.inverse_func = lambda x: jnp.log(x / (1 - x))

class Sine(UnivariateTransform):
"""
Expand Down

0 comments on commit 8ab92a4

Please sign in to comment.