Skip to content

Commit

Permalink
Implement metric scaling (#733)
Browse files Browse the repository at this point in the history
* Plotting BlackJAX with BlackJAX

* Plotting BlackJAX with BlackJAX

* Proposed implementation for metric scaling

* Add tests and fix some small typing issues raised by pre-commit.

* Fix remaining failing tests

* pre-commit run

* The original implementation was using upper cholesky, I was using lower.

* Fixing a bunch of tests

* Update blackjax/mcmc/metrics.py

Co-authored-by: Junpeng Lao <[email protected]>

* Update blackjax/mcmc/metrics.py

Co-authored-by: Junpeng Lao <[email protected]>

* Merged comments from Junpeng

* Merged comments from Junpeng

---------

Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
AdrienCorenflos and junpenglao authored Sep 16, 2024
1 parent 8a9b546 commit e1d816a
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 78 deletions.
9 changes: 7 additions & 2 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
Expand Down Expand Up @@ -129,8 +130,8 @@ def kernel(
"""

flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0]
momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
flat_inverse_scale = ravel_pytree(momentum_inverse_scale)[0]
momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean(
flat_inverse_scale**2
)

Expand Down Expand Up @@ -248,6 +249,10 @@ def as_top_level_api(
A PyTree of the same structure as the target PyTree (position) with the
values used for as a step size for each dimension of the target space in
the velocity verlet integrator.
momentum_inverse_scale
Pytree with the same structure as the targeted position variable
specifying the per dimension inverse scaling transformation applied
to the persistent momentum variable prior to the integration step.
alpha
The value defining the persistence of the momentum variable.
delta
Expand Down
186 changes: 123 additions & 63 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
from jax.scipy import stats as sp_stats

from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise
from blackjax.types import Array, ArrayLikeTree, ArrayTree, Numeric, PRNGKey
from blackjax.util import generate_gaussian_noise, linear_map

__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"]


class KineticEnergy(Protocol):
def __call__(
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
) -> Numeric:
...


Expand All @@ -60,10 +60,18 @@ def __call__(
...


class Scale(Protocol):
def __call__(
self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
...


class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
scale: Scale


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down Expand Up @@ -128,46 +136,19 @@ def gaussian_euclidean(
itself given the values of the momentum along the trajectory.
"""
ndim = jnp.ndim(inverse_mass_matrix) # type: ignore[arg-type]
shape = jnp.shape(inverse_mass_matrix)[:1] # type: ignore[arg-type]

if ndim == 1: # diagonal mass matrix
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
matmul = jnp.multiply

elif ndim == 2:
# inverse mass matrix can be factored into L*L.T. We want the cholesky
# factor (inverse of L.T) of the mass matrix.
L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
identity = jnp.identity(shape[0])
mass_matrix_sqrt = jscipy.linalg.solve_triangular(
L, identity, lower=True, trans=True
)
# Note that mass_matrix_sqrt is a upper triangular matrix here, with
# jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T)
# == inverse_mass_matrix
# An alternative is to compute directly the cholesky factor of the inverse mass
# matrix
# mass_matrix_sqrt = jscipy.linalg.cholesky(
# jscipy.linalg.inv(inverse_mass_matrix), lower=True)
# which the result would instead be a lower triangular matrix.
matmul = jnp.matmul

else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {ndim}."
)
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
inverse_mass_matrix, is_inv=True
)

def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
) -> Numeric:
del position
momentum, _ = ravel_pytree(momentum)
velocity = matmul(inverse_mass_matrix, momentum)
velocity = linear_map(inverse_mass_matrix, momentum)
kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum)
return kinetic_energy_val

Expand Down Expand Up @@ -196,39 +177,60 @@ def is_turning(
m_right, _ = ravel_pytree(momentum_right)
m_sum, _ = ravel_pytree(momentum_sum)

velocity_left = matmul(inverse_mass_matrix, m_left)
velocity_right = matmul(inverse_mass_matrix, m_right)
velocity_left = linear_map(inverse_mass_matrix, m_left)
velocity_right = linear_map(inverse_mass_matrix, m_right)

# rho = m_sum
rho = m_sum - (m_right + m_left) / 2
turning_at_left = jnp.dot(velocity_left, rho) <= 0
turning_at_right = jnp.dot(velocity_right, rho) <= 0
return turning_at_left | turning_at_right

return Metric(momentum_generator, kinetic_energy, is_turning)
def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Parameters
----------
position
The current position. Not used in this metric.
elements
Elements to scale
invs
Whether to scale the elements by the inverse mass matrix or the mass matrix.
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
Same pytree structure as `elements`.
Returns
-------
scaled_elements
The scaled elements.
"""

ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)
return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)


def gaussian_riemannian(
mass_matrix_fn: Callable,
) -> Metric:
def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree:
mass_matrix = mass_matrix_fn(position)
ndim = jnp.ndim(mass_matrix)
if ndim == 1:
mass_matrix_sqrt = jnp.sqrt(mass_matrix)
elif ndim == 2:
mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True)
else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(mass_matrix)}."
)
mass_matrix_sqrt, *_ = _format_covariance(mass_matrix, is_inv=False)

return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
) -> Numeric:
if position is None:
raise ValueError(
"A Reinmannian kinetic energy function must be called with the "
Expand All @@ -238,18 +240,11 @@ def kinetic_energy(

momentum, _ = ravel_pytree(momentum)
mass_matrix = mass_matrix_fn(position)
ndim = jnp.ndim(mass_matrix)
if ndim == 1:
return -jnp.sum(sp_stats.norm.logpdf(momentum, 0.0, jnp.sqrt(mass_matrix)))
elif ndim == 2:
return -sp_stats.multivariate_normal.logpdf(
momentum, jnp.zeros_like(momentum), mass_matrix
)
else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(mass_matrix)}."
)
sqrt_mass_matrix, inv_sqrt_mass_matrix, diag = _format_covariance(
mass_matrix, is_inv=False
)

return _energy(momentum, 0, sqrt_mass_matrix, inv_sqrt_mass_matrix.T, diag)

def is_turning(
momentum_left: ArrayLikeTree,
Expand Down Expand Up @@ -283,4 +278,69 @@ def is_turning(
# turning_at_right = jnp.dot(velocity_right, rho) <= 0
# return turning_at_left | turning_at_right

return Metric(momentum_generator, kinetic_energy, is_turning)
def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Parameters
----------
position
The current position.
Returns
-------
scaled_elements
The scaled elements.
"""
mass_matrix = mass_matrix_fn(position)
mass_matrix_sqrt, inv_mass_matrix_sqrt, diag = _format_covariance(
mass_matrix, is_inv=False
)
ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)
return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)


def _format_covariance(cov: Array, is_inv):
ndim = jnp.ndim(cov)
if ndim == 1:
cov_sqrt = jnp.sqrt(cov)
inv_cov_sqrt = 1 / cov_sqrt
diag = lambda x: x
if is_inv:
inv_cov_sqrt, cov_sqrt = cov_sqrt, inv_cov_sqrt
elif ndim == 2:
identity = jnp.identity(cov.shape[0])
if is_inv:
inv_cov_sqrt = jscipy.linalg.cholesky(cov, lower=True)
cov_sqrt = jscipy.linalg.solve_triangular(
inv_cov_sqrt, identity, lower=True, trans=True
)
else:
cov_sqrt = jscipy.linalg.cholesky(cov, lower=False).T
inv_cov_sqrt = jscipy.linalg.solve_triangular(
cov_sqrt, identity, lower=True, trans=True
)

diag = lambda x: jnp.diag(x)

else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(cov)}."
)
return cov_sqrt, inv_cov_sqrt, diag


def _energy(x, mean, cov_sqrt, inv_cov_sqrt, diag):
d = x.shape[0]
z = linear_map(inv_cov_sqrt, x - mean)
const = jnp.sum(jnp.log(diag(cov_sqrt))) + d / 2 * jnp.log(2 * jnp.pi)
return 0.5 * jnp.sum(z**2) + const
2 changes: 1 addition & 1 deletion blackjax/mcmc/periodic_orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def kernel(
"""

momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean(
inverse_mass_matrix
)
bijection_fn = bijection(logdensity_fn, kinetic_energy_fn)
Expand Down
4 changes: 4 additions & 0 deletions blackjax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ class WelfordAlgorithmState(NamedTuple):

#: JAX PRNGKey
PRNGKey = jax.Array

#: JAX Scalar types
Scalar = Union[float, int]
Numeric = Union[jax.Array, Scalar]
Loading

0 comments on commit e1d816a

Please sign in to comment.