Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preconditioned mclmc #673

Merged
merged 42 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
b60e4ca
TESTS
reubenharry May 13, 2024
0c5aa2d
TESTS
reubenharry May 13, 2024
5eeb3e1
UPDATE DOCSTRING
reubenharry May 13, 2024
4a09156
ADD STREAMING VERSION
reubenharry May 13, 2024
dfb5ee0
ADD PRECONDITIONING TO MCLMC
reubenharry May 13, 2024
2ab3365
ADD PRECONDITIONING TO TUNING FOR MCLMC
reubenharry May 13, 2024
4cc3971
UPDATE GITIGNORE
reubenharry May 13, 2024
f987da3
UPDATE GITIGNORE
reubenharry May 13, 2024
dbab9a3
UPDATE TESTS
reubenharry May 13, 2024
a7ffdb8
UPDATE TESTS
reubenharry May 13, 2024
098f5ad
UPDATE TESTS
reubenharry May 13, 2024
5bd2a3f
ADD DOCSTRING
reubenharry May 13, 2024
4fc1453
ADD TEST
reubenharry May 13, 2024
3678428
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 13, 2024
203f1fd
STREAMING AVERAGE
reubenharry May 15, 2024
fc347d6
ADD TEST
reubenharry May 15, 2024
49410f9
REFACTOR RUN_INFERENCE_ALGORITHM
reubenharry May 15, 2024
ffdca93
UPDATE DOCSTRING
reubenharry May 15, 2024
b7b7084
Precommit
reubenharry May 15, 2024
9d2601d
RESOLVE MERGE CONFLICTS
reubenharry May 15, 2024
97cfc9e
CLEAN TESTS
reubenharry May 15, 2024
45429b8
CLEAN TESTS
reubenharry May 15, 2024
dd9fb1c
Merge branch 'preconditioned_mclmc' of https://github.com/reubenharry…
reubenharry May 15, 2024
a27dba9
GITIGNORE
reubenharry May 15, 2024
7a6e42b
PRECOMMIT CLEAN UP
reubenharry May 15, 2024
6bd5ab1
ADD INITIAL_POSITION
reubenharry May 17, 2024
5615261
FIX TEST
reubenharry May 17, 2024
d66a561
Merge branch 'main' into inference_algorithm
reubenharry May 17, 2024
290addc
Merge branch 'main' into inference_algorithm
reubenharry May 18, 2024
67c0002
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 18, 2024
51fee69
ADD TEST
reubenharry May 18, 2024
29994d7
REMOVE BENCHMARKS
reubenharry May 18, 2024
64948e5
BUG FIX
reubenharry May 18, 2024
c3d44f3
CHANGE PRECISION
reubenharry May 18, 2024
94d43bd
CHANGE PRECISION
reubenharry May 18, 2024
178b452
RENAME O
reubenharry May 19, 2024
9c1c816
Merge branch 'inference_algorithm' of github.com:reubenharry/blackjax…
reubenharry May 19, 2024
db90cdc
Merge branch 'inference_algorithm' into preconditioned_mclmc
reubenharry May 19, 2024
a26d4a0
UPDATE STREAMING AVG
reubenharry May 19, 2024
4e2b7c0
MERGE
reubenharry May 20, 2024
6bacb6c
UPDATE PR
reubenharry May 24, 2024
9c2fea7
RENAME STD_MAT
reubenharry May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 49 additions & 32 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size, streaming_average
from blackjax.util import pytree_size, streaming_average_update


class MCLMCAdaptationState(NamedTuple):
Expand All @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
The momentum decoherence rate for the MCLMC algorithm.
step_size
The step size used for the MCLMC algorithm.
std_mat
sqrt_diag_cov_mat
reubenharry marked this conversation as resolved.
Show resolved Hide resolved
The square root of the diagonal of the covariance matrix (a vector of per-dimension standard deviations), used for preconditioning.
"""

L: float
step_size: float
std_mat: float
sqrt_diag_cov_mat: float


def mclmc_find_L_and_step_size(
Expand Down Expand Up @@ -81,10 +81,30 @@ def mclmc_find_L_and_step_size(
Returns
-------
A tuple containing the final state of the MCMC algorithm and the final hyperparameters.

Example
-------
.. code::
kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=integrator,
sqrt_diag_cov_mat=sqrt_diag_cov_mat,
)

(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
diagonal_preconditioning=preconditioning,
)
"""
dim = pytree_size(state.position)
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, std_mat=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,))
)
part1_key, part2_key = jax.random.split(rng_key, 2)

Expand All @@ -101,7 +121,7 @@ def mclmc_find_L_and_step_size(

if frac_tune3 != 0:
state, params = make_adaptation_L(
mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4
mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)

return state, params
Expand All @@ -128,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
time, x_average, step_size_max = adaptive_state

# dynamics
next_state, info = kernel(params.std_mat)(
next_state, info = kernel(params.sqrt_diag_cov_mat)(
rng_key=rng_key,
state=previous_state,
L=params.L,
Expand Down Expand Up @@ -179,7 +199,7 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average(
streaming_avg = streaming_average_update(
expectation=jnp.array([x, jnp.square(x)]),
streaming_avg=streaming_avg,
weight=(1 - mask) * success * params.step_size,
Expand All @@ -188,6 +208,17 @@ def step(iteration_state, weight_and_key):

return (state, params, adaptive_state, streaming_avg), None

run_steps = lambda xs, state, params: jax.lax.scan(
step,
init=(
state,
params,
(0.0, 0.0, jnp.inf),
(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
),
xs=xs,
)[0]

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = (
int(num_steps * frac_tune1) + 1,
Expand All @@ -205,45 +236,31 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

# run the steps
state, params, _, (_, average) = jax.lax.scan(
step,
init=(
state,
params,
(0.0, 0.0, jnp.inf),
(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
),
xs=(mask, L_step_size_adaptation_keys),
)[0]
state, params, _, (_, average) = run_steps(
xs=(mask, L_step_size_adaptation_keys), state=state, params=params
)

L = params.L
# determine L
std_mat = params.std_mat
sqrt_diag_cov_mat = params.sqrt_diag_cov_mat
if num_steps2 != 0.0:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))

if diagonal_preconditioning:
std_mat = jnp.sqrt(variances)
params = params._replace(std_mat=std_mat)
sqrt_diag_cov_mat = jnp.sqrt(variances)
params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat)
L = jnp.sqrt(dim)

# readjust the stepsize
steps = num_steps2 // 3 # we do some small number of steps
keys = jax.random.split(final_key, steps)
state, params, _, (_, average) = jax.lax.scan(
step,
init=(
state,
params,
(0.0, 0.0, jnp.inf),
(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
),
xs=(jnp.ones(steps), keys),
)[0]

return state, MCLMCAdaptationState(L, params.step_size, std_mat)
state, params, _, (_, average) = run_steps(
xs=(jnp.ones(steps), keys), state=state, params=params
)

return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat)

return L_step_size_adaptation

Expand Down
11 changes: 6 additions & 5 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def _normalized_flatten_array(x, tol=1e-13):
return jnp.where(norm > tol, x / norm, x), norm


def esh_dynamics_momentum_update_one_step(std_mat):
def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0):
def update(
momentum: ArrayTree,
logdensity_grad: ArrayTree,
Expand All @@ -313,7 +313,7 @@ def update(

logdensity_grad = logdensity_grad
flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
flatten_grads = flatten_grads * std_mat
flatten_grads = flatten_grads * sqrt_diag_cov_mat
flatten_momentum, _ = ravel_pytree(momentum)
dims = flatten_momentum.shape[0]
normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads)
Expand All @@ -325,7 +325,7 @@ def update(
+ 2 * zeta * flatten_momentum
)
new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw)
gr = unravel_fn(new_momentum_normalized * std_mat)
gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat)
next_momentum = unravel_fn(new_momentum_normalized)
kinetic_energy_change = (
delta
Expand Down Expand Up @@ -357,11 +357,12 @@ def format_isokinetic_state_output(

def generate_isokinetic_integrator(coefficients):
def isokinetic_integrator(
logdensity_fn: Callable, std_mat: ArrayTree = 1.0, *args, **kwargs
logdensity_fn: Callable, *args, **kwargs
) -> GeneralIntegrator:
sqrt_diag_cov_mat = kwargs.get("sqrt_diag_cov_mat", 1.0)
position_update_fn = euclidean_position_update_fn(logdensity_fn)
one_step = generalized_two_stage_integrator(
esh_dynamics_momentum_update_one_step(std_mat),
esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat),
position_update_fn,
coefficients,
format_output_fn=format_isokinetic_state_output,
Expand Down
8 changes: 4 additions & 4 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
)


def build_kernel(logdensity_fn, std_mat, integrator):
def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator):
"""Build an MCLMC kernel.

Parameters
Expand All @@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, std_mat, integrator):

"""

step = with_isokinetic_maruyama(integrator(logdensity_fn, std_mat))
step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat))

def kernel(
rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float
Expand All @@ -105,7 +105,7 @@ def as_top_level_api(
L,
step_size,
integrator=isokinetic_mclachlan,
std_mat=1.0,
sqrt_diag_cov_mat=1.0,
) -> SamplingAlgorithm:
"""The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
Expand Down Expand Up @@ -153,7 +153,7 @@ def as_top_level_api(
A ``SamplingAlgorithm``.
"""

kernel = build_kernel(logdensity_fn, std_mat, integrator)
kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator)

def init_fn(position: ArrayLike, rng_key: PRNGKey):
return init(position, logdensity_fn, rng_key)
Expand Down
6 changes: 4 additions & 2 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
average = streaming_average(expectation(transform(state)), average)
average = streaming_average_update(expectation(transform(state)), average)
if return_state:
return (average, state), (transform(state), info)
else:
Expand All @@ -232,7 +232,9 @@ def one_step(average_and_state, xs, return_state):
return transform(final_state), state_history, info_history


def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
def streaming_average_update(
expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
Expand Down
6 changes: 4 additions & 2 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def test_esh_momentum_update(self, dims):
) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta)))

# Efficient implementation
update_stable = self.variant(esh_dynamics_momentum_update_one_step(std_mat=1.0))
update_stable = self.variant(
esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
)
next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)

Expand All @@ -258,7 +260,7 @@ def test_isokinetic_leapfrog(self):
next_state, kinetic_energy_change = step(initial_state, step_size)

# explicit integration
op1 = esh_dynamics_momentum_update_one_step(std_mat=1.0)
op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
op2 = integrators.euclidean_position_update_fn(logdensity_fn)
position, momentum, _, logdensity_grad = initial_state
momentum, kinetic_grad, kinetic_energy_change0 = op1(
Expand Down
18 changes: 9 additions & 9 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def run_mclmc(
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan,
std_mat=std_mat,
sqrt_diag_cov_mat=sqrt_diag_cov_mat,
)

(
Expand All @@ -132,7 +132,7 @@ def run_mclmc(
logdensity_fn,
L=blackjax_mclmc_sampler_params.L,
step_size=blackjax_mclmc_sampler_params.step_size,
std_mat=blackjax_mclmc_sampler_params.std_mat,
sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat,
)

_, samples, _ = run_inference_algorithm(
Expand Down Expand Up @@ -300,7 +300,7 @@ def __init__(self, d, condition_number):

integrator = isokinetic_mclachlan

def get_std_mat():
def get_sqrt_diag_cov_mat():
init_key, tune_key = jax.random.split(key)

initial_position = model.sample_init(init_key)
Expand All @@ -311,10 +311,10 @@ def get_std_mat():
rng_key=init_key,
)

kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=model.logdensity_fn,
integrator=integrator,
std_mat=std_mat,
sqrt_diag_cov_mat=sqrt_diag_cov_mat,
)

(
Expand All @@ -328,13 +328,13 @@ def get_std_mat():
diagonal_preconditioning=True,
)

return blackjax_mclmc_sampler_params.std_mat
return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat

std_mat = get_std_mat()
sqrt_diag_cov_mat = get_sqrt_diag_cov_mat()
assert (
jnp.abs(
jnp.dot(
(std_mat**2) / jnp.linalg.norm(std_mat**2),
(sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2),
eigs / jnp.linalg.norm(eigs),
)
- 1
Expand Down
Loading