diff --git a/.buildinfo b/.buildinfo index 025b1a1a3..a93fd10c6 100644 --- a/.buildinfo +++ b/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file records the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 85d59c9c35810c6b2db32f0a854c643a +config: 2ab4986d76b17ecc833ff035a2ace445 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/.doctrees/autoapi/blackjax/adaptation/adjusted_mclmc_adaptation/index.doctree b/.doctrees/autoapi/blackjax/adaptation/adjusted_mclmc_adaptation/index.doctree new file mode 100644 index 000000000..0c9c997f0 Binary files /dev/null and b/.doctrees/autoapi/blackjax/adaptation/adjusted_mclmc_adaptation/index.doctree differ diff --git a/.doctrees/autoapi/blackjax/adaptation/index.doctree b/.doctrees/autoapi/blackjax/adaptation/index.doctree index 5026589e5..616d820a0 100644 Binary files a/.doctrees/autoapi/blackjax/adaptation/index.doctree and b/.doctrees/autoapi/blackjax/adaptation/index.doctree differ diff --git a/.doctrees/autoapi/blackjax/adaptation/mclmc_adaptation/index.doctree b/.doctrees/autoapi/blackjax/adaptation/mclmc_adaptation/index.doctree index 1de429fa4..55fb3afb1 100644 Binary files a/.doctrees/autoapi/blackjax/adaptation/mclmc_adaptation/index.doctree and b/.doctrees/autoapi/blackjax/adaptation/mclmc_adaptation/index.doctree differ diff --git a/.doctrees/autoapi/blackjax/mcmc/adjusted_mclmc/index.doctree b/.doctrees/autoapi/blackjax/mcmc/adjusted_mclmc/index.doctree new file mode 100644 index 000000000..e2d4935cc Binary files /dev/null and b/.doctrees/autoapi/blackjax/mcmc/adjusted_mclmc/index.doctree differ diff --git a/.doctrees/autoapi/blackjax/mcmc/index.doctree b/.doctrees/autoapi/blackjax/mcmc/index.doctree index b7701569e..e1646584a 100644 Binary files a/.doctrees/autoapi/blackjax/mcmc/index.doctree and b/.doctrees/autoapi/blackjax/mcmc/index.doctree differ diff --git a/.doctrees/environment.pickle b/.doctrees/environment.pickle index 52d861f22..de2376da8 100644 Binary files a/.doctrees/environment.pickle and b/.doctrees/environment.pickle differ diff --git a/.doctrees/examples/howto_custom_gradients.doctree b/.doctrees/examples/howto_custom_gradients.doctree index b5c559e0a..0460c3a7c 100644 Binary files a/.doctrees/examples/howto_custom_gradients.doctree and b/.doctrees/examples/howto_custom_gradients.doctree differ diff --git a/.doctrees/examples/howto_metropolis_within_gibbs.doctree b/.doctrees/examples/howto_metropolis_within_gibbs.doctree index 61ad4c70f..575456583 100644 Binary files a/.doctrees/examples/howto_metropolis_within_gibbs.doctree and b/.doctrees/examples/howto_metropolis_within_gibbs.doctree differ diff --git a/.doctrees/examples/howto_other_frameworks.doctree b/.doctrees/examples/howto_other_frameworks.doctree index 493973dc2..e2999c887 100644 Binary files a/.doctrees/examples/howto_other_frameworks.doctree and b/.doctrees/examples/howto_other_frameworks.doctree differ diff --git a/.doctrees/examples/howto_sample_multiple_chains.doctree b/.doctrees/examples/howto_sample_multiple_chains.doctree index f207bae9c..11f72732c 100644 Binary files a/.doctrees/examples/howto_sample_multiple_chains.doctree and b/.doctrees/examples/howto_sample_multiple_chains.doctree differ diff --git a/.doctrees/examples/howto_use_aesara.doctree b/.doctrees/examples/howto_use_aesara.doctree index 6e6619eb5..b29547843 100644 Binary files a/.doctrees/examples/howto_use_aesara.doctree and b/.doctrees/examples/howto_use_aesara.doctree differ diff --git a/.doctrees/examples/howto_use_numpyro.doctree b/.doctrees/examples/howto_use_numpyro.doctree index 68d17345a..a476422d9 100644 Binary files a/.doctrees/examples/howto_use_numpyro.doctree and b/.doctrees/examples/howto_use_numpyro.doctree differ diff --git a/.doctrees/examples/howto_use_oryx.doctree b/.doctrees/examples/howto_use_oryx.doctree index 51db8a135..a606f894b 100644 Binary files a/.doctrees/examples/howto_use_oryx.doctree and b/.doctrees/examples/howto_use_oryx.doctree differ diff --git a/.doctrees/examples/howto_use_pymc.doctree b/.doctrees/examples/howto_use_pymc.doctree index d18452af4..241d7be1e 100644 Binary files a/.doctrees/examples/howto_use_pymc.doctree and b/.doctrees/examples/howto_use_pymc.doctree differ diff --git a/.doctrees/examples/howto_use_tfp.doctree b/.doctrees/examples/howto_use_tfp.doctree index 1d53ea25c..49d865ea6 100644 Binary files a/.doctrees/examples/howto_use_tfp.doctree and b/.doctrees/examples/howto_use_tfp.doctree differ diff --git a/.doctrees/examples/quickstart.doctree b/.doctrees/examples/quickstart.doctree index 71bbed9b7..d35918cac 100644 Binary files a/.doctrees/examples/quickstart.doctree and b/.doctrees/examples/quickstart.doctree differ diff --git a/_images/30151cf96c302689790ef3a3f854839083dd142e8ee7e5205af7e6427f690770.png b/_images/30151cf96c302689790ef3a3f854839083dd142e8ee7e5205af7e6427f690770.png new file mode 100644 index 000000000..fda3c2106 Binary files /dev/null and b/_images/30151cf96c302689790ef3a3f854839083dd142e8ee7e5205af7e6427f690770.png differ diff --git a/_images/33a3a728360fc8e933d6993544e9fff34b9e76a404e1d4040affbf502eafa39f.png b/_images/33a3a728360fc8e933d6993544e9fff34b9e76a404e1d4040affbf502eafa39f.png new file mode 100644 index 000000000..6b4ce4f7d Binary files /dev/null and b/_images/33a3a728360fc8e933d6993544e9fff34b9e76a404e1d4040affbf502eafa39f.png differ diff --git a/_images/47148b2d4fc86f80bbb19e8f82772b2fd9df45444fddc0d4751e7e4536c596c4.png b/_images/47148b2d4fc86f80bbb19e8f82772b2fd9df45444fddc0d4751e7e4536c596c4.png new file mode 100644 index 000000000..c0ea7af7e Binary files /dev/null and b/_images/47148b2d4fc86f80bbb19e8f82772b2fd9df45444fddc0d4751e7e4536c596c4.png differ diff --git a/_images/4ca7410e1335c9aaeb014a351e26e6b8c451290ec6ae183b3b0e17f25597749e.png b/_images/4ca7410e1335c9aaeb014a351e26e6b8c451290ec6ae183b3b0e17f25597749e.png deleted file mode 100644 index 13ddef655..000000000 Binary files a/_images/4ca7410e1335c9aaeb014a351e26e6b8c451290ec6ae183b3b0e17f25597749e.png and /dev/null differ diff --git a/_images/522e438db0b41742da015ffd1722198ddac7b1cac367632b1d493dd2188bde80.png b/_images/522e438db0b41742da015ffd1722198ddac7b1cac367632b1d493dd2188bde80.png deleted file mode 100644 index ec00dffd9..000000000 Binary files a/_images/522e438db0b41742da015ffd1722198ddac7b1cac367632b1d493dd2188bde80.png and /dev/null differ diff --git a/_images/57d2726ad1815ae171b46a27fb936e16f4973893a63fc9ae307afe87297cb6dc.png b/_images/57d2726ad1815ae171b46a27fb936e16f4973893a63fc9ae307afe87297cb6dc.png deleted file mode 100644 index 1a254ee62..000000000 Binary files a/_images/57d2726ad1815ae171b46a27fb936e16f4973893a63fc9ae307afe87297cb6dc.png and /dev/null differ diff --git a/_images/5d0e3bee35fc007f43f08168bddec27d2311a935b61140ebb2dc7aec2384b5a7.png b/_images/5d0e3bee35fc007f43f08168bddec27d2311a935b61140ebb2dc7aec2384b5a7.png new file mode 100644 index 000000000..406572867 Binary files /dev/null and b/_images/5d0e3bee35fc007f43f08168bddec27d2311a935b61140ebb2dc7aec2384b5a7.png differ diff --git a/_images/6cc068b371f698e9906873cbbdfe0c841b8651d750c2976a6651efcbd9f642a5.png b/_images/6cc068b371f698e9906873cbbdfe0c841b8651d750c2976a6651efcbd9f642a5.png new file mode 100644 index 000000000..8bda4dda0 Binary files /dev/null and b/_images/6cc068b371f698e9906873cbbdfe0c841b8651d750c2976a6651efcbd9f642a5.png differ diff --git a/_images/87b755145447df9e8345a79c2038bfe327352ae945e7a05759891bc6c6bdd3ed.png b/_images/87b755145447df9e8345a79c2038bfe327352ae945e7a05759891bc6c6bdd3ed.png new file mode 100644 index 000000000..379f5d944 Binary files /dev/null and b/_images/87b755145447df9e8345a79c2038bfe327352ae945e7a05759891bc6c6bdd3ed.png differ diff --git a/_images/90865f631acbe8c9182e89c18ac584d0ca7ceebca841b82bd00a3fedd248941b.png b/_images/90865f631acbe8c9182e89c18ac584d0ca7ceebca841b82bd00a3fedd248941b.png new file mode 100644 index 000000000..0cb3af758 Binary files /dev/null and b/_images/90865f631acbe8c9182e89c18ac584d0ca7ceebca841b82bd00a3fedd248941b.png differ diff --git a/_images/90867b15d083bf53a3189ef0afd8737f8b607760a73eea5c018ba6838f8b585e.png b/_images/90867b15d083bf53a3189ef0afd8737f8b607760a73eea5c018ba6838f8b585e.png deleted file mode 100644 index 0d8d9e087..000000000 Binary files a/_images/90867b15d083bf53a3189ef0afd8737f8b607760a73eea5c018ba6838f8b585e.png and /dev/null differ diff --git a/_images/9298628c8c92a85d183f424dc473737db5853e98389b890d64f5416b2ea715f7.png b/_images/9298628c8c92a85d183f424dc473737db5853e98389b890d64f5416b2ea715f7.png deleted file mode 100644 index 28e5e13d6..000000000 Binary files a/_images/9298628c8c92a85d183f424dc473737db5853e98389b890d64f5416b2ea715f7.png and /dev/null differ diff --git a/_images/a285f4e3be44d9af9c6a03d37e8554abf09b02ace3281b163313dff6300276a9.png b/_images/a285f4e3be44d9af9c6a03d37e8554abf09b02ace3281b163313dff6300276a9.png new file mode 100644 index 000000000..c9c2fa02f Binary files /dev/null and b/_images/a285f4e3be44d9af9c6a03d37e8554abf09b02ace3281b163313dff6300276a9.png differ diff --git a/_images/ab5af6945663321dd6a32df73162f38632823cd96201e5d16acd724c6a956980.png b/_images/ab5af6945663321dd6a32df73162f38632823cd96201e5d16acd724c6a956980.png deleted file mode 100644 index fe9c8db1f..000000000 Binary files a/_images/ab5af6945663321dd6a32df73162f38632823cd96201e5d16acd724c6a956980.png and /dev/null differ diff --git a/_images/ba1fe251807b6937f6dccb69bb577f04e7ff0888f406bc08c9ac0dba388fa759.png b/_images/ba1fe251807b6937f6dccb69bb577f04e7ff0888f406bc08c9ac0dba388fa759.png deleted file mode 100644 index 94015a20c..000000000 Binary files a/_images/ba1fe251807b6937f6dccb69bb577f04e7ff0888f406bc08c9ac0dba388fa759.png and /dev/null differ diff --git a/_images/dce209a2bd9d607c07e71737aae343d349049186b9f237f811a2e79945fca591.png b/_images/dce209a2bd9d607c07e71737aae343d349049186b9f237f811a2e79945fca591.png deleted file mode 100644 index a849f7a02..000000000 Binary files a/_images/dce209a2bd9d607c07e71737aae343d349049186b9f237f811a2e79945fca591.png and /dev/null differ diff --git a/_images/f3231baa9117215e2eae18a64d7752622ce0f6698ba3575aa9cf69036f4e26cd.png b/_images/f3231baa9117215e2eae18a64d7752622ce0f6698ba3575aa9cf69036f4e26cd.png deleted file mode 100644 index 77d5f1e4f..000000000 Binary files a/_images/f3231baa9117215e2eae18a64d7752622ce0f6698ba3575aa9cf69036f4e26cd.png and /dev/null differ diff --git a/_modules/blackjax/adaptation/adjusted_mclmc_adaptation.html b/_modules/blackjax/adaptation/adjusted_mclmc_adaptation.html new file mode 100644 index 000000000..aa45855d8 --- /dev/null +++ b/_modules/blackjax/adaptation/adjusted_mclmc_adaptation.html @@ -0,0 +1,822 @@ + + + + + + + + + + blackjax.adaptation.adjusted_mclmc_adaptation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + + + + + +
+ +
+ + + + + +
+
+ + + + + + +
+ + + + + + + + + + + +
+ +
+ + + +
+ +
+
+ +
+
+ +
+ +
+ +
+ + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+
+ + + +
+

+ +
+
+ +
+
+
+ + + + +
+ +

Source code for blackjax.adaptation.adjusted_mclmc_adaptation

+import jax
+import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree
+
+from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState
+from blackjax.adaptation.step_size import (
+    DualAveragingAdaptationState,
+    dual_averaging_adaptation,
+)
+from blackjax.diagnostics import effective_sample_size
+from blackjax.util import incremental_value_update, pytree_size
+
+
+[docs] +Lratio_lowerbound = 0.0
+ +
+[docs] +Lratio_upperbound = 2.0
+ + + +
+[docs] +def adjusted_mclmc_find_L_and_step_size( + mclmc_kernel, + num_steps, + state, + rng_key, + target, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + diagonal_preconditioning=True, + params=None, + max="avg", + num_windows=1, + tuning_factor=1.3, +): + """ + Finds the optimal value of the parameters for the MH-MCHMC algorithm. + + Parameters + ---------- + mclmc_kernel + The kernel function used for the MCMC algorithm. + num_steps + The number of MCMC steps that will subsequently be run, after tuning. + state + The initial state of the MCMC algorithm. + rng_key + The random number generator key. + target + The target acceptance rate for the step size adaptation. + frac_tune1 + The fraction of tuning for the first step of the adaptation. + frac_tune2 + The fraction of tuning for the second step of the adaptation. + frac_tune3 + The fraction of tuning for the third step of the adaptation. + diagonal_preconditioning + Whether to do diagonal preconditioning (i.e. a mass matrix) + params + Initial params to start tuning from (optional) + max + whether to calculate L from maximum or average eigenvalue. Average is advised. + num_windows + how many iterations of the tuning are carried out + tuning_factor + multiplicative factor for L + + + Returns + ------- + A tuple containing the final state of the MCMC algorithm and the final hyperparameters. + """ + + frac_tune1 /= num_windows + frac_tune2 /= num_windows + frac_tune3 /= num_windows + + dim = pytree_size(state.position) + if params is None: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + ) + + part1_key, part2_key = jax.random.split(rng_key, 2) + + for i in range(num_windows): + window_key = jax.random.fold_in(part1_key, i) + (state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + target=target, + diagonal_preconditioning=diagonal_preconditioning, + max=max, + tuning_factor=tuning_factor, + )(state, params, num_steps, window_key) + + if frac_tune3 != 0: + for i in range(num_windows): + part2_key = jax.random.fold_in(part2_key, i) + part2_key1, part2_key2 = jax.random.split(part2_key, 2) + + state, params = adjusted_mclmc_make_adaptation_L( + mclmc_kernel, + frac=frac_tune3, + Lfactor=0.5, + max=max, + eigenvector=eigenvector, + )(state, params, num_steps, part2_key1) + + (state, params, _) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=0, + target=target, + fix_L_first_da=True, + diagonal_preconditioning=diagonal_preconditioning, + max=max, + tuning_factor=tuning_factor, + )(state, params, num_steps, part2_key2) + + return state, params
+ + + +
+[docs] +def adjusted_mclmc_make_L_step_size_adaptation( + kernel, + dim, + frac_tune1, + frac_tune2, + target, + diagonal_preconditioning, + fix_L_first_da=False, + max="avg", + tuning_factor=1.0, +): + """Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC""" + + def dual_avg_step(fix_L, update_da): + """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" + + def step(iteration_state, weight_and_key): + mask, rng_key = weight_and_key + ( + previous_state, + params, + (adaptive_state, step_size_max), + previous_weight_and_average, + ) = iteration_state + + avg_num_integration_steps = params.L / params.step_size + + state, info = kernel( + rng_key=rng_key, + state=previous_state, + avg_num_integration_steps=avg_num_integration_steps, + step_size=params.step_size, + sqrt_diag_cov=params.sqrt_diag_cov, + ) + + # step updating + success, state, step_size_max, energy_change = handle_nans( + previous_state, + state, + params.step_size, + step_size_max, + info.energy, + ) + + with_mask = lambda x, y: mask * x + (1 - mask) * y + + log_step_size, log_step_size_avg, step, avg_error, mu = update_da( + adaptive_state, info.acceptance_rate + ) + + adaptive_state = DualAveragingAdaptationState( + with_mask(log_step_size, adaptive_state.log_step_size), + with_mask(log_step_size_avg, adaptive_state.log_step_size_avg), + with_mask(step, adaptive_state.step), + with_mask(avg_error, adaptive_state.avg_error), + with_mask(mu, adaptive_state.mu), + ) + + step_size = jax.lax.clamp( + 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 + ) + adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) + + x = ravel_pytree(state.position)[0] + + # update the running average of x, x^2 + previous_weight_and_average = incremental_value_update( + expectation=jnp.array([x, jnp.square(x)]), + incremental_val=previous_weight_and_average, + weight=(1 - mask) * success * step_size, + zero_prevention=mask, + ) + + params = params._replace(step_size=with_mask(step_size, params.step_size)) + if not fix_L: + params = params._replace( + L=with_mask(params.L * (step_size / params.step_size), params.L), + ) + + state_position = state.position + + return ( + state, + params, + (adaptive_state, step_size_max), + previous_weight_and_average, + ), ( + info, + state_position, + ) + + return step + + def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da): + return jax.lax.scan( + dual_avg_step(fix_L, update_da), + init=( + state, + params, + (initial_da(params.step_size), jnp.inf), # step size max + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=(mask, keys), + ) + + def L_step_size_adaptation(state, params, num_steps, rng_key): + num_steps1, num_steps2 = int(num_steps * frac_tune1), int( + num_steps * frac_tune2 + ) + + check_key, rng_key = jax.random.split(rng_key, 2) + + rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) + L_step_size_adaptation_keys_pass1 = jax.random.split( + rng_key_pass1, num_steps1 + num_steps2 + ) + L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1) + + # determine which steps to ignore in the streaming average + mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, position_samples), + ) = step_size_adaptation( + mask, + state, + params, + L_step_size_adaptation_keys_pass1, + fix_L=fix_L_first_da, + initial_da=initial_da, + update_da=update_da, + ) + + final_stepsize = final_da(dual_avg_state) + params = params._replace(step_size=final_stepsize) + + # determine L + eigenvector = None + if num_steps2 != 0.0: + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) + + if max == "max": + contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor + + elif max == "avg": + contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor + + else: + raise ValueError("max should be either 'max' or 'avg'") + + change = jax.lax.clamp( + Lratio_lowerbound, + contract(variances) / params.L, + Lratio_upperbound, + ) + params = params._replace( + L=params.L * change, step_size=params.step_size * change + ) + if diagonal_preconditioning: + params = params._replace( + sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) + ) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, params_history), + ) = step_size_adaptation( + jnp.ones(num_steps1), + state, + params, + L_step_size_adaptation_keys_pass2, + fix_L=True, + update_da=update_da, + initial_da=initial_da, + ) + + params = params._replace(step_size=final_da(dual_avg_state)) + + return state, params, eigenvector + + return L_step_size_adaptation
+ + + +
+[docs] +def adjusted_mclmc_make_adaptation_L( + kernel, frac, Lfactor, max="avg", eigenvector=None +): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + + def adaptation_L(state, params, num_steps, key): + num_steps = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps) + + def step(state, key): + next_state, _ = kernel( + rng_key=key, + state=state, + step_size=params.step_size, + avg_num_integration_steps=params.L / params.step_size, + sqrt_diag_cov=params.sqrt_diag_cov, + ) + return next_state, next_state.position + + state, samples = jax.lax.scan( + f=step, + init=state, + xs=adaptation_L_keys, + ) + + if max == "max": + contract = jnp.min + else: + contract = jnp.mean + + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + + if eigenvector is not None: + flat_samples = jnp.expand_dims( + jnp.einsum("ij,j", flat_samples, eigenvector), 1 + ) + + # number of effective samples per 1 actual sample + ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps + + return state, params._replace( + L=jnp.clip( + Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound + ) + ) + + return adaptation_L
+ + + +
+[docs] +def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): + """if there are nans, let's reduce the stepsize, and not update the state. The + function returns the old state in this case.""" + + reduced_step_size = 0.8 + p, unravel_fn = ravel_pytree(next_state.position) + nonans = jnp.all(jnp.isfinite(p)) + state, step_size, kinetic_change = jax.tree_util.tree_map( + lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (next_state, step_size_max, kinetic_change), + (previous_state, step_size * reduced_step_size, 0.0), + ) + + return nonans, state, step_size, kinetic_change
+ +
+ +
+ + + + + + +
+ +
+
+
+ +
+ + + + +
+ + + +
+
+
+ + + + + + + + \ No newline at end of file diff --git a/_modules/blackjax/adaptation/base.html b/_modules/blackjax/adaptation/base.html index 46f808f75..7d65d4bb9 100644 --- a/_modules/blackjax/adaptation/base.html +++ b/_modules/blackjax/adaptation/base.html @@ -15,8 +15,8 @@ document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + + + + + +
+ +
+ + + + + +
+
+ + + + + + +
+ + + + + + + + + + + +
+ +
+ + + +
+ +
+
+ +
+
+ +
+ +
+ +
+ + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+
+ + + +
+

+ +
+
+ +
+
+
+ + + + +
+ +

Source code for blackjax.mcmc.adjusted_mclmc

+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin"."""
+from typing import Callable, Union
+
+import jax
+import jax.numpy as jnp
+
+import blackjax.mcmc.integrators as integrators
+from blackjax.base import SamplingAlgorithm
+from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence
+from blackjax.mcmc.hmc import HMCInfo
+from blackjax.mcmc.proposal import static_binomial_sampling
+from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
+from blackjax.util import generate_unit_vector
+
+__all__ = ["init", "build_kernel", "as_top_level_api"]
+
+
+
+[docs] +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg)
+ + + +
+[docs] +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + sqrt_diag_cov=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel
+ + + +
+[docs] +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + sqrt_diag_cov=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + sqrt_diag_cov=sqrt_diag_cov, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type]
+ + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. + """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) +
+ +
+ + + + + + +
+ +
+
+
+ +
+ + + + +
+ + + +
+
+
+ + + + + + + + \ No newline at end of file diff --git a/_modules/blackjax/mcmc/barker.html b/_modules/blackjax/mcmc/barker.html index 0450f36fb..e44ac108d 100644 --- a/_modules/blackjax/mcmc/barker.html +++ b/_modules/blackjax/mcmc/barker.html @@ -15,8 +15,8 @@ document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + + + + + +
+ +
+ + + + + +
+
+ + + + +
+ + + + + + + + + + + +
+ +
+ + + +
+ +
+
+ +
+
+ +
+ +
+ +
+ + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+
+ + + + + + + + +
+ +
+

blackjax.adaptation.adjusted_mclmc_adaptation#

+
+

Attributes#

+ +
+
+

Functions#

+
+ + + + + + + + + + + + + + +

adjusted_mclmc_find_L_and_step_size(mclmc_kernel, ...)

Finds the optimal value of the parameters for the MH-MCHMC algorithm.

adjusted_mclmc_make_L_step_size_adaptation(kernel, ...)

Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC

adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor)

determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)

handle_nans(previous_state, next_state, step_size, ...)

if there are nans, let's reduce the stepsize, and not update the state. The

+
+
+
+

Module Contents#

+
+
+Lratio_lowerbound = 0.0[source]#
+
+ +
+
+Lratio_upperbound = 2.0[source]#
+
+ +
+
+adjusted_mclmc_find_L_and_step_size(mclmc_kernel, num_steps, state, rng_key, target, frac_tune1=0.1, frac_tune2=0.1, frac_tune3=0.0, diagonal_preconditioning=True, params=None, max='avg', num_windows=1, tuning_factor=1.3)[source]#
+

Finds the optimal value of the parameters for the MH-MCHMC algorithm.

+
+
Parameters:
+
    +
  • mclmc_kernel – The kernel function used for the MCMC algorithm.

  • +
  • num_steps – The number of MCMC steps that will subsequently be run, after tuning.

  • +
  • state – The initial state of the MCMC algorithm.

  • +
  • rng_key – The random number generator key.

  • +
  • target – The target acceptance rate for the step size adaptation.

  • +
  • frac_tune1 – The fraction of tuning for the first step of the adaptation.

  • +
  • frac_tune2 – The fraction of tuning for the second step of the adaptation.

  • +
  • frac_tune3 – The fraction of tuning for the third step of the adaptation.

  • +
  • diagonal_preconditioning – Whether to do diagonal preconditioning (i.e. a mass matrix)

  • +
  • params – Initial params to start tuning from (optional)

  • +
  • max – whether to calculate L from maximum or average eigenvalue. Average is advised.

  • +
  • num_windows – how many iterations of the tuning are carried out

  • +
  • tuning_factor – multiplicative factor for L

  • +
+
+
Return type:
+

A tuple containing the final state of the MCMC algorithm and the final hyperparameters.

+
+
+
+ +
+
+adjusted_mclmc_make_L_step_size_adaptation(kernel, dim, frac_tune1, frac_tune2, target, diagonal_preconditioning, fix_L_first_da=False, max='avg', tuning_factor=1.0)[source]#
+

Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC

+
+ +
+
+adjusted_mclmc_make_adaptation_L(kernel, frac, Lfactor, max='avg', eigenvector=None)[source]#
+

determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)

+
+ +
+
+handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change)[source]#
+

if there are nans, let’s reduce the stepsize, and not update the state. The +function returns the old state in this case.

+
+ +
+
+ + +
+ + + + + + + + +
+ + + + + + + +
+ + + +
+
+
+ + + + + + + + \ No newline at end of file diff --git a/autoapi/blackjax/adaptation/base/index.html b/autoapi/blackjax/adaptation/base/index.html index e392c74ec..89cf0d275 100644 --- a/autoapi/blackjax/adaptation/base/index.html +++ b/autoapi/blackjax/adaptation/base/index.html @@ -16,8 +16,8 @@ document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + + + + + +
+ +
+ + + + + +
+
+ + + + +
+ + + + + + + + + + + +
+ +
+ + + +
+ +
+
+ +
+
+ +
+ +
+ +
+ + +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ +
+ +
+
+ + + +
+

blackjax.mcmc.adjusted_mclmc

+ +
+
+ +
+

Contents

+
+ +
+
+
+ + + + +
+ +
+

blackjax.mcmc.adjusted_mclmc#

+

Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence “Hamiltonian” and not “Langevin”.

+
+

Functions#

+
+ + + + + + + + + + + +

init(position, logdensity_fn, random_generator_arg)

build_kernel(integration_steps_fn[, integrator, ...])

Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly.

as_top_level_api([1], integration_steps_fn, 1, ...)

Implements the (basic) user interface for the dynamic MHMCHMC kernel.

+
+
+
+

Module Contents#

+
+
+init(position: blackjax.types.ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: blackjax.types.Array)[source]#
+
+ +
+
+build_kernel(integration_steps_fn, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: ..., sqrt_diag_cov=1.0)[source]#
+

Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly.

+
+
Parameters:
+
    +
  • integrator – The integrator to use to integrate the Hamiltonian dynamics.

  • +
  • divergence_threshold – Value of the difference in energy above which we consider that the transition is divergent.

  • +
  • next_random_arg_fn – Function that generates the next random_generator_arg from its previous value.

  • +
  • integration_steps_fn – Function that generates the next pseudo or quasi-random number of integration steps in the +sequence, given the current random_generator_arg. Needs to return an int.

  • +
+
+
Returns:
+

    +
  • A kernel that takes a rng_key and a Pytree that contains the current state

  • +
  • of the chain and that returns a new state of the chain along with

  • +
  • information about the transition.

  • +
+

+
+
+
+ +
+
+as_top_level_api(logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, sqrt_diag_cov=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, next_random_arg_fn: Callable = lambda key: ..., integration_steps_fn: Callable = lambda key: ...) blackjax.base.SamplingAlgorithm[source]#
+

Implements the (basic) user interface for the dynamic MHMCHMC kernel.

+
+
Parameters:
+
    +
  • logdensity_fn – The log-density function we wish to draw samples from.

  • +
  • step_size – The value to use for the step size in the symplectic integrator.

  • +
  • divergence_threshold – The absolute value of the difference in energy between two states above +which we say that the transition is divergent. The default value is +commonly found in other libraries, and yet is arbitrary.

  • +
  • integrator – (algorithm parameter) The symplectic integrator to use to integrate the trajectory.

  • +
  • next_random_arg_fn – Function that generates the next random_generator_arg from its previous value.

  • +
  • integration_steps_fn – Function that generates the next pseudo or quasi-random number of integration steps in the +sequence, given the current random_generator_arg.

  • +
+
+
Return type:
+

A SamplingAlgorithm.

+
+
+
+ +
+
+ + +
+ + + + + + + + +
+ + + + +
+ + +
+ + + +
+
+
+ + + + + + + + \ No newline at end of file diff --git a/autoapi/blackjax/mcmc/barker/index.html b/autoapi/blackjax/mcmc/barker/index.html index 3e63ecb1d..e9d6b2199 100644 --- a/autoapi/blackjax/mcmc/barker/index.html +++ b/autoapi/blackjax/mcmc/barker/index.html @@ -16,8 +16,8 @@ document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; -