Skip to content

Commit

Permalink
Proposed implementation for metric scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienCorenflos committed Sep 6, 2024
1 parent 6b0f2aa commit 0f9f9d4
Showing 1 changed file with 123 additions and 52 deletions.
175 changes: 123 additions & 52 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union
from typing import Callable, NamedTuple, Optional, Protocol, Union, Tuple, List

import jax.numpy as jnp
import jax.scipy as jscipy
Expand All @@ -43,19 +43,19 @@

class KineticEnergy(Protocol):
def __call__(
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
...


class CheckTurning(Protocol):
def __call__(
self,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
self,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
...

Expand All @@ -64,6 +64,7 @@ class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
scale: Callable[[ArrayLikeTree, Tuple[Tuple[ArrayLikeTree, bool]]], ArrayLikeTree]


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down Expand Up @@ -94,7 +95,7 @@ def default_metric(metric: MetricTypes) -> Metric:


def gaussian_euclidean(
inverse_mass_matrix: Array,
inverse_mass_matrix: Array,
) -> Metric:
r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum
:cite:p:`betancourt2013general`.
Expand Down Expand Up @@ -128,42 +129,14 @@ def gaussian_euclidean(
itself given the values of the momentum along the trajectory.
"""
ndim = jnp.ndim(inverse_mass_matrix) # type: ignore[arg-type]
shape = jnp.shape(inverse_mass_matrix)[:1] # type: ignore[arg-type]
inv_mass_matrix_sqrt, mass_matrix_sqrt, matmul = _format_covariance(inverse_mass_matrix, get_inv=True)

if ndim == 1: # diagonal mass matrix
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
matmul = jnp.multiply

elif ndim == 2:
# inverse mass matrix can be factored into L*L.T. We want the cholesky
# factor (inverse of L.T) of the mass matrix.
L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
identity = jnp.identity(shape[0])
mass_matrix_sqrt = jscipy.linalg.solve_triangular(
L, identity, lower=True, trans=True
)
# Note that mass_matrix_sqrt is a upper triangular matrix here, with
# jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T)
# == inverse_mass_matrix
# An alternative is to compute directly the cholesky factor of the inverse mass
# matrix
# mass_matrix_sqrt = jscipy.linalg.cholesky(
# jscipy.linalg.inv(inverse_mass_matrix), lower=True)
# which the result would instead be a lower triangular matrix.
matmul = jnp.matmul

else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {ndim}."
)

def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
del position
momentum, _ = ravel_pytree(momentum)
Expand All @@ -172,11 +145,11 @@ def kinetic_energy(
return kinetic_energy_val

def is_turning(
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
"""Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`.
Expand Down Expand Up @@ -205,12 +178,43 @@ def is_turning(
turning_at_right = jnp.dot(velocity_right, rho) <= 0
return turning_at_left | turning_at_right

return Metric(momentum_generator, kinetic_energy, is_turning)
def scale(position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]) -> Tuple[ArrayLikeTree]:
"""Scale elements by the mass matrix.
Parameters
----------
position
The current position. Not used in this metric.
elements
A tuple of (element, inv) pairs to scale.
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
Returns
-------
scaled_elements
The scaled elements.
"""
scaled_elements = []
for element, inv in elements:
ravelled_element, unravel_fn = ravel_pytree(element)
if inv:
ravelled_element = matmul(inv_mass_matrix_sqrt, ravelled_element)
else:
ravelled_element = matmul(mass_matrix_sqrt, ravelled_element)
scaled_elements.append(unravel_fn(ravelled_element))
return tuple(scaled_elements)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)


def gaussian_riemannian(
mass_matrix_fn: Callable,
mass_matrix_fn: Callable,
) -> Metric:





def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree:
mass_matrix = mass_matrix_fn(position)
ndim = jnp.ndim(mass_matrix)
Expand All @@ -227,7 +231,7 @@ def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTr
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
if position is None:
raise ValueError(
Expand All @@ -252,11 +256,11 @@ def kinetic_energy(
)

def is_turning(
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
del momentum_left, momentum_right, momentum_sum, position_left, position_right
raise NotImplementedError(
Expand All @@ -283,4 +287,71 @@ def is_turning(
# turning_at_right = jnp.dot(velocity_right, rho) <= 0
# return turning_at_left | turning_at_right

def scale(position: ArrayLikeTree, elements: Tuple[Tuple[ArrayLikeTree, bool]]) -> Tuple[ArrayLikeTree]:
"""Scale elements by the mass matrix.
Parameters
----------
position
The current position.
elements
A tuple of (element, inv) pairs to scale.
If inv is True, the element is scaled by the inverse square root mass matrix, i.e., elem <- M^{-1/2} elem.
Returns
-------
scaled_elements
The scaled elements.
"""
scaled_elements = []
mass_matrix = mass_matrix_fn(position)
# some small performance improvement: group by inv and only compute the inverse Cholesky if needed

inv_elements = [(k, element) for k, (element, inv) in enumerate(elements) if inv]
non_inv_elements = [(k, element) for k, (element, inv) in enumerate(elements) if not inv]
argsort = [k for k, _ in non_inv_elements] + [k for k, _ in inv_elements]

mass_matrix_sqrt, inv_sqrt_mass_matrix, matmul = _format_covariance(mass_matrix, get_inv=bool(inv_elements))

for _, element in non_inv_elements:
rav_element, unravel_fn = ravel_pytree(element)
rav_element = matmul(mass_matrix_sqrt, rav_element)
scaled_elements.append(unravel_fn(rav_element))

if inv_elements:
for _, element in inv_elements:
rav_element, unravel_fn = ravel_pytree(element)
rav_element = matmul(inv_sqrt_mass_matrix, rav_element)
scaled_elements.append(unravel_fn(rav_element))

scaled_elements = [scaled_elements[k] for k in argsort]

return tuple(scaled_elements)

return Metric(momentum_generator, kinetic_energy, is_turning)

def _format_covariance(mass_matrix: Array, get_inv):
ndim = jnp.ndim(mass_matrix)
if ndim == 1:
mass_matrix_sqrt = jnp.sqrt(mass_matrix)
matmul = jnp.multiply
if get_inv:
inv_mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt)
else:
inv_mass_matrix_sqrt = None
elif ndim == 2:
mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True)
matmul = jnp.matmul
if get_inv:
identity = jnp.identity(mass_matrix.shape[0])
inv_mass_matrix_sqrt = jscipy.linalg.solve_triangular(
mass_matrix_sqrt, identity, lower=True
)
else:
inv_mass_matrix_sqrt = None
else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {jnp.ndim(mass_matrix)}."
)
return mass_matrix_sqrt, inv_mass_matrix_sqrt, matmul

0 comments on commit 0f9f9d4

Please sign in to comment.