From 2946c53a8c80cc5651603e2236b959a6657a5f98 Mon Sep 17 00:00:00 2001
From: Alberto Cabezas Gonzalez
Date: Sun, 12 Feb 2023 11:57:06 +0000
Subject: [PATCH 1/4] Include first draft of design guidelines for developer
 docs w/ skeletons for sampling and approximate inference algorithms

---
 .pre-commit-config.yaml                     |  10 ++
 docs/developer/approximate_inf_algorithm.py | 137 +++++++++++++++++
 docs/developer/guidelines.md                |  41 +++++
 docs/developer/sampling_algorithm.py        | 159 ++++++++++++++++++++
 docs/index.md                               |   9 ++
 5 files changed, 356 insertions(+)
 create mode 100644 docs/developer/approximate_inf_algorithm.py
 create mode 100644 docs/developer/guidelines.md
 create mode 100644 docs/developer/sampling_algorithm.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b12b82640..562f00c90 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,6 +25,11 @@ repos:
     rev: 6.0.0
     hooks:
       - id: flake8
+        exclude: |
+          (?x)^(
+              docs/developer/approximate_inf_algorithm.py|
+              docs/developer/sampling_algorithm.py
+          )$
   - repo: https://github.com/psf/black
     rev: 23.1.0
     hooks:
@@ -33,6 +38,11 @@ repos:
     rev: v1.0.1
     hooks:
       - id: mypy
+        exclude: |
+          (?x)^(
+              docs/developer/approximate_inf_algorithm.py|
+              docs/developer/sampling_algorithm.py
+          )$
   - repo: https://github.com/nbQA-dev/nbQA
     rev: 1.3.1
     hooks:
diff --git a/docs/developer/approximate_inf_algorithm.py b/docs/developer/approximate_inf_algorithm.py
new file mode 100644
index 000000000..67b9bb2e0
--- /dev/null
+++ b/docs/developer/approximate_inf_algorithm.py
@@ -0,0 +1,137 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, NamedTuple, Tuple
+
+import jax
+from optax import GradientTransformation
+
+# import basic components that are already implemented
+# or that you have implemented with a general structure
+from blackjax.base import VIAlgorithm
+from blackjax.types import PRNGKey, PyTree
+
+__all__ = [
+    "ApproxInfState",
+    "ApproxInfInfo",
+    "init",
+    "sample",
+    "step",
+    "approx_inf_algorithm",
+]
+
+
+class ApproxInfState(NamedTuple):
+    """State of the approximate inference algorithm.
+
+    Give an overview of the variables needed at each step and for sampling.
+    """
+
+    ...
+
+
+class ApproxInfInfo(NamedTuple):
+    """Additional information on the algorithm transition.
+
+    Given an overview of the collected values at each step of the approximation.
+    """
+
+    ...
+
+
+def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
+    # build an initial state
+    state = ApproxInfState(...)
+    return state
+
+
+def step(
+    rng_key: PRNGKey,
+    state: ApproxInfState,
+    logdensity_fn: Callable,
+    optimizer: GradientTransformation,
+    *args,
+    **kwargs,
+) -> Tuple[ApproxInfState, ApproxInfInfo]:
+    """Approximate the target density using the some approximation.
+
+    Parameters
+    ----------
+    List and describe its parameters.
+    """
+    # extract the previous parameters from the state
+    params = ...
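+    # illustration only: for a mean-field Gaussian approximation, `params` could
+    # hold the variational mean and scale of each leaf of the position PyTree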
+ # generate pseudorandom keys + key_other, key_update = jax.random.split(rng_key, 2) + # update the parameters and build a new state + new_state = ApproxInfState(...) + info = ApproxInfInfo(...) + + return new_state, info + + +def sample(rng_key: PRNGKey, state: ApproxInfState, num_samples: int = 1): + """Sample from the approximation.""" + # the sample should be a PyTree of the same structure as the `position` in the init function + samples = ... + return samples + + +class approx_inf_algorithm: + """Implements the (basic) user interface for the approximate inference method. + + Describe in detail the inner mechanism of the method and its use. + + Example + ------- + Illustrate the use of the algorithm. + + Parameters + ---------- + List and describe its parameters. + + Returns + ------- + A ``VIAlgorithm``. + """ + + init = staticmethod(init) + step = staticmethod(step) + sample = staticmethod(sample) + + def __new__( # type: ignore[misc] + cls, + logdensity_fn: Callable, + optimizer: GradientTransformation, + *args, + **kwargs, + ) -> VIAlgorithm: + def init_fn(position: PyTree): + return cls.init(position, optimizer, ...) + + def step_fn(rng_key: PRNGKey, state): + return cls.step( + rng_key, + state, + logdensity_fn, + optimizer, + ..., + ) + + def sample_fn(rng_key: PRNGKey, state, num_samples): + return cls.sample(rng_key, state, num_samples) + + return VIAlgorithm(init_fn, step_fn, sample_fn) + + +# other functions that help make `init`,` `step` and/or `sample` easier to read and understand diff --git a/docs/developer/guidelines.md b/docs/developer/guidelines.md new file mode 100644 index 000000000..16b4d9a4e --- /dev/null +++ b/docs/developer/guidelines.md @@ -0,0 +1,41 @@ +# Developer Guidelines + +## Style +In its broadest sense, an algorithm that belongs in the blackjax library should approximate integrals on a probability space. An introduction to probability theory is outside the scope of this document, but the Monte Carlo method is ever-present and important to understand. In simple terms, we want to approximate an integral with a sum. To do this, generate samples with probabilities defined by a density (continuous variable) or measure (discrete variable) function. The idea is to sample more from areas with higher probability but also from areas with low probability, just at a lower rate. You can also approximate the target density directly, using an approximation that is easier to handle, then do inference, i.e. solve integrals, with the approximation directly and use importance sampling to correct its bias. + +In the following section, we’ll explain blackjax’s design of different algorithms for Monte Carlo integration. Keep in mind some basic principles: + +Leverage JAX's unique strengths: functional programming and composable function-transformation approach. +Write small and general functions, compose them to create complex methods, reuse the same building blocks for similar algorithms. +Consider compatibility with the broader JAX ecosystem (Flax, Optax, GPJax). +Write code that is easy to read and understand. +Write code that is well documented, describe in detail the inner mechanism of the algorithm and its use. + +## Core implementation +There are three types of sampling algorithms blackjax currently supports: Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), and Stochastic Gradient MCMC (SGMCMC); and one type of approximate inference algorithm: Variational Inference (VI). 
Additionally, blackjax supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples. + +Basic components are functions, which do specific tasks but are generally applicable, used to build all inference algorithms. When implementing a new inference algorithm, you should first break it down to its basic components, then find and use all that are already implemented *before* writing your own. A recurrent example is the Metropolis-Hastings step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In blackjax, this common accept/reject step done with two functions: first the Hastings ratio is calculated by creating a proposal using `mcmc.proposal.proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. + +Because JAX operates on pure functions, inference algorithms always return a NamedTuple containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of blackjax, so it must be done in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a NamedTuple with important information of each iteration. + +The user-facing interface of a **sampling algorithm** should work like this: +```python +import blackjax +sampling_algorithm = blackjax.sampling_algorithm(logdensity_fn, *args, **kwargs) +state = sampling_algorithm.init(initial_position) +new_state, info = sampling_algorithm.step(rng_key, state) +``` +Achieve this by building from the basic skeleton of a sampling algorithm (here)[https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py]. Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary. + +The user-facing interface of an **approximate inference algorithm** should work like this: +```python +import blackjax +approx_inf_algorithm = blackjax.approx_inf_algorithm(logdensity_fn, optimizer, *args, **kwargs) +state = approx_inf_algorithm.init(initial_position) +new_state, info = approx_inf_algorithm.step(rng_key, state) +#user is able to build the approximate distribution using the state, or generate samples: +position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples) +``` +Achieve this by building from the basic skeleton of an approximate inference algorithm (here)[https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py]. Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary. + +Well documented code is essential for a useful library. Start by decomposing your algorithm into basic components, finding those that are already implemented, then implement your own and build the high-level API from basic components. diff --git a/docs/developer/sampling_algorithm.py b/docs/developer/sampling_algorithm.py new file mode 100644 index 000000000..df9a057dd --- /dev/null +++ b/docs/developer/sampling_algorithm.py @@ -0,0 +1,159 @@ +# Copyright 2020- The Blackjax Authors. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, NamedTuple, Tuple
+
+import jax
+
+# import basic components that are already implemented
+# or that you have implemented with a general structure
+# for example, if you do a Metropolis-Hastings accept/reject step:
+import blackjax.mcmc.proposal as proposal
+from blackjax.base import MCMCSamplingAlgorithm
+from blackjax.types import PRNGKey, PyTree
+
+__all__ = [
+    "SamplingAlgoState",
+    "SamplingAlgoInfo",
+    "init",
+    "build_kernel",
+    "sampling_algorithm",
+]
+
+
+class SamplingAlgoState(NamedTuple):
+    """State of the sampling algorithm.
+
+    Give an overview of the variables needed at each iteration of the model.
+    """
+
+    ...
+
+
+class SamplingAlgoInfo(NamedTuple):
+    """Additional information on the algorithm transition.
+
+    Give an overview of the collected values at each iteration of the model.
+    """
+
+    ...
+
+
+def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
+    # build an initial state
+    state = SamplingAlgoState(...)
+    return state
+
+
+def build_kernel(*args, **kwargs):
+    """Build a HMC kernel.
+
+    Parameters
+    ----------
+    List and describe its parameters.
+
+    Returns
+    -------
+    Describe the kernel that is returned.
+    """
+
+    def kernel(
+        rng_key: PRNGKey,
+        state: SamplingAlgoState,
+        logdensity_fn: Callable,
+        *args,
+        **kwargs,
+    ) -> Tuple[SamplingAlgoState, SamplingAlgoInfo]:
+        """Generate a new sample with the sampling kernel."""
+
+        # build everything you'll need
+        proposal_generator = sampling_algorithm_proposal(...)
+
+        # generate pseudorandom keys
+        key_other, key_proposal = jax.random.split(rng_key, 2)
+
+        # generate the proposal with all its parts
+        proposal, info = proposal_generator(key_proposal, ...)
+        new_state = SamplingAlgoState(...)
+
+        return new_state, info
+
+    return kernel
+
+
+class sampling_algorithm:
+    """Implements the (basic) user interface for the sampling kernel.
+
+    Describe in detail the inner mechanism of the algorithm and its use.
+
+    Example
+    -------
+    Illustrate the use of the algorithm.
+
+    Parameters
+    ----------
+    List and describe its parameters.
+
+    Returns
+    -------
+    A ``MCMCSamplingAlgorithm``.
+    """
+
+    init = staticmethod(init)
+    build_kernel = staticmethod(build_kernel)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        logdensity_fn: Callable,
+        *args,
+        **kwargs,
+    ) -> MCMCSamplingAlgorithm:
+        kernel = cls.build_kernel(...)
+
+        def init_fn(position: PyTree):
+            return cls.init(position, logdensity_fn, ...)
+
+        def step_fn(rng_key: PRNGKey, state):
+            return kernel(
+                rng_key,
+                state,
+                logdensity_fn,
+                ...,
+            )
+
+        return MCMCSamplingAlgorithm(init_fn, step_fn)
+
+
+# and other functions that help make `init` and/or `build_kernel` easier to read and understand
+def sampling_algorithm_proposal(*args, **kwargs) -> Callable:
+    """Title
+
+    Description
+
+    Parameters
+    ----------
+    List and describe its parameters.
+
+    Returns
+    -------
+    Describe what is returned.
+    """
+    ...
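+    # illustration only: build here everything that `generate` closes over,
+    # e.g. a proposal distribution or precomputed constants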
+ + def generate(*args, **kwargs): + """Generate a new chain state.""" + sampled_state, info = ... + + return sampled_state, info + + return generate diff --git a/docs/index.md b/docs/index.md index 59e4b3f0d..412309677 100644 --- a/docs/index.md +++ b/docs/index.md @@ -134,3 +134,12 @@ maxdepth: 2 API Reference Bibliography ``` + +```{toctree} +--- +maxdepth: 1 +caption: DEVELOPER DOCUMENTATION +hidden: +--- +Guidelines +``` From 6f94b9468dcf841a46055a5913fbda8c7d14dd36 Mon Sep 17 00:00:00 2001 From: Alberto Cabezas Gonzalez Date: Fri, 12 May 2023 15:29:59 +0100 Subject: [PATCH 2/4] Guidelines with MH example and revised skeletons --- README.md | 4 +++ docs/developer/approximate_inf_algorithm.py | 12 +++---- docs/developer/guidelines.md | 36 +++++++++++---------- docs/developer/sampling_algorithm.py | 24 ++++++++++---- docs/index.md | 2 +- 5 files changed, 47 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 085e87a9d..6b5339ab4 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,10 @@ information related to the transition are returned separately. They can thus be easily composed and exchanged. We specialize these kernels by closure instead of passing parameters. +### New algorithms + +We hope to make implementing and testing new algorithms easy with BlackJAX. Many basic methods are already implemented in the library, and you can use them to test new algorithms. Follow the [guidelines](https://blackjax-devs.github.io/blackjax/developer/guidelines.html) to implement your own method and test new ideas on existing methods without writing everything from scratch! + ## Contributions Please follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/main/CONTRIBUTING.md). diff --git a/docs/developer/approximate_inf_algorithm.py b/docs/developer/approximate_inf_algorithm.py index 67b9bb2e0..52f5a8c24 100644 --- a/docs/developer/approximate_inf_algorithm.py +++ b/docs/developer/approximate_inf_algorithm.py @@ -32,7 +32,7 @@ class ApproxInfState(NamedTuple): - """State of the approximate inference algorithm. + """State of your approximate inference algorithm. Give an overview of the variables needed at each step and for sampling. """ @@ -41,9 +41,9 @@ class ApproxInfState(NamedTuple): class ApproxInfInfo(NamedTuple): - """Additional information on the algorithm transition. + """Additional information on your algorithm transition. - Given an overview of the collected values at each step of the approximation. + Give an overview of the collected values at each step of the approximation. """ ... @@ -63,7 +63,7 @@ def step( *args, **kwargs, ) -> Tuple[ApproxInfState, ApproxInfInfo]: - """Approximate the target density using the some approximation. + """Approximate the target density using your approximation. Parameters ---------- @@ -81,14 +81,14 @@ def step( def sample(rng_key: PRNGKey, state: ApproxInfState, num_samples: int = 1): - """Sample from the approximation.""" + """Sample from your approximation.""" # the sample should be a PyTree of the same structure as the `position` in the init function samples = ... return samples class approx_inf_algorithm: - """Implements the (basic) user interface for the approximate inference method. + """Implements the (basic) user interface for your approximate inference method. Describe in detail the inner mechanism of the method and its use. 
diff --git a/docs/developer/guidelines.md b/docs/developer/guidelines.md index 16b4d9a4e..6d1d98695 100644 --- a/docs/developer/guidelines.md +++ b/docs/developer/guidelines.md @@ -1,41 +1,43 @@ # Developer Guidelines -## Style -In its broadest sense, an algorithm that belongs in the blackjax library should approximate integrals on a probability space. An introduction to probability theory is outside the scope of this document, but the Monte Carlo method is ever-present and important to understand. In simple terms, we want to approximate an integral with a sum. To do this, generate samples with probabilities defined by a density (continuous variable) or measure (discrete variable) function. The idea is to sample more from areas with higher probability but also from areas with low probability, just at a lower rate. You can also approximate the target density directly, using an approximation that is easier to handle, then do inference, i.e. solve integrals, with the approximation directly and use importance sampling to correct its bias. +In the broadest sense, an algorithm that belongs in the BlackJAX library should provide the tools to approximate integrals on a probability space. An introduction to probability theory is outside the scope of this document, but the Monte Carlo method is ever-present and important to understand. In simple terms, we want to approximate an integral with a sum. To do this, generate samples with [relative likelihood](https://en.wikipedia.org/wiki/Relative_likelihood) given by a target probability density function (known up to a normalization constant). The idea is to sample more from areas with higher likelihood but also from areas with low likelihood, just at a lower rate. You can also approximate the target density directly, using a density that is tractable and easy to sample from, then do inference with the approximation instead of the target, potentially using [importance sampling](https://en.wikipedia.org/wiki/Importance_sampling) to correct the approximation error. -In the following section, we’ll explain blackjax’s design of different algorithms for Monte Carlo integration. Keep in mind some basic principles: +In the following section, we’ll explain BlackJAX’s design of different algorithms for Monte Carlo integration. Keep in mind some basic principles: -Leverage JAX's unique strengths: functional programming and composable function-transformation approach. -Write small and general functions, compose them to create complex methods, reuse the same building blocks for similar algorithms. -Consider compatibility with the broader JAX ecosystem (Flax, Optax, GPJax). -Write code that is easy to read and understand. -Write code that is well documented, describe in detail the inner mechanism of the algorithm and its use. +- Leverage JAX's unique strengths: functional programming and composable function-transformation approach. +- Write small and general functions, compose them to create complex methods, reuse the same building blocks for similar algorithms. +- Consider compatibility with the broader JAX ecosystem (Flax, Optax, GPJax). +- Write code that is easy to read and understand. +- Write code that is well documented, describe in detail the inner mechanism of the algorithm and its use. ## Core implementation -There are three types of sampling algorithms blackjax currently supports: Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), and Stochastic Gradient MCMC (SGMCMC); and one type of approximate inference algorithm: Variational Inference (VI). 
Additionally, blackjax supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples. +There are three types of sampling algorithms BlackJAX currently supports: Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), and Stochastic Gradient MCMC (SGMCMC); and one type of approximate inference algorithm: Variational Inference (VI). Additionally, BlackJAX supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples. -Basic components are functions, which do specific tasks but are generally applicable, used to build all inference algorithms. When implementing a new inference algorithm, you should first break it down to its basic components, then find and use all that are already implemented *before* writing your own. A recurrent example is the Metropolis-Hastings step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In blackjax, this common accept/reject step done with two functions: first the Hastings ratio is calculated by creating a proposal using `mcmc.proposal.proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. +Basic components are functions which do specific tasks but are generally applicable, used to build all inference algorithms. When implementing a new inference algorithm you should first break it down to its basic components then find and use all that are already implemented *before* writing your own. A recurrent example is the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX there are two basic components that do a specific (but simpler) and a general version of this accept/reject step: -Because JAX operates on pure functions, inference algorithms always return a NamedTuple containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of blackjax, so it must be done in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a NamedTuple with important information of each iteration. +- Metropolis step: if the proposal transition kernel is symmetric, i.e. if the probability of going from the initial to the proposed position is always equal to the probability of going from the proposed to the initial position, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. +- Metropolis-Hastings step: for the more general case of an asymmetric proposal transition kernel, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.asymmetric_proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. + +When implementing an algorithm you could choose to replace the classic, reversible Metropolis-Hastings step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. 
Just make sure to carry over to the next iteration an updated slice, instead of passing a pseudo-random number generating key, for the slice sampling step! + +The previous example illustrates the power of basic components, useful not only to avoid rewriting the same methods for each new algorithm but also useful to personalize and test new algorithms which replace some steps of common efficient algorithms. Like how `blackjax.mcmc.ghmc` is `blackjax.mcmc.hmc` with a persistent momentum and a non-reversible slice sampling step instead of the Metropolis-Hastings step. + +Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so it must be done in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration. The user-facing interface of a **sampling algorithm** should work like this: ```python -import blackjax sampling_algorithm = blackjax.sampling_algorithm(logdensity_fn, *args, **kwargs) state = sampling_algorithm.init(initial_position) new_state, info = sampling_algorithm.step(rng_key, state) ``` -Achieve this by building from the basic skeleton of a sampling algorithm (here)[https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py]. Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary. +Achieve this by building from the basic skeleton of a sampling algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py). Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary. The user-facing interface of an **approximate inference algorithm** should work like this: ```python -import blackjax approx_inf_algorithm = blackjax.approx_inf_algorithm(logdensity_fn, optimizer, *args, **kwargs) state = approx_inf_algorithm.init(initial_position) new_state, info = approx_inf_algorithm.step(rng_key, state) #user is able to build the approximate distribution using the state, or generate samples: position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples) ``` -Achieve this by building from the basic skeleton of an approximate inference algorithm (here)[https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py]. Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary. - -Well documented code is essential for a useful library. Start by decomposing your algorithm into basic components, finding those that are already implemented, then implement your own and build the high-level API from basic components. +Achieve this by building from the basic skeleton of an approximate inference algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py). 
Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary.
diff --git a/docs/developer/sampling_algorithm.py b/docs/developer/sampling_algorithm.py
index df9a057dd..747c9e23b 100644
--- a/docs/developer/sampling_algorithm.py
+++ b/docs/developer/sampling_algorithm.py
@@ -32,7 +32,7 @@


 class SamplingAlgoState(NamedTuple):
-    """State of the sampling algorithm.
+    """State of your sampling algorithm.

     Give an overview of the variables needed at each iteration of the model.
     """
@@ -41,7 +41,7 @@ class SamplingAlgoState(NamedTuple):


 class SamplingAlgoInfo(NamedTuple):
-    """Additional information on the algorithm transition.
+    """Additional information on your algorithm transition.

     Give an overview of the collected values at each iteration of the model.
     """
@@ -56,7 +56,7 @@ def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):


 def build_kernel(*args, **kwargs):
-    """Build a HMC kernel.
+    """Build your kernel.

     Parameters
     ----------
@@ -92,7 +92,7 @@ def kernel(


 class sampling_algorithm:
-    """Implements the (basic) user interface for the sampling kernel.
+    """Implements the (basic) user interface for your sampling kernel.

     Describe in detail the inner mechanism of the algorithm and its use.

@@ -148,10 +148,20 @@ def sampling_algorithm_proposal(*args, **kwargs) -> Callable:
     -------
     Describe what is returned.
     """
-    ...
+    # as an example, a Metropolis-Hastings step would look like this:
+    init_proposal, generate_proposal = proposal.proposal_generator(...)
+    sample_proposal = proposal.static_binomial_sampling(...)
+
+    def generate(rng_key, state):
+        # propose a new sample
+        proposal_state = ...
+
+        # accept or reject the proposed sample
+        proposal = init_proposal(state)
+        new_proposal, is_diverging = generate_proposal(proposal.energy, proposal_state)
+        sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)

-    def generate(*args, **kwargs):
-        """Generate a new chain state."""
+        # build a new state and collect useful information
         sampled_state, info = ...

         return sampled_state, info
diff --git a/docs/index.md b/docs/index.md
index 412309677..f4f43a87f 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -141,5 +141,5 @@ maxdepth: 1
 caption: DEVELOPER DOCUMENTATION
 hidden:
 ---
-Guidelines
+Guidelines
 ```
From 8b2edea2429993370dff81b384619167b75d40d7 Mon Sep 17 00:00:00 2001
From: Alberto Cabezas Gonzalez
Date: Thu, 22 Feb 2024 10:11:11 +0000
Subject: [PATCH 3/4] revise text and skeletons to new API

---
 README.md                                   |  2 +-
 docs/developer/approximate_inf_algorithm.py |  6 +--
 docs/developer/guidelines.md                | 43 +++++++++++----------
 docs/developer/sampling_algorithm.py        | 27 ++++++-------
 4 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/README.md b/README.md
index 6b5339ab4..f8937563b 100644
--- a/README.md
+++ b/README.md
@@ -131,7 +131,7 @@ passing parameters.

 ### New algorithms

-We hope to make implementing and testing new algorithms easy with BlackJAX. Many basic methods are already implemented in the library, and you can use them to test new algorithms. Follow the [guidelines](https://blackjax-devs.github.io/blackjax/developer/guidelines.html) to implement your own method and test new ideas on existing methods without writing everything from scratch!
+We want to make implementing and testing new algorithms easy with BlackJAX.
You can test new algorithms by reusing the basic components of the many known methods already implemented in the library. Follow the [guidelines](https://blackjax-devs.github.io/blackjax/developer/guidelines.html) to implement your method and test new ideas on existing methods without writing everything from scratch.

 ## Contributions

diff --git a/docs/developer/approximate_inf_algorithm.py b/docs/developer/approximate_inf_algorithm.py
index 52f5a8c24..803de00c3 100644
--- a/docs/developer/approximate_inf_algorithm.py
+++ b/docs/developer/approximate_inf_algorithm.py
@@ -19,7 +19,7 @@
 # import basic components that are already implemented
 # or that you have implemented with a general structure
 from blackjax.base import VIAlgorithm
-from blackjax.types import PRNGKey, PyTree
+from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

 __all__ = [
     "ApproxInfState",
@@ -49,7 +49,7 @@ class ApproxInfInfo(NamedTuple):
     ...


-def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
+def init(position: ArrayLikeTree, logdensity_fn: Callable, *args, **kwargs):
     # build an initial state
     state = ApproxInfState(...)
     return state
@@ -116,7 +116,7 @@ def __new__(  # type: ignore[misc]
         *args,
         **kwargs,
     ) -> VIAlgorithm:
-        def init_fn(position: PyTree):
+        def init_fn(position: ArrayLikeTree):
             return cls.init(position, optimizer, ...)

         def step_fn(rng_key: PRNGKey, state):
diff --git a/docs/developer/guidelines.md b/docs/developer/guidelines.md
index 6d1d98695..a0222cc9d 100644
--- a/docs/developer/guidelines.md
+++ b/docs/developer/guidelines.md
@@ -1,43 +1,44 @@
 # Developer Guidelines

-In the broadest sense, an algorithm that belongs in the BlackJAX library should provide the tools to approximate integrals on a probability space. An introduction to probability theory is outside the scope of this document, but the Monte Carlo method is ever-present and important to understand. In simple terms, we want to approximate an integral with a sum. To do this, generate samples with [relative likelihood](https://en.wikipedia.org/wiki/Relative_likelihood) given by a target probability density function (known up to a normalization constant). The idea is to sample more from areas with higher likelihood but also from areas with low likelihood, just at a lower rate. You can also approximate the target density directly, using a density that is tractable and easy to sample from, then do inference with the approximation instead of the target, potentially using [importance sampling](https://en.wikipedia.org/wiki/Importance_sampling) to correct the approximation error.
-
 In the following section, we'll explain BlackJAX's design of different algorithms for Monte Carlo integration. Keep in mind some basic principles:

 - Leverage JAX's unique strengths: functional programming and composable function-transformation approach.
-- Write small and general functions, compose them to create complex methods, reuse the same building blocks for similar algorithms.
+- Write small and general functions, compose them to create complex methods, and reuse the same building blocks for similar algorithms.
 - Consider compatibility with the broader JAX ecosystem (Flax, Optax, GPJax).
 - Write code that is easy to read and understand.
-- Write code that is well documented, describe in detail the inner mechanism of the algorithm and its use.
+- Write well-documented code describing in detail the inner mechanism of the algorithm and its use.
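+
+For instance, because a `step` function is pure, it composes directly with JAX's transformations. A minimal sketch of a whole inference loop, assuming the generic `sampling_algorithm` interface described in the next section:
+
+```python
+import jax
+
+
+def inference_loop(rng_key, initial_state, num_samples):
+    def one_step(state, rng_key):
+        state, info = sampling_algorithm.step(rng_key, state)
+        return state, (state, info)
+
+    keys = jax.random.split(rng_key, num_samples)
+    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
+    return states, infos
+
+
+# the whole loop can be compiled once and reused:
+# inference_loop = jax.jit(inference_loop, static_argnums=2)
+```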
## Core implementation -There are three types of sampling algorithms BlackJAX currently supports: Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), and Stochastic Gradient MCMC (SGMCMC); and one type of approximate inference algorithm: Variational Inference (VI). Additionally, BlackJAX supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples. - -Basic components are functions which do specific tasks but are generally applicable, used to build all inference algorithms. When implementing a new inference algorithm you should first break it down to its basic components then find and use all that are already implemented *before* writing your own. A recurrent example is the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX there are two basic components that do a specific (but simpler) and a general version of this accept/reject step: - -- Metropolis step: if the proposal transition kernel is symmetric, i.e. if the probability of going from the initial to the proposed position is always equal to the probability of going from the proposed to the initial position, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. -- Metropolis-Hastings step: for the more general case of an asymmetric proposal transition kernel, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.asymmetric_proposal_generator`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. - -When implementing an algorithm you could choose to replace the classic, reversible Metropolis-Hastings step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. Just make sure to carry over to the next iteration an updated slice, instead of passing a pseudo-random number generating key, for the slice sampling step! - -The previous example illustrates the power of basic components, useful not only to avoid rewriting the same methods for each new algorithm but also useful to personalize and test new algorithms which replace some steps of common efficient algorithms. Like how `blackjax.mcmc.ghmc` is `blackjax.mcmc.hmc` with a persistent momentum and a non-reversible slice sampling step instead of the Metropolis-Hastings step. - -Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so it must be done in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration. +BlackJAX supports sampling algorithms such as Markov Chain Monte Carlo (MCMC), Sequential Monte Carlo (SMC), Stochastic Gradient MCMC (SGMCMC), and approximate inference algorithms such as Variational Inference (VI). 
In all cases, BlackJAX takes a Markovian approach, whereby its current state contains all the information to obtain the next iteration of an algorithm. This naturally results in a functionally pure structure, where no side-effects are allowed, simplifying parallelisation. Additionally, BlackJAX supports adaptation algorithms that efficiently tune the hyperparameters of sampling algorithms, usually aimed at reducing autocorrelation between sequential samples. -The user-facing interface of a **sampling algorithm** should work like this: +The user-facing interface of a **sampling algorithm** is made up of an initializer and an iterator: ```python +# Generic sampling algorithm: sampling_algorithm = blackjax.sampling_algorithm(logdensity_fn, *args, **kwargs) state = sampling_algorithm.init(initial_position) new_state, info = sampling_algorithm.step(rng_key, state) ``` -Achieve this by building from the basic skeleton of a sampling algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py). Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary. +Build from the basic skeleton of a sampling algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/sampling_algorithm.py). Only the `sampling_algorithm` class and the `init` and `build_kernel` functions need to be in the final version of your algorithm; the rest might be useful but are not necessary. -The user-facing interface of an **approximate inference algorithm** should work like this: +The user-facing interface of an **approximate inference algorithm** is made up of an initializer, iterator, and sampler: ```python +# Generic approximate inference algorithm: approx_inf_algorithm = blackjax.approx_inf_algorithm(logdensity_fn, optimizer, *args, **kwargs) state = approx_inf_algorithm.init(initial_position) new_state, info = approx_inf_algorithm.step(rng_key, state) -#user is able to build the approximate distribution using the state, or generate samples: +# user is able to build the approximate distribution using the state, or generate samples: position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples) ``` -Achieve this by building from the basic skeleton of an approximate inference algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py). Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm, the rest might become useful but are not necessary. +Build from the basic skeleton of an approximate inference algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py). Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm; the rest might be useful but are not necessary. + +## Basic components +All inference algorithms are composed of basic components which provide the lowest level of algorithm abstraction and are available to the user. When implementing a new inference algorithm, you should first break it down to its basic components, then find and use all already implemented *before* writing your own. 
A recurrent example is the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX, two basic components do a specific (but simpler) and a general version of this accept/reject step:
+
+- Metropolis step: if the proposal transition kernel is symmetric, i.e. if the probability of going from the initial to the proposed position is always equal to the probability of going from the proposed to the initial position, the acceptance probability is calculated using `mcmc.proposal.safe_energy_diff`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. For example, see `mcmc.hmc.hmc_proposal`.
+- Metropolis-Hastings step: for the more general case of an asymmetric proposal transition kernel, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.compute_asymmetric_acceptance_ratio`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. For example, see `mcmc.mala.build_kernel`.
+
+When implementing an algorithm you could choose to replace the classic, reversible Metropolis-Hastings step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. Make sure to carry over to the next iteration an updated slice for the slice sampling step, instead of passing a pseudo-random number generating key!
+
+The previous example illustrates the power of basic components, useful not only to avoid rewriting the same methods for each new algorithm but also to personalize and test new algorithms that replace some steps of standard efficient algorithms, like how `blackjax.mcmc.ghmc` is `blackjax.mcmc.hmc` only with a persistent momentum and a non-reversible slice sampling step instead of the Metropolis-Hastings step.
+
+Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so you must do it in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration.
diff --git a/docs/developer/sampling_algorithm.py b/docs/developer/sampling_algorithm.py
index 747c9e23b..5ee1e6679 100644
--- a/docs/developer/sampling_algorithm.py
+++ b/docs/developer/sampling_algorithm.py
@@ -19,8 +19,8 @@
 # or that you have implemented with a general structure
 # for example, if you do a Metropolis-Hastings accept/reject step:
 import blackjax.mcmc.proposal as proposal
-from blackjax.base import MCMCSamplingAlgorithm
-from blackjax.types import PRNGKey, PyTree
+from blackjax.base import SamplingAlgorithm
+from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

 __all__ = [
     "SamplingAlgoState",
@@ -49,7 +49,7 @@ class SamplingAlgoInfo(NamedTuple):
     ...


-def init(position: PyTree, logdensity_fn: Callable, *args, **kwargs):
+def init(position: ArrayLikeTree, logdensity_fn: Callable, *args, **kwargs):
     # build an initial state
     state = SamplingAlgoState(...)
return state
@@ -117,10 +117,10 @@ def __new__(  # type: ignore[misc]
         logdensity_fn: Callable,
         *args,
         **kwargs,
-    ) -> MCMCSamplingAlgorithm:
+    ) -> SamplingAlgorithm:
         kernel = cls.build_kernel(...)

-        def init_fn(position: PyTree):
+        def init_fn(position: ArrayLikeTree):
             return cls.init(position, logdensity_fn, ...)

         def step_fn(rng_key: PRNGKey, state):
@@ -131,7 +131,7 @@ def step_fn(rng_key: PRNGKey, state):
                 ...,
             )

-        return MCMCSamplingAlgorithm(init_fn, step_fn)
+        return SamplingAlgorithm(init_fn, step_fn)


 # and other functions that help make `init` and/or `build_kernel` easier to read and understand
@@ -148,20 +148,22 @@ def sampling_algorithm_proposal(*args, **kwargs) -> Callable:
     -------
     Describe what is returned.
     """
-    # as an example, a Metropolis-Hastings step would look like this:
-    init_proposal, generate_proposal = proposal.proposal_generator(...)
-    sample_proposal = proposal.static_binomial_sampling(...)
+    # as an example, a Metropolis-Hastings step with a symmetric transition would look like this:
+    acceptance_ratio = proposal.safe_energy_diff
+    sample_proposal = proposal.static_binomial_sampling

     def generate(rng_key, state):
         # propose a new sample
         proposal_state = ...

         # accept or reject the proposed sample
-        proposal = init_proposal(state)
-        new_proposal, is_diverging = generate_proposal(proposal.energy, proposal_state)
-        sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)
+        initial_energy = ...
+        proposal_energy = ...
+        current_proposal = ...  # built from `state`; don't reuse the name `proposal`, it shadows the imported module
+        new_proposal, is_diverging = acceptance_ratio(initial_energy, proposal_energy)
+        sampled_state, info = sample_proposal(rng_key, current_proposal, new_proposal)

-        # build a new state and collect useful information
+        # maybe add to the returned state and collect more useful information
         sampled_state, info = ...

         return sampled_state, info
From ff87bac085c2cb97d8ba2ca56b0058285aa04e0c Mon Sep 17 00:00:00 2001
From: Alberto Cabezas Gonzalez
Date: Fri, 23 Feb 2024 11:02:10 +0000
Subject: [PATCH 4/4] Improve text

---
 README.md                    |  4 ++--
 docs/developer/guidelines.md | 13 ++++++------
 docs/index.md                |  2 +-
 3 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index f8937563b..b925c5753 100644
--- a/README.md
+++ b/README.md
@@ -129,9 +129,9 @@ information related to the transition are returned
 separately. They can thus be easily composed and exchanged. We specialize these
 kernels by closure instead of passing parameters.

-### New algorithms
+## New algorithms

 We want to make implementing and testing new algorithms easy with BlackJAX. You can test new algorithms by reusing the basic components of the many known methods already implemented in the library. Follow the [developer guidelines](https://blackjax-devs.github.io/blackjax/developer/guidelines.html) to implement your method and test new ideas on existing methods without writing everything from scratch.
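+
+As a sketch of the idea (the exact signatures may change between versions; `logdensity_fn` is assumed given), reusing an existing kernel and specializing it by closure looks like this:
+
+```python
+import blackjax
+
+# reuse the MALA transition kernel as-is...
+kernel = blackjax.mcmc.mala.build_kernel()
+
+
+# ...and specialize it by closure to test, say, a particular step size
+def step(rng_key, state):
+    return kernel(rng_key, state, logdensity_fn, step_size=1e-3)
+```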
## Contributions

diff --git a/docs/developer/guidelines.md b/docs/developer/guidelines.md
index a0222cc9d..09ed7ac9f 100644
--- a/docs/developer/guidelines.md
+++ b/docs/developer/guidelines.md
@@ -26,19 +26,18 @@ The user-facing interface of an **approximate inference algorithm** is made up o
 approx_inf_algorithm = blackjax.approx_inf_algorithm(logdensity_fn, optimizer, *args, **kwargs)
 state = approx_inf_algorithm.init(initial_position)
 new_state, info = approx_inf_algorithm.step(rng_key, state)
-# user is able to build the approximate distribution using the state, or generate samples:
 position_samples = approx_inf_algorithm.sample(rng_key, state, num_samples)
 ```
 Build from the basic skeleton of an approximate inference algorithm [here](https://github.com/blackjax-devs/blackjax/tree/main/docs/developer/approximate_inf_algorithm.py). Only the `approx_inf_algorithm` class and the `init`, `step` and `sample` functions need to be in the final version of your algorithm; the rest might be useful but are not necessary.

 ## Basic components
-All inference algorithms are composed of basic components which provide the lowest level of algorithm abstraction and are available to the user. When implementing a new inference algorithm, you should first break it down to its basic components, then find and use all already implemented *before* writing your own. A recurrent example is the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX, two basic components do a specific (but simpler) and a general version of this accept/reject step:
+All inference algorithms are composed of basic components which provide the lowest level of algorithm abstraction and are available to the user. When implementing a new inference algorithm, you should first break it down to its basic components, then find and use those already implemented *before* writing your own. For example, consider the [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) step, a basic component used by many MCMC algorithms to keep the target distribution invariant. In BlackJAX, two basic components provide a specific (but simpler) version and a general version of this accept/reject step:

- Metropolis step: if the proposal transition kernel is symmetric, i.e. if the probability of going from the initial to the proposed position is always equal to the probability of going from the proposed to the initial position, the acceptance probability is calculated using `mcmc.proposal.safe_energy_diff`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`.
For instance, see `mcmc.hmc.hmc_proposal`.
- Metropolis-Hastings step: for the more general case of an asymmetric proposal transition kernel, the acceptance probability is calculated by creating a proposal using `mcmc.proposal.compute_asymmetric_acceptance_ratio`, then the proposal is accepted or rejected using `mcmc.proposal.static_binomial_sampling`. For instance, see `mcmc.mala.build_kernel`.

-When implementing an algorithm you could choose to replace the classic, reversible Metropolis-Hastings step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. Make sure to carry over to the next iteration an updated slice for the slice sampling step, instead of passing a pseudo-random number generating key!
+When implementing an algorithm you could choose to replace the reversible binomial sampling step with Neal's [non-reversible slice sampling](https://arxiv.org/abs/2001.11950) step by simply replacing `mcmc.proposal.static_binomial_sampling` with `mcmc.proposal.nonreversible_slice_sampling` on either of the previous implementations. Make sure to carry over to the next iteration an updated slice for the slice sampling step, instead of passing a pseudo-random number generating key!

-The previous example illustrates the power of basic components, useful not only to avoid rewriting the same methods for each new algorithm but also to personalize and test new algorithms that replace some steps of standard efficient algorithms, like how `blackjax.mcmc.ghmc` is `blackjax.mcmc.hmc` only with a persistent momentum and a non-reversible slice sampling step instead of the Metropolis-Hastings step.
+The previous example illustrates the practicality of basic components: they avoid rewriting the same methods for each new algorithm and make it easy to test variants of established algorithms, like how `blackjax.mcmc.ghmc` is simply `blackjax.mcmc.hmc` with a persistent momentum and a non-reversible slice sampling step instead of the static binomial sampling step.

-Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so you must do it in a way that abstracts the uninteresting bookkeeping from the end user but allows her to access important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration.
+Because JAX operates on pure functions, inference algorithms always return a `typing.NamedTuple` containing the necessary variables to generate the next sample. Arguably, abstracting the handling of these variables is the whole point of BlackJAX, so you must do it in a way that abstracts the uninteresting bookkeeping and allows access to important variables at each step. The algorithms should also return a `typing.NamedTuple` with important information about each iteration.
diff --git a/docs/index.md b/docs/index.md
index f4f43a87f..e4ede97f1 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -141,5 +141,5 @@ maxdepth: 1
 caption: DEVELOPER DOCUMENTATION
 hidden:
 ---
-Guidelines
+Developer Guidelines
 ```