Skip to content

Commit

Permalink
add trans argument to riemannian scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Sep 24, 2024
1 parent a7e1831 commit f33b8a0
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,23 @@ def scale(
mass_matrix, is_inv=False
)
ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

def _linear_map_transpose():
return jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt.T, ravelled_element),
lambda: linear_map(mass_matrix_sqrt.T, ravelled_element),
)

def _linear_map():
return jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

scaled = jax.lax.cond(trans, _linear_map_transpose, _linear_map)

return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)
Expand Down

0 comments on commit f33b8a0

Please sign in to comment.