Skip to content

Commit

Permalink
fix tests add trans to scale
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Sep 24, 2024
1 parent 8d580ec commit a7e1831
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/mcmc/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def test_gaussian_euclidean_dim_1(self):
assert momentum_val == expected_momentum_val
assert kinetic_energy_val == expected_kinetic_energy_val

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True)
scaled_momentum = scale(arbitrary_position, momentum_val, False)
inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)

expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix)
expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix)
Expand Down Expand Up @@ -164,8 +164,8 @@ def test_gaussian_euclidean_dim_2(self):
np.testing.assert_allclose(expected_momentum_val, momentum_val)
np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val)

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True)
scaled_momentum = scale(arbitrary_position, momentum_val, False)
inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)

expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val
expected_scaled_momentum = L_inv @ momentum_val
Expand Down Expand Up @@ -226,8 +226,8 @@ def test_gaussian_riemannian_dim_1(self):
np.testing.assert_allclose(expected_momentum_val, momentum_val)
np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val)

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True)
scaled_momentum = scale(arbitrary_position, momentum_val, False)
inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)
expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix)
expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix)

Expand Down Expand Up @@ -265,8 +265,8 @@ def test_gaussian_riemannian_dim_2(self):
np.testing.assert_allclose(expected_momentum_val, momentum_val)
np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val)

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True)
scaled_momentum = scale(arbitrary_position, momentum_val, False)
inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)
expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val
expected_scaled_momentum = L_inv @ momentum_val

Expand Down

0 comments on commit a7e1831

Please sign in to comment.