Skip to content

Commit

Permalink
Enable fitlering of AdaptationInfo (#674)
Browse files Browse the repository at this point in the history
* enable AdaptationInfo filtering

* revert progress_bar

* fix pre-commit

* fix empty sets

* enable adapt info filtering for all adaptation algorithms

* fix precommit /progressbar=True

* change filter tuple to use tree_map
  • Loading branch information
andrewdipper authored May 16, 2024
1 parent af79fa4 commit cd91e41
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 20 deletions.
35 changes: 34 additions & 1 deletion blackjax/adaptation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple
from typing import NamedTuple, Set

import jax

from blackjax.types import ArrayTree

Expand All @@ -25,3 +27,34 @@ class AdaptationInfo(NamedTuple):
state: NamedTuple
info: NamedTuple
adaptation_state: NamedTuple


def return_all_adapt_info(state, info, adaptation_state):
"""Return fully populated AdaptationInfo. Used for adaptation_info_fn
parameters of the adaptation algorithms.
"""
return AdaptationInfo(state, info, adaptation_state)


def get_filter_adapt_info_fn(
state_keys: Set[str] = set(),
info_keys: Set[str] = set(),
adapt_state_keys: Set[str] = set(),
):
"""Generate a function to filter what is saved in AdaptationInfo. Used
for adptation_info_fn parameters of the adaptation algorithms.
adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information
"""

def filter_tuple(tup, key_set):
mapfn = lambda key, val: None if key not in key_set else val
return jax.tree.map(mapfn, type(tup)(*tup._fields), tup)

def filter_fn(state, info, adaptation_state):
sample_state = filter_tuple(state, state_keys)
new_info = filter_tuple(info, info_keys)
new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys)

return AdaptationInfo(sample_state, new_info, new_adapt_state)

return filter_fn
14 changes: 9 additions & 5 deletions blackjax/adaptation/chees_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import blackjax.mcmc.dynamic_hmc as dynamic_hmc
import blackjax.optimizers.dual_averaging as dual_averaging
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size
Expand Down Expand Up @@ -278,6 +278,7 @@ def chees_adaptation(
jitter_amount: float = 1.0,
target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
decay_rate: float = 0.5,
adaptation_info_fn: Callable = return_all_adapt_info,
) -> AdaptationAlgorithm:
"""Adapt the step size and trajectory length (number of integration steps / step size)
parameters of the jittered HMC algorthm.
Expand Down Expand Up @@ -337,6 +338,11 @@ def chees_adaptation(
Float representing how much to favor recent iterations over earlier ones in the optimization
of step size and trajectory length. A value of 1 gives equal weight to all history. A value
of 0 gives weight only to the most recent iteration.
adaptation_info_fn
Function to select the adaptation info returned. See return_all_adapt_info
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
information is saved - this can result in excessive memory usage if the
information is unused.
Returns
-------
Expand Down Expand Up @@ -411,10 +417,8 @@ def one_step(carry, rng_key):
info.is_divergent,
)

return (new_states, new_adaptation_state), AdaptationInfo(
new_states,
info,
new_adaptation_state,
return (new_states, new_adaptation_state), adaptation_info_fn(
new_states, info, new_adaptation_state
)

batch_init = jax.vmap(
Expand Down
14 changes: 9 additions & 5 deletions blackjax/adaptation/meads_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax.numpy as jnp

import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

Expand Down Expand Up @@ -165,6 +165,7 @@ def update(
def meads_adaptation(
logdensity_fn: Callable,
num_chains: int,
adaptation_info_fn: Callable = return_all_adapt_info,
) -> AdaptationAlgorithm:
"""Adapt the parameters of the Generalized HMC algorithm.
Expand Down Expand Up @@ -194,6 +195,11 @@ def meads_adaptation(
The log density probability density function from which we wish to sample.
num_chains
Number of chains used for cross-chain warm-up training.
adaptation_info_fn
Function to select the adaptation info returned. See return_all_adapt_info
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
information is saved - this can result in excessive memory usage if the
information is unused.
Returns
-------
Expand Down Expand Up @@ -227,10 +233,8 @@ def one_step(carry, rng_key):
adaptation_state, new_states.position, new_states.logdensity_grad
)

return (new_states, new_adaptation_state), AdaptationInfo(
new_states,
info,
new_adaptation_state,
return (new_states, new_adaptation_state), adaptation_info_fn(
new_states, info, new_adaptation_state
)

def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000):
Expand Down
10 changes: 8 additions & 2 deletions blackjax/adaptation/pathfinder_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.numpy as jnp

import blackjax.vi as vi
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.adaptation.step_size import (
DualAveragingAdaptationState,
dual_averaging_adaptation,
Expand Down Expand Up @@ -141,6 +141,7 @@ def pathfinder_adaptation(
logdensity_fn: Callable,
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.80,
adaptation_info_fn: Callable = return_all_adapt_info,
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
Expand All @@ -156,6 +157,11 @@ def pathfinder_adaptation(
The initial step size used in the algorithm.
target_acceptance_rate
The acceptance rate that we target during step size adaptation.
adaptation_info_fn
Function to select the adaptation info returned. See return_all_adapt_info
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
information is saved - this can result in excessive memory usage if the
information is unused.
**extra_parameters
The extra parameters to pass to the algorithm, e.g. the number of
integration steps for HMC.
Expand Down Expand Up @@ -188,7 +194,7 @@ def one_step(carry, rng_key):
)
return (
(new_state, new_adaptation_state),
AdaptationInfo(new_state, info, new_adaptation_state),
adaptation_info_fn(new_state, info, new_adaptation_state),
)

def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400):
Expand Down
10 changes: 8 additions & 2 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax
import jax.numpy as jnp

from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.adaptation.mass_matrix import (
MassMatrixAdaptationState,
mass_matrix_adaptation,
Expand Down Expand Up @@ -248,6 +248,7 @@ def window_adaptation(
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.80,
progress_bar: bool = False,
adaptation_info_fn: Callable = return_all_adapt_info,
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
Expand Down Expand Up @@ -278,6 +279,11 @@ def window_adaptation(
The acceptance rate that we target during step size adaptation.
progress_bar
Whether we should display a progress bar.
adaptation_info_fn
Function to select the adaptation info returned. See return_all_adapt_info
and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
information is saved - this can result in excessive memory usage if the
information is unused.
**extra_parameters
The extra parameters to pass to the algorithm, e.g. the number of
integration steps for HMC.
Expand Down Expand Up @@ -316,7 +322,7 @@ def one_step(carry, xs):

return (
(new_state, new_adaptation_state),
AdaptationInfo(new_state, info, new_adaptation_state),
adaptation_info_fn(new_state, info, new_adaptation_state),
)

def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
Expand Down
52 changes: 50 additions & 2 deletions tests/adaptation/test_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import blackjax
from blackjax.adaptation import window_adaptation
from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
from blackjax.util import run_inference_algorithm


Expand Down Expand Up @@ -34,7 +35,32 @@ def test_adaptation_schedule(num_steps, expected_schedule):
assert np.array_equal(adaptation_schedule, expected_schedule)


def test_chees_adaptation():
@pytest.mark.parametrize(
"adaptation_filters",
[
{
"filter_fn": return_all_adapt_info,
"return_sets": None,
},
{
"filter_fn": get_filter_adapt_info_fn(),
"return_sets": (set(), set(), set()),
},
{
"filter_fn": get_filter_adapt_info_fn(
{"logdensity"},
{"proposal"},
{"random_generator_arg", "step", "da_state"},
),
"return_sets": (
{"logdensity"},
{"proposal"},
{"random_generator_arg", "step", "da_state"},
),
},
],
)
def test_chees_adaptation(adaptation_filters):
logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
x, loc=0.0, scale=jnp.array([1.0, 10.0])
).sum()
Expand All @@ -47,7 +73,10 @@ def test_chees_adaptation():
init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3)

warmup = blackjax.chees_adaptation(
logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
logprob_fn,
num_chains=num_chains,
target_acceptance_rate=0.75,
adaptation_info_fn=adaptation_filters["filter_fn"],
)

initial_positions = jax.random.normal(init_key, (num_chains, 2))
Expand All @@ -66,6 +95,25 @@ def test_chees_adaptation():
)(chain_keys, last_states)

harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)

def check_attrs(attribute, keyset):
for name, param in getattr(warmup_info, attribute)._asdict().items():
print(name, param)
if name in keyset:
assert param is not None
else:
assert param is None

keysets = adaptation_filters["return_sets"]
if keysets is None:
keysets = (
warmup_info.state._fields,
warmup_info.info._fields,
warmup_info.adaptation_state._fields,
)
for i, attribute in enumerate(["state", "info", "adaptation_state"]):
check_attrs(attribute, keysets[i])

np.testing.assert_allclose(harmonic_mean, 0.75, atol=1e-1)
np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
np.testing.assert_array_less(infos.num_integration_steps.mean(), 15.0)
52 changes: 49 additions & 3 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import blackjax
import blackjax.diagnostics as diagnostics
import blackjax.mcmc.random_walk
from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
from blackjax.util import run_inference_algorithm


Expand Down Expand Up @@ -56,6 +57,27 @@ def rmh_proposal_distribution(rng_key, position):
},
]

window_adaptation_filters = [
{
"filter_fn": return_all_adapt_info,
"return_sets": None,
},
{
"filter_fn": get_filter_adapt_info_fn(),
"return_sets": (set(), set(), set()),
},
{
"filter_fn": get_filter_adapt_info_fn(
{"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}
),
"return_sets": (
{"position"},
{"is_divergent"},
{"ss_state", "inverse_mass_matrix"},
),
},
]


class LinearRegressionTest(chex.TestCase):
"""Test sampling of a linear regression model."""
Expand Down Expand Up @@ -112,8 +134,14 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):

return samples

@parameterized.parameters(itertools.product(regression_test_cases, [True, False]))
def test_window_adaptation(self, case, is_mass_matrix_diagonal):
@parameterized.parameters(
itertools.product(
regression_test_cases, [True, False], window_adaptation_filters
)
)
def test_window_adaptation(
self, case, is_mass_matrix_diagonal, window_adapt_config
):
"""Test the HMC kernel and the Stan warmup."""
rng_key, init_key0, init_key1 = jax.random.split(self.key, 3)
x_data = jax.random.normal(init_key0, shape=(1000, 1))
Expand All @@ -131,15 +159,33 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
logposterior_fn,
is_mass_matrix_diagonal,
progress_bar=True,
adaptation_info_fn=window_adapt_config["filter_fn"],
**case["parameters"],
)
(state, parameters), _ = warmup.run(
(state, parameters), info = warmup.run(
warmup_key,
case["initial_position"],
case["num_warmup_steps"],
)
inference_algorithm = case["algorithm"](logposterior_fn, **parameters)

def check_attrs(attribute, keyset):
for name, param in getattr(info, attribute)._asdict().items():
if name in keyset:
assert param is not None
else:
assert param is None

keysets = window_adapt_config["return_sets"]
if keysets is None:
keysets = (
info.state._fields,
info.info._fields,
info.adaptation_state._fields,
)
for i, attribute in enumerate(["state", "info", "adaptation_state"]):
check_attrs(attribute, keysets[i])

_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, case["num_sampling_steps"]
)
Expand Down

0 comments on commit cd91e41

Please sign in to comment.