diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index 0791f3cb1..098649a9a 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -131,8 +131,8 @@ def test_gaussian_euclidean_dim_1(self): assert momentum_val == expected_momentum_val assert kinetic_energy_val == expected_kinetic_energy_val - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -164,8 +164,8 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val @@ -226,8 +226,8 @@ def test_gaussian_riemannian_dim_1(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -265,8 +265,8 @@ def test_gaussian_riemannian_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val