Skip to content

Commit

Permalink
Merge pull request #137 from xuyuon/fix-jacobian
Browse files Browse the repository at this point in the history
Taking the absolute value of the Jacobian determinant
  • Loading branch information
kazewong authored Aug 22, 2024
2 parents 1b79a3a + 702ee20 commit 7910785
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.transform_func(transform_params)
jacobian = jax.jacfwd(self.transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jax.tree.map(
lambda key: x_copy.pop(key),
self.name_mapping[0],
Expand Down Expand Up @@ -126,7 +126,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.inverse_transform_func(transform_params)
jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jax.tree.map(
lambda key: y_copy.pop(key),
self.name_mapping[1],
Expand Down

0 comments on commit 7910785

Please sign in to comment.