Update run_inference_algorithm to split initial_position and initial_state #672

Merged: 17 commits, May 20, 2024
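The change in a nutshell: `run_inference_algorithm` previously took a single `initial_state_or_position` argument and guessed which one it was via try/except; it now takes explicit `initial_state` and `initial_position` keywords, of which exactly one must be passed. A minimal sketch of the new calling convention (the toy target and MALA step size are illustrative, not from this PR):

```python
import jax
import jax.numpy as jnp
import blackjax
from blackjax.util import run_inference_algorithm

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # toy standard-normal target
algorithm = blackjax.mala(logdensity_fn, 1e-2)

# Pass exactly one of initial_state / initial_position, by keyword.
final_state, state_history, info_history = run_inference_algorithm(
    rng_key=jax.random.PRNGKey(0),
    initial_position=jnp.ones(2),  # converted internally via algorithm.init
    inference_algorithm=algorithm,
    num_steps=1_000,
)
```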
107 changes: 83 additions & 24 deletions blackjax/util.py
@@ -2,13 +2,14 @@
from functools import partial
from typing import Callable, Union

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
from jax.random import normal, split
from jax.tree_util import tree_leaves

from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm
from blackjax.base import SamplingAlgorithm, VIAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

@@ -142,12 +143,15 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:

def run_inference_algorithm(
rng_key: PRNGKey,
initial_state_or_position: ArrayLikeTree,
inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm],
num_steps: int,
initial_state: ArrayLikeTree = None,
initial_position: ArrayLikeTree = None,
progress_bar: bool = False,
transform: Callable = lambda x: x,
) -> tuple[State, State, Info]:
return_state_history=True,
expectation: Callable = lambda x: x,
Member:

Should we combine the kwargs of transform and expectation for now? Otherwise we need a better name for expectation

Contributor Author:

expectation seems to me like an appropriate name, since indeed the value calculated is the expectation. For example, if expectation=lambda x: x**2, then you get back $E[x^2]$.

By contrast, transform operates on the full history of samples. In the future, I think it would make sense for it to also take Info, so that the user can choose to dispense with (part of) the diagnostic info. That is why I thought it was better to keep them separate.

I think it would in theory be possible to make them the same, if necessary.

Member:

That makes sense, but looking at the implementation, the returned state is always transform(state), which means that for an expectation transformation we should probably do expectation(transform(x))?
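To make the distinction in this thread concrete, here is a hedged sketch of how the two kwargs behave under the implementation in this PR; the target and sampler settings are illustrative. Note that, as the comment above observes, `expectation` is applied to the raw state, not to `transform(state)`:

```python
import jax
import jax.numpy as jnp
import blackjax
from blackjax.util import run_inference_algorithm

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
alg = blackjax.mala(logdensity_fn, 1e-2)

# With return_state_history=False, only the streaming average of
# expectation(state) and the (transformed) final state come back.
average, final_state = run_inference_algorithm(
    rng_key=jax.random.PRNGKey(0),
    initial_position=jnp.ones(2),
    inference_algorithm=alg,
    num_steps=1_000,
    transform=lambda state: state.position,       # shapes the stored trace
    expectation=lambda state: state.position**2,  # streaming estimate of E[x^2]
    return_state_history=False,
)
```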

) -> tuple:
"""Wrapper to run an inference algorithm.

Note that this utility function does not work for Stochastic Gradient MCMC samplers
@@ -158,9 +162,10 @@
----------
rng_key
The random state used by JAX's random numbers generator.
initial_state_or_position
The initial state OR the initial position of the inference algorithm. If an initial position
is passed in, the function will automatically convert it into an initial state.
initial_state
The initial state of the inference algorithm.
initial_position
The initial position of the inference algorithm. This is used when the initial state is not provided.
inference_algorithm
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps
@@ -171,34 +176,88 @@ def run_inference_algorithm(
A transformation of the trace of states to be returned. This is useful for
computing deterministic variables, or returning a subset of the states.
By default, the states are returned as is.
expectation
A function of the state whose average over the chain is computed incrementally (a streaming average), so it does not require storing all the states.
return_state_history
If False, `run_inference_algorithm` returns only the streaming average of `expectation(state)` and the final state, instead of the full trace of samples. This is useful when memory is a bottleneck.

Returns
-------
Tuple[State, State, Info]
1. The final state of the inference algorithm.
2. The trace of states of the inference algorithm (contains the MCMC samples).
If return_state_history is True:
1. The final state.
2. The trace of the state.
3. The trace of the info of the inference algorithm for diagnostics.
If return_state_history is False:
1. The streaming average of `expectation(state)` over the chain.
2. The final state of the inference algorithm.
"""
init_key, sample_key = split(rng_key, 2)
try:
initial_state = inference_algorithm.init(initial_state_or_position, init_key)
except (TypeError, ValueError, AttributeError):
# We assume initial_state is already in the right format.
initial_state = initial_state_or_position

keys = split(sample_key, num_steps)
if initial_state is None and initial_position is None:
raise ValueError("Either initial_state or initial_position must be provided.")
if initial_state is not None and initial_position is not None:
raise ValueError(
"Only one of initial_state or initial_position must be provided."
)

@jit
def _one_step(state, xs):
rng_key, init_key = split(rng_key, 2)
if initial_position is not None:
initial_state = inference_algorithm.init(initial_position, init_key)

keys = split(rng_key, num_steps)

def one_step(average_and_state, xs, return_state):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
return state, (transform(state), info)
average = streaming_average(expectation, state, average)
if return_state:
return (average, state), (transform(state), info)
else:
return (average, state), None

one_step = jax.jit(partial(one_step, return_state=return_state_history))

if progress_bar:
one_step = progress_bar_scan(num_steps)(_one_step)
else:
one_step = _one_step
one_step = progress_bar_scan(num_steps)(one_step)

xs = (jnp.arange(num_steps), keys)
final_state, (state_history, info_history) = lax.scan(one_step, initial_state, xs)
return final_state, state_history, info_history
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(initial_state)), initial_state), xs
)

if not return_state_history:
return average, transform(final_state)
else:
state_history, info_history = history
return transform(final_state), state_history, info_history


def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0):
"""Compute the streaming average of a function O(x) using a weight.
Parameters
----------
O
function to be averaged
x
current state
streaming_avg
tuple of (total, average) where total is the sum of weights and average is the current average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns
-------
The updated (total, average) tuple.
"""

expectation = O(x)
flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
total + weight + zero_prevention
)
total += weight
streaming_avg = (total, unravel_fn(average))
return streaming_avg
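A quick sanity check on the update rule above: with `O` the identity and unit weights, the recurrence `average = (total * average + weight * O(x)) / (total + weight)` reduces to the ordinary running mean. A minimal sketch, assuming `streaming_average` is imported from `blackjax.util` as added in this diff:

```python
import jax.numpy as jnp
from blackjax.util import streaming_average  # added in this diff

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# streaming_avg is a (total weight, current average) pair whose average
# entry has the same pytree structure as O(x).
state = (0.0, jnp.zeros(()))
for x in xs:
    state = streaming_average(lambda v: v, x, state)

total, average = state
print(total, average)  # 4.0 2.5, i.e. xs.mean()
```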
7 changes: 6 additions & 1 deletion tests/adaptation/test_adaptation.py
@@ -91,7 +91,12 @@ def test_chees_adaptation(adaptation_filters):

chain_keys = jax.random.split(inference_key, num_chains)
_, _, infos = jax.vmap(
lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
inference_algorithm=algorithm,
num_steps=num_results,
)
)(chain_keys, last_states)

harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
42 changes: 32 additions & 10 deletions tests/mcmc/test_sampling.py
@@ -126,7 +126,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):

_, samples, _ = run_inference_algorithm(
rng_key=run_key,
initial_state_or_position=blackjax_state_after_tuning,
initial_state=blackjax_state_after_tuning,
inference_algorithm=sampling_alg,
num_steps=num_steps,
transform=lambda x: x.position,
@@ -187,7 +187,10 @@ def check_attrs(attribute, keyset):
check_attrs(attribute, keysets[i])

_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, case["num_sampling_steps"]
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=case["num_sampling_steps"],
)

coefs_samples = states.position["coefs"]
@@ -209,7 +212,12 @@ def test_mala(self):

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "log_scale": 1.0})
_, states, _ = run_inference_algorithm(inference_key, state, mala, 10_000)
_, states, _ = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=mala,
num_steps=10_000,
)

coefs_samples = states.position["coefs"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -275,7 +283,10 @@ def test_pathfinder_adaptation(
inference_algorithm = algorithm(logposterior_fn, **parameters)

_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, num_sampling_steps
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
)

coefs_samples = states.position["coefs"]
@@ -316,7 +327,10 @@ def test_meads(self):
chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
lambda key, state: run_inference_algorithm(
key, state, inference_algorithm, 100
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=100,
)
)(chain_keys, last_states)

@@ -360,7 +374,10 @@ def test_chees(self, jitter_generator):
chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
lambda key, state: run_inference_algorithm(
key, state, inference_algorithm, 100
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=100,
)
)(chain_keys, last_states)

@@ -384,7 +401,12 @@ def test_barker(self):
barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
state = barker.init({"coefs": 1.0, "log_scale": 1.0})

_, states, _ = run_inference_algorithm(inference_key, state, barker, 10_000)
_, states, _ = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=barker,
num_steps=10_000,
)

coefs_samples = states.position["coefs"][3000:]
scale_samples = np.exp(states.position["log_scale"][3000:])
@@ -570,7 +592,7 @@ def test_latent_gaussian(self):
inference_algorithm=inference_algorithm,
num_steps=self.sampling_steps,
),
)(self.key, initial_state)
)(rng_key=self.key, initial_state=initial_state)

np.testing.assert_allclose(
np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2
@@ -614,7 +636,7 @@ def univariate_normal_test_case(
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
)
)(inference_key, initial_state)
)(rng_key=inference_key, initial_state=initial_state)

# else:
if postprocess_samples:
@@ -885,7 +907,7 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
)
)
_, states, _ = inference_loop_multiple_chains(
multi_chain_sample_key, initial_states
rng_key=multi_chain_sample_key, initial_state=initial_states
)

posterior_samples = states.position[:, -1000:]
5 changes: 4 additions & 1 deletion tests/test_benchmarks.py
@@ -49,7 +49,10 @@ def run_regression(algorithm, **parameters):
inference_algorithm = algorithm(logdensity_fn, **parameters)

_, states, _ = run_inference_algorithm(
inference_key, state, inference_algorithm, 10_000
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=10_000,
)

return states
63 changes: 55 additions & 8 deletions tests/test_util.py
@@ -19,23 +19,70 @@ def setUp(self):
)
self.num_steps = 10

def check_compatible(self, initial_state_or_position, progress_bar):
def check_compatible(self, initial_state, progress_bar):
"""
Runs 10 steps with `run_inference_algorithm` starting with
`initial_state_or_position` and potentially a progress bar.
`initial_state` and potentially a progress bar.
"""
_ = run_inference_algorithm(
self.key,
initial_state_or_position,
self.algorithm,
self.num_steps,
progress_bar,
rng_key=self.key,
initial_state=initial_state,
inference_algorithm=self.algorithm,
num_steps=self.num_steps,
progress_bar=progress_bar,
transform=lambda x: x.position,
)

def test_streaming(self):
def logdensity_fn(x):
return -0.5 * jnp.sum(jnp.square(x))

initial_position = jnp.ones(
10,
)

init_key, run_key = jax.random.split(self.key, 2)

initial_state = blackjax.mcmc.mclmc.init(
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1)

_, states, info = run_inference_algorithm(
rng_key=run_key,
initial_state=initial_state,
inference_algorithm=alg,
num_steps=50,
progress_bar=False,
expectation=lambda x: x.position,
transform=lambda x: x.position,
return_state_history=True,
)

average, _ = run_inference_algorithm(
rng_key=run_key,
initial_state=initial_state,
inference_algorithm=alg,
num_steps=50,
progress_bar=False,
expectation=lambda x: x.position,
transform=lambda x: x.position,
return_state_history=False,
)

assert jnp.allclose(states.mean(axis=0), average)

@parameterized.parameters([True, False])
def test_compatible_with_initial_pos(self, progress_bar):
self.check_compatible(jnp.array([1.0, 1.0]), progress_bar)
_ = run_inference_algorithm(
rng_key=self.key,
initial_position=jnp.array([1.0, 1.0]),
inference_algorithm=self.algorithm,
num_steps=self.num_steps,
progress_bar=progress_bar,
transform=lambda x: x.position,
)

@parameterized.parameters([True, False])
def test_compatible_with_initial_state(self, progress_bar):