Preconditioned mclmc #673

Merged: 42 commits, May 25, 2024

Commits (42)
b60e4ca  TESTS (reubenharry, May 13, 2024)
0c5aa2d  TESTS (reubenharry, May 13, 2024)
5eeb3e1  UPDATE DOCSTRING (reubenharry, May 13, 2024)
4a09156  ADD STREAMING VERSION (reubenharry, May 13, 2024)
dfb5ee0  ADD PRECONDITIONING TO MCLMC (reubenharry, May 13, 2024)
2ab3365  ADD PRECONDITIONING TO TUNING FOR MCLMC (reubenharry, May 13, 2024)
4cc3971  UPDATE GITIGNORE (reubenharry, May 13, 2024)
f987da3  UPDATE GITIGNORE (reubenharry, May 13, 2024)
dbab9a3  UPDATE TESTS (reubenharry, May 13, 2024)
a7ffdb8  UPDATE TESTS (reubenharry, May 13, 2024)
098f5ad  UPDATE TESTS (reubenharry, May 13, 2024)
5bd2a3f  ADD DOCSTRING (reubenharry, May 13, 2024)
4fc1453  ADD TEST (reubenharry, May 13, 2024)
3678428  Merge branch 'inference_algorithm' into preconditioned_mclmc (reubenharry, May 13, 2024)
203f1fd  STREAMING AVERAGE (reubenharry, May 15, 2024)
fc347d6  ADD TEST (reubenharry, May 15, 2024)
49410f9  REFACTOR RUN_INFERENCE_ALGORITHM (reubenharry, May 15, 2024)
ffdca93  UPDATE DOCSTRING (reubenharry, May 15, 2024)
b7b7084  Precommit (reubenharry, May 15, 2024)
9d2601d  RESOLVE MERGE CONFLICTS (reubenharry, May 15, 2024)
97cfc9e  CLEAN TESTS (reubenharry, May 15, 2024)
45429b8  CLEAN TESTS (reubenharry, May 15, 2024)
dd9fb1c  Merge branch 'preconditioned_mclmc' of https://github.com/reubenharry… (reubenharry, May 15, 2024)
a27dba9  GITIGNORE (reubenharry, May 15, 2024)
7a6e42b  PRECOMMIT CLEAN UP (reubenharry, May 15, 2024)
6bd5ab1  ADD INITIAL_POSITION (reubenharry, May 17, 2024)
5615261  FIX TEST (reubenharry, May 17, 2024)
d66a561  Merge branch 'main' into inference_algorithm (reubenharry, May 17, 2024)
290addc  Merge branch 'main' into inference_algorithm (reubenharry, May 18, 2024)
67c0002  Merge branch 'inference_algorithm' into preconditioned_mclmc (reubenharry, May 18, 2024)
51fee69  ADD TEST (reubenharry, May 18, 2024)
29994d7  REMOVE BENCHMARKS (reubenharry, May 18, 2024)
64948e5  BUG FIX (reubenharry, May 18, 2024)
c3d44f3  CHANGE PRECISION (reubenharry, May 18, 2024)
94d43bd  CHANGE PRECISION (reubenharry, May 18, 2024)
178b452  RENAME O (reubenharry, May 19, 2024)
9c1c816  Merge branch 'inference_algorithm' of github.com:reubenharry/blackjax… (reubenharry, May 19, 2024)
db90cdc  Merge branch 'inference_algorithm' into preconditioned_mclmc (reubenharry, May 19, 2024)
a26d4a0  UPDATE STREAMING AVG (reubenharry, May 19, 2024)
4e2b7c0  MERGE (reubenharry, May 20, 2024)
6bacb6c  UPDATE PR (reubenharry, May 24, 2024)
9c2fea7  RENAME STD_MAT (reubenharry, May 24, 2024)
2 changes: 2 additions & 0 deletions .gitignore

@@ -1,6 +1,8 @@
 # Created by https://www.gitignore.io/api/python
 # Edit at https://www.gitignore.io/?templates=python
 
+explore.py
+
 ### Python ###
 # Byte-compiled / optimized / DLL files
 __pycache__/
152 changes: 82 additions & 70 deletions blackjax/adaptation/mclmc_adaptation.py

@@ -20,7 +20,7 @@
 from jax.flatten_util import ravel_pytree
 
 from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size
+from blackjax.util import pytree_size, streaming_average_update
 
 
 class MCLMCAdaptationState(NamedTuple):
@@ -30,10 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
         The momentum decoherent rate for the MCLMC algorithm.
     step_size
         The step size used for the MCLMC algorithm.
+    sqrt_diag_cov_mat
+        A matrix used for preconditioning.
     """
 
     L: float
     step_size: float
+    sqrt_diag_cov_mat: float
 
 
 def mclmc_find_L_and_step_size(
@@ -47,6 +50,7 @@ def mclmc_find_L_and_step_size(
     desired_energy_var=5e-4,
     trust_in_estimate=1.5,
     num_effective_samples=150,
+    diagonal_preconditioning=True,
 ):
     """
     Finds the optimal value of the parameters for the MCLMC algorithm.
@@ -78,38 +82,30 @@
     -------
     A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
 
-    Examples
+    Example
     -------
 
     .. code::
-
-        # Define the kernel function
-        def kernel(x):
-            return x ** 2
-
-        # Define the initial state
-        initial_state = MCMCState(position=0, momentum=1)
-
-        # Generate a random number generator key
-        rng_key = jax.random.key(0)
-
-        # Find the optimal parameters for the MCLMC algorithm
-        final_state, final_params = mclmc_find_L_and_step_size(
+        kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel(
+            logdensity_fn=logdensity_fn,
+            integrator=integrator,
+            std_mat=std_mat,
+        )
+
+        (
+            blackjax_state_after_tuning,
+            blackjax_mclmc_sampler_params,
+        ) = blackjax.mclmc_find_L_and_step_size(
             mclmc_kernel=kernel,
-            num_steps=1000,
+            num_steps=num_steps,
             state=initial_state,
-            rng_key=rng_key,
-            frac_tune1=0.2,
-            frac_tune2=0.3,
-            frac_tune3=0.1,
-            desired_energy_var=1e-4,
-            trust_in_estimate=2.0,
-            num_effective_samples=200,
+            rng_key=tune_key,
+            diagonal_preconditioning=preconditioning,
         )
     """
     dim = pytree_size(state.position)
-    params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25)
+    params = MCLMCAdaptationState(
+        jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,))
+    )
     part1_key, part2_key = jax.random.split(rng_key, 2)
 
     state, params = make_L_step_size_adaptation(
@@ -120,12 +116,13 @@ def kernel(x):
         desired_energy_var=desired_energy_var,
         trust_in_estimate=trust_in_estimate,
         num_effective_samples=num_effective_samples,
+        diagonal_preconditioning=diagonal_preconditioning,
     )(state, params, num_steps, part1_key)
 
     if frac_tune3 != 0:
-        state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
-            state, params, num_steps, part2_key
-        )
+        state, params = make_adaptation_L(
+            mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4
+        )(state, params, num_steps, part2_key)
 
     return state, params
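Note on the API here: after this PR, `mclmc_kernel` is a factory rather than a finished kernel. It maps the preconditioning matrix to a kernel, so the tuner can rebuild the kernel as its estimate of `sqrt_diag_cov_mat` changes (hence `mclmc_kernel(params.sqrt_diag_cov_mat)` above). A minimal sketch of a factory with the expected shape, using a hypothetical standard-Gaussian target and one of blackjax's isokinetic integrators (both are illustrative assumptions, not part of this diff):

```python
import blackjax
import jax.numpy as jnp

# Hypothetical target density (illustrative only)
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))

# Factory: preconditioning matrix -> kernel; the tuner calls this again
# whenever it updates sqrt_diag_cov_mat
kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
    logdensity_fn=logdensity_fn,
    integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
    std_mat=std_mat,
)
```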

@@ -135,6 +132,7 @@ def make_L_step_size_adaptation(
     dim,
     frac_tune1,
     frac_tune2,
+    diagonal_preconditioning,
     desired_energy_var=1e-3,
     trust_in_estimate=1.5,
     num_effective_samples=150,
@@ -150,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
         time, x_average, step_size_max = adaptive_state
 
         # dynamics
-        next_state, info = kernel(
+        next_state, info = kernel(params.sqrt_diag_cov_mat)(
             rng_key=rng_key,
             state=previous_state,
             L=params.L,
@@ -185,68 +183,84 @@ def predictor(previous_state, params, adaptive_state, rng_key):
         ) * step_size_max  # if the proposed stepsize is above the stepsize where we have seen divergences
         params_new = params._replace(step_size=step_size)
 
-        return state, params_new, params_new, (time, x_average, step_size_max), success
-
-    def update_kalman(x, state, outer_weight, success, step_size):
-        """kalman filter to estimate the size of the posterior"""
-        time, x_average, x_squared_average = state
-        weight = outer_weight * step_size * success
-        zero_prevention = 1 - outer_weight
-        x_average = (time * x_average + weight * x) / (
-            time + weight + zero_prevention
-        )  # Update <f(x)> with a Kalman filter
-        x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / (
-            time + weight + zero_prevention
-        )  # Update <f(x)^2> with a Kalman filter
-        time += weight
-        return (time, x_average, x_squared_average)
-
-    adap0 = (0.0, 0.0, jnp.inf)
+        adaptive_state = (time, x_average, step_size_max)
+
+        return state, params_new, adaptive_state, success
 
     def step(iteration_state, weight_and_key):
         """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
 
-        outer_weight, rng_key = weight_and_key
-        state, params, adaptive_state, kalman_state = iteration_state
-        state, params, params_final, adaptive_state, success = predictor(
+        mask, rng_key = weight_and_key
+        state, params, adaptive_state, streaming_avg = iteration_state
+
+        state, params, adaptive_state, success = predictor(
             state, params, adaptive_state, rng_key
         )
-        position, _ = ravel_pytree(state.position)
-        kalman_state = update_kalman(
-            position, kalman_state, outer_weight, success, params.step_size
+
+        x = ravel_pytree(state.position)[0]
+        # update the running average of x, x^2
+        streaming_avg = streaming_average_update(
+            expectation=jnp.array([x, jnp.square(x)]),
+            streaming_avg=streaming_avg,
+            weight=(1 - mask) * success * params.step_size,
+            zero_prevention=mask,
        )
 
-        return (state, params_final, adaptive_state, kalman_state), None
+        return (state, params, adaptive_state, streaming_avg), None
 
+    run_steps = lambda xs, state, params: jax.lax.scan(
+        step,
+        init=(
+            state,
+            params,
+            (0.0, 0.0, jnp.inf),
+            (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
+        ),
+        xs=xs,
+    )[0]
+
     def L_step_size_adaptation(state, params, num_steps, rng_key):
-        num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
-            num_steps * frac_tune2
+        num_steps1, num_steps2 = (
+            int(num_steps * frac_tune1) + 1,
+            int(num_steps * frac_tune2) + 1,
         )
-        L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)
+        L_step_size_adaptation_keys = jax.random.split(
+            rng_key, num_steps1 + num_steps2 + 1
+        )
+        L_step_size_adaptation_keys, final_key = (
+            L_step_size_adaptation_keys[:-1],
+            L_step_size_adaptation_keys[-1],
+        )
 
         # we use the last num_steps2 to compute the diagonal preconditioner
-        outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
-
-        # initial state of the kalman filter
-        kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))
+        mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
 
         # run the steps
-        kalman_state, *_ = jax.lax.scan(
-            step,
-            init=(state, params, adap0, kalman_state),
-            xs=(outer_weights, L_step_size_adaptation_keys),
-            length=num_steps1 + num_steps2,
-        )
-        state, params, _, kalman_state_output = kalman_state
+        state, params, _, (_, average) = run_steps(
+            xs=(mask, L_step_size_adaptation_keys), state=state, params=params
+        )
 
         L = params.L
+        # determine L
+        sqrt_diag_cov_mat = params.sqrt_diag_cov_mat
         if num_steps2 != 0.0:
-            _, F1, F2 = kalman_state_output
-            variances = F2 - jnp.square(F1)
+            x_average, x_squared_average = average[0], average[1]
+            variances = x_squared_average - jnp.square(x_average)
             L = jnp.sqrt(jnp.sum(variances))
 
-        return state, MCLMCAdaptationState(L, params.step_size)
+            if diagonal_preconditioning:
+                sqrt_diag_cov_mat = jnp.sqrt(variances)
+                params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat)
+                L = jnp.sqrt(dim)
+
+                # readjust the stepsize
+                steps = num_steps2 // 3  # we do some small number of steps
+                keys = jax.random.split(final_key, steps)
+                state, params, _, (_, average) = run_steps(
+                    xs=(jnp.ones(steps), keys), state=state, params=params
+                )
+
+        return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat)
 
     return L_step_size_adaptation
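The `update_kalman` helper deleted above now lives behind `blackjax.util.streaming_average_update`, imported at the top of this file. As called here, it is a weighted streaming mean of `[x, x**2]`, from which the variances (and hence the diagonal preconditioner) are recovered. A minimal sketch of the update rule, inferred from the call site rather than copied from the library source:

```python
import jax.numpy as jnp

def streaming_average_update_sketch(expectation, streaming_avg, weight, zero_prevention):
    # streaming_avg carries (total weight so far, running weighted average)
    total, average = streaming_avg
    # zero_prevention keeps the denominator nonzero on masked steps, where
    # weight == 0 (during the frac_tune1 phase, mask == 1 above)
    average = (total * average + weight * expectation) / (
        total + weight + zero_prevention
    )
    return total + weight, average

# Two unmasked unit-weight steps averaging [x, x**2], as in the tuner's scan:
avg = (0.0, jnp.array([0.0, 0.0]))
for x in (1.0, 3.0):
    avg = streaming_average_update_sketch(jnp.array([x, x**2]), avg, 1.0, 0.0)
# avg == (2.0, [2.0, 5.0]); the variance estimate is then 5.0 - 2.0**2 = 1.0
```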

@@ -258,7 +272,6 @@ def adaptation_L(state, params, num_steps, key):
         num_steps = int(num_steps * frac)
         adaptation_L_keys = jax.random.split(key, num_steps)
 
-        # run kernel in the normal way
         def step(state, key):
             next_state, _ = kernel(
                 rng_key=key,
@@ -297,5 +310,4 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch
         (next_state, step_size_max, kinetic_change),
         (previous_state, step_size * reduced_step_size, 0.0),
     )
-
     return nonans, state, step_size, kinetic_change
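Taken together, the tuner now returns `L`, `step_size`, and a diagonal preconditioner in a single pass. An end-to-end sketch on a deliberately anisotropic Gaussian, following the docstring example in this diff; the target, sample count, integrator choice, and `init` call are illustrative assumptions, not part of the PR:

```python
import blackjax
import jax
import jax.numpy as jnp

# Hypothetical anisotropic Gaussian with standard deviations (1, 10)
scales = jnp.array([1.0, 10.0])
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x / scales))

init_key, tune_key = jax.random.split(jax.random.key(0))
initial_state = blackjax.mcmc.mclmc.init(jnp.ones(2), logdensity_fn, init_key)

kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
    logdensity_fn=logdensity_fn,
    integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
    std_mat=std_mat,
)

# Tune L, step_size, and the diagonal preconditioner together
state, params = blackjax.mclmc_find_L_and_step_size(
    mclmc_kernel=kernel,
    num_steps=2000,
    state=initial_state,
    rng_key=tune_key,
    diagonal_preconditioning=True,
)
# params.sqrt_diag_cov_mat should roughly recover the target scales (1, 10)
```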