diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 39c55642..b201f6f8 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -87,7 +87,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -124,7 +126,9 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], @@ -298,6 +302,60 @@ def __init__( } +@jaxtyped(typechecker=typechecker) +class SineTransform(BijectiveTransform): + """ + Sine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.sin(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.arcsin(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + +@jaxtyped(typechecker=typechecker) +class CosineTransform(BijectiveTransform): + """ + Cosine transformation + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + ): + super().__init__(name_mapping) + self.transform_func = lambda x: { + name_mapping[1][i]: jnp.cos(x[name_mapping[0][i]]) + for i in range(len(name_mapping[0])) + } + self.inverse_transform_func = lambda x: { + name_mapping[0][i]: jnp.arccos(x[name_mapping[1][i]]) + for i in range(len(name_mapping[1])) + } + + @jaxtyped(typechecker=typechecker) class ArcSineTransform(BijectiveTransform): """