From f33b8a04c30d41d4696d888612c2fab854256372 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Tue, 24 Sep 2024 15:00:38 -0500 Subject: [PATCH] add trans argument to riemannian scaling --- blackjax/mcmc/metrics.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index e298a3425..72c167698 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -322,11 +322,23 @@ def scale( mass_matrix, is_inv=False ) ravelled_element, unravel_fn = ravel_pytree(element) - scaled = jax.lax.cond( - inv, - lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), - lambda: linear_map(mass_matrix_sqrt, ravelled_element), - ) + + def _linear_map_transpose(): + return jax.lax.cond( + inv, + lambda: linear_map(inv_mass_matrix_sqrt.T, ravelled_element), + lambda: linear_map(mass_matrix_sqrt.T, ravelled_element), + ) + + def _linear_map(): + return jax.lax.cond( + inv, + lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), + lambda: linear_map(mass_matrix_sqrt, ravelled_element), + ) + + scaled = jax.lax.cond(trans, _linear_map_transpose, _linear_map) + return unravel_fn(scaled) return Metric(momentum_generator, kinetic_energy, is_turning, scale)