The goal of this issue is to kick off a discussion on the testing strategy. Hopefully we can unify names/concepts.
Our current testing landscape
E2E tests: they run a complete sampling over some model and assert over the shape of the resulting distribution. This may require several kernel steps (a rough sketch of such a test appears after this list).
Advantages:
checks that sampling works for a simple model. All pieces of the model ‘fit’, in the sense that the inputs/outputs of, say, the proposal_generator are compatible with other components like the trajectory builders, etc.
checks that the output of one step can be input for the following one
checks properties that may only hold after many steps are taken (like stationarity?).
Disadvantages:
slow to run
if a regression is introduced, it is difficult to find where it happened (slow developer feedback).
as more ‘ifs’ are added, it becomes impossible to test every combination.
Integration tests: they verify that some components (two, for example) can be used together, without assembling the full, user-facing grouping of objects. That is, they don’t group enough objects to form something the user would use directly, like a full kernel.
For example, running one step of a kernel.
Some math properties may be checked (such as invariance of the distribution).
Some code properties may be checked: pieces work together correctly.
Unit tests: they verify the behavior of a single component. Small tests that run fast.
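To make the E2E category concrete, here is a rough sketch of what such a test looks like for us. It is only illustrative: the blackjax call names may differ slightly between versions, and the tolerances are arbitrary.

import jax
import jax.numpy as jnp
import blackjax


def test_nuts_recovers_gaussian_moments():
    # Standard normal target; the whole user-facing sampler is exercised.
    logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
    nuts = blackjax.nuts(logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2))

    state = nuts.init(jnp.zeros(2))

    def one_step(state, key):
        state, _ = nuts.step(key, state)
        return state, state.position

    keys = jax.random.split(jax.random.PRNGKey(0), 2_000)
    _, positions = jax.lax.scan(one_step, state, keys)

    # Assertions on the shape of the result distribution: they only hold after
    # many steps, and a failure tells us little about where the bug is.
    assert jnp.abs(positions.mean()) < 0.2
    assert jnp.abs(positions.std() - 1.0) < 0.2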
Some of the advantages/disadvantages listed above are specific to our domain (like the fact that certain properties of the output are only valid after several steps of execution), but others are typical characteristics of E2E/big tests.
Our desired testing landscape
It is generally considered that a healthy codebase has a test suite with a pyramid shape [1,2]: a large set of unit tests (quick developer feedback, easy to understand, good at catching regressions), complemented by a smaller set of integration tests that verify the connections between objects, and finally an even smaller set of big, slow, but user-facing E2E tests.
In systems design, it is always important to identify and differentiate between what Brooks [4] called accidental complexity and essential complexity. Accidental complexity has to do with details that are specific to solving a problem on a computer: for example, the fact that Blackjax runs on top of JAX and the code needs to be jittable for performance. Essential complexity, in contrast, is inherent to the problem being solved: for example, the fact that in order to run HMC we need to differentiate, or the fact that we need to sample in order to implement the RMH acceptance rule. This is closely related to the concept of ‘domain logic’: the set of ideas that we want our code to represent, and the operations related to them. In our case, these are the components used to build samplers and the way they interoperate. We would like to clearly represent samplers in code, and thus be able to test them in a simple and fast way, trying to isolate the essential complexity from the accidental complexity.
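As a toy sketch of that split (the function below is illustrative, not Blackjax’s actual API): the RMH acceptance rule is essential complexity that we want to represent and test directly, while the requirement that the same code also works under jit on CPU/GPU is accidental complexity.

import jax
import jax.numpy as jnp


def rmh_accept(rng_key, logdensity_current, logdensity_proposed):
    """Accept the proposal with probability min(1, exp(logdensity_proposed - logdensity_current))."""
    log_ratio = logdensity_proposed - logdensity_current
    return jnp.log(jax.random.uniform(rng_key)) < log_ratio


# Accidental complexity: the same domain logic must also behave correctly once compiled.
rmh_accept_jitted = jax.jit(rmh_accept)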
Theory to practice
All of this looks great in theory, but in practice Blackjax is a library with a heavy dependence on JAX. We want all our code to be jittable and runnable on devices such as GPUs, so we need to balance testing the domain logic with making sure certain accidental properties hold.
Quoting from [5]:
JAX relies extensively on code transformation and compilation, meaning that it can be hard to ensure that code is properly tested. For instance, just testing a python function using JAX code will not cover the actual code path that is executed when jitted, and that path will also differ whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure and hard to catch bugs where XLA changes would lead to undesirable behaviours that however only manifest in one specific code transformation.
To account for this situation, the JAX ecosystem provides Chex, which, among other things, exposes chex.TestCase; it forces code compilation and can emulate the presence/absence of devices such as GPUs. Since the compilation is just-in-time, JAX requires us to run the code on JAX-compilable inputs in order to extract the shapes, and then compiles the code for those shapes. As a consequence, Test Double [3] libraries like MagicMock can't be used at all in Chex tests, since their objects are not JAX-compilable.
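A minimal sketch of how this looks in practice (the function under test is a stand-in): chex.variants runs the same test body with and without jit, so the compiled code path is covered too.

import chex
import jax.numpy as jnp


class SquareTest(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_square(self):
        def square(x):
            return x**2

        # self.variant applies jax.jit (or not) depending on the active variant.
        chex.assert_trees_all_close(self.variant(square)(jnp.array(3.0)), 9.0)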
A real example
Let’s try to analyze this code, which corresponds to the NUTS kernel.
Let’s start by noting that kernel is a factory: it builds another object, in this case the one_step function. But that is not its only responsibility (I am talking about the SOLID notion of responsibility here): it also has one_step defined inside it, so all of one_step’s behavior lives inside kernel as well. Now, if we think about one_step, its main goal should be: given an input, execute one NUTS step. But here, again, there are more responsibilities inside:
choosing gaussian_euclidean as metric
choosing nuts_proposal as proposal
applying the step
The main problem with this code from a testing standpoint is that everything listed here ends up being tested together; there is no way to decouple the test of the step from an execution of the gaussian_euclidean code. Naturally, this leads to bigger tests, closer to E2E than to unit tests. Searching for usages of this code, it seems to be exercised only by test_compilation and test_benchmarks. If we remove those two, the file has 45% test coverage.
Design suggestion 1.
In this case, the suggestion is to split the building/factory code from the code that actually performs the sampling.
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the NUTS Kernel"""
import functools
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
import numpy as np

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
import blackjax.mcmc.termination as termination
import blackjax.mcmc.trajectory as trajectory
from blackjax.types import Array, PRNGKey, PyTree

__all__ = ["NUTSInfo", "init", "build_kernel"]

init = hmc.init


class NUTSInfo(NamedTuple):
    """Additional information on the NUTS transition.

    This additional information can be used for debugging or computing diagnostics.
    momentum: The momentum that was sampled and used to integrate the trajectory.
    is_divergent: Whether the difference in energy between the original and the new
        state exceeded the divergence threshold.
    is_turning: Whether the sampling returned because the trajectory started turning
        back on itself.
    energy: Energy of the transition.
    trajectory_leftmost_state: The leftmost state of the full trajectory.
    trajectory_rightmost_state: The rightmost state of the full trajectory.
    num_trajectory_expansions: Number of subtrajectory samples that were taken.
    num_integration_steps: Number of integration steps that were taken. This is also
        the number of states in the full trajectory.
    acceptance_rate: Average acceptance probability across entire trajectory.
    """

    momentum: PyTree
    is_divergent: bool
    is_turning: bool
    energy: float
    trajectory_leftmost_state: integrators.IntegratorState
    trajectory_rightmost_state: integrators.IntegratorState
    num_trajectory_expansions: int
    num_integration_steps: int
    acceptance_rate: float


def propose_from_momentum(rng_key, momentum_generator, proposal_generator, state, step_size):
    key_momentum, key_integrator = jax.random.split(rng_key, 2)
    position, logdensity, logdensity_grad = state
    momentum = momentum_generator(key_momentum, position)

    integrator_state = integrators.IntegratorState(
        position, momentum, logdensity, logdensity_grad
    )
    proposal, info = proposal_generator(key_integrator, integrator_state, step_size)
    proposal = hmc.HMCState(
        proposal.position, proposal.logdensity, proposal.logdensity_grad
    )
    return proposal, info


def build_kernel(
    integrator: Callable = integrators.velocity_verlet,
    divergence_threshold: int = 1000,
    max_num_doublings: int = 10,
):
    """Build an iterative NUTS kernel.

    This algorithm is an iteration on the original NUTS algorithm :cite:p:`hoffman2014no`
    with two major differences: we do not use slice sampling but multinomial sampling for
    the proposal :cite:p:`betancourt2017conceptual`; the trajectory expansion is not
    recursive but iterative :cite:p:`phan2019composable`, :cite:p:`lao2020tfp`.

    The implementation can seem unusual for those familiar with similar algorithms.
    Indeed, we do not conceptualize the trajectory construction as building a tree. We
    feel that the tree lingo, inherited from the recursive version, is unnecessarily
    complicated and hides the more general concepts upon which the NUTS algorithm is
    built. NUTS, in essence, consists in sampling a trajectory by iteratively choosing a
    direction at random and integrating in this direction a number of times that doubles
    at every step. From this trajectory we continuously sample a proposal. When the
    trajectory turns on itself or when we have reached the maximum trajectory length we
    return the current proposal.

    Parameters
    ----------
    integrator
        The symplectic integrator used to build trajectories.
    divergence_threshold
        The absolute difference in energy above which we consider a transition "divergent".
    max_num_doublings
        The maximum number of times we expand the trajectory by doubling the number of
        steps if the trajectory does not turn onto itself.
    """

    def kernel(
        rng_key: PRNGKey,
        state: hmc.HMCState,
        logdensity_fn: Callable,
        step_size: float,
        inverse_mass_matrix: Array,
    ) -> Tuple[hmc.HMCState, NUTSInfo]:
        """Generate a new sample with the NUTS kernel."""
        (
            momentum_generator,
            kinetic_energy_fn,
            uturn_check_fn,
        ) = metrics.gaussian_euclidean(inverse_mass_matrix)
        symplectic_integrator = integrator(logdensity_fn, kinetic_energy_fn)
        proposal_generator = build_iterative_nuts_proposal(
            symplectic_integrator,
            kinetic_energy_fn,
            uturn_check_fn,
            max_num_doublings,
            divergence_threshold,
        )
        return propose_from_momentum(
            rng_key, momentum_generator, proposal_generator, state, step_size
        )

    return kernel


def nuts_proposal(
    rng_key,
    initial_state: integrators.IntegratorState,
    step_size,
    new_termination_state,
    expand,
    proposal_generator,
):
    initial_termination_state = new_termination_state(initial_state)
    initial_proposal = proposal_generator(initial_state)
    initial_trajectory = trajectory.zero_steps_trajectory(initial_state)
    initial_expansion_state = trajectory.DynamicExpansionState(
        0, initial_proposal, initial_trajectory, initial_termination_state
    )

    expansion_state, info = expand(
        rng_key, initial_expansion_state, initial_proposal.energy, step_size
    )
    is_diverging, is_turning = info
    num_doublings, sampled_proposal, new_trajectory, _ = expansion_state
    # Compute average acceptance probability across entire trajectory,
    # even over subtrees that may have been rejected
    acceptance_rate = (
        jnp.exp(sampled_proposal.sum_log_p_accept) / new_trajectory.num_states
    )

    info = NUTSInfo(
        initial_state.momentum,
        is_diverging,
        is_turning,
        sampled_proposal.energy,
        new_trajectory.leftmost_state,
        new_trajectory.rightmost_state,
        num_doublings,
        new_trajectory.num_states,
        acceptance_rate,
    )

    return sampled_proposal.state, info


def build_iterative_nuts_proposal(
    integrator: Callable,
    kinetic_energy: Callable,
    uturn_check_fn: Callable,
    max_num_expansions: int = 10,
    divergence_threshold: float = 1000,
) -> Callable:
    """Iterative NUTS proposal.

    Parameters
    ----------
    integrator
        Symplectic integrator used to build the trajectory step by step.
    kinetic_energy
        Function that computes the kinetic energy.
    uturn_check_fn
        Function that determines whether the trajectory is turning on itself
        (metric-dependent).
    step_size
        Size of the integration step.
    max_num_expansions
        The number of sub-trajectory samples we take to build the trajectory.
    divergence_threshold
        Threshold above which we say that there is a divergence.

    Returns
    -------
    A kernel that generates a new chain state and information about the transition.
    """
    (
        new_termination_state,
        update_termination_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy,
        update_termination_state,
        is_criterion_met,
        divergence_threshold,
    )

    expand = trajectory.dynamic_multiplicative_expansion(
        trajectory_integrator,
        uturn_check_fn,
        max_num_expansions,
    )

    new, _ = proposal.proposal_generator(trajectory.hmc_energy(kinetic_energy), np.inf)

    return functools.partial(
        nuts_proposal,
        new_termination_state=lambda state: new_termination_state(state, max_num_expansions),
        expand=expand,
        proposal_generator=new,
    )
This code looks far more testable than before, so let’s try writing those tests. Let’s first look at the components the file has:
class NUTSInfo -> no behavior
def build_kernel -> integration test (a rough sketch appears after this list)
def propose_from_momentum -> unit test
def nuts_proposal -> unit test
def build_iterative_nuts_proposal -> integration test
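As a rough sketch of the build_kernel integration test (names follow the refactored module above; the target, step size and shape checks are illustrative), it would wire the real metric, integrator and proposal together and run a single step:

class TestBuildKernel(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_one_step_runs_and_preserves_shapes(self):
        logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
        kernel = build_kernel()
        state = init(jnp.zeros(2), logdensity_fn)

        new_state, info = self.variant(
            functools.partial(kernel, logdensity_fn=logdensity_fn)
        )(
            rng_key=jax.random.PRNGKey(0),
            state=state,
            step_size=1e-1,
            inverse_mass_matrix=jnp.ones(2),
        )

        # Code properties: the pieces work together and the output has the
        # shape that the following step expects as input.
        chex.assert_trees_all_equal_shapes(new_state.position, state.position)
        self.assertIsInstance(info, NUTSInfo)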
I’ll start with propose_from_momentum. I could first try to build a test like this one:
class TestProposeFromMomentum(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_propose_from_momentum(self):
        """
        Given propose_from_momentum
        When calling it
        Then the proposal generator uses the momentum to generate a new proposal
        """
        state = HMCState(position=jnp.array([1., 2.]),
                         logdensity=0.5,
                         logdensity_grad=0.3)
        key = jax.random.PRNGKey(42)
        expected_info = jnp.array([1, 2, 3, 4])

        def momentum_generator(key, position):
            return jnp.array([50.0])

        def proposal_generator(key_integrator, integrator_state, step_size):
            return IntegratorState(position=jnp.array([1., 2.]),
                                   momentum=30,
                                   logdensity=0.3,
                                   logdensity_grad=0.8), expected_info

        def _nuts_kernel(key, state, step_size):
            return propose_from_momentum(rng_key=key,
                                         momentum_generator=momentum_generator,
                                         proposal_generator=proposal_generator,
                                         state=state,
                                         step_size=step_size)

        proposal, info = self.variant(_nuts_kernel)(key, state, 30)

        np.testing.assert_allclose(proposal.position, jnp.array([1., 2.]))
        np.testing.assert_allclose(proposal.logdensity, 0.3)
        np.testing.assert_allclose(proposal.logdensity_grad, 0.8)
        np.testing.assert_allclose(expected_info, info)
Note that I have used a Stub for proposal_generator and a Dummy for expected_info.
It works, and the assertions make sense. But there is still room to break the code; let’s try that.
I have included a regression inside the code on purpose. If we rerun the above test, no failure is raised: there is no assertion whatsoever that the integrator_state passed to proposal_generator, or any of the other parameters, is correctly computed, even though that is part of this function’s responsibility (for instance, a regression that passes the position where the freshly sampled momentum should go would not be caught by the Stub above).
There are two solutions to this type of problem: use Chex runtime assertions so that proposal_generator becomes a Mock, or use a Fake.
With Fake
class TestProposeFromMomentum(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_propose_from_momentum_with_fake(self):
        """
        Given propose_from_momentum
        When calling it
        Then the proposal generator uses the momentum to generate a new proposal
        """
        state = HMCState(position=jnp.array([1., 2.]),
                         logdensity=0.5,
                         logdensity_grad=0.3)
        key = jax.random.PRNGKey(42)
        expected_info = jnp.array([1, 2, 3, 4])

        def momentum_generator(key, position):
            return jnp.array([50.0])

        def proposal_generator(key_integrator, integrator_state, step_size):
            return IntegratorState(position=(integrator_state.position + jnp.array([1., 2.])) * step_size,
                                   momentum=30,
                                   logdensity=0.3,
                                   logdensity_grad=0.8), expected_info

        proposal, info = self.variant(
            functools.partial(propose_from_momentum,
                              momentum_generator=momentum_generator,
                              proposal_generator=proposal_generator)
        )(rng_key=key, state=state, step_size=30)

        np.testing.assert_allclose(proposal.position, jnp.array([60., 120.]))
        np.testing.assert_allclose(proposal.logdensity, 0.3)
        np.testing.assert_allclose(proposal.logdensity_grad, 0.8)
        np.testing.assert_allclose(expected_info, info)
With Mock
class TestProposeFromMomentum(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_propose_from_momentum_with_mock(self):
        """
        Given propose_from_momentum
        When calling it
        Then the proposal generator uses the momentum to generate a new proposal
        """
        state = HMCState(position=jnp.array([1., 2.]),
                         logdensity=0.5,
                         logdensity_grad=0.3)
        key = jax.random.PRNGKey(42)
        expected_info = jnp.array([1, 2, 3, 4])

        def momentum_generator(key, position):
            return jnp.array([50.0])

        @chex.chexify
        def proposal_generator(key_integrator, integrator_state, step_size):
            # Runtime assertion on the argument: this turns the Stub into a Mock.
            chex.assert_trees_all_close(state.position, integrator_state.position)
            return IntegratorState(position=jnp.array([1., 2.]),
                                   momentum=30,
                                   logdensity=0.3,
                                   logdensity_grad=0.8), expected_info

        proposal, info = self.variant(
            functools.partial(propose_from_momentum,
                              momentum_generator=momentum_generator,
                              proposal_generator=proposal_generator)
        )(rng_key=key, state=state, step_size=30)

        np.testing.assert_allclose(proposal.position, jnp.array([1., 2.]))
        np.testing.assert_allclose(proposal.logdensity, 0.3)
        np.testing.assert_allclose(proposal.logdensity_grad, 0.8)
        np.testing.assert_allclose(expected_info, info)
Both approaches have advantages and disadvantages. The Mock is simpler to understand for someone new to the algorithm. The Fake has the challenge that we need to come up with the simplest case that both makes sense for the algorithm and is easy for someone else to understand, so it may require more explanation.
Summary of suggestions so far:
Having more, smaller tests is going to help us improve developer feedback and make changes with confidence.
Splitting factories from the functions they build is going to help us write more comprehensive tests.
We can use Chex runtime assertions to enhance our test doubles, given that we can't use test-double libraries within JAX.
[1] https://testing.googleblog.com/2015/04/just-say-no-to-more-end-to-end-tests.html
[2] https://martinfowler.com/articles/practical-test-pyramid.html
[3] https://martinfowler.com/bliki/TestDouble.html
[4] https://en.wikipedia.org/wiki/No_Silver_Bullet
[5] https://github.com/deepmind/chex
Thank you for the detailed write-up - I will need some time to digest it, but I just want to first express that we appreciate how much thought you put into it.