diff --git a/example/GW150914_Pv2.py b/example/GW150914_Pv2.py index c922c822..5342fd68 100644 --- a/example/GW150914_Pv2.py +++ b/example/GW150914_Pv2.py @@ -117,7 +117,7 @@ Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -n_epochs = 30 +n_epochs = 40 n_loop_training = 100 learning_rate = 1e-4 diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..ac56a1e1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -89,7 +89,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -126,7 +126,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1],