From 9bdbc699e6fcda494c337078183f45ed362940ac Mon Sep 17 00:00:00 2001
From: homerjed
Date: Mon, 7 Oct 2024 14:26:24 +0200
Subject: [PATCH] paper, rename loaders, readme

---
 .github/workflows/joss.yml |  35 +++++
 README.md                  |  13 +-
 data/__init__.py           |   2 +-
 data/cifar10.py            |   6 +-
 data/flowers.py            |   6 +-
 data/grfs.py               |   6 +-
 data/mnist.py              |   6 +-
 data/moons.py              |   6 +-
 data/quijote.py            |  18 +--
 data/utils.py              |  10 +-
 paper/paper.bib            | 270 +++++++++++++++++++++++++++++--------
 paper/paper.md             |  86 ++++++------
 12 files changed, 327 insertions(+), 137 deletions(-)
 create mode 100644 .github/workflows/joss.yml

diff --git a/.github/workflows/joss.yml b/.github/workflows/joss.yml
new file mode 100644
index 0000000..37130c0
--- /dev/null
+++ b/.github/workflows/joss.yml
@@ -0,0 +1,35 @@
+name: Draft JOSS paper PDF
+on:
+  push:
+    paths:
+      - paper/**
+      - .github/workflows/joss.yml
+
+jobs:
+  paper:
+    runs-on: ubuntu-latest
+    name: Paper Draft
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Build draft PDF
+        uses: openjournals/openjournals-draft-action@master
+        with:
+          journal: joss
+          # This should be the path to the paper within your repo.
+          paper-path: paper/paper.md
+      - name: Upload
+        uses: actions/upload-artifact@v4
+        with:
+          name: paper
+          # This is the output path where Pandoc will write the compiled
+          # PDF. Note, this should be the same directory as the input
+          # paper.md
+          path: paper/paper.pdf
+
+      - name: Commit PDF to repository
+        uses: EndBug/add-and-commit@v9
+        with:
+          message: '(auto) Paper PDF Draft'
+          # This should be the path to the compiled PDF within your repo.
+          add: 'paper/paper.pdf'  # 'paper/*.pdf' to commit all PDFs in the paper directory
\ No newline at end of file
diff --git a/README.md b/README.md
index 02aa1ac..dd66a89 100644
--- a/README.md
+++ b/README.md
@@ -52,16 +52,22 @@ The Stein score of the marginal probability distributions over $t$ is approximat
For each SDE there exists a deterministic ODE with marginal likelihoods $p_t(\boldsymbol{x})$ that match the SDE for all time $t$

$$
-\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t)\text{d}t - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t = F(\boldsymbol{x}(t), t).
+\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t = f'(\boldsymbol{x}(t), t)\text{d}t.
$$

The continuous normalizing flow formalism allows the ODE to be expressed as

$$
-\frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\text{Tr}\bigg [ \frac{\partial}{\partial \boldsymbol{x}(t)} F(\boldsymbol{x}(t), t) \bigg ],
+\frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\nabla_{\boldsymbol{x}} \cdot f'(\boldsymbol{x}(t), t),
$$

-but note that maximum-likelihood training is prohibitively expensive for SDE based diffusion models.
+which gives the log-likelihood of a datapoint $\boldsymbol{x}$ as
+
+$$
+\log p(\boldsymbol{x}(0)) = \log p(\boldsymbol{x}(T)) + \int_{t=0}^{t=T}\text{d}t \; \nabla_{\boldsymbol{x}}\cdot f'(\boldsymbol{x}, t).
+$$
+
+Note that maximum-likelihood training is prohibitively expensive for SDE-based diffusion models.
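+The log-likelihood above can be computed by integrating the divergence of the probability flow ODE alongside the state. The following is a minimal, illustrative sketch of that calculation and is not the `sbgm` API: it assumes a trained score function `score_fn(x, t)`, a variance-preserving SDE with constant $\beta$, a standard Gaussian prior at $t=T$, a flattened 1-D data vector, and an exact Jacobian trace (practical only in low dimensions).

```python
import jax
import jax.numpy as jnp
import diffrax

def vp_drift(x, t, beta=1.0):
    return -0.5 * beta * x  # f(x, t) for a VP SDE with constant beta

def vp_g2(t, beta=1.0):
    return beta             # g(t)^2 for the same SDE

def log_likelihood(x0, score_fn, t1=1.0):
    # Probability flow ODE drift: f'(x, t) = f(x, t) - (1/2) g(t)^2 * score(x, t)
    def ode_drift(x, t):
        return vp_drift(x, t) - 0.5 * vp_g2(t) * score_fn(x, t)

    # Augmented ODE: evolve x(t) and accumulate the divergence of f' along the path
    def vector_field(t, state, args):
        x, _ = state
        dx = ode_drift(x, t)
        div = jnp.trace(jax.jacfwd(lambda x_: ode_drift(x_, t))(x))
        return dx, div

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=t1,
        dt0=1e-2,
        y0=(x0, 0.0),
    )
    x1 = sol.ys[0][-1]
    delta = sol.ys[1][-1]
    # log p(x(0)) = log p(x(T)) + int_0^T div f'(x, t) dt, with a N(0, I) prior at t = T
    prior_logp = jax.scipy.stats.multivariate_normal.logpdf(
        x1, mean=jnp.zeros_like(x1), cov=jnp.eye(x1.shape[0])
    )
    return prior_logp + delta
```

+Swapping the exact trace for a Hutchinson estimator recovers the cheaper estimate that is used in practice for image-sized data.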
### Usage @@ -112,7 +118,6 @@ model = sbgm.train.train( sde, dataset, config, - reload_opt_state=False, sharding=sharding, save_dir=root_dir ) diff --git a/data/__init__.py b/data/__init__.py index 4d181dd..6250755 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -7,7 +7,7 @@ from .moons import moons from .grfs import grfs from .quijote import quijote -from .utils import Scaler, ScalerDataset, _InMemoryDataLoader, _TorchDataLoader +from .utils import Scaler, ScalerDataset, InMemoryDataLoader, TorchDataLoader def get_dataset( diff --git a/data/cifar10.py b/data/cifar10.py index 054577f..716d702 100644 --- a/data/cifar10.py +++ b/data/cifar10.py @@ -4,7 +4,7 @@ from jaxtyping import Key from torchvision import transforms, datasets -from .utils import Scaler, ScalerDataset, _TorchDataLoader +from .utils import Scaler, ScalerDataset, TorchDataLoader def cifar10(path: str, key: Key) -> ScalerDataset: @@ -44,10 +44,10 @@ def cifar10(path: str, key: Key) -> ScalerDataset: transform=valid_transform ) - train_dataloader = _TorchDataLoader( + train_dataloader = TorchDataLoader( train_dataset, data_shape, context_shape=None, parameter_dim=parameter_dim, key=key_train ) - valid_dataloader = _TorchDataLoader( + valid_dataloader = TorchDataLoader( valid_dataset, data_shape, context_shape=None, parameter_dim=parameter_dim, key=key_valid ) diff --git a/data/flowers.py b/data/flowers.py index aaa8d10..f87b961 100644 --- a/data/flowers.py +++ b/data/flowers.py @@ -3,7 +3,7 @@ from jaxtyping import Key from torchvision import transforms, datasets -from .utils import Scaler, ScalerDataset, _TorchDataLoader +from .utils import Scaler, ScalerDataset, TorchDataLoader def flowers(key: Key, n_pix: int) -> ScalerDataset: @@ -46,8 +46,8 @@ def flowers(key: Key, n_pix: int) -> ScalerDataset: transform=valid_transform ) - train_dataloader = _TorchDataLoader(train_dataset, key=key_train) - valid_dataloader = _TorchDataLoader(valid_dataset, key=key_valid) + train_dataloader = TorchDataLoader(train_dataset, key=key_train) + valid_dataloader = TorchDataLoader(valid_dataset, key=key_valid) def label_fn(key, n): Q = None diff --git a/data/grfs.py b/data/grfs.py index 84a5f01..e095b6e 100644 --- a/data/grfs.py +++ b/data/grfs.py @@ -9,7 +9,7 @@ from torchvision import transforms import powerbox -from .utils import Scaler, ScalerDataset, _TorchDataLoader +from .utils import Scaler, ScalerDataset, TorchDataLoader data_dir = "/project/ls-gruen/users/jed.homer/data/fields/" @@ -135,8 +135,8 @@ def grfs(key, n_pix, split=0.5): valid_dataset = MapDataset( (X[n_train:], Q[n_train:], A[n_train:]), transform=valid_transform ) - train_dataloader = _TorchDataLoader(train_dataset, key=key_train) - valid_dataloader = _TorchDataLoader(valid_dataset, key=key_valid) + train_dataloader = TorchDataLoader(train_dataset, key=key_train) + valid_dataloader = TorchDataLoader(valid_dataset, key=key_valid) # Don't have many maps # train_dataloader = _InMemoryDataLoader( diff --git a/data/mnist.py b/data/mnist.py index 4d75c69..5aa0501 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -5,7 +5,7 @@ from torch import Tensor from torchvision import datasets -from .utils import Scaler, ScalerDataset, _InMemoryDataLoader +from .utils import Scaler, ScalerDataset, InMemoryDataLoader def tensor_to_array(tensor: Tensor) -> Array: @@ -46,10 +46,10 @@ def mnist(path:str, key: Key) -> ScalerDataset: # scaler = Normer(train_data.mean(), train_data.std()) - train_dataloader = _InMemoryDataLoader( + train_dataloader = InMemoryDataLoader( train_data, Q=None, 
A=train_targets, key=key_train ) - valid_dataloader = _InMemoryDataLoader( + valid_dataloader = InMemoryDataLoader( valid_data, Q=None, A=valid_targets, key=key_valid ) diff --git a/data/moons.py b/data/moons.py index dbd52bb..9a1664a 100644 --- a/data/moons.py +++ b/data/moons.py @@ -2,7 +2,7 @@ import jax.random as jr from sklearn.datasets import make_moons -from .utils import ScalerDataset, _InMemoryDataLoader +from .utils import ScalerDataset, InMemoryDataLoader def key_to_seed(key): @@ -32,12 +32,12 @@ def moons(key): train_data = (Xt - mean) / std valid_data = (Xv - mean) / std - train_dataloader = _InMemoryDataLoader( + train_dataloader = InMemoryDataLoader( X=jnp.asarray(train_data), A=jnp.asarray(Yt)[:, jnp.newaxis], key=key_train ) - valid_dataloader = _InMemoryDataLoader( + valid_dataloader = InMemoryDataLoader( X=jnp.asarray(valid_data), A=jnp.asarray(Yv)[:, jnp.newaxis], key=key_valid diff --git a/data/quijote.py b/data/quijote.py index 36688a3..73debe1 100644 --- a/data/quijote.py +++ b/data/quijote.py @@ -8,9 +8,9 @@ import torch from torchvision import transforms -from .utils import Scaler, ScalerDataset, _TorchDataLoader, _InMemoryDataLoader +from .utils import Scaler, ScalerDataset, TorchDataLoader, InMemoryDataLoader -data_dir = "/project/ls-gruen/users/jed.homer/data/fields/" +DATA_DIR = "/project/ls-gruen/users/jed.homer/data/fields/" class MapDataset(torch.utils.data.Dataset): @@ -39,8 +39,8 @@ def __len__(self): def get_quijote_data(n_pix: int) -> Tuple[Array, Array]: - X = np.load(os.path.join(data_dir, "quijote_fields.npy"))[:, np.newaxis, ...] - A = np.load(os.path.join(data_dir, "quijote_parameters.npy")) + X = np.load(os.path.join(DATA_DIR, "quijote_fields.npy"))[:, np.newaxis, ...] + A = np.load(os.path.join(DATA_DIR, "quijote_parameters.npy")) dx = int(256 / n_pix) X = X.reshape((-1, 1, n_pix, dx, n_pix, dx)).mean(axis=(3, 5)) @@ -48,7 +48,7 @@ def get_quijote_data(n_pix: int) -> Tuple[Array, Array]: def get_quijote_labels() -> Array: - Q = np.load(os.path.join(data_dir, "quijote_parameters.npy")) + Q = np.load(os.path.join(DATA_DIR, "quijote_parameters.npy")) return Q @@ -91,14 +91,14 @@ def quijote(key, n_pix, split=0.5): valid_dataset = MapDataset( (X[n_train:], A[n_train:]), transform=valid_transform ) - # train_dataloader = _TorchDataLoader( + # train_dataloader = TorchDataLoader( # train_dataset, # data_shape=data_shape, # context_shape=None, # parameter_dim=parameter_dim, # key=key_train # ) - # valid_dataloader = _TorchDataLoader( + # valid_dataloader = TorchDataLoader( # valid_dataset, # data_shape=data_shape, # context_shape=None, @@ -107,10 +107,10 @@ def quijote(key, n_pix, split=0.5): # ) # Don't have many maps - train_dataloader = _InMemoryDataLoader( + train_dataloader = InMemoryDataLoader( X=X[:n_train], A=A[:n_train], key=key_train ) - valid_dataloader = _InMemoryDataLoader( + valid_dataloader = InMemoryDataLoader( X=X[n_train:], A=A[n_train:], key=key_valid ) diff --git a/data/utils.py b/data/utils.py index c091430..8821454 100644 --- a/data/utils.py +++ b/data/utils.py @@ -1,5 +1,5 @@ import abc -from typing import Tuple, Union, NamedTuple, Callable +from typing import Tuple, Union, Callable from dataclasses import dataclass import jax import jax.numpy as jnp @@ -27,7 +27,7 @@ def loop(self, batch_size): pass -class _InMemoryDataLoader(_AbstractDataLoader): +class InMemoryDataLoader(_AbstractDataLoader): def __init__(self, X, Q=None, A=None, *, key): self.X = X self.Q = Q @@ -58,7 +58,7 @@ def loop(self, batch_size): end = start + batch_size 
-class _TorchDataLoader(_AbstractDataLoader): +class TorchDataLoader(_AbstractDataLoader): def __init__( self, dataset, @@ -132,8 +132,8 @@ def __init__(self, x_mean=0., x_std=1.): @dataclass class ScalerDataset: name: str - train_dataloader: Union[_TorchDataLoader | _InMemoryDataLoader] - valid_dataloader: Union[_TorchDataLoader | _InMemoryDataLoader] + train_dataloader: Union[TorchDataLoader | InMemoryDataLoader] + valid_dataloader: Union[TorchDataLoader | InMemoryDataLoader] data_shape: Tuple[int] context_shape: Tuple[int] parameter_dim: int diff --git a/paper/paper.bib b/paper/paper.bib index 72e3977..44171c3 100644 --- a/paper/paper.bib +++ b/paper/paper.bib @@ -1,59 +1,215 @@ -@article{Pearson:2017, - url = {http://adsabs.harvard.edu/abs/2017arXiv170304627P}, - Archiveprefix = {arXiv}, - Author = {{Pearson}, S. and {Price-Whelan}, A.~M. and {Johnston}, K.~V.}, - Eprint = {1703.04627}, - Journal = {ArXiv e-prints}, - Keywords = {Astrophysics - Astrophysics of Galaxies}, - Month = mar, - Title = {{Gaps in Globular Cluster Streams: Pal 5 and the Galactic Bar}}, - Year = 2017 -} - -@book{Binney:2008, - url = {http://adsabs.harvard.edu/abs/2008gady.book.....B}, - Author = {{Binney}, J. and {Tremaine}, S.}, - Booktitle = {Galactic Dynamics: Second Edition, by James Binney and Scott Tremaine.~ISBN 978-0-691-13026-2 (HB).~Published by Princeton University Press, Princeton, NJ USA, 2008.}, - Publisher = {Princeton University Press}, - Title = {{Galactic Dynamics: Second Edition}}, - Year = 2008 -} - -@article{gaia, - author = {{Gaia Collaboration}}, - title = "{The Gaia mission}", - journal = {Astronomy and Astrophysics}, - archivePrefix = "arXiv", - eprint = {1609.04153}, - primaryClass = "astro-ph.IM", - keywords = {space vehicles: instruments, Galaxy: structure, astrometry, parallaxes, proper motions, telescopes}, - year = 2016, - month = nov, - volume = 595, - doi = {10.1051/0004-6361/201629272}, - url = {http://adsabs.harvard.edu/abs/2016A%26A...595A...1G}, -} - -@article{astropy, - author = {{Astropy Collaboration}}, - title = "{Astropy: A community Python package for astronomy}", - journal = {Astronomy and Astrophysics}, - archivePrefix = "arXiv", - eprint = {1307.6212}, - primaryClass = "astro-ph.IM", - keywords = {methods: data analysis, methods: miscellaneous, virtual observatory tools}, - year = 2013, - month = oct, - volume = 558, - doi = {10.1051/0004-6361/201322068}, - url = {http://adsabs.harvard.edu/abs/2013A%26A...558A..33A} -} - -@misc{fidgit, - author = {A. M. Smith and K. Thaney and M. Hahnel}, - title = {Fidgit: An ungodly union of GitHub and Figshare}, +# https://joss.readthedocs.io/en/latest/paper.html#citations BibTeX is ok + +@misc{diffusion, + title={Deep Unsupervised Learning using Nonequilibrium Thermodynamics}, + author={Jascha Sohl-Dickstein and Eric A. 
Weiss and Niru Maheswaranathan and Surya Ganguli}, + year={2015}, + eprint={1503.03585}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/1503.03585}, +} + +@article{sbi, + author = {Kyle Cranmer and Johann Brehmer and Gilles Louppe }, + title = {The frontier of simulation-based inference}, + journal = {Proceedings of the National Academy of Sciences}, + volume = {117}, + number = {48}, + pages = {30055-30062}, + year = {2020}, + doi = {10.1073/pnas.1912789117}, + URL = {https://www.pnas.org/doi/abs/10.1073/pnas.1912789117}, + eprint = {https://www.pnas.org/doi/pdf/10.1073/pnas.1912789117}, + abstract = {Many domains of science have developed complex simulations to describe phenomena of interest. While these simulations provide high-fidelity models, they are poorly suited for inference and lead to challenging inverse problems. We review the rapidly developing field of simulation-based inference and identify the forces giving additional momentum to the field. Finally, we describe how the frontier is expanding so that a broad audience can appreciate the profound influence these developments may have on science.} +} + +@article{field_level_inference, + title={Bayesian field-level inference of primordial non-Gaussianity using next-generation galaxy surveys}, + volume={520}, + ISSN={1365-2966}, + url={http://dx.doi.org/10.1093/mnras/stad432}, + DOI={10.1093/mnras/stad432}, + number={4}, + journal={Monthly Notices of the Royal Astronomical Society}, + publisher={Oxford University Press (OUP)}, + author={Andrews, Adam and Jasche, Jens and Lavaux, Guilhem and Schmidt, Fabian}, + year={2023}, + month=feb, pages={5746–5763} } + + +@misc{Feng2023, + title={Score-Based Diffusion Models as Principled Priors for Inverse Imaging}, + author={Berthy T. Feng and Jamie Smith and Michael Rubinstein and Huiwen Chang and Katherine L. Bouman and William T. Freeman}, + year={2023}, + eprint={2304.11751}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2304.11751}, +} + +@misc{Feng2024, + title={Variational Bayesian Imaging with an Efficient Surrogate Score-based Prior}, + author={Berthy T. Feng and Katherine L. Bouman}, + year={2024}, + eprint={2309.01949}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2309.01949}, +} + +@misc{inverse_problem_medical, + title={Solving Inverse Problems in Medical Imaging with Score-Based Generative Models}, + author={Yang Song and Liyue Shen and Lei Xing and Stefano Ermon}, + year={2022}, + eprint={2111.08005}, + archivePrefix={arXiv}, + primaryClass={eess.IV}, + url={https://arxiv.org/abs/2111.08005}, +} + +@misc{conditional_diffusion, + title={Conditional Image Generation with Score-Based Diffusion Models}, + author={Georgios Batzolis and Jan Stanczuk and Carola-Bibiane Schönlieb and Christian Etmann}, + year={2021}, + eprint={2111.13606}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2111.13606}, +} + +@misc{kidger, + title={On Neural Differential Equations}, + author={Patrick Kidger}, + year={2022}, + eprint={2202.02435}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2202.02435}, +} + +@misc{sde, + title={Score-Based Generative Modeling through Stochastic Differential Equations}, + author={Yang Song and Jascha Sohl-Dickstein and Diederik P. 
Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole}, + year={2021}, + eprint={2011.13456}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2011.13456}, +} + +@misc{sde_ml, + title={Maximum Likelihood Training of Score-Based Diffusion Models}, + author={Yang Song and Conor Durkan and Iain Murray and Stefano Ermon}, + year={2021}, + eprint={2101.09258}, + archivePrefix={arXiv}, + primaryClass={stat.ML}, + url={https://arxiv.org/abs/2101.09258}, +} + +@misc{ddpm, + title={Denoising Diffusion Probabilistic Models}, + author={Jonathan Ho and Ajay Jain and Pieter Abbeel}, + year={2020}, + eprint={2006.11239}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2006.11239}, +} + +@misc{ffjord, + title={FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models}, + author={Will Grathwohl and Ricky T. Q. Chen and Jesse Bettencourt and Ilya Sutskever and David Duvenaud}, + year={2018}, + eprint={1810.01367}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/1810.01367}, +} + +@misc{neuralodes, + title={Neural Ordinary Differential Equations}, + author={Ricky T. Q. Chen and Yulia Rubanova and Jesse Bettencourt and David Duvenaud}, + year={2019}, + eprint={1806.07366}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/1806.07366}, +} + +@misc{blinddiffusion, + title={Parallel Diffusion Models of Operator and Image for Blind Inverse Problems}, + author={Hyungjin Chung and Jeongsol Kim and Sehui Kim and Jong Chul Ye}, + year={2022}, + eprint={2211.10656}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2211.10656}, +} + +@misc{ambientdiffusion, + title={Consistent Diffusion Meets Tweedie: Training Exact Ambient Diffusion Models with Noisy Data}, + author={Giannis Daras and Alexandros G. 
Dimakis and Constantinos Daskalakis}, + year={2024}, + eprint={2404.10177}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2404.10177}, +} + +@article{emulating, + title={CosmoPower: emulating cosmological power spectra for accelerated Bayesian inference from next-generation surveys}, + volume={511}, + ISSN={1365-2966}, + url={http://dx.doi.org/10.1093/mnras/stac064}, + DOI={10.1093/mnras/stac064}, + number={2}, + journal={Monthly Notices of the Royal Astronomical Society}, + publisher={Oxford University Press (OUP)}, + author={Spurio Mancini, Alessio and Piras, Davide and Alsing, Justin and Joachimi, Benjamin and Hobson, Michael P}, + year={2022}, + month=jan, pages={1771–1788} } + + +@software{jax, + author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, + title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, + url = {http://github.com/jax-ml/jax}, + version = {0.3.13}, + year = {2018}, +} + +@article{equinox, + author={Patrick Kidger and Cristian Garcia}, + title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations}, + year={2021}, + journal={Differentiable Programming workshop at Neural Information Processing Systems 2021} +} + +@software{optax, + title = {The {D}eep{M}ind {JAX} {E}cosystem}, + author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio}, + url = {http://github.com/google-deepmind}, year = {2020}, - publisher = {GitHub}, - journal = {GitHub repository}, - url = {https://github.com/arfon/fidgit} +} + +@misc{resnet, + title={Deep Residual Learning for Image Recognition}, + author={Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun}, + year={2015}, + eprint={1512.03385}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/1512.03385}, +} + +@misc{unet, + title={U-Net: Convolutional Networks for Biomedical Image Segmentation}, + author={Olaf Ronneberger and Philipp Fischer and Thomas Brox}, + year={2015}, + eprint={1505.04597}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/1505.04597}, } \ No newline at end of file diff --git a/paper/paper.md b/paper/paper.md index c88d074..6302c02 100644 --- a/paper/paper.md +++ b/paper/paper.md @@ -9,7 +9,7 @@ tags: - Emulators authors: - name: Jed Homer - orcid: 0000-0000-0000-0000 + orcid: 0009-0002-0985-1437 equal-contrib: true affiliation: "1" # (Multiple affiliations must be quoted) affiliations: @@ -72,43 +72,23 @@ aas-journal: Astrophysical Journal <- The name of the AAS journal. 
- likelihood weighting (maximum likelihood training of SBGMs) -->

-Diffusion-based generative models are a method for density estimation and sampling from high-dimensional distributions. A sub-class of these models, score-based diffusion generatives models (SBGMs), permit exact-likelihood estimation via a change-of-variables associated with the forward diffusion process. Diffusion models allow fitting generative models to high-dimensional data in a more efficient way than normalising flows since only one neural network model parameterises the diffusion process as opposed to a stack of networks in typical normalising flow architectures.
+Diffusion-based generative models [@diffusion; @ddpm] are a method for density estimation and sampling from high-dimensional distributions. A sub-class of these models, score-based diffusion generative models (SBGMs) [@sde], permits exact-likelihood estimation via a change-of-variables associated with the forward diffusion process [@sde_ml]. Diffusion models allow fitting generative models to high-dimensional data in a more efficient way than normalising flows since only one neural network model parameterises the diffusion process as opposed to a stack of networks in typical normalising flow architectures.

-The software we present, `sbgm`, is designed to be used by machine learning and physics researchers for fitting diffusion models with a suite of custom architectures for their tasks. These models can be fit easily with multi-accelerator training and inference within the code. Typical use cases for these kinds of generative models are emulator approaches, simulation-based inference (likelihood-free inference), field-level infrence and general inverse problems. This code allows for seemless integration of diffusion models to these applications by allowing for easy conditioning of data on parameters, classes or other data such as images.
-
-
-
-
-
+The software we present, `sbgm`, is designed to be used by researchers in machine learning and the natural sciences for fitting diffusion models with a suite of custom architectures for their tasks. These models can be fit easily with multi-accelerator training and inference within the code. Typical use cases for these kinds of generative models are emulator approaches [@emulating], simulation-based inference (likelihood-free inference) [@sbi], field-level inference [@field_level_inference] and general inverse problems [@inverse_problem_medical; @Feng2023; @Feng2024] (e.g. image inpainting [@sde] and denoising [@ambientdiffusion; @blinddiffusion]). This code allows for seamless integration of diffusion models into these applications by allowing for easy conditioning of data on parameters, class labels or other data such as images.
+

![A diagram showing how to map data to a noise distribution (the prior) with an SDE, and reverse this SDE for generative modeling. One can also reverse the associated probability flow ODE, which yields a deterministic process that samples from the same distribution as the SDE. Both the reverse-time SDE and probability flow ODE can be obtained by estimating the score.\label{fig:sde_ode}](sde_ode.png)

# Mathematics

-Diffusion models model the reverse of a forward diffusion process on samples of data $\boldsymbol{x}$ by adding a sequence of noisy perturbations.
+Diffusion models learn to reverse a forward diffusion process that corrupts samples of data $\boldsymbol{x}$ with a sequence of noisy perturbations [@diffusion].
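+To make the forward process concrete, the sketch below draws a perturbed sample $\boldsymbol{x}(t)$ from the variance-preserving transition kernel discussed below and returns the corresponding denoising score-matching target. It is illustrative only, assumes a constant $\beta(t) = \beta$, and `vp_perturb` is not part of the `sbgm` API.

```python
import jax
import jax.numpy as jnp

def vp_perturb(x0, t, key, beta=1.0):
    # Variance-preserving kernel with constant beta:
    # mu_t = exp(-0.5 * beta * t), sigma_t^2 = 1 - mu_t^2
    mu_t = jnp.exp(-0.5 * beta * t)
    sigma_t = jnp.sqrt(1.0 - mu_t**2)
    eps = jax.random.normal(key, x0.shape)
    xt = mu_t * x0 + sigma_t * eps
    # Score of the transition kernel, the regression target for s_theta(x(t), t):
    # grad_x log p(x(t) | x(0)) = -(x(t) - mu_t * x(0)) / sigma_t^2 = -eps / sigma_t
    target_score = -eps / sigma_t
    return xt, target_score
```

+The network $\boldsymbol{s}_{\theta}$ is then trained to regress `target_score` over random draws of $t$, $\boldsymbol{x}(0)$ and noise, as in the score-matching loss below.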
-Score-based diffusion models model the forward diffusion process with Stochastic Differential Equations (SDEs) of the form
+Score-based diffusion models model the forward diffusion process with Stochastic Differential Equations (SDEs) [@sde] of the form

$$
\text{d}\boldsymbol{x} = f(\boldsymbol{x}, t)\text{d}t + g(t)\text{d}\boldsymbol{w},
$$
@@ -132,44 +112,60 @@ $$
where the score function $\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$ is substituted with a neural network $\boldsymbol{s}_{\theta}(\boldsymbol{x}(t), t)$ for the sampling process. This network predicts the noise added to the image at time $t$ with the forward diffusion process, in accordance with the SDE, and removes it. This defines the sampling chain for a diffusion model.

-The parameters of the network $\theta$ are fit via stochastic gradient descent of the score-matching loss
+The score-based diffusion model for the data is fit by optimising the parameters of the network $\theta$ via stochastic gradient descent of the score-matching loss

$$
-    \mathbb{E}_{t\sim\mathcal{U}(0, T)}\mathbb{E}_{\boldsymbol{x}\sim p(\boldsymbol{x})}[\lambda(t)||\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}) - \boldsymbol{s}_{\theta}(\boldsymbol{x},t)||_2^2]
+    % \mathbb{E}_{t\sim\mathcal{U}(0, T)}\mathbb{E}_{\boldsymbol{x}\sim p(\boldsymbol{x})}[\lambda(t)||\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}) - \boldsymbol{s}_{\theta}(\boldsymbol{x},t)||_2^2]
+
+    \mathcal{L}(\theta) = \mathbb{E}_{t\sim\mathcal{U}(0, T)}\mathbb{E}_{\boldsymbol{x}(0)\sim p(\boldsymbol{x})}\mathbb{E}_{\boldsymbol{x}(t)\sim p(\boldsymbol{x}(t)|\boldsymbol{x}(0))}[\lambda(t)||\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}(t)|\boldsymbol{x}(0)) - \boldsymbol{s}_{\theta}(\boldsymbol{x}(t),t)||_2^2]
+
$$
-where $\lambda(t)$ is an arbitrary scalar weighting function, chosen to weight certain times - usually near $t=0$ where the data has only a small amount of noise added.
+where $\lambda(t)$ is an arbitrary scalar weighting function, chosen to preferentially weight certain times, usually near $t=0$ where the data has only a small amount of noise added. Here, $p_t(\boldsymbol{x}(t)|\boldsymbol{x}(0))$ is the transition kernel for Gaussian diffusion paths. This is defined depending on the form of the SDE [@sde] and for the common variance-preserving (VP) SDE the kernel is written as

-In Figure \autoref{fig:sde_ode} the forward and reverse diffusion processes are shown for a toy problem with their corresponding SDE and ODE paths.
+$$
+    p(\boldsymbol{x}(t)|\boldsymbol{x}(0)) = \mathcal{G}[\boldsymbol{x}(t)|\mu_t \cdot \boldsymbol{x}(0), \sigma^2_t \cdot \mathbb{I}]
+$$
+where $\mathcal{G}[\cdot]$ is a Gaussian distribution, $\mu_t=\exp(-\frac{1}{2}\int_0^t\text{d}s \; \beta(s))$ and $\sigma^2_t = 1 - \mu_t^2$. $\beta(t)$ is typically chosen to be a simple linear function of $t$.
+
+In Figure \ref{fig:sde_ode} the forward and reverse diffusion processes are shown for samples from a Gaussian mixture with their corresponding SDE and ODE paths.

-The reverse SDE may be solved with Euler-Murayama sampling (or other annealed Langevin sampling methods) which is featured in the code. However, many of the applications of generative models depend on being able to calculate the likelihood of data. In [1] it is shown that any SDE may be converted into an ordinary differential equation (ODE) without changing the distributions, defined by the SDE, from which the noise is sampled from in the diffusion process.
+The reverse SDE may be solved with Euler-Maruyama sampling (or other annealed Langevin sampling methods), which is featured in the code; a minimal sketch of such a sampler is given below. However, many applications of generative models depend on being able to calculate the likelihood of data. In [@sde] it is shown that any such SDE may be converted into an ordinary differential equation (ODE) that shares the marginal distributions $p_t(\boldsymbol{x})$ defined by the SDE.
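+The following is a minimal sketch of the Euler-Maruyama sampler referred to above. It assumes a trained score network `score_fn(x, t)` and the constant-$\beta$ VP SDE used earlier, and is illustrative rather than the `sbgm` implementation.

```python
import jax
import jax.numpy as jnp

def euler_maruyama_sample(key, score_fn, shape, n_steps=1000, beta=1.0, t1=1.0):
    # Integrate the reverse-time SDE
    #   dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw
    # from t = T down to t = 0, starting from the N(0, I) prior.
    dt = t1 / n_steps
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey, shape)
    for i in range(n_steps):
        t = t1 - i * dt
        f = -0.5 * beta * x              # VP drift
        g = jnp.sqrt(beta)               # VP diffusion coefficient
        drift = f - g**2 * score_fn(x, t)
        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, shape)
        x = x - drift * dt + g * jnp.sqrt(dt) * noise  # backwards-in-time Euler-Maruyama step
    return x
```

+The deterministic counterpart of this stochastic sampler is the probability flow ODE described next.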
This ODE is known as the probability flow ODE and is written

$$
-    \text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - g^2(t)\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t.
+    \text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - \frac{1}{2}g^2(t)\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t = f'(\boldsymbol{x}, t)\text{d}t.
$$

-This ODE can be solved with an initial-value problem that maps a prior
-sample from a multivariate Gaussian to the data distribution. This inherits the formalism of continuous normalising flows without the expensive ODE simulations used to train these flows. The likelihood estimate under a score-based diffusion model is estimated by solving the change-of-variables equation for continuous normalising flows. The code implements these calculations also for the Hutchinson trace estimation method that reduces the computational expense of the estimate.
+This ODE can be solved as an initial-value problem that maps a prior sample from a multivariate Gaussian to the data distribution. This inherits the formalism of continuous normalising flows [@neuralodes; @ffjord] without the expensive ODE simulations used to train these flows.
+
+The likelihood of data under a score-based diffusion model is estimated by solving the change-of-variables equation for continuous normalising flows,
+
+$$
+\frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\nabla_{\boldsymbol{x}} \cdot f'(\boldsymbol{x}(t), t),
+$$
+
+which gives the log-likelihood of a single datapoint $\boldsymbol{x}(0)$ as
+
+$$
+\log p(\boldsymbol{x}(0)) = \log p(\boldsymbol{x}(T)) + \int_{t=0}^{t=T}\text{d}t \; \nabla_{\boldsymbol{x}}\cdot f'(\boldsymbol{x}, t).
+$$
+
+
+The code also implements these calculations with the Hutchinson trace estimation method [@ffjord], which reduces the computational expense of the estimate.

-
-
+and refer to \autoref{eq:fourier} from text. -->

# Citations

@@ -185,17 +181,15 @@ For a quick reference, the following citation commands can be used:
- `[@author:2001]` -> "(Author et al., 2001)"
- `[@author1:2001; @author2:2001]` -> "(Author1 et al., 2001; Author2 et al., 2002)"

-# Figures
-
-Figures can be included like this:
+

# Acknowledgements

-We thank the developers of these packages for their work and for making their code available to the community.
+We thank the developers of the packages `jax` [@jax], `optax` [@optax], `equinox` [@equinox] and `diffrax` [@kidger] for their work and for making their code available to the community.

# References
\ No newline at end of file