
Merge branch 'main' into adjusted_mclmc
reubenharry committed Dec 27, 2024
2 parents 477b11a + e1d816a commit 677dea7
Showing 29 changed files with 873 additions and 375 deletions.
48 changes: 0 additions & 48 deletions .github/workflows/nightly.yml

This file was deleted.

13 changes: 4 additions & 9 deletions README.md
@@ -41,12 +41,6 @@ or via conda-forge:
conda install -c conda-forge blackjax
```

-Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`:
-
-```bash
-pip install blackjax-nightly
-```

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
version of JAX that will be installed along with BlackJAX will make your code
run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow
@@ -81,9 +75,10 @@ state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
-for step in range(100):
-    nuts_key = jax.random.fold_in(rng_key, step)
-    state, _ = nuts.step(nuts_key, state)
+step = jax.jit(nuts.step)
+for i in range(100):
+    nuts_key = jax.random.fold_in(rng_key, i)
+    state, _ = step(nuts_key, state)
```

See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
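For longer runs, the documentation's inference-loop pattern compiles the whole loop with `jax.lax.scan` instead of stepping in Python; a minimal sketch, assuming `nuts`, `state`, and `rng_key` are defined as in the example above:

```python
import jax

def inference_loop(rng_key, kernel, initial_state, num_samples):
    # One transition of the chain; scan carries the state and stacks the outputs.
    def one_step(state, step_key):
        state, info = kernel(step_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos

# states.position holds one draw per step, stacked along the leading axis.
states, infos = inference_loop(rng_key, nuts.step, state, 1_000)
```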
35 changes: 25 additions & 10 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size, streaming_average_update
+from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size


class MCLMCAdaptationState(NamedTuple):
@@ -147,20 +147,24 @@ def predictor(previous_state, params, adaptive_state, rng_key):

time, x_average, step_size_max = adaptive_state

+rng_key, nan_key = jax.random.split(rng_key)

# dynamics
next_state, info = kernel(params.sqrt_diag_cov)(
rng_key=rng_key,
state=previous_state,
L=params.L,
step_size=params.step_size,
)

# step updating
success, state, step_size_max, energy_change = handle_nans(
previous_state,
next_state,
params.step_size,
step_size_max,
info.energy_change,
+    nan_key,
)

# Warning: var = 0 if there were nans, but we will give it a very small weight
@@ -199,11 +203,10 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
-streaming_avg = streaming_average_update(
-    current_value=jnp.array([x, jnp.square(x)]),
-    previous_weight_and_average=streaming_avg,
-    weight=(1 - mask) * success * params.step_size,
-    zero_prevention=mask,
+streaming_avg = incremental_value_update(
+    expectation=jnp.array([x, jnp.square(x)]),
+    incremental_val=streaming_avg,
+    weight=mask * success * params.step_size,
)

return (state, params, adaptive_state, streaming_avg), None
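The weighted running average that `incremental_value_update` maintains can be illustrated standalone; a minimal sketch of the idea (a hypothetical helper, not the actual `blackjax.util` signature):

```python
import jax.numpy as jnp

def weighted_running_average(new_value, carry, weight):
    # carry = (total_weight, average); weight may be zero for steps
    # that should not contribute (e.g. masked-out adaptation steps).
    total_weight, average = carry
    new_total = total_weight + weight
    average = (total_weight * average + weight * new_value) / jnp.maximum(
        new_total, 1e-12  # guard against division by zero before any weight arrives
    )
    return new_total, average
```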
@@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
)

# we use the last num_steps2 to compute the diagonal preconditioner
-mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
+mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

# run the steps
state, params, _, (_, average) = run_steps(
@@ -243,7 +246,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L = params.L
# determine L
sqrt_diag_cov = params.sqrt_diag_cov
-if num_steps2 != 0.0:
+if num_steps2 > 1:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))
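Here L is set to the square root of the trace of the estimated (diagonal) covariance, computed from the running averages of x and x²; a quick worked example with made-up numbers:

```python
import jax.numpy as jnp

# Suppose the streaming averages over the second phase came out as:
x_average = jnp.array([0.0, 1.0])          # E[x] per coordinate
x_squared_average = jnp.array([4.0, 2.0])  # E[x^2] per coordinate

variances = x_squared_average - jnp.square(x_average)  # [4.0, 1.0]
L = jnp.sqrt(jnp.sum(variances))                       # sqrt(5) ~= 2.236
```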
@@ -298,16 +301,28 @@ def step(state, key):
return adaptation_L


-def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
+def handle_nans(
+    previous_state, next_state, step_size, step_size_max, kinetic_change, key
+):
"""if there are nans, let's reduce the stepsize, and not update the state. The
function returns the old state in this case."""

reduced_step_size = 0.8
p, unravel_fn = ravel_pytree(next_state.position)
-nonans = jnp.all(jnp.isfinite(p))
+q, unravel_fn = ravel_pytree(next_state.momentum)
+nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q)))
state, step_size, kinetic_change = jax.tree_util.tree_map(
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(next_state, step_size_max, kinetic_change),
(previous_state, step_size * reduced_step_size, 0.0),
)

+state = jax.lax.cond(
+    jnp.isnan(next_state.logdensity),
+    lambda: state._replace(
+        momentum=generate_unit_vector(key, previous_state.position)
+    ),
+    lambda: state,
+)

return nonans, state, step_size, kinetic_change
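The guard above follows a standard JAX pattern: flatten the pytree, test finiteness, and select between the new and previous values without Python branching. A self-contained sketch of the same pattern (an illustrative helper, not the library function):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def guard_nans(new_state, old_state):
    # Flatten the candidate state and check that every entry is finite.
    flat, _ = ravel_pytree(new_state)
    ok = jnp.all(jnp.isfinite(flat))
    # Keep the new state only if it is finite; otherwise fall back to the old one.
    state = jax.tree_util.tree_map(
        lambda new, old: jax.lax.select(ok, jnp.nan_to_num(new), old),
        new_state,
        old_state,
    )
    return ok, state

ok, state = guard_nans({"x": jnp.array([1.0, jnp.nan])}, {"x": jnp.zeros(2)})
# ok == False and state["x"] == [0., 0.]
```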
4 changes: 2 additions & 2 deletions blackjax/adaptation/meads_adaptation.py
@@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple):
alpha
Value of the alpha parameter of the generalized HMC algorithm.
delta
-Value of the alpha parameter of the generalized HMC algorithm.
+Value of the delta parameter of the generalized HMC algorithm.
"""

@@ -60,7 +60,7 @@ def base():
with shape.
This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain
-adaptation instead of parallel ensample chain adaptation.
+adaptation instead of parallel ensemble chain adaptation.
Returns
-------
15 changes: 7 additions & 8 deletions blackjax/adaptation/window_adaptation.py
@@ -28,7 +28,7 @@
dual_averaging_adaptation,
)
from blackjax.base import AdaptationAlgorithm
-from blackjax.progress_bar import progress_bar_scan
+from blackjax.progress_bar import gen_scan_fn
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size

@@ -333,17 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):

if progress_bar:
print("Running window adaptation")
-    one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))
-else:
-    one_step_ = jax.jit(one_step)

+scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar)
+start_state = (init_state, init_adaptation_state)
keys = jax.random.split(rng_key, num_steps)
schedule = build_schedule(num_steps)
-last_state, info = jax.lax.scan(
-    one_step_,
-    (init_state, init_adaptation_state),
+last_state, info = scan_fn(
+    one_step,
+    start_state,
(jnp.arange(num_steps), keys, schedule),
)

last_chain_state, last_warmup_state, *_ = last_state

step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
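`gen_scan_fn` returns a scan whose behavior depends on the `progress_bar` flag; a hedged sketch of what such a helper could look like (the actual `blackjax.progress_bar.gen_scan_fn` implementation differs):

```python
import jax

def gen_scan_fn(num_steps, progress_bar):
    """Return a drop-in replacement for jax.lax.scan, optionally reporting progress."""
    if not progress_bar:
        return jax.lax.scan

    def scan_with_progress(body_fn, init, xs):
        def wrapped(carry, x):
            index = x[0]  # assumes the first leaf of xs is jnp.arange(num_steps)
            jax.debug.callback(lambda i: print(f"step {int(i) + 1}/{num_steps}"), index)
            return body_fn(carry, x)

        return jax.lax.scan(wrapped, init, xs)

    return scan_with_progress
```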
2 changes: 1 addition & 1 deletion blackjax/base.py
@@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple):
"""A pair of functions that represents a MCMC sampling algorithm.
Blackjax sampling algorithms are implemented as a pair of pure functions: a
kernel, that takes a new samples starting from the current state, and an
kernel, that generates a new sample from the current state, and an
initialization function that creates a kernel state from a chain position.
As they represent Markov kernels, the kernel functions are pure functions
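Concretely, this init/step pair is what every top-level constructor returns; a brief usage sketch consistent with the README example above:

```python
import jax
import jax.numpy as jnp
import blackjax

# Any SamplingAlgorithm exposes the same two pure functions.
nuts = blackjax.nuts(
    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),  # standard normal target
    step_size=0.5,
    inverse_mass_matrix=jnp.ones(2),
)
state = nuts.init(jnp.zeros(2))                    # position -> kernel state
state, info = nuts.step(jax.random.key(0), state)  # PRNG key + state -> new state
```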
9 changes: 7 additions & 2 deletions blackjax/mcmc/ghmc.py
@@ -16,6 +16,7 @@

import jax
import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
@@ -129,8 +130,8 @@ def kernel(
"""

-flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0]
-momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
+flat_inverse_scale = ravel_pytree(momentum_inverse_scale)[0]
+momentum_generator, kinetic_energy_fn, *_ = metrics.gaussian_euclidean(
flat_inverse_scale**2
)
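`ravel_pytree` flattens an arbitrary pytree into a single vector and returns the inverse transformation; a quick illustration:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

scale = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
flat, unravel_fn = ravel_pytree(scale)
# flat is one 1-D array; dict keys are traversed in sorted order.
print(flat)              # e.g. [0.5, 1.0, 2.0]
print(unravel_fn(flat))  # reconstructs {"b": ..., "w": ...}
```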

@@ -248,6 +249,10 @@ def as_top_level_api(
A PyTree of the same structure as the target PyTree (position) with the
values used as a step size for each dimension of the target space in
the velocity verlet integrator.
+momentum_inverse_scale
+Pytree with the same structure as the targeted position variable
+specifying the per dimension inverse scaling transformation applied
+to the persistent momentum variable prior to the integration step.
alpha
The value defining the persistence of the momentum variable.
delta
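Putting the documented parameters together, a hypothetical call to the top-level API might look as follows (parameter values are made up; check the signature against your installed version):

```python
import jax
import jax.numpy as jnp
import blackjax

ghmc = blackjax.ghmc(
    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
    step_size=0.1,
    momentum_inverse_scale=jnp.ones(2),  # same structure as the position
    alpha=0.9,  # persistence of the momentum between steps
    delta=0.1,  # delta parameter of the generalized HMC algorithm
)
# ghmc's init also draws the persistent momentum, so it takes a PRNG key.
state = ghmc.init(jnp.zeros(2), jax.random.key(1))
```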
1 change: 1 addition & 0 deletions blackjax/mcmc/integrators.py
@@ -443,6 +443,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key):
)
# one step of the deterministic dynamics
state, info = integrator(state, step_size)

# partial refreshment
state = state._replace(
momentum=partially_refresh_momentum(
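Partial refreshment mixes the current momentum with fresh Gaussian noise instead of resampling it outright; a minimal sketch of one common (Ornstein-Uhlenbeck-style) form, not necessarily blackjax's exact `partially_refresh_momentum`:

```python
import jax
import jax.numpy as jnp

def partial_refresh(momentum, rng_key, step_size, L_proposal):
    """Mix momentum with Gaussian noise; c -> 0 recovers full refreshment."""
    c = jnp.exp(-step_size / L_proposal)  # persistence over one step
    noise = jax.random.normal(rng_key, momentum.shape, momentum.dtype)
    return c * momentum + jnp.sqrt(1.0 - c**2) * noise
```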
