Enable filtering of AdaptationInfo #674

Merged · 7 commits · May 16, 2024
Changes from 6 commits
32 changes: 31 additions & 1 deletion blackjax/adaptation/base.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import NamedTuple
+from typing import NamedTuple, Set

from blackjax.types import ArrayTree

@@ -25,3 +25,33 @@ class AdaptationInfo(NamedTuple):
    state: NamedTuple
    info: NamedTuple
    adaptation_state: NamedTuple


def return_all_adapt_info(state, info, adaptation_state):
    """Return fully populated AdaptationInfo. Used for the adaptation_info_fn
    parameter of the adaptation algorithms.
    """
    return AdaptationInfo(state, info, adaptation_state)


def get_filter_adapt_info_fn(
    state_keys: Set[str] = set(),
    info_keys: Set[str] = set(),
    adapt_state_keys: Set[str] = set(),
):
    """Generate a function that filters what is saved in AdaptationInfo. Used
    for the adaptation_info_fn parameter of the adaptation algorithms.
    adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information.
    """

    def filter_tuple(tup, key_set):
Review comment (Member): BTW, is there a jax tree function that does this kind of filtering?

Reply (Contributor, author): I changed it to use tree_map - that seems like the most I can offload to the jax tree functions.
        return tup._replace(**{k: None for k in tup._fields if k not in key_set})
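Editor's note: the tree_map rewrite the author mentions landed in a later commit that is not part of this 6-commit view. Purely as a hypothetical sketch of what the reviewer asks about (names and approach are assumptions, not the merged code), field-level filtering can be phrased with jax.tree_util:

import jax

def filter_tuple_tree(tup, key_set):
    # Flag each NamedTuple field with a boolean, then map over the flags.
    # The boolean tuple has the same node type as `tup`, so it acts as a
    # structural prefix: tree_map pairs each flag leaf with the entire
    # corresponding field value rather than with its individual leaves.
    keep = type(tup)(*(name in key_set for name in tup._fields))
    return jax.tree_util.tree_map(lambda flag, sub: sub if flag else None, keep, tup)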

    def filter_fn(state, info, adaptation_state):
        sample_state = filter_tuple(state, state_keys)
        new_info = filter_tuple(info, info_keys)
        new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys)

        return AdaptationInfo(sample_state, new_info, new_adapt_state)

    return filter_fn
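To make the filtering semantics concrete, here is a minimal, self-contained sketch (the Dummy* tuples are hypothetical stand-ins, not blackjax types): fields missing from a key set are replaced by None, which JAX treats as an empty pytree node, so nothing is accumulated for them when the info is stacked across warm-up iterations.

from typing import NamedTuple

from blackjax.adaptation.base import get_filter_adapt_info_fn


class DummyState(NamedTuple):  # hypothetical stand-in for a sampler state
    position: float
    logdensity: float


class DummyInfo(NamedTuple):  # hypothetical stand-in for sampler info
    acceptance_rate: float


class DummyAdaptState(NamedTuple):  # hypothetical adaptation state
    step_size: float


filter_fn = get_filter_adapt_info_fn(state_keys={"position"})
result = filter_fn(DummyState(1.0, -0.5), DummyInfo(0.9), DummyAdaptState(0.1))
assert result.state == DummyState(position=1.0, logdensity=None)
assert result.info == DummyInfo(acceptance_rate=None)
assert result.adaptation_state == DummyAdaptState(step_size=None)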
14 changes: 9 additions & 5 deletions blackjax/adaptation/chees_adaptation.py
@@ -10,7 +10,7 @@

import blackjax.mcmc.dynamic_hmc as dynamic_hmc
import blackjax.optimizers.dual_averaging as dual_averaging
-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size
@@ -278,6 +278,7 @@ def chees_adaptation(
    jitter_amount: float = 1.0,
    target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
    decay_rate: float = 0.5,
    adaptation_info_fn: Callable = return_all_adapt_info,
) -> AdaptationAlgorithm:
"""Adapt the step size and trajectory length (number of integration steps / step size)
parameters of the jittered HMC algorthm.
@@ -337,6 +338,11 @@
        Float representing how much to favor recent iterations over earlier ones in the optimization
        of step size and trajectory length. A value of 1 gives equal weight to all history. A value
        of 0 gives weight only to the most recent iteration.
    adaptation_info_fn
        Function to select the adaptation info returned. See return_all_adapt_info
        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
        information is saved - this can result in excessive memory usage if the
        information is unused.

    Returns
    -------

@@ -411,10 +417,8 @@ def one_step(carry, rng_key):
                info.is_divergent,
            )

-            return (new_states, new_adaptation_state), AdaptationInfo(
-                new_states,
-                info,
-                new_adaptation_state,
+            return (new_states, new_adaptation_state), adaptation_info_fn(
+                new_states, info, new_adaptation_state
            )

    batch_init = jax.vmap(
14 changes: 9 additions & 5 deletions blackjax/adaptation/meads_adaptation.py
@@ -17,7 +17,7 @@
import jax.numpy as jnp

import blackjax.mcmc as mcmc
-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.base import AdaptationAlgorithm
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

@@ -165,6 +165,7 @@ def update(
def meads_adaptation(
    logdensity_fn: Callable,
    num_chains: int,
    adaptation_info_fn: Callable = return_all_adapt_info,
) -> AdaptationAlgorithm:
    """Adapt the parameters of the Generalized HMC algorithm.

@@ -194,6 +195,11 @@
        The log density function of the distribution from which we wish to sample.
    num_chains
        Number of chains used for cross-chain warm-up training.
    adaptation_info_fn
        Function to select the adaptation info returned. See return_all_adapt_info
        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
        information is saved - this can result in excessive memory usage if the
        information is unused.

    Returns
    -------

@@ -227,10 +233,8 @@ def one_step(carry, rng_key):
            adaptation_state, new_states.position, new_states.logdensity_grad
        )

-        return (new_states, new_adaptation_state), AdaptationInfo(
-            new_states,
-            info,
-            new_adaptation_state,
+        return (new_states, new_adaptation_state), adaptation_info_fn(
+            new_states, info, new_adaptation_state
        )

    def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000):
10 changes: 8 additions & 2 deletions blackjax/adaptation/pathfinder_adaptation.py
@@ -18,7 +18,7 @@
import jax.numpy as jnp

import blackjax.vi as vi
-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.adaptation.step_size import (
    DualAveragingAdaptationState,
    dual_averaging_adaptation,

@@ -141,6 +141,7 @@ def pathfinder_adaptation(
    logdensity_fn: Callable,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    adaptation_info_fn: Callable = return_all_adapt_info,
    **extra_parameters,
) -> AdaptationAlgorithm:
    """Adapt the value of the inverse mass matrix and step size parameters of

@@ -156,6 +157,11 @@
        The initial step size used in the algorithm.
    target_acceptance_rate
        The acceptance rate that we target during step size adaptation.
    adaptation_info_fn
        Function to select the adaptation info returned. See return_all_adapt_info
        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
        information is saved - this can result in excessive memory usage if the
        information is unused.
    **extra_parameters
        The extra parameters to pass to the algorithm, e.g. the number of
        integration steps for HMC.
@@ -188,7 +194,7 @@ def one_step(carry, rng_key):
        )
        return (
            (new_state, new_adaptation_state),
-            AdaptationInfo(new_state, info, new_adaptation_state),
+            adaptation_info_fn(new_state, info, new_adaptation_state),
        )

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400):
10 changes: 8 additions & 2 deletions blackjax/adaptation/window_adaptation.py
@@ -17,7 +17,7 @@
import jax
import jax.numpy as jnp

-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.adaptation.mass_matrix import (
    MassMatrixAdaptationState,
    mass_matrix_adaptation,

@@ -248,6 +248,7 @@ def window_adaptation(
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    progress_bar: bool = False,
    adaptation_info_fn: Callable = return_all_adapt_info,
    **extra_parameters,
) -> AdaptationAlgorithm:
    """Adapt the value of the inverse mass matrix and step size parameters of
@@ -278,6 +279,11 @@
        The acceptance rate that we target during step size adaptation.
    progress_bar
        Whether we should display a progress bar.
    adaptation_info_fn
        Function to select the adaptation info returned. See return_all_adapt_info
        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
        information is saved - this can result in excessive memory usage if the
        information is unused.
    **extra_parameters
        The extra parameters to pass to the algorithm, e.g. the number of
        integration steps for HMC.
@@ -316,7 +322,7 @@ def one_step(carry, xs):

        return (
            (new_state, new_adaptation_state),
-            AdaptationInfo(new_state, info, new_adaptation_state),
+            adaptation_info_fn(new_state, info, new_adaptation_state),
        )

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
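Taken together, the new parameter is used like this; a minimal sketch against the public API (the toy logdensity, chain length, and the shape assertion are illustrative assumptions, not code from this PR):

import jax
import jax.numpy as jnp

import blackjax
from blackjax.adaptation.base import get_filter_adapt_info_fn

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

# Keep only the chain position and the divergence flag during warm-up;
# every other AdaptationInfo field is stored as None, bounding memory use.
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    logdensity_fn,
    adaptation_info_fn=get_filter_adapt_info_fn(
        state_keys={"position"}, info_keys={"is_divergent"}
    ),
)
rng_key = jax.random.key(0)
(state, parameters), info = warmup.run(rng_key, jnp.zeros(2), num_steps=200)
assert info.state.logdensity is None  # filtered out
assert info.state.position.shape == (200, 2)  # kept, stacked over steps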
52 changes: 50 additions & 2 deletions tests/adaptation/test_adaptation.py
@@ -6,6 +6,7 @@

import blackjax
from blackjax.adaptation import window_adaptation
from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
from blackjax.util import run_inference_algorithm


@@ -34,7 +35,32 @@ def test_adaptation_schedule(num_steps, expected_schedule):
    assert np.array_equal(adaptation_schedule, expected_schedule)


-def test_chees_adaptation():
+@pytest.mark.parametrize(
+    "adaptation_filters",
+    [
+        {
+            "filter_fn": return_all_adapt_info,
+            "return_sets": None,
+        },
+        {
+            "filter_fn": get_filter_adapt_info_fn(),
+            "return_sets": (set(), set(), set()),
+        },
+        {
+            "filter_fn": get_filter_adapt_info_fn(
+                {"logdensity"},
+                {"proposal"},
+                {"random_generator_arg", "step", "da_state"},
+            ),
+            "return_sets": (
+                {"logdensity"},
+                {"proposal"},
+                {"random_generator_arg", "step", "da_state"},
+            ),
+        },
+    ],
+)
+def test_chees_adaptation(adaptation_filters):
    logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
        x, loc=0.0, scale=jnp.array([1.0, 10.0])
    ).sum()
@@ -47,7 +73,10 @@ def test_chees_adaptation():
    init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3)

    warmup = blackjax.chees_adaptation(
-        logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
+        logprob_fn,
+        num_chains=num_chains,
+        target_acceptance_rate=0.75,
+        adaptation_info_fn=adaptation_filters["filter_fn"],
    )

    initial_positions = jax.random.normal(init_key, (num_chains, 2))

@@ -66,6 +95,25 @@ def test_chees_adaptation():
    )(chain_keys, last_states)

    harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)

    def check_attrs(attribute, keyset):
        for name, param in getattr(warmup_info, attribute)._asdict().items():
            print(name, param)
            if name in keyset:
                assert param is not None
            else:
                assert param is None

    keysets = adaptation_filters["return_sets"]
    if keysets is None:
        keysets = (
            warmup_info.state._fields,
            warmup_info.info._fields,
            warmup_info.adaptation_state._fields,
        )
    for i, attribute in enumerate(["state", "info", "adaptation_state"]):
        check_attrs(attribute, keysets[i])

    np.testing.assert_allclose(harmonic_mean, 0.75, atol=1e-1)
    np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
    np.testing.assert_array_less(infos.num_integration_steps.mean(), 15.0)
52 changes: 49 additions & 3 deletions tests/mcmc/test_sampling.py
@@ -13,6 +13,7 @@
import blackjax
import blackjax.diagnostics as diagnostics
import blackjax.mcmc.random_walk
from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
from blackjax.util import run_inference_algorithm


@@ -56,6 +57,27 @@ def rmh_proposal_distribution(rng_key, position):
    },
]

window_adaptation_filters = [
    {
        "filter_fn": return_all_adapt_info,
        "return_sets": None,
    },
    {
        "filter_fn": get_filter_adapt_info_fn(),
        "return_sets": (set(), set(), set()),
    },
    {
        "filter_fn": get_filter_adapt_info_fn(
            {"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}
        ),
        "return_sets": (
            {"position"},
            {"is_divergent"},
            {"ss_state", "inverse_mass_matrix"},
        ),
    },
]


class LinearRegressionTest(chex.TestCase):
"""Test sampling of a linear regression model."""
@@ -112,8 +134,14 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):

        return samples

-    @parameterized.parameters(itertools.product(regression_test_cases, [True, False]))
-    def test_window_adaptation(self, case, is_mass_matrix_diagonal):
+    @parameterized.parameters(
+        itertools.product(
+            regression_test_cases, [True, False], window_adaptation_filters
+        )
+    )
+    def test_window_adaptation(
+        self, case, is_mass_matrix_diagonal, window_adapt_config
+    ):
"""Test the HMC kernel and the Stan warmup."""
rng_key, init_key0, init_key1 = jax.random.split(self.key, 3)
x_data = jax.random.normal(init_key0, shape=(1000, 1))
Expand All @@ -131,15 +159,33 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
            logposterior_fn,
            is_mass_matrix_diagonal,
            progress_bar=True,
+            adaptation_info_fn=window_adapt_config["filter_fn"],
            **case["parameters"],
        )
-        (state, parameters), _ = warmup.run(
+        (state, parameters), info = warmup.run(
            warmup_key,
            case["initial_position"],
            case["num_warmup_steps"],
        )
        inference_algorithm = case["algorithm"](logposterior_fn, **parameters)

        def check_attrs(attribute, keyset):
            for name, param in getattr(info, attribute)._asdict().items():
                if name in keyset:
                    assert param is not None
                else:
                    assert param is None

        keysets = window_adapt_config["return_sets"]
        if keysets is None:
            keysets = (
                info.state._fields,
                info.info._fields,
                info.adaptation_state._fields,
            )
        for i, attribute in enumerate(["state", "info", "adaptation_state"]):
            check_attrs(attribute, keysets[i])

        _, states, _ = run_inference_algorithm(
            inference_key, state, inference_algorithm, case["num_sampling_steps"]
        )