From 85702ae0ebc17a6f6cd5bfc86c2921d589a3ff10 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Fri, 18 Oct 2024 14:50:40 -0400 Subject: [PATCH 1/2] Adding Sine and Cosine transform --- src/jimgw/transforms.py | 62 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 39c55642..b201f6f8 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -87,7 +87,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -124,7 +126,9 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], @@ -298,6 +302,60 @@ def __init__( } +@jaxtyped(typechecker=typechecker) +class SineTransform(BijectiveTransform): + """ + Sine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.sin(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.arcsin(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class CosineTransform(BijectiveTransform): + """ + Cosine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.cos(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.arccos(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + @jaxtyped(typechecker=typechecker) class ArcSineTransform(BijectiveTransform): """ From 5ef9b8ed30574d07f1a059a3ff6ec0416b26aee6 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Fri, 18 Oct 2024 15:02:36 -0400 Subject: [PATCH 2/2] Adding a few more comments --- src/jimgw/transforms.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index b201f6f8..f8e128b7 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -307,6 +307,8 @@ class SineTransform(BijectiveTransform): """ Sine transformation + The original parameter is expected to be in [-pi/2, pi/2] + Parameters ---------- name_mapping : tuple[list[str], list[str]] @@ -334,6 +336,8 @@ class CosineTransform(BijectiveTransform): """ Cosine transformation + The original parameter is expected to be in [0, pi] + Parameters ---------- name_mapping : tuple[list[str], list[str]]