Skip to content

Commit

Permalink
Update transforms.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyuon authored Aug 21, 2024
1 parent 8e014ee commit bc8e7dc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.transform_func(transform_params)
jacobian = jax.jacfwd(self.transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
jax.tree.map(
lambda key: x_copy.pop(key),
self.name_mapping[0],
Expand Down Expand Up @@ -126,7 +126,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.inverse_transform_func(transform_params)
jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
jax.tree.map(
lambda key: y_copy.pop(key),
self.name_mapping[1],
Expand Down

0 comments on commit bc8e7dc

Please sign in to comment.