From 8a22b0988452a2d1866380bf6daac1bab337f696 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 14 May 2024 17:20:36 -0700 Subject: [PATCH 1/7] enable AdaptationInfo filtering --- blackjax/adaptation/window_adaptation.py | 35 +++++++++++++++++++++++- tests/mcmc/test_sampling.py | 34 ++++++++++++++++++++--- 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index e15121dc5..346f09173 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -241,6 +241,34 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: return init, update, final +def return_all_adapt_info(state, info, adaptation_state): + """Return fully populated AdaptationInfo. Used for adaptation_info_fn + parameter of window_adaptation + """ + return AdaptationInfo(state, info, adaptation_state) + +def get_filter_adapt_info_fn( + state_keys: set = {}, + info_keys: set = {}, + adapt_state_keys: set = {}, + ): + """Generate a function to filter what is saved in AdaptationInfo. Used + for adptation_info_fn parameter of window_adaptation. + adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information + """ + + def filter_tuple(tup, key_set): + return tup._replace(**{k: None for k in tup._fields if k not in key_set}) + + def filter_fn(state, info, adaptation_state): + sample_state = filter_tuple(state, state_keys) + new_info = filter_tuple(info, info_keys) + new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys) + + return AdaptationInfo(sample_state, new_info, new_adapt_state) + return filter_fn + + def window_adaptation( algorithm, logdensity_fn: Callable, @@ -248,6 +276,7 @@ def window_adaptation( initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, progress_bar: bool = False, + adaptation_info_fn: Callable = return_all_adapt_info, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of @@ -278,6 +307,10 @@ def window_adaptation( The acceptance rate that we target during step size adaptation. progress_bar Whether we should display a progress bar. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn. By default all information is saved - this can + result in excessive memory usage if the information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC. @@ -316,7 +349,7 @@ def one_step(carry, xs): return ( (new_state, new_adaptation_state), - AdaptationInfo(new_state, info, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 39c1b811b..034b2ea35 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -56,6 +56,18 @@ def rmh_proposal_distribution(rng_key, position): }, ] +window_adaptation_filters = [ + {"filter_fn": blackjax.adaptation.window_adaptation.return_all_adapt_info, + "return_sets": None, + }, + {"filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn(), + "return_sets": ({}, {}, {}), + }, + {"filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn({"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}), + "return_sets": ({"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}), + }, +] + class LinearRegressionTest(chex.TestCase): """Test sampling of a linear regression model.""" @@ -112,8 +124,8 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): return samples - @parameterized.parameters(itertools.product(regression_test_cases, [True, False])) - def test_window_adaptation(self, case, is_mass_matrix_diagonal): + @parameterized.parameters(itertools.product(regression_test_cases, [True, False], window_adaptation_filters)) + def test_window_adaptation(self, case, is_mass_matrix_diagonal, window_adapt_config): """Test the HMC kernel and the Stan warmup.""" rng_key, init_key0, init_key1 = jax.random.split(self.key, 3) x_data = jax.random.normal(init_key0, shape=(1000, 1)) @@ -130,16 +142,30 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal): case["algorithm"], logposterior_fn, is_mass_matrix_diagonal, - progress_bar=True, + progress_bar=False, + adaptation_info_fn=window_adapt_config["filter_fn"], **case["parameters"], ) - (state, parameters), _ = warmup.run( + (state, parameters), info = warmup.run( warmup_key, case["initial_position"], case["num_warmup_steps"], ) inference_algorithm = case["algorithm"](logposterior_fn, **parameters) + def check_attrs(attribute, keyset): + for name, param in getattr(info, attribute)._asdict().items(): + if name in keyset: + assert param is not None + else: + assert param is None + + keysets = window_adapt_config["return_sets"] + if keysets is None: + keysets = (info.state._fields, info.info._fields, info.adaptation_state._fields) + for i, attribute in enumerate(["state", "info", "adaptation_state"]): + check_attrs(attribute, keysets[i]) + _, states, _ = run_inference_algorithm( inference_key, state, inference_algorithm, case["num_sampling_steps"] ) From 1abc952c3a343fb0c32f0c778aae34f8ad59f348 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 14 May 2024 17:23:27 -0700 Subject: [PATCH 2/7] revert progress_bar --- tests/mcmc/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 034b2ea35..ece7310fe 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -142,7 +142,7 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal, window_adapt_con case["algorithm"], logposterior_fn, is_mass_matrix_diagonal, - progress_bar=False, + progress_bar=True, adaptation_info_fn=window_adapt_config["filter_fn"], **case["parameters"], ) From b070a5f82de101e80fb30192139cbcde44ef9941 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 14 May 2024 18:42:53 -0700 Subject: [PATCH 3/7] fix pre-commit --- blackjax/adaptation/window_adaptation.py | 12 ++++---- tests/mcmc/test_sampling.py | 37 ++++++++++++++++++------ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 346f09173..9283b7083 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Stan warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple +from typing import Callable, NamedTuple, Set import jax import jax.numpy as jnp @@ -247,11 +247,12 @@ def return_all_adapt_info(state, info, adaptation_state): """ return AdaptationInfo(state, info, adaptation_state) + def get_filter_adapt_info_fn( - state_keys: set = {}, - info_keys: set = {}, - adapt_state_keys: set = {}, - ): + state_keys: Set[str] = set(), + info_keys: Set[str] = set(), + adapt_state_keys: Set[str] = set(), +): """Generate a function to filter what is saved in AdaptationInfo. Used for adptation_info_fn parameter of window_adaptation. adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information @@ -266,6 +267,7 @@ def filter_fn(state, info, adaptation_state): new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys) return AdaptationInfo(sample_state, new_info, new_adapt_state) + return filter_fn diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index ece7310fe..3314c81cc 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -57,14 +57,23 @@ def rmh_proposal_distribution(rng_key, position): ] window_adaptation_filters = [ - {"filter_fn": blackjax.adaptation.window_adaptation.return_all_adapt_info, - "return_sets": None, + { + "filter_fn": blackjax.adaptation.window_adaptation.return_all_adapt_info, + "return_sets": None, }, - {"filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn(), - "return_sets": ({}, {}, {}), + { + "filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn(), + "return_sets": ({}, {}, {}), }, - {"filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn({"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}), - "return_sets": ({"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}), + { + "filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn( + {"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"} + ), + "return_sets": ( + {"position"}, + {"is_divergent"}, + {"ss_state", "inverse_mass_matrix"}, + ), }, ] @@ -124,8 +133,14 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): return samples - @parameterized.parameters(itertools.product(regression_test_cases, [True, False], window_adaptation_filters)) - def test_window_adaptation(self, case, is_mass_matrix_diagonal, window_adapt_config): + @parameterized.parameters( + itertools.product( + regression_test_cases, [True, False], window_adaptation_filters + ) + ) + def test_window_adaptation( + self, case, is_mass_matrix_diagonal, window_adapt_config + ): """Test the HMC kernel and the Stan warmup.""" rng_key, init_key0, init_key1 = jax.random.split(self.key, 3) x_data = jax.random.normal(init_key0, shape=(1000, 1)) @@ -162,7 +177,11 @@ def check_attrs(attribute, keyset): keysets = window_adapt_config["return_sets"] if keysets is None: - keysets = (info.state._fields, info.info._fields, info.adaptation_state._fields) + keysets = ( + info.state._fields, + info.info._fields, + info.adaptation_state._fields, + ) for i, attribute in enumerate(["state", "info", "adaptation_state"]): check_attrs(attribute, keysets[i]) From de8859e9e043b8295aac50b4e658b2542487fdca Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 14 May 2024 18:45:54 -0700 Subject: [PATCH 4/7] fix empty sets --- tests/mcmc/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 3314c81cc..faa1f7803 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -63,7 +63,7 @@ def rmh_proposal_distribution(rng_key, position): }, { "filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn(), - "return_sets": ({}, {}, {}), + "return_sets": (set(), set(), set()), }, { "filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn( From d8cfa4124ddc9dd1d3e87d4e4d8df3bcd86ff3f3 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Wed, 15 May 2024 12:50:13 -0700 Subject: [PATCH 5/7] enable adapt info filtering for all adaptation algorithms --- blackjax/adaptation/base.py | 32 +++++++++++- blackjax/adaptation/chees_adaptation.py | 14 ++++-- blackjax/adaptation/meads_adaptation.py | 14 ++++-- blackjax/adaptation/pathfinder_adaptation.py | 10 +++- blackjax/adaptation/window_adaptation.py | 39 ++------------- tests/adaptation/test_adaptation.py | 51 +++++++++++++++++++- tests/mcmc/test_sampling.py | 8 +-- 7 files changed, 115 insertions(+), 53 deletions(-) diff --git a/blackjax/adaptation/base.py b/blackjax/adaptation/base.py index e0a01e596..c637bacc3 100644 --- a/blackjax/adaptation/base.py +++ b/blackjax/adaptation/base.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple +from typing import NamedTuple, Set from blackjax.types import ArrayTree @@ -25,3 +25,33 @@ class AdaptationInfo(NamedTuple): state: NamedTuple info: NamedTuple adaptation_state: NamedTuple + + +def return_all_adapt_info(state, info, adaptation_state): + """Return fully populated AdaptationInfo. Used for adaptation_info_fn + parameters of the adaptation algorithms. + """ + return AdaptationInfo(state, info, adaptation_state) + + +def get_filter_adapt_info_fn( + state_keys: Set[str] = set(), + info_keys: Set[str] = set(), + adapt_state_keys: Set[str] = set(), +): + """Generate a function to filter what is saved in AdaptationInfo. Used + for adptation_info_fn parameters of the adaptation algorithms. + adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information + """ + + def filter_tuple(tup, key_set): + return tup._replace(**{k: None for k in tup._fields if k not in key_set}) + + def filter_fn(state, info, adaptation_state): + sample_state = filter_tuple(state, state_keys) + new_info = filter_tuple(info, info_keys) + new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys) + + return AdaptationInfo(sample_state, new_info, new_adapt_state) + + return filter_fn diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py index e81bbeef8..60b3e719f 100644 --- a/blackjax/adaptation/chees_adaptation.py +++ b/blackjax/adaptation/chees_adaptation.py @@ -10,7 +10,7 @@ import blackjax.mcmc.dynamic_hmc as dynamic_hmc import blackjax.optimizers.dual_averaging as dual_averaging -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.base import AdaptationAlgorithm from blackjax.types import Array, ArrayLikeTree, PRNGKey from blackjax.util import pytree_size @@ -278,6 +278,7 @@ def chees_adaptation( jitter_amount: float = 1.0, target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE, decay_rate: float = 0.5, + adaptation_info_fn: Callable = return_all_adapt_info, ) -> AdaptationAlgorithm: """Adapt the step size and trajectory length (number of integration steps / step size) parameters of the jittered HMC algorthm. @@ -337,6 +338,11 @@ def chees_adaptation( Float representing how much to favor recent iterations over earlier ones in the optimization of step size and trajectory length. A value of 1 gives equal weight to all history. A value of 0 gives weight only to the most recent iteration. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. Returns ------- @@ -411,10 +417,8 @@ def one_step(carry, rng_key): info.is_divergent, ) - return (new_states, new_adaptation_state), AdaptationInfo( - new_states, - info, - new_adaptation_state, + return (new_states, new_adaptation_state), adaptation_info_fn( + new_states, info, new_adaptation_state ) batch_init = jax.vmap( diff --git a/blackjax/adaptation/meads_adaptation.py b/blackjax/adaptation/meads_adaptation.py index 8ed135fb5..a431a591d 100644 --- a/blackjax/adaptation/meads_adaptation.py +++ b/blackjax/adaptation/meads_adaptation.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import blackjax.mcmc as mcmc -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.base import AdaptationAlgorithm from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -165,6 +165,7 @@ def update( def meads_adaptation( logdensity_fn: Callable, num_chains: int, + adaptation_info_fn: Callable = return_all_adapt_info, ) -> AdaptationAlgorithm: """Adapt the parameters of the Generalized HMC algorithm. @@ -194,6 +195,11 @@ def meads_adaptation( The log density probability density function from which we wish to sample. num_chains Number of chains used for cross-chain warm-up training. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. Returns ------- @@ -227,10 +233,8 @@ def one_step(carry, rng_key): adaptation_state, new_states.position, new_states.logdensity_grad ) - return (new_states, new_adaptation_state), AdaptationInfo( - new_states, - info, - new_adaptation_state, + return (new_states, new_adaptation_state), adaptation_info_fn( + new_states, info, new_adaptation_state ) def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000): diff --git a/blackjax/adaptation/pathfinder_adaptation.py b/blackjax/adaptation/pathfinder_adaptation.py index efcc55741..c0b4ebc50 100644 --- a/blackjax/adaptation/pathfinder_adaptation.py +++ b/blackjax/adaptation/pathfinder_adaptation.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import blackjax.vi as vi -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.step_size import ( DualAveragingAdaptationState, dual_averaging_adaptation, @@ -141,6 +141,7 @@ def pathfinder_adaptation( logdensity_fn: Callable, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, + adaptation_info_fn: Callable = return_all_adapt_info, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of @@ -156,6 +157,11 @@ def pathfinder_adaptation( The initial step size used in the algorithm. target_acceptance_rate The acceptance rate that we target during step size adaptation. + adaptation_info_fn + Function to select the adaptation info returned. See return_all_adapt_info + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC. @@ -188,7 +194,7 @@ def one_step(carry, rng_key): ) return ( (new_state, new_adaptation_state), - AdaptationInfo(new_state, info, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400): diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 9283b7083..dd3e7b282 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Stan warmup for the HMC family of sampling algorithms.""" -from typing import Callable, NamedTuple, Set +from typing import Callable, NamedTuple import jax import jax.numpy as jnp -from blackjax.adaptation.base import AdaptationInfo, AdaptationResults +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info from blackjax.adaptation.mass_matrix import ( MassMatrixAdaptationState, mass_matrix_adaptation, @@ -241,36 +241,6 @@ def final(warmup_state: WindowAdaptationState) -> tuple[float, Array]: return init, update, final -def return_all_adapt_info(state, info, adaptation_state): - """Return fully populated AdaptationInfo. Used for adaptation_info_fn - parameter of window_adaptation - """ - return AdaptationInfo(state, info, adaptation_state) - - -def get_filter_adapt_info_fn( - state_keys: Set[str] = set(), - info_keys: Set[str] = set(), - adapt_state_keys: Set[str] = set(), -): - """Generate a function to filter what is saved in AdaptationInfo. Used - for adptation_info_fn parameter of window_adaptation. - adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information - """ - - def filter_tuple(tup, key_set): - return tup._replace(**{k: None for k in tup._fields if k not in key_set}) - - def filter_fn(state, info, adaptation_state): - sample_state = filter_tuple(state, state_keys) - new_info = filter_tuple(info, info_keys) - new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys) - - return AdaptationInfo(sample_state, new_info, new_adapt_state) - - return filter_fn - - def window_adaptation( algorithm, logdensity_fn: Callable, @@ -311,8 +281,9 @@ def window_adaptation( Whether we should display a progress bar. adaptation_info_fn Function to select the adaptation info returned. See return_all_adapt_info - and get_filter_adapt_info_fn. By default all information is saved - this can - result in excessive memory usage if the information is unused. + and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all + information is saved - this can result in excessive memory usage if the + information is unused. **extra_parameters The extra parameters to pass to the algorithm, e.g. the number of integration steps for HMC. diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index f54d18c21..f4728d2c3 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -34,7 +34,32 @@ def test_adaptation_schedule(num_steps, expected_schedule): assert np.array_equal(adaptation_schedule, expected_schedule) -def test_chees_adaptation(): +@pytest.mark.parametrize( + "adaptation_filters", + [ + { + "filter_fn": blackjax.adaptation.base.return_all_adapt_info, + "return_sets": None, + }, + { + "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn(), + "return_sets": (set(), set(), set()), + }, + { + "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn( + {"logdensity"}, + {"proposal"}, + {"random_generator_arg", "step", "da_state"}, + ), + "return_sets": ( + {"logdensity"}, + {"proposal"}, + {"random_generator_arg", "step", "da_state"}, + ), + }, + ], +) +def test_chees_adaptation(adaptation_filters): logprob_fn = lambda x: jax.scipy.stats.norm.logpdf( x, loc=0.0, scale=jnp.array([1.0, 10.0]) ).sum() @@ -47,7 +72,10 @@ def test_chees_adaptation(): init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3) warmup = blackjax.chees_adaptation( - logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75 + logprob_fn, + num_chains=num_chains, + target_acceptance_rate=0.75, + adaptation_info_fn=adaptation_filters["filter_fn"], ) initial_positions = jax.random.normal(init_key, (num_chains, 2)) @@ -66,6 +94,25 @@ def test_chees_adaptation(): )(chain_keys, last_states) harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate) + + def check_attrs(attribute, keyset): + for name, param in getattr(warmup_info, attribute)._asdict().items(): + print(name, param) + if name in keyset: + assert param is not None + else: + assert param is None + + keysets = adaptation_filters["return_sets"] + if keysets is None: + keysets = ( + warmup_info.state._fields, + warmup_info.info._fields, + warmup_info.adaptation_state._fields, + ) + for i, attribute in enumerate(["state", "info", "adaptation_state"]): + check_attrs(attribute, keysets[i]) + np.testing.assert_allclose(harmonic_mean, 0.75, atol=1e-1) np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1) np.testing.assert_array_less(infos.num_integration_steps.mean(), 15.0) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index faa1f7803..04f762fd9 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -58,15 +58,15 @@ def rmh_proposal_distribution(rng_key, position): window_adaptation_filters = [ { - "filter_fn": blackjax.adaptation.window_adaptation.return_all_adapt_info, + "filter_fn": blackjax.adaptation.base.return_all_adapt_info, "return_sets": None, }, { - "filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn(), + "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn(), "return_sets": (set(), set(), set()), }, { - "filter_fn": blackjax.adaptation.window_adaptation.get_filter_adapt_info_fn( + "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn( {"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"} ), "return_sets": ( @@ -157,7 +157,7 @@ def test_window_adaptation( case["algorithm"], logposterior_fn, is_mass_matrix_diagonal, - progress_bar=True, + progress_bar=False, adaptation_info_fn=window_adapt_config["filter_fn"], **case["parameters"], ) From f0d86576099ba8e33a5ea197d6a89589176df570 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Wed, 15 May 2024 13:26:50 -0700 Subject: [PATCH 6/7] fix precommit /progressbar=True --- tests/adaptation/test_adaptation.py | 7 ++++--- tests/mcmc/test_sampling.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py index f4728d2c3..8b0f55a7f 100644 --- a/tests/adaptation/test_adaptation.py +++ b/tests/adaptation/test_adaptation.py @@ -6,6 +6,7 @@ import blackjax from blackjax.adaptation import window_adaptation +from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info from blackjax.util import run_inference_algorithm @@ -38,15 +39,15 @@ def test_adaptation_schedule(num_steps, expected_schedule): "adaptation_filters", [ { - "filter_fn": blackjax.adaptation.base.return_all_adapt_info, + "filter_fn": return_all_adapt_info, "return_sets": None, }, { - "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn(), + "filter_fn": get_filter_adapt_info_fn(), "return_sets": (set(), set(), set()), }, { - "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn( + "filter_fn": get_filter_adapt_info_fn( {"logdensity"}, {"proposal"}, {"random_generator_arg", "step", "da_state"}, diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 04f762fd9..e4ac5978d 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -13,6 +13,7 @@ import blackjax import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk +from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info from blackjax.util import run_inference_algorithm @@ -58,15 +59,15 @@ def rmh_proposal_distribution(rng_key, position): window_adaptation_filters = [ { - "filter_fn": blackjax.adaptation.base.return_all_adapt_info, + "filter_fn": return_all_adapt_info, "return_sets": None, }, { - "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn(), + "filter_fn": get_filter_adapt_info_fn(), "return_sets": (set(), set(), set()), }, { - "filter_fn": blackjax.adaptation.base.get_filter_adapt_info_fn( + "filter_fn": get_filter_adapt_info_fn( {"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"} ), "return_sets": ( @@ -157,7 +158,7 @@ def test_window_adaptation( case["algorithm"], logposterior_fn, is_mass_matrix_diagonal, - progress_bar=False, + progress_bar=True, adaptation_info_fn=window_adapt_config["filter_fn"], **case["parameters"], ) From 2a26f85c4e25671a96edbddd960ff016b8c82d9e Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 16 May 2024 08:42:43 -0700 Subject: [PATCH 7/7] change filter tuple to use tree_map --- blackjax/adaptation/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/blackjax/adaptation/base.py b/blackjax/adaptation/base.py index c637bacc3..c510abaf4 100644 --- a/blackjax/adaptation/base.py +++ b/blackjax/adaptation/base.py @@ -13,6 +13,8 @@ # limitations under the License. from typing import NamedTuple, Set +import jax + from blackjax.types import ArrayTree @@ -45,7 +47,8 @@ def get_filter_adapt_info_fn( """ def filter_tuple(tup, key_set): - return tup._replace(**{k: None for k in tup._fields if k not in key_set}) + mapfn = lambda key, val: None if key not in key_set else val + return jax.tree.map(mapfn, type(tup)(*tup._fields), tup) def filter_fn(state, info, adaptation_state): sample_state = filter_tuple(state, state_keys)