docstring, paper
homerjed committed Oct 22, 2024
1 parent 1ffc81d commit c94c633
Showing 2 changed files with 89 additions and 22 deletions.
37 changes: 22 additions & 15 deletions paper/paper.bib
@@ -37,21 +37,6 @@ @misc{blackjax
primaryClass={cs.MS}
}

@article{equinox,
author={Patrick Kidger and Cristian Garcia},
title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
year={2021},
journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}

@misc{mafs,
title={Masked Autoregressive Flow for Density Estimation},
author={George Papamakarios and Theo Pavlakou and Iain Murray},
@@ -132,4 +117,26 @@ @misc{npe
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/1905.07488},
}

@software{jax,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}

@article{equinox,
author={Patrick Kidger and Cristian Garcia},
title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
year={2021},
journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

@software{optax,
title = {The {D}eep{M}ind {JAX} {E}cosystem},
author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
url = {http://github.com/google-deepmind},
year = {2020},
}
74 changes: 67 additions & 7 deletions sbiax/inference/nle/nuts.py
@@ -1,11 +1,65 @@
from typing import Tuple, Callable
import blackjax.progress_bar
import jax
import jax.random as jr
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray, Array
import blackjax
from tensorflow_probability.substrates.jax.distributions import Distribution

def nuts_sample(
key: PRNGKeyArray,
log_prob_fn: Callable,
prior: Distribution,
n_samples: int = 100_000
) -> Tuple[Array, Array]:
"""
Runs NUTS (No-U-Turn Sampler) to sample from a posterior distribution using JAX.
def nuts_sample(key, log_prob_fn, prior=None, n_samples=100_000):
This function performs sampling with the NUTS algorithm, implemented via BlackJAX.
An initial warm-up phase uses window adaptation to tune the sampler's parameters,
after which sampling runs in parallel over multiple chains.
Args:
key: A JAX `PRNGKeyArray`.
log_prob_fn: A callable representing the log probability function of the
posterior distribution. This function should take a set of parameters
as input and return their log probability (`Callable`).
prior: A `tensorflow_probability` `Distribution` object representing the
prior distribution from which the initial parameter values are sampled (`Distribution`).
n_samples: The number of posterior samples to generate (`int`). Default is 100,000.
Returns:
Tuple[`Array`, `Array`]:
- The first array contains the sampled parameter positions from the NUTS algorithm,
with leading dimension `n_samples`, for a single chain.
- The second array contains the log densities (log posterior probabilities) of each
sampled position, with shape `(n_samples,)`.
Process:
1. The prior distribution is used to sample initial parameter values.
2. The NUTS sampler is tuned and adapted during the warm-up phase using `blackjax.window_adaptation`.
3. After warm-up, the function draws `n_samples` samples using the NUTS kernel
provided by `blackjax.nuts.build_kernel`, running the sampler for one or more chains.
4. The function returns the positions (sampled parameter values) and their associated
log densities from the posterior distribution.
Example:
```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates.jax.distributions import Normal

def log_prob_fn(params):
    # In practice this would condition on an observed datavector.
    return -0.5 * jnp.sum(params ** 2)  # Example: standard normal

key = jax.random.PRNGKey(0)
prior = Normal(0., 1.)

samples, log_densities = nuts_sample(key, log_prob_fn, prior)
```
"""

def init_param_fn(seed):
return prior.sample(seed=seed)
@@ -28,12 +82,13 @@ def call_warmup(seed, param):
initial_states, tuned_params = jax.jit(call_warmup)(warmup_keys, init_params)

def inference_loop_multiple_chains(
rng_key, initial_states, tuned_params, log_prob_fn, n_samples, num_chains
rng_key,
initial_states,
tuned_params,
log_prob_fn,
n_samples,
num_chains
):
"""
Does this just step EACH sample once?
Need to run this for multiple steps?!
"""
kernel = blackjax.nuts.build_kernel()

def step_fn(key, state, **params):
@@ -52,7 +107,12 @@ def one_step(states, i):

key, sample_key = jr.split(key)
states, infos = inference_loop_multiple_chains(
sample_key, initial_states, tuned_params, log_prob_fn, n_samples, n_chains
sample_key,
initial_states,
tuned_params,
log_prob_fn,
n_samples,
n_chains
)

return states.position[:, 0], states.logdensity[:, 0]
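
For reference, the warm-up-then-sample pattern the docstring describes can be sketched end to end for a single chain as below. This is an illustrative sketch only, assuming a recent BlackJAX API (`blackjax.window_adaptation(...).run` and `blackjax.nuts(...).step`); the toy standard-normal target and the `one_step` helper are hypothetical, and exact signatures may differ between BlackJAX versions.

```python
# Minimal single-chain sketch: warm-up via window adaptation, then scanned
# NUTS steps. Assumes a recent BlackJAX API; the toy target and `one_step`
# helper are hypothetical.
import jax
import jax.numpy as jnp
import jax.random as jr
import blackjax

def log_prob_fn(theta):
    # Toy standard-normal target; a real model would condition on data.
    return -0.5 * jnp.sum(theta ** 2)

key = jr.PRNGKey(0)
warmup_key, sample_key = jr.split(key)

# Warm-up: window adaptation tunes the step size and inverse mass matrix.
warmup = blackjax.window_adaptation(blackjax.nuts, log_prob_fn)
(state, tuned_params), _ = warmup.run(warmup_key, jnp.zeros(2), num_steps=500)

# Sampling: one NUTS transition per scan iteration, collecting positions
# and log densities along the way.
step = blackjax.nuts(log_prob_fn, **tuned_params).step

def one_step(state, step_key):
    state, _ = step(step_key, state)
    return state, (state.position, state.logdensity)

keys = jr.split(sample_key, 1_000)
_, (positions, log_densities) = jax.lax.scan(one_step, state, keys)
```

Using `jax.lax.scan` keeps the sampling loop jittable; extending it to several chains typically amounts to `jax.vmap`-ing the transition over a leading chain axis.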
