Preconditioned mclmc #673

Merged
merged 42 commits on May 25, 2024
Changes from 1 commit (of 42)
b60e4ca
TESTS
reubenharry May 13, 2024
0c5aa2d
TESTS
reubenharry May 13, 2024
5eeb3e1
UPDATE DOCSTRING
reubenharry May 13, 2024
4a09156
ADD STREAMING VERSION
reubenharry May 13, 2024
dfb5ee0
ADD PRECONDITIONING TO MCLMC
reubenharry May 13, 2024
2ab3365
ADD PRECONDITIONING TO TUNING FOR MCLMC
reubenharry May 13, 2024
4cc3971
UPDATE GITIGNORE
reubenharry May 13, 2024
f987da3
UPDATE GITIGNORE
reubenharry May 13, 2024
dbab9a3
UPDATE TESTS
reubenharry May 13, 2024
a7ffdb8
UPDATE TESTS
reubenharry May 13, 2024
098f5ad
UPDATE TESTS
reubenharry May 13, 2024
5bd2a3f
ADD DOCSTRING
reubenharry May 13, 2024
4fc1453
ADD TEST
reubenharry May 13, 2024
3678428
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 13, 2024
203f1fd
STREAMING AVERAGE
reubenharry May 15, 2024
fc347d6
ADD TEST
reubenharry May 15, 2024
49410f9
REFACTOR RUN_INFERENCE_ALGORITHM
reubenharry May 15, 2024
ffdca93
UPDATE DOCSTRING
reubenharry May 15, 2024
b7b7084
Precommit
reubenharry May 15, 2024
9d2601d
RESOLVE MERGE CONFLICTS
reubenharry May 15, 2024
97cfc9e
CLEAN TESTS
reubenharry May 15, 2024
45429b8
CLEAN TESTS
reubenharry May 15, 2024
dd9fb1c
Merge branch 'preconditioned_mclmc' of https://github.com/reubenharry…
reubenharry May 15, 2024
a27dba9
GITIGNORE
reubenharry May 15, 2024
7a6e42b
PRECOMMIT CLEAN UP
reubenharry May 15, 2024
6bd5ab1
ADD INITIAL_POSITION
reubenharry May 17, 2024
5615261
FIX TEST
reubenharry May 17, 2024
d66a561
Merge branch 'main' into inference_algorithm
reubenharry May 17, 2024
290addc
Merge branch 'main' into inference_algorithm
reubenharry May 18, 2024
67c0002
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 18, 2024
51fee69
ADD TEST
reubenharry May 18, 2024
29994d7
REMOVE BENCHMARKS
reubenharry May 18, 2024
64948e5
BUG FIX
reubenharry May 18, 2024
c3d44f3
CHANGE PRECISION
reubenharry May 18, 2024
94d43bd
CHANGE PRECISION
reubenharry May 18, 2024
178b452
RENAME O
reubenharry May 19, 2024
9c1c816
Merge branch 'inference_algorithm' of github.com:reubenharry/blackjax…
reubenharry May 19, 2024
db90cdc
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 19, 2024
a26d4a0
UPDATE STREAMING AVG
reubenharry May 19, 2024
4e2b7c0
MERGE
reubenharry May 20, 2024
6bacb6c
UPDATE PR
reubenharry May 24, 2024
9c2fea7
RENAME STD_MAT
reubenharry May 24, 2024
ADD PRECONDITIONING TO MCLMC
reubenharry committed May 13, 2024
commit dfb5ee0146d5f4f50ba7f71b164320002eced394
129 changes: 88 additions & 41 deletions blackjax/mcmc/integrators.py
@@ -17,6 +17,7 @@
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.random import normal

from blackjax.mcmc.metrics import KineticEnergy
from blackjax.types import ArrayTree
@@ -293,43 +294,48 @@ def _normalized_flatten_array(x, tol=1e-13):
return jnp.where(norm > tol, x / norm, x), norm


def esh_dynamics_momentum_update_one_step(
momentum: ArrayTree,
logdensity_grad: ArrayTree,
step_size: float,
coef: float,
previous_kinetic_energy_change=None,
is_last_call=False,
):
"""Momentum update based on Esh dynamics.

The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian`
There are no exponentials e^delta, which prevents overflows when the gradient norm
is large.
"""
del is_last_call

flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
flatten_momentum, _ = ravel_pytree(momentum)
dims = flatten_momentum.shape[0]
normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads)
momentum_proj = jnp.dot(flatten_momentum, normalized_gradient)
delta = step_size * coef * gradient_norm / (dims - 1)
zeta = jnp.exp(-delta)
new_momentum_raw = (
normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta))
+ 2 * zeta * flatten_momentum
)
new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw)
next_momentum = unravel_fn(new_momentum_normalized)
kinetic_energy_change = (
delta
- jnp.log(2)
+ jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2)
) * (dims - 1)
if previous_kinetic_energy_change is not None:
kinetic_energy_change += previous_kinetic_energy_change
return next_momentum, next_momentum, kinetic_energy_change
def esh_dynamics_momentum_update_one_step(std_mat):
def update(
momentum: ArrayTree,
logdensity_grad: ArrayTree,
step_size: float,
coef: float,
previous_kinetic_energy_change=None,
is_last_call=False,
):
"""Momentum update based on Esh dynamics.

The momentum updating map of the esh dynamics as derived in :cite:p:`steeg2021hamiltonian`
There are no exponentials e^delta, which prevents overflows when the gradient norm
is large.
"""
del is_last_call

logdensity_grad = logdensity_grad * std_mat
flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
flatten_momentum, _ = ravel_pytree(momentum)
dims = flatten_momentum.shape[0]
normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads)
momentum_proj = jnp.dot(flatten_momentum, normalized_gradient)
delta = step_size * coef * gradient_norm / (dims - 1)
zeta = jnp.exp(-delta)
new_momentum_raw = (
normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta))
+ 2 * zeta * flatten_momentum
)
new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw)
next_momentum = unravel_fn(new_momentum_normalized)
kinetic_energy_change = (
delta
- jnp.log(2)
+ jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2)
) * (dims - 1)
if previous_kinetic_energy_change is not None:
kinetic_energy_change += previous_kinetic_energy_change
gr = std_mat * next_momentum
return next_momentum, gr, kinetic_energy_change

return update


def format_isokinetic_state_output(
@@ -348,15 +354,15 @@ def format_isokinetic_state_output(
)


def generate_isokinetic_integrator(cofficients):
def generate_isokinetic_integrator(coefficients):
def isokinetic_integrator(
logdensity_fn: Callable, *args, **kwargs
logdensity_fn: Callable, std_mat: ArrayTree, *args, **kwargs
) -> GeneralIntegrator:
position_update_fn = euclidean_position_update_fn(logdensity_fn)
one_step = generalized_two_stage_integrator(
esh_dynamics_momentum_update_one_step,
esh_dynamics_momentum_update_one_step(std_mat),
position_update_fn,
cofficients,
coefficients,
format_output_fn=format_isokinetic_state_output,
)
return one_step
@@ -368,6 +374,47 @@ def isokinetic_integrator(
isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients)
isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients)

def partially_refresh_momentum(momentum, rng_key, step_size, L):
"""Adds a small noise to momentum and normalizes.

Parameters
----------
rng_key
The pseudo-random number generator key used to generate random numbers.
momentum
PyTree that the structure the output should to match.
step_size
Step size
L
controls rate of momentum change

Returns
-------
momentum with random change in angle
"""
m, unravel_fn = ravel_pytree(momentum)
dim = m.shape[0]
nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)
z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype)
return unravel_fn((m + z) / jnp.linalg.norm(m + z))



def with_isokinetic_maruyama(integrator):

def stochastic_integrator(init_state, step_size, L_proposal, rng_key):

key1, key2 = jax.random.split(rng_key)
# partial refreshment
state = init_state._replace(momentum=partially_refresh_momentum(momentum=init_state.momentum, rng_key=key1, L=L_proposal, step_size=step_size * 0.5))
# one step of the deterministic dynamics
state, info = integrator(state, step_size)
# partial refreshment
state = state._replace(momentum=partially_refresh_momentum(momentum=state.momentum, rng_key=key2, L=L_proposal, step_size=step_size * 0.5))
return state, info

return stochastic_integrator

FixedPointSolver = Callable[
[Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree],
Tuple[ArrayTree, ArrayTree, Any],
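
Taken together, the integrator changes above can be exercised on their own. The following is a minimal sketch (not part of the diff) that builds a preconditioned isokinetic integrator, wraps it with the Maruyama-style partial refresh, and takes one stochastic step on a standard Gaussian target; the dimension, step size, and L value are illustrative.

import jax
import jax.numpy as jnp

from blackjax.mcmc.integrators import (
    IntegratorState,
    isokinetic_mclachlan,
    with_isokinetic_maruyama,
)


def logdensity_fn(x):
    # Standard Gaussian target, purely for illustration.
    return -0.5 * jnp.sum(jnp.square(x))


dim = 10
# Diagonal preconditioner; all ones recovers the unpreconditioned sampler.
std_mat = jnp.ones(dim)

# The integrator factory now takes std_mat, which rescales the gradient inside
# the momentum update (and the direction returned for the position update).
step = with_isokinetic_maruyama(isokinetic_mclachlan(logdensity_fn, std_mat))

position = jnp.zeros(dim)
momentum = jnp.ones(dim) / jnp.sqrt(dim)  # isokinetic momentum is unit-norm
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
state = IntegratorState(position, momentum, logdensity, logdensity_grad)

# One stochastic step: half a partial momentum refresh, one deterministic
# isokinetic step, then another half refresh.
new_state, kinetic_energy_change = step(state, 0.1, 1.0, jax.random.PRNGKey(0))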
44 changes: 9 additions & 35 deletions blackjax/mcmc/mclmc.py
@@ -15,12 +15,10 @@
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.random import normal


from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan
from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan, with_isokinetic_maruyama
from blackjax.types import ArrayLike, PRNGKey
from blackjax.util import generate_unit_vector, pytree_size

@@ -58,8 +56,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
logdensity_grad=g,
)


def build_kernel(logdensity_fn, integrator):
def build_kernel(logdensity_fn, std_mat, integrator):
"""Build a HMC kernel.

Parameters
@@ -78,19 +75,17 @@ def build_kernel(logdensity_fn, integrator):
information about the transition.

"""
step = integrator(logdensity_fn)

step = with_isokinetic_maruyama(integrator(logdensity_fn, std_mat))

def kernel(
rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float
) -> tuple[IntegratorState, MCLMCInfo]:
(position, momentum, logdensity, logdensitygrad), kinetic_change = step(
state, step_size
state, step_size, L, rng_key
)

# Langevin-like noise
momentum = partially_refresh_momentum(
momentum=momentum, rng_key=rng_key, L=L, step_size=step_size
)

return IntegratorState(
position, momentum, logdensity, logdensitygrad
@@ -108,6 +103,7 @@ def as_top_level_api(
L,
step_size,
integrator=isokinetic_mclachlan,
std_mat=1.,
) -> SamplingAlgorithm:
"""The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
@@ -155,7 +151,7 @@ def as_top_level_api(
A ``SamplingAlgorithm``.
"""

kernel = build_kernel(logdensity_fn, integrator)
kernel = build_kernel(logdensity_fn, std_mat, integrator)

def init_fn(position: ArrayLike, rng_key: PRNGKey):
return init(position, logdensity_fn, rng_key)
@@ -166,26 +162,4 @@ def update_fn(rng_key, state):
return SamplingAlgorithm(init_fn, update_fn)


def partially_refresh_momentum(momentum, rng_key, step_size, L):
"""Adds a small noise to momentum and normalizes.

Parameters
----------
rng_key
The pseudo-random number generator key used to generate random numbers.
momentum
PyTree that the structure the output should to match.
step_size
Step size
L
controls rate of momentum change

Returns
-------
momentum with random change in angle
"""
m, unravel_fn = ravel_pytree(momentum)
dim = m.shape[0]
nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)
z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype)
return unravel_fn((m + z) / jnp.linalg.norm(m + z))
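
On the API side, build_kernel and the blackjax.mclmc top-level constructor now thread std_mat through to the integrator, and the Langevin-like noise lives inside the integrator via with_isokinetic_maruyama rather than in the kernel. A usage sketch, assuming the top-level alias shown in this PR (the target and hyperparameter values are illustrative):

import jax
import jax.numpy as jnp

import blackjax


def logdensity_fn(x):
    return -0.5 * jnp.sum(jnp.square(x))


init_key, run_key = jax.random.split(jax.random.PRNGKey(0))

# std_mat defaults to 1. (no preconditioning); a per-coordinate scale, e.g. an
# estimate of the posterior standard deviations, can be passed instead.
alg = blackjax.mclmc(logdensity_fn, L=0.5, step_size=0.1, std_mat=jnp.ones(10))

state = alg.init(jnp.ones(10), init_key)


def one_step(state, key):
    state, _info = alg.step(key, state)
    return state, state.position


keys = jax.random.split(run_key, 1000)
_, positions = jax.lax.scan(one_step, state, keys)
print(positions.mean(axis=0))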
62 changes: 35 additions & 27 deletions explore.py
@@ -1,53 +1,61 @@
import jax
import jax.numpy as jnp
from benchmarks.mcmc.sampling_algorithms import samplers

import blackjax
from blackjax.mcmc.mhmclmc import mhmclmc, rescale
from blackjax.mcmc.hmc import hmc
from blackjax.mcmc.dynamic_hmc import dynamic_hmc
from blackjax.mcmc.integrators import isokinetic_mclachlan
from blackjax.util import run_inference_algorithm





init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)


def logdensity_fn(x):
return -0.5 * jnp.sum(jnp.square(x))

initial_position = jnp.ones(10,)

initial_position = jnp.ones(
10,
)

def run_mclmc(logdensity_fn, num_steps, initial_position):
key = jax.random.PRNGKey(0)
init_key, tune_key, run_key = jax.random.split(key, 3)

def run_mclmc(logdensity_fn, key, num_steps, initial_position):
init_key, tune_key, run_key = jax.random.split(key, 3)

initial_state = blackjax.mcmc.mclmc.init(
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

kernel = blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1, std_mat=1.)

average, states = run_inference_algorithm(
rng_key=run_key,
initial_state=initial_state,
inference_algorithm=alg,
num_steps=num_steps,
progress_bar=True,
transform=lambda x: x.position,
streaming=True,
)

(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
print(average)

_, states, _ = run_inference_algorithm(
rng_key=run_key,
initial_state=initial_state,
inference_algorithm=alg,
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
progress_bar=False,
transform=lambda x: x.position,
streaming=False,
)

print(blackjax_mclmc_sampler_params)
print(states.mean(axis=0))

# out = run_hmc(initial_position)
out = samplers["mhmclmc"](logdensity_fn=logdensity_fn, num_steps=5000, initial_position=initial_position, key=jax.random.PRNGKey(0))
print(out.mean(axis=0) )
return states


# out = run_hmc(initial_position)
out = run_mclmc(
logdensity_fn=logdensity_fn,
num_steps=5,
initial_position=initial_position,
key=jax.random.PRNGKey(0),
)
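
For intuition about what the preconditioner buys, here is one more sketch (not part of the PR; the scales and hyperparameters are made up) that samples an anisotropic Gaussian with its known per-coordinate scales passed as std_mat, so that all coordinates should mix at a comparable rate:

import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

scales = jnp.array([0.1, 1.0, 10.0])  # illustrative anisotropic target


def logdensity_fn(x):
    return -0.5 * jnp.sum(jnp.square(x / scales))


init_key, run_key = jax.random.split(jax.random.PRNGKey(1))

initial_state = blackjax.mcmc.mclmc.init(
    position=jnp.zeros(3), logdensity_fn=logdensity_fn, rng_key=init_key
)

# Passing the target scales as std_mat rescales the gradients inside the
# momentum update; with std_mat=1. the smallest scale would dominate.
alg = blackjax.mclmc(logdensity_fn, L=1.0, step_size=0.1, std_mat=scales)

_, states, _ = run_inference_algorithm(
    rng_key=run_key,
    initial_state=initial_state,
    inference_algorithm=alg,
    num_steps=10_000,
    progress_bar=False,
    transform=lambda x: x.position,
    streaming=False,
)

print(states.std(axis=0))  # should roughly recover `scales` if mixing is adequate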