diff --git a/figures/moded.png b/figures/moded.png deleted file mode 100644 index 1762044..0000000 Binary files a/figures/moded.png and /dev/null differ diff --git a/figures/moped.png b/figures/moped.png new file mode 100644 index 0000000..1710897 Binary files /dev/null and b/figures/moped.png differ diff --git a/paper_OJA.tex b/paper_OJA.tex index 0184a5b..51968c4 100644 --- a/paper_OJA.tex +++ b/paper_OJA.tex @@ -158,6 +158,7 @@ \newcommand{\bydef}{:=} \newcommand{\jaxcosmo}{\texttt{jax-cosmo}} \newcommand{\autodiff}{\texttt{autodiff}} +\newcommand{\jax}{\texttt{JAX}} @@ -177,40 +178,47 @@ J.~E. Campagne$^{1,\ast}$, F. Lanusse$^2$, J. Zuntz$^3$,\\ -A. Boucaud$^8$, -D.~Kirkby$^4$, -D. Lanzieri$^2$, -Y.~Li$^{5,6}$, -A. Peel$^7$ -\jaxcosmo\ contributors} +A. Boucaud$^4$, +S. Casas$^{5}$, +M.~Karamanis$^{6,7}$, +D.~Kirkby$^8$, +D. Lanzieri$^9$, +Y.~Li$^{10,11}$, +A. Peel$^{12}$ +} \thanks{$^\ast$jean-eric.campagne@ijclab.in2p3.fr} \affiliation{ -$^1$Université Paris-Saclay, CNRS/IN2P3, IJCLab, 91405 Orsay, France\\ +$^1$Université Paris-Saclay, CNRS/IN2P3, IJCLab, 91405 Orsay, France } - -\affiliation{$^5$Department of Mathematics and Theory, Peng Cheng Laboratory, Shenzhen, Guangdong 518066, China} -\affiliation{$^6$Center for Computational Astrophysics \& Center for Computational Mathematics, Flatiron Institute, New York, New York 10010, USA} -\affiliation{$^7$Ecole Polytechnique F\'ed\'erale de Lausanne (EPFL), Observatoire de Sauverny, 1290 Versoix, Switzerland} -\affiliation{$^8$Université de Paris, CNRS, Astroparticule et Cosmologie, F-75013 Paris, France} +\affiliation{$^2$Université Paris-Saclay, Université Paris Cité, CEA, CNRS, AIM, 91191, Gif-sur-Yvette, France} +\affiliation{$^3$Institute for Astronomy, University of Edinburgh, Edinburgh EH9 3HJ, United Kingdom} +\affiliation{$^4$Université de Paris, CNRS, Astroparticule et Cosmologie, F-75013 Paris, France} +\affiliation{$^{5}$Institute for Theoretical Particle Physics and Cosmology (TTK), RWTH Aachen University, 52056 Aachen, Germany.} +\affiliation{$^{6}$Berkeley Center for Cosmological Physics, University of California, Berkeley, CA 94720, USA} +\affiliation{$^{7}$Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA} +\affiliation{$^8$Department of Physics and Astronomy, University of California, Irvine, CA 92697, USA} +\affiliation{$^9$Université Paris Cité, Université Paris-Saclay, CEA, CNRS, AIM, F-91191, Gif-sur-Yvette, France} +\affiliation{$^{10}$Department of Mathematics and Theory, Peng Cheng Laboratory, Shenzhen, Guangdong 518066, China} +\affiliation{$^{11}$Center for Computational Astrophysics \& Center for Computational Mathematics, Flatiron Institute, New York, New York 10010, USA} +\affiliation{$^{12}$Ecole Polytechnique F\'ed\'erale de Lausanne (EPFL), Observatoire de Sauverny, 1290 Versoix, Switzerland} %\date{\today} \begin{abstract} - - -We present \jaxcosmo, a library for differentiable cosmological theory calculations. \jaxcosmo\ uses the recent \texttt{JAX} library, which has created a new coding ecosystem, especially in probabilistic programming. +We present \jaxcosmo, a library for automatically differentiable cosmological theory calculations. \jaxcosmo\ uses the \jax\ library, which has created a new coding ecosystem, especially in probabilistic programming. % -As well as batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU), \texttt{JAX} exposes an \textit{automatic differentiation} (AD) mechanism. 
AD can take a Python function that uses \texttt{JAX} and return its (vector) derivative, computed not with finite differences but by successively differentiating each instruction in turn. This gives us the opportunity to apply a range of powerful Bayesian inference algorithms, otherwise impractical in cosmology, such as Hamiltonian Monte Carlo and Variational Inference. +As well as batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU), \jax\ exposes an \textit{automatic differentiation} (AD) mechanism. Thanks to AD, \jaxcosmo\ gives access to the derivatives of cosmological likelihoods with respect to any of their parameters, and thus enables a range of powerful % AD can take a Python function that uses \jax\ and return its (vector) derivative, computed not with finite differences but by successively differentiating each instruction in turn. This gives us the opportunity to apply a range of powerful +Bayesian inference algorithms, otherwise impractical in cosmology, such as Hamiltonian Monte Carlo and Variational Inference. % -\jaxcosmo\ implements background evolution, linear and non-linear power spectra (using \texttt{halofit} or the Eisenstein and Hu transfer function), as well as angular power spectra ($C_\ell$) with the Limber approximation for galaxy and weak lensing probes, all differentiable with respect to the cosmological parameters and their other inputs. +In its initial release, \jaxcosmo\ implements background evolution, linear and non-linear power spectra (using \texttt{halofit} or the Eisenstein and Hu transfer function), as well as angular power spectra ($C_\ell$) with the Limber approximation for galaxy and weak lensing probes, all differentiable with respect to the cosmological parameters and their other inputs. % -We discuss algorithms made possible by this library, and present comparisons with the Core Cosmology Library as a benchmark, and run a series of tests using the Dark Energy Survey Year 1 3x2pt analysis with the \numpyro\ library to demonstrate practical inference. +We illustrate how automatic differentiation can be a game-changer for common tasks involving Fisher matrix computations, or full posterior inference with gradient-based techniques (e.g. Hamiltonian Monte Carlo). In particular, we show how Fisher matrices are now fast, exact, no longer require any fine tuning, and are themselves differentiable with respect to parameters of the likelihood enabling complex survey optimization by simple gradient descent. Finally, using a Dark Energy Survey Year 1 3x2pt analysis as a benchmark, we demonstrate how \jaxcosmo\ can be combined with Probabilistic Programming Languages such as \numpyro\ to perform posterior inference with state-of-the-art algorithms including a No U-Turn Sampler (NUTS), Automatic Differentiation Variational Inference (ADVI), and Neural Transport HMC (NeuTra).% We discuss algorithms made possible by this library, present comparisons with the Core Cosmology Library as a benchmark, and run a series of tests using the Dark Energy Survey Year 1 3x2pt analysis with the \numpyro\ library to demonstrate practical inference. % -We show that clear improvements are possible using HMC compared to Metropolis-Hasting, and that Normalizing Flows using the Neural Transport are a promising methodology. 
+We show that clear improvements are possible using HMC compared to Metropolis-Hastings, and that Normalizing Flows using the Neural Transport are a promising methodology.\FrL{we need a better statement than this, what are our conclusions? Are we faster/better than Cobaya?} -% The recent \texttt{JAX} library has created a new ecosystem from which probabilistic programming software can take enormous benefit: batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU) can provide huge speed-ups to a wide range of different problems. In particular, \texttt{JAX} exposes an \textit{automatic differentiation} mechanism, which can take a python function that uses \texttt{JAX} and return its (vector) derivative. This gives us the opportunity to apply a range of powerful but otherwise unfeasible algorithms used in Bayesian inference, such as Hamiltonian Monte Carlo (HMC) and Variational Inference. +% The recent \jax\ library has created a new ecosystem from which probabilistic programming software can take enormous benefit: batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU) can provide huge speed-ups to a wide range of different problems. In particular, \jax\ exposes an \textit{automatic differentiation} mechanism, which can take a python function that uses \jax\ and return its (vector) derivative. This gives us the opportunity to apply a range of powerful but otherwise unfeasible algorithms used in Bayesian inference, such as Hamiltonian Monte Carlo (HMC) and Variational Inference. % To take advantage of these possibilities within cosmology we have developed the \jaxcosmo\ library, which implements background evolution, linear and non linear power spectra (using \texttt{halofit} or the Eisenstein and Hu transfer function), as well as angular power spectra ($C_\ell$) with the Limber approximation for galaxy and weak lensing probes. We discuss algorithms made possible by this library, and present comparisons with the Core Cosmology Library as a benchmark, and run a series of tests using the Dark Energy Survey Year 1 3x2pt analysis with the \numpyro\ library to demonstrate practical usage. % We show that clear improvements are possible using HMC compared to Metropolis-Hasting, and that the Normalizing Flows using the Neural Transport is a promising methodology. @@ -229,7 +237,7 @@ %% main text \section{Introduction} -Bayesian inference has been widely used in cosmology in the form of Monte Carlo Markov Chains (MCMC) since the work of \citep{2001ApJ...563L..95K,2003MNRAS.341.1084R} and has been the keystone for past and current analysis thanks partly to packages as \texttt{CosmoMC} \citep{2002PhRvD..66j3511L}, \texttt{CosmoSIS} \citep{2015A&C....12...45Z}, \texttt{MontePython} \citep{2019PDU....24..260B} as well as \texttt{Cobaya} \citep{2019ascl.soft10019T,2021JCAP...05..057T}; see, for instance, the list of citations to these popular packages for an idea of the wide usage in the community.
+Bayesian inference has been widely used in cosmology in the form of Monte Carlo Markov Chains (MCMC) since the work of \citet{2001ApJ...563L..95K} and \citet{2003MNRAS.341.1084R}, and has been the keystone for past and current analysis thanks partly to packages such as \texttt{CosmoMC} \citep{2002PhRvD..66j3511L}, \texttt{CosmoSIS} \citep{2015A&C....12...45Z}, \texttt{MontePython} \citep{2019PDU....24..260B}, and \texttt{Cobaya} \citep{2019ascl.soft10019T,2021JCAP...05..057T}; see, for instance, the list of citations to these popular packages for an idea of the wide usage in the community. % \footnote{\url{https://cosmologist.info/cosmomc/readme.html}} % \footnote{\url{http://bitbucket.org/joezuntz/cosmosis}} @@ -239,12 +247,12 @@ \section{Introduction} % Note that some direct optimization of likelihood function has also been successfully undertaken for instance for some Planck analysis \citep{2014A&A...566A..54P}, Since the development of these MCMC packages, major advances have been made in automatic differentiation (\textit{autodiff}) \citep{10.5555/3122009.3242010, Margossian2019}, a set of technologies for transforming pieces of code into their derivatives. -While these tools have especially been applied to neural network optimization and machine learning (ML) in general, they can also enable classical statistical methods that require the derivatives of (e.g. likelihood) functions to operate: we consider such methods in this paper. \textit{Autodiff} has been implemented in widely used libraries like \texttt{Stan} \citep{JSSv076i01}, \texttt{TensorFlow} \citep{tensorflow2015-whitepaper}, and \texttt{PyTorch} \citep{NEURIPS2019_9015}. +While these tools have especially been applied to neural network optimization and machine learning (ML) in general, they can also enable classical statistical methods that require the derivatives of (e.g. likelihood) functions to operate: we consider such methods in this paper. \textit{Autodiff} has been implemented in widely used libraries like \texttt{Stan} \citep{JSSv076i01}, \texttt{TensorFlow} \citep{tensorflow2015-whitepaper}, \texttt{Julia} \citep{bezanson2017julia}, and \texttt{PyTorch} \citep{NEURIPS2019_9015}. -A recent entrant to this field is the \texttt{JAX} library\footnote{\url{https://jax.readthedocs.io}} \citep{jax2018github} which has undergone rapid development and can automatically differentiate native \texttt{Python} and \texttt{NumPy} functions, offering a speed up to the development process and indeed code runtimes. \texttt{JAX} offers an easy parallelization mechanism (\texttt{vmap}), and just-in-time compilation (\texttt{jit}) and optimization targetting CPU, GPU and TPU hardware thanks to the \texttt{XLA} library that powers TensorFlow. These attractive features have driven wide adoption of \texttt{JAX} in computational research, and motivate us to consider its usage in cosmology. +A recent entrant to this field is the \jax\ library\footnote{\url{https://jax.readthedocs.io}} \citep{jax2018github} which has undergone rapid development and can automatically differentiate native \texttt{Python} and \texttt{NumPy} functions, offering a speed up to the development process and indeed code runtimes. \jax\ offers an easy parallelization mechanism (\texttt{vmap}), just-in-time compilation (\texttt{jit}), and optimization targeting CPU, GPU, and TPU hardware thanks to the \texttt{XLA} library that powers TensorFlow. 
These attractive features have driven wide adoption of \jax\ in computational research, and motivate us to consider its usage in cosmology. -\texttt{JAX} contains bespoke reimplementation of packages such as \texttt{jax.numpy} and \texttt{jax.scipy}, as well as example libraries as \texttt{Stax} for simple but flexible neural network development. It has formed the seed for a wider ecosystem of packages, including, for example: -\texttt{Flax} \citep{flax2020github} a high-performance neural network library, \texttt{JAXopt} \citep{jaxopt_implicit_diff} an hardware accelerated, batchable and differentiable collection of optimizers, \texttt{Optax} \citep{optax2020github} a gradient processing and optimization library, and \numpyro\ \citep{phan2019composable,bingham2019pyro}, a probabilistic programming language (PPL) that is used in this paper. Other PPL packages such as \texttt{PyMC} \citep{Salvatier2016} have switched to \texttt{JAX} backend in recent versions\footnote{A more exhaustive list and rapidly growing list of packages can be found at \url{https://project-awesome.org/n2cholas/awesome-jax}}. +\jax\ contains bespoke reimplementations of packages such as \texttt{jax.numpy} and \texttt{jax.scipy}, as well as example libraries such as \texttt{Stax} for simple but flexible neural network development. It has formed the seed for a wider ecosystem of packages, including, for example: +\texttt{Flax} \citep{flax2020github} a high-performance neural network library, \texttt{JAXopt} \citep{jaxopt_implicit_diff} a hardware accelerated, batchable and differentiable collection of optimizers, \texttt{Optax} \citep{optax2020github} a gradient processing and optimization library, and \numpyro\ \citep{phan2019composable,bingham2019pyro}, a probabilistic programming language (PPL) that is used in this paper. Other PPL packages such as \texttt{PyMC} \citep{Salvatier2016} have switched to a \jax\ backend in recent versions\footnote{A more exhaustive list and rapidly growing list of packages can be found at \url{https://project-awesome.org/n2cholas/awesome-jax}}. % I don't think we need all these URLs, it's getting excessive, especially when they have papers: %\footnote{\url{https://flax.readthedocs.io/}} @@ -257,46 +265,46 @@ \section{Introduction} % \footnote{\url{https://pytorch.org/docs/}} % \footnote{\url{https://mc-stan.org/users/documentation}} -In this context we have developed the open source \jaxcosmo\ library\footnote{\url{https://github.com/DifferentiableUniverseInitiative/jax_cosmo}}, which we present in this paper. The package represents a first step in making the powerful features described above useful for cosmology; it implements a selection of theory predictions for key cosmology observables as differentiable \texttt{JAX} functions. +In this context we have developed the open source \jaxcosmo\ library\footnote{\url{https://github.com/DifferentiableUniverseInitiative/jax_cosmo}}, which we present in this paper. The package represents a first step in making the powerful features described above useful for cosmology; it implements a selection of theory predictions for key cosmology observables as differentiable \jax\ functions. We give an overview of the code's design and contents in Section~\ref{sec-jaxcosmo-design}. 
We show how to use it for rapid and numerically stable Fisher forecasts in Section~\ref{sec-fisher-forecast}, in more complete gradient-based cosmological inference with variants of Hamiltonian Monte Carlo including the No-U-Turn Sampler, and ML-accelerated Stochastic Variational Inference in Section~\ref{sec:chmc}. We discuss and compare these methods in Section~\ref{sec-discussion} and conclude in Section~\ref{sec-conclusion}. -\section{JAX: GPU Accelerated and Automatically Differentiable Python Programming} +\section{\jax: GPU Accelerated and Automatically Differentiable Python Programming} \label{sec-primer} -The aim of this section is to provide a brief technical primer on \texttt{JAX}, necessary to fully grasp the potential of a cosmology library implemented in this framework. +The aim of this section is to provide a brief technical primer on \jax, necessary to fully grasp the potential of a cosmology library implemented in this framework. % Maybe a short history explaining what autograd and XLA are. %\FrL{Add a short history of XLA and Autograd.} Or maybe not -\paragraph{\textbf{Automatic Differentiation}} Traditionally, two different approaches have been used in cosmology to obtain derivatives of given numerical expressions. The first approach is to derive analytically the formula for the derivatives of interest \citep[e.g.][]{2013MNRAS.432..894J}, with or without the help of tools such as \texttt{Mathematica}\footnote{\url{https://www.wolfram.com/mathematica}.}. This is however only practical typically for simple analytical models. -The second approach is to compute numerical derivatives by finite differences. This approach can be applied on black-box models of arbitrary complexity (from typical Boltzmann codes to cosmological simulations \citep{2020ApJS..250....2V}). However it is notoriously difficult to obtain stable derivatives by finite differences \citep[e.g.][]{2021arXiv210100298B, 2021A&A...649A..52Y}. In addition, their computational cost does not scale well with the number of parameters (a minimum of $(2N+1)$ model evaluations is typically required for $N$ parameters), making them impractical whenever derivatives are needed as part of outer iterative algorithm. +\paragraph{\textbf{Automatic Differentiation}} Traditionally, two different approaches have been used in cosmology to obtain derivatives of given numerical expressions. The first approach is to derive analytically the formula for the derivatives of interest \citep[e.g.][]{2013MNRAS.432..894J}, with or without the help of tools such as \texttt{Mathematica}\footnote{\url{https://www.wolfram.com/mathematica}.}. This is typically only practical, however, for simple analytical models. +The second approach is to compute numerical derivatives by finite differences. This approach can be applied on black-box models of arbitrary complexity (from typical Boltzmann codes to cosmological simulations; \citealp{2020ApJS..250....2V}). However it is notoriously difficult to obtain stable derivatives by finite differences \citep[e.g.][]{2021arXiv210100298B, 2021A&A...649A..52Y}. In addition, their computational cost does not scale well with the number of parameters (a minimum of $(2N+1)$ model evaluations is typically required for $N$ parameters), making them impractical whenever derivatives are needed as part of an outer iterative algorithm. -Automatic differentiation frameworks like \texttt{JAX} take a different approach. 
They trace the execution of a given model and decompose this trace into fundamental operations with known derivatives (e.g. the derivative of a multiplication operation is known). Then, by applying the chain rule formula, the computational graph for the derivatives (of any order) of the model can be built from the known derivatives of every elementary operations. A new function corresponding to the derivative of the original function is therefore built automatically for the user. We direct the interested reader to \cite{baydin2018automatic, Margossian2019} for in-depth introductions to automatic differentiation. +Automatic differentiation frameworks like \jax\ take a different approach. They trace the execution of a given model and decompose this trace into primitive operations with known derivatives (e.g. multiplication). Then, by applying the chain rule formula, the computational graph for the derivatives (of any order) of the model can be built from the known derivatives of every elementary operation. A new function corresponding to the derivative of the original function is therefore built automatically for the user. We direct the interested reader to \citet{baydin2018automatic} and \citet{Margossian2019} for in-depth introductions to automatic differentiation. -JAX provides in particular a number of operators (\texttt{jax.grad}, \texttt{jax.jacobian}, \texttt{jax.hessian}) which can compute derivatives of any function written in \texttt{JAX}. +\jax\ provides in particular a number of operators (\texttt{jax.grad}, \texttt{jax.jacobian}, \texttt{jax.hessian}) which can compute derivatives of any function written in \jax: \begin{lstlisting}[language=iPython] -# Let us define a simple function +# Define a simple function def f(x): return 5 * x + 2 # Take the derivative df_dx = jax.grad(f) -# df_dx is a new function that always return 5 +# df_dx is a new function that always returns 5 \end{lstlisting} -\textbf{Why is that interesting?} \autodiff\ makes it possible to obtain \textit{exact gradients of cosmological likelihoods} with respect to all input parameters at the cost of only 2 likelihood evaluations. +\textbf{Why is this interesting?} \autodiff\ makes it possible to obtain \textit{exact gradients of cosmological likelihoods} with respect to all input parameters at the cost of only two likelihood evaluations. -\paragraph{\textbf{Just In Time Compilation (JIT)}} Despite its convenience and wide adoption in astrophysics, Python still suffers from slow execution times compared to fully compiled languages such as C/C++. One approach to mitigate these issues and make Python code fast is Just In Time compilation, which traces the execution of a given Python function the first time it is called, and compiles it into a fast executable (by-passing the Python interpreter), which can be transparently used in subsequent calls to this function. Compared to other strategies for speeding up Python code such as Cython, JIT allows the user to simply write plain Python code, and reap the benefits of compiled code. +\paragraph{\textbf{Just In Time Compilation (JIT)}} Despite its convenience and wide adoption in astrophysics, Python still suffers from slow execution times compared to fully compiled languages such as C/C++.
One approach to mitigate these issues and make Python code fast is Just In Time compilation, which traces the execution of a given Python function the first time it is called, and compiles it into a fast executable (by-passing the Python interpreter), which can be transparently used in subsequent calls to this function. Compared to other strategies for speeding up Python code such as Cython, JIT allows the user to write plain Python code, and reap the benefits of compiled code. -A number of libraries allowing for JIT have already been used in astrophysics, in particular Numba\footnote{\url{https://numba.pydata.org/}}, or the HOPE library \cite{2015A&C....10....1A} developed specifically for the need of astrophysics. \texttt{JAX} stands out compared to these other frameworks in that it relies on the highly optimized XLA library\footnote{\url{https://www.tensorflow.org/xla}} for executing the compiled expressions. XLA is continuously developed by Google as part of the TensorFlow project, for efficient training and inference of large scale deep learning applications, and as such supports computational backends such as GPU an Tensor Processing Units (TPU) clusters. The ability to perform computations directly on GPUs through XLA is one of the major benefits of \texttt{JAX}, as speed-ups of at least two orders of magnitudes can be expected for typical parallel linear algebra computations compared to CPU. +A number of libraries allowing for JIT have already been used in astrophysics, in particular Numba\footnote{\url{https://numba.pydata.org/}}, or the HOPE library \cite{2015A&C....10....1A} developed specifically for the needs of astrophysics. \jax\ stands out compared to these other frameworks in that it relies on the highly optimized XLA library\footnote{\url{https://www.tensorflow.org/xla}} for executing the compiled expressions. XLA is continuously developed by Google as part of the TensorFlow project, for efficient training and inference of large scale deep learning applications, and as such supports computational backends such as GPU and Tensor Processing Units (TPU) clusters. The ability to perform computations directly on GPUs through XLA is one of the major benefits of \jax, as speed-ups of at least two orders of magnitude can be expected for typical parallel linear algebra computations compared to CPU. -In \texttt{JAX}, jitting is achieved simply by transforming a function with \texttt{jax.jit} operation: +In \jax, jitting is achieved by transforming a function with the \texttt{jax.jit} operation: \begin{lstlisting}[language=iPython] -# Let us redefine our function +# Redefine our function def f(x): return 5 * x + 2 # And JIT it @@ -306,11 +314,11 @@ \section{JAX: GPU Accelerated and Automatically Differentiable Python Programmin # and run as a compiled code directly on GPU \end{lstlisting} -\textbf{Why is that interesting?} JIT makes it possible to execute entire cosmological \textit{MCMC chains directly on GPUs as compiled code}, with orders of magnitude in speedup over Python code. +\textbf{Why is this interesting?} JIT makes it possible to execute entire cosmological \textit{MCMC chains directly on GPUs as compiled code}, with orders of magnitude speedup over Python code. -\paragraph{\textbf{Automatic Vectorization}} Another extremely powerful feature of \texttt{JAX} is its ability to automatically vectorize or paralellize any function.
Through the same tracing mechanism used for automatic differentiation, \texttt{JAX} can decompose a given computation into fundamental operations and add a new \textit{batch} dimension so that the computation can be applied on a batch of inputs as opposed to individual inputs. In doing so, note that the computation will not run sequentially over all entries of the batch, but trully in parallel making full use of the intrinsic parallel architecture of modern GPUs. +\paragraph{\textbf{Automatic Vectorization}} Another extremely powerful feature of \jax\ is its ability to automatically vectorize or parallelize any function. Through the same tracing mechanism used for automatic differentiation, \jax\ can decompose a given computation into primitive operations and add a new \textit{batch} dimension so that the computation can be applied to a batch of inputs as opposed to individual ones. In doing so, the computation will not run sequentially over all entries of the batch, but truly in parallel making full use of the intrinsic parallel architecture of modern GPUs. -In \texttt{JAX} automatic vectorization is achieved using the \textit{jax.vmap} operation: +In \jax\ automatic vectorization is achieved using the \texttt{jax.vmap} operation: \begin{lstlisting}[language=iPython] # Our function f only applies to scalars def f(x): @@ -319,28 +327,29 @@ \section{JAX: GPU Accelerated and Automatically Differentiable Python Programmin batched_f = jax.vmap(f) # batched_f now applies to 1D arrays \end{lstlisting} -Again, we stress that in this example, \texttt{batched\_f} will not be implemented in terms of a for loop, but with operations over vectors. The function above is trivial, but the same approach can be used to parallelize any function, from Limber integrals, to an entire likelihood. When considering multi-devices use-case (eg. GPU or TPU), \texttt{JAX} provides \texttt{pmap} which compile and execute in parallel replicates of the same code on each device. Moreover, recent experimental developments deal with parallelization of function over supercomputer-scale hardware meshes. Notice that in the examples detailed in this article, we have only relied on \texttt{vmap} functionality. +Again, we stress that in this example, \texttt{batched\_f} will not be implemented in terms of a for loop, but with operations over vectors. The function above is trivial, but the same approach can be used to parallelize any function, from Limber integrals, to an entire likelihood. For multi-device use-cases (e.g., several GPUs or TPUs), \jax\ provides \texttt{pmap} which compiles and executes, in parallel, replicas of the same code on each device. Moreover, recent experimental developments deal with parallelization of functions over supercomputer-scale hardware meshes. In the examples detailed in this article, we have only relied on \texttt{vmap} functionality. -\textbf{Why is that interesting?} Automatic Vectorization makes it possible to trivially parallelize cosmological likelihood evaluations, to run many parallel MCMC chains on a single GPU. +\textbf{Why is this interesting?} Automatic Vectorization makes it possible to trivially parallelize cosmological likelihood evaluations, to run many parallel MCMC chains on a single GPU. -\paragraph{\textbf{NumPy API compliance}} Finally, the last point to note about \texttt{JAX}, is that it mirrors the NumPy API (with only a few exceptions). This means in practice that existing NumPy code can easily be converted to \texttt{JAX}.
This is in contrast to other similar frameworks like TensorFlow, PyTorch, or Julia which all require the user to learn, and adapt their code to, a new API or even a new language. +\paragraph{\textbf{NumPy API compliance}} Finally, the last point to note about \jax, is that it mirrors the NumPy API (with only a few exceptions). This means in practice that existing NumPy code can easily be converted to \jax. This is in contrast to other similar frameworks like TensorFlow, PyTorch, or Julia which all require the user to learn, and adapt their code to, a new API or even a new language. -\textbf{Why is that interesting?} NumPy compliance implies improved maintainability and lower barrier to entry for new contributors. +\textbf{Why is this interesting?} NumPy compliance implies improved maintainability and lower barrier to entry for new contributors. -\section{Design of the \jaxcosmo\ library} +\section{Capabilities of the \jaxcosmo\ library} \label{sec-jaxcosmo-design} -In this section, we describe the cosmological modeling provided by \jaxcosmo, and its implementation in JAX. The general design follows that of the Core Cosmology Library (CCL) \citep{2019ApJS..242....2C}, though, in its initial release, +In this section, we describe the cosmological modeling provided by \jaxcosmo, and its implementation in \jax. The general design follows that of the Core Cosmology Library (CCL; \citealp{2019ApJS..242....2C}), though in its initial release \jaxcosmo\ only implements a subset of CCL features and options. -All \jaxcosmo\ data structures are organized as \texttt{JAX} \textit{pyTree} container objects, ensuring that two key \texttt{JAX} features are available to them: \textit{autodiff} and \textit{vmap}. The vmap feature enables any operation defined in \texttt{JAX} (including complicated composite operations) to be applied efficiently as a vector operation. The autodiff feature further makes it possible to take the derivative of any operation, automatically transforming a function that takes $n$ inputs and produces $m$ outputs into a new function that generates an $m \times n$ matrix of partial derivatives. +All \jaxcosmo\ data structures are organized as \jax\ container objects, which means that two key \jax\ features are automatically available to them: \textit{autodiff} and \textit{vmap}. The vmap feature enables any operation defined in \jax\ (including complicated composite operations) to be applied efficiently as a vector operation. The autodiff feature further makes it possible to take the derivative of any operation, automatically transforming a function that takes $n$ inputs and produces $m$ outputs into a new function that generates an $m \times n$ matrix of partial derivatives. -\jaxcosmo\ implements, for example, Runge-Kutta solvers for ODEs, and Simpson \& Romberg integration routines through this framework, so that we can automatically compute the derivatives of their solutions with respect to their inputs. This includes not only the cosmological parameters (the standard set of $w_0 w_a CDM$ cosmological parameters is exposed, using $\sigma_8$ as an amplitude parameter), but also other input quantities such as redshift distributions as described below. +\jaxcosmo\ implements, for example, Runge-Kutta solvers for ODEs, as well as Simpson and Romberg integration routines through this framework, so that we can automatically compute the derivatives of their solutions with respect to their inputs. 
This includes not only the cosmological parameters (the standard set of $w_0 w_a CDM$ cosmological parameters is exposed, using $\sigma_8$ as an amplitude parameter), but also other input quantities such as redshift distributions as described below. In the rest of this section we describe the cosmological calculations that are implemented using these facilities. +\subsection{Formalism \& Implementation} -\subsection{Background cosmology} +\subsubsection{Background cosmology} The computation of the evolution of the cosmological background follows a typical implementation of a Friedmann equation (\citealp[see e.g.][]{2005A&A...443..819P}): \begin{equation} @@ -352,17 +361,17 @@ \subsection{Background cosmology} \end{equation} Notably, the relativistic contributions of massless neutrinos and photon radiation, as well as the massive neutrino contribution are currently neglected. From these expressions, in \texttt{jc.background}, are computed the different cosmological distance functions, such as the radial comoving distance \begin{equation} - \chi(a) = R_H \int_a^1 \frac{da^\prime}{{a^\prime}^2 E(a^\prime)} + \chi(a) = R_H \int_a^1 \frac{\mathrm{d}a^\prime}{{a^\prime}^2 E(a^\prime)} \label{eq:radial_comoving} \end{equation} with $R_H$ the Hubble radius. %Using the scale factor ($a$) redshift ($z$) relationship, $\chi$ can be viewed as a function of $z$. % -\subsection{Growth of perturbations} +\subsubsection{Growth of perturbations} % Currently, \jaxcosmo\ implements the \citet{Eisenstein_1998} transfer function $T$ which transforms the primordial matter power spectrum to its late-time non-linear value: \begin{equation} - P(k, z) = P(k, z=\infty) \cdot T(k, z; \Omega_m, \Omega_b, ...), + P(k, z) = P(k, z=\infty) \cdot T^2(k, z; \Omega_m, \Omega_b, ...), \end{equation} through the \textit{halofit} model by \cite{2012ApJ...761..152T} or \cite{2003MNRAS.341.1311S} without the neutrino contribution introduced by \cite{10.1111/j.1365-2966.2011.20222.x}. No Baryon feedback modeling is considered yet. @@ -371,26 +380,26 @@ \subsection{Growth of perturbations} P(k) = A k^{n_s}. \end{equation} -The normalisation $A$ is parametrized via $\sigma_8$ at $z=0$ as +The normalisation $A$ is parameterised via $\sigma_8$ at $z=0$ as \begin{equation} - A = \sigma_8^2 \times \left(\frac{1}{2 \pi^2} \int_0^\infty \frac{dk}{k} k^3 P(k) W^2(kR_8) \right)^{-1} + A = \sigma_8^2 \times \left(\frac{1}{2 \pi^2} \int_0^\infty \frac{\mathrm{d}k}{k} k^3 P(k) W^2(kR_8) \right)^{-1} \end{equation} with $R_8 = 8 \mathrm{Mpc}/h$ and $W(x)$ related to the $j_1$ spherical Bessel function as \begin{equation} W(x) = \frac{3j_1(x)}{x} \end{equation} -%Future version of the library would offer the possibility to call a \texttt{JAX} emulator of the Cosmic Linear Anisotropy Solving System \citep{2011JCAP...07..034B}. \JZ{How close is this to being usable? If it's just an idea or an early prototype then skip this paragraph.} +%Future version of the library would offer the possibility to call a \jax\ emulator of the Cosmic Linear Anisotropy Solving System \citep{2011JCAP...07..034B}. \JZ{How close is this to being usable? If it's just an idea or an early prototype then skip this paragraph.} % -\subsection{Angular power spectra} +\subsubsection{Angular power spectra} % \jaxcosmo\ is currently focused on predicting projected 2D Fourier-space 2pt galaxy lensing, clustering, and cross correlations, the $C_\ell$ angular power spectra that are a primary target of upcoming photometric surveys. 
The details of the implementation is in \texttt{jc.angular\_cl} which deals with the mean and Gaussian covariance matrix computations. The angular power spectra $C_\ell^{ij}$ for the probes $(i,j)$ and for redshift bin window selections are computed in the first order Limber approximation \citep{PhysRevD.78.123506}: \begin{align} - C_\ell^{i,j} \approx \left(\ell+\frac{1}{2}\right)^{m_i+m_j}\int\frac{d\chi}{c^2\chi^2}K_i(\chi)K_j(\chi)\,P\left(k=\frac{\ell+1/2}{\chi},z\right),\label{eq:Cell_limber} + C_\ell^{i,j} \approx \left(\ell+\frac{1}{2}\right)^{m_i+m_j}\int\frac{\mathrm{d}\chi}{c^2\chi^2}K_i(\chi)K_j(\chi)\,P\left(k=\frac{\ell+1/2}{\chi},z\right),\label{eq:Cell_limber} \end{align} -The $m_i$ factors are $(0,-2)$ for the galaxy clustering and weak lensing, respectively, and each $K(z)$ function represent a single tomographic redshift bin's number density. These tracers are implemented as two kernel functions: +The $m_i$ factors are $(0,-2)$ for the galaxy clustering and weak lensing, respectively, and each $K(z)$ function represents a single tomographic redshift bin's number density. These tracers are implemented as two kernel functions: \begin{description} \item[\texttt{NumberCounts}] @@ -402,11 +411,11 @@ \subsection{Angular power spectra} \item [\texttt{WeakLensing}] \begin{multline} K_i(z) = \left( \frac{3 H_0^2\Omega_m}{2 c} \right) \left(\frac{(\ell+2)!}{(\ell-2)!} \right)^{1/2}\times - \\ (1+z)\ \chi(z) \int_z^\infty p_i(z^\prime)\ \frac{\chi(z^\prime)-\chi(z)}{\chi(z^\prime)}\ dz^\prime + K_{IA}(z) + \\ (1+z)\ \chi(z) \int_z^\infty p_i(z^\prime)\ \frac{\chi(z^\prime)-\chi(z)}{\chi(z^\prime)}\ \mathrm{d}z^\prime + K_{IA}(z) \end{multline} where $K_{IA}(z)$ is an optional kernel function to deal with the Intrinsic Alignment. The implementation of this term currently follows \citet{2011A&A...527A..26J}, and is given by: \begin{equation} - K_{IA}(z) = \left(\frac{(\ell+2)!}{(\ell-2)!}\right)^{1/2}\ p_i(z)\ b(z)\ H(z)\ \frac{C\ \Omega_m}{D(z)} + K_{IA}(z) = \left(\frac{(\ell+2)!}{(\ell-2)!}\right)^{1/2}\ p_i(z)\ b(z)\ H(z)\ \frac{C\ \Omega_m}{D(z)} \end{equation} with $C\approx 0.0134$ being a dimensionless constant and $D(z)$ the growth factor of linear perturbations. \end{description} @@ -415,15 +424,15 @@ \subsection{Angular power spectra} % \FrL{We probably actually want to start by the Limber formula, which explicits what a kernel is, otherwise people don't necessarily know where these kernels come from.} -Because, like the other ingredients, the kernels are implemented as pyTree objects, namely \texttt{NumberCounts} and \texttt{WeakLensing}, all the integrals involved in these computations can be differentiated with respect to the cosmological parameters and the number densities, using \textit{autodiff} and \jaxcosmo's implementation of integration quadrature. An example is given in the context of DES Y1 3x2pts analysis (Sec~\ref{sec-DESY1}). +Because, like the other ingredients, the \texttt{NumberCounts} and \texttt{WeakLensing} kernels are implemented as JAX objects, all the integrals involved in these computations can be differentiated with respect to the cosmological parameters and the number densities, using \textit{autodiff} and \jaxcosmo's implementation of integration quadrature. An example is given in the context of DES Y1 3x2pts analysis (Sec~\ref{sec-DESY1}). 
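As a minimal sketch of what this differentiability looks like in practice (the \texttt{smail\_nz} redshift distribution helper and the \texttt{jc.probes.WeakLensing} tracer below are names from the public \jaxcosmo\ API, and the numerical values are arbitrary), the Jacobian of a set of lensing angular power spectra with respect to $(\Omega_c, \sigma_8)$ can be obtained in a few lines:
\begin{lstlisting}[language=iPython]
# Sketch: differentiate lensing C_ells w.r.t. (Omega_c, sigma8).
# smail_nz and WeakLensing follow the public jax-cosmo API;
# the numerical values below are arbitrary.
import jax
import jax.numpy as jnp
import jax_cosmo as jc

nz = jc.redshift.smail_nz(1.0, 2.0, 1.0)   # toy n(z)
probes = [jc.probes.WeakLensing([nz])]
ell = jnp.logspace(1, 3, 50)

def mean_cl(p):
    cosmo = jc.Planck15(Omega_c=p[0], sigma8=p[1])
    return jc.angular_cl.angular_cl(cosmo, ell, probes)

# Derivatives of every C_ell with respect to the two parameters
dcl = jax.jacfwd(mean_cl)(jnp.array([0.2589, 0.8159]))
\end{lstlisting}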
% \subsection{Validation against the Core Cosmology Library (CCL)} % -To illustrate the different features available with the present version of the library (\jaxcosmo\ \texttt{0.1rc9}, which is available in the Python Package Index PyPI\footnote{\url{https://pypi.org/}}), we have written a companion notebook \nblink{CCL_comparison} to compare it to the well-validated Core Cosmology Library \citep{2019ApJS..242....2C}\footnote{\url{https://ccl.readthedocs.io}, version \texttt{2.5.1}.}. As examples, Figures \ref{fig:chi_comparison}, +To illustrate the different features available with the present version of the library (\jaxcosmo\ \texttt{0.1}, which is available in the Python Package Index PyPI\footnote{\url{https://pypi.org/}}), we have written a companion notebook \nblink{CCL_comparison} to compare it to the well-validated Core Cosmology Library \citep{2019ApJS..242....2C}\footnote{\url{https://ccl.readthedocs.io}, version \texttt{2.5.1}.}. As examples, Figures \ref{fig:chi_comparison}, \ref{fig:halofit_comparison} and \ref{fig:Cell_comparison} -show the radial comoving distance (Eq.~\ref{eq:radial_comoving}), the non-linear matter power spectrum computation, and the angular power spectrum for galaxy-galaxy lensing (Eq.~\ref{eq:Cell_limber}) using the \texttt{NumberCounts} and \texttt{WeakLensing} kernel functions. \jaxcosmo\ features a suite of validation tests against CCL, automatically validating the precision of all computations to within the desired numberical accuracy; the relative differences between the two libraries are at the level of few $10^{-3}$ or better. +show the radial comoving distance (Eq.~\ref{eq:radial_comoving}), the non-linear matter power spectrum computation, and the angular power spectrum for galaxy-galaxy lensing (Eq.~\ref{eq:Cell_limber}) using the \texttt{NumberCounts} and \texttt{WeakLensing} kernel functions. \jaxcosmo\ features a suite of validation tests against CCL, automatically validating the precision of all computations to within the desired numerical accuracy; the relative differences between the two libraries are at the level of a few $10^{-3}$ or better. -These numerical differences are mostly due to different choices of integration methods and accuracy parameters (eg. quadrature number of points). Increasing these parameters lead to performance degradation for \jaxcosmo\ as it increases the XLA compilation memory requirements significantly, especially for the angular power spectra computation. Since these differences are likely to be within the tolerance of the current generation of cosmological surveys, this trade-off is an acceptable one. +These numerical differences are mostly due to different choices of integration methods and accuracy parameters (e.g. number of quadrature points). Increasing these parameters leads to performance degradation for \jaxcosmo, as it increases the XLA compilation memory requirements significantly, especially for the angular power spectra computation. Since these differences are likely to be within the tolerance of the current generation of cosmological surveys, this trade-off is an acceptable one. \begin{figure} \centering @@ -448,7 +457,7 @@ \subsection{Validation against the Core Cosmology Library (CCL)} \section{Fisher Forecasts \& Data Compression} \label{sec-fisher-forecast} -As a first illustration of the value of \textit{autodiff} in cosmological computations, we present in this section a few direct applications involving the computation of the Fisher information matrix, widely used in cosmology.
\\ +As a first illustration of the value of \textit{autodiff} in cosmological computations, we present in this section a few direct applications involving the computation of the Fisher information matrix, widely used in cosmology \citep{1997ApJ...480...22T,Stuart1991}. \\ Not only does the computation of the Fisher matrix become trivial, but the Fisher matrix itself becomes differentiable, allowing in particular for powerful survey optimization applications. %JEC 5 August 22: not sure that should be given here as we have an Appendix for that and the main application described in the following sections concern DES Y1. @@ -460,7 +469,7 @@ \section{Fisher Forecasts \& Data Compression} \subsection{Instantaneous Fisher Forecasts} Fisher matrices are a key tool in cosmology forecasting and experimental planning. By computing the Hessian matrix of a likelihood with respect to its parameters, we can find a Gaussian approximation to a posterior, which is usually sufficient for comparing the constraining power of different experimental configurations. As noted above, computing Fisher matrices is notoriously error-prone, since finite difference approximations to likelihoods must be carefully tuned for convergence, and observable calculations can be numerically unstable. Autodiff can help avoid this challenge. -We first illustrate, with an artificial case study, the computation of a Fisher matrix \citep{1997ApJ...480...22T,Stuart1991} using two methods with the \textit{autodiff} ability of JAX. For the detailed implementation, the reader is invited to look at the following companion notebook \nblink{Simple-Fisher}. In this example we use four tracer redshift distributions: two to define \texttt{WeakLensing} kernels and two for \texttt{NumberCounts} kernels. Then, the $10$ angular power spectra $C_\ell^{p,q}$ ($p,q:1,\dots,4$) are computed for $50$ logarithmically-spaced angular moments between $\ell=10$ and $\ell=1000$ using Equation \ref{eq:Cell_limber}. The Gaussian covariance matrix is computed simultaneously. A dataset is obtained from the computation of the $C_\ell^{p,q}$ with a fiducial cosmology. Then, the following snippet shows the log likelihood function $\mathcal{L}(\theta)$ implementation considering a constant covariance matrix ($\theta$ stands for the set of cosmological parameters). +We first illustrate, with an artificial case study, the computation of a Fisher matrix using two methods with the \textit{autodiff} ability of \jax. For the detailed implementation, the reader is invited to look at the following companion notebook \nblink{Simple-Fisher}. In this example we use four tracer redshift distributions: two to define \texttt{WeakLensing} kernels and two for \texttt{NumberCounts} kernels. Then, the $10$ angular power spectra $C_\ell^{p,q}$ ($p,q:1,\dots,4$) are computed for $50$ logarithmically-spaced angular moments between $\ell=10$ and $\ell=1000$ using Equation \ref{eq:Cell_limber}. The Gaussian covariance matrix is computed simultaneously. A dataset is obtained from the computation of the $C_\ell^{p,q}$ with a fiducial cosmology. Then, the following snippet shows the log likelihood function $\mathcal{L}(\theta)$ implementation considering a constant covariance matrix ($\theta$ stands for the set of cosmological parameters). 
%\begin{minted}[fontsize=\footnotesize]{python} \begin{lstlisting}[language=iPython] @jax.jit @@ -477,7 +486,7 @@ \subsection{Instantaneous Fisher Forecasts} return -0.5 * r.T @ jc.sparse.sparse_dot_vec(P, r) \end{lstlisting} %\end{minted} -The \texttt{jc.sparse} functions are implementation of block matrix computations: a sparse matrix is represented as a 3D array of shape $(n_y, n_x, n_{diag})$ composed of $n_y \times n_x$ square blocks of size $n_{diag} \times n_{diag}$. The \texttt{jax.jit} decorator builds a compiled version of the function at the first use. +The \texttt{jc.sparse} functions are implementations of block matrix computations: a sparse matrix is represented as a 3D array of shape $(n_y, n_x, n_{diag})$ composed of $n_y \times n_x$ square blocks of size $n_{diag} \times n_{diag}$. The \texttt{jax.jit} decorator builds a compiled version of the function on first use. The first approach to obtaining approximate 1-sigma contours of the two parameters ($\Omega_c, \sigma_8$) with a Fisher matrix uses the Hessian of the log-likelihood as follows: \begin{equation} @@ -485,28 +494,28 @@ \subsection{Instantaneous Fisher Forecasts} \qquad (\theta_1=\Omega_c, \theta_2=\sigma_8) \label{eq:fisher_way1} \end{equation} -which is accomplished in two lines of \texttt{JAX} code: +which is accomplished in two lines of \jax\ code: \begin{lstlisting}[language=iPython] hessian_loglike = jax.jit(jax.hessian(likelihood)) F = - hessian_loglike(params) \end{lstlisting} -The second approach to computing the Fisher matrix, which is restricted to Gaussian likelihoods but is more commonly used in the field because of its numerical stability, is to define a function that computes the observable mean $\mu(\ell; \theta)$; it follows that the Fisher matrix elements are +The second approach to computing the Fisher matrix, which is restricted to Gaussian likelihoods but is more commonly used in the field because of its numerical stability, is to define a function that computes the summary statistic mean $\mu(\ell; \theta)$; the Fisher matrix elements are then: \begin{equation} F_{i,j} = \sum_\ell \frac{\partial \mu(\ell)}{\partial \theta_i}^T C^{-1}(\ell)\frac{\partial \mu(\ell)}{\partial \theta_j} \label{eq:fisher_way2} \end{equation} -where $C^{-1}(\ell)$ is the covariance matrix computed with the fiducial cosmology. This can also be simply computed in JAX: +where $C^{-1}(\ell)$ is the covariance matrix computed with the fiducial cosmology. This can be computed in \jax\ as: \begin{lstlisting}[language=iPython] -# We define a parameter dependent function that computes the mean +# We define a parameter dependent mean function @jax.jit def jc_mean(p): cosmo = jc.Planck15(Omega_c=p[0], sigma8=p[1]) # Compute signal vector mu = jc.angular_cl.angular_cl(cosmo, ell, tracers) - # We want mu in 1d to operate against the covariance matrix + # We want mu in 1d to match the covariance matrix return mu.flatten() -# We compute it's jacobian with JAX, and we JIT it for efficiency +# We compute its jacobian with JAX, and we JIT it for efficiency jac_mean = jax.jit(jax.jacfwd(jc_mean)) # We can now evaluate the jacobian at the fiducial cosmology dmu = jac_mean(params) @@ -514,15 +523,15 @@ \subsection{Instantaneous Fisher Forecasts} F = jc.sparse.dot(dmu.T, jc.sparse.inv(cov), dmu) \end{lstlisting} -JAX implementations of the two methods agree to near perfect accuracy, as shown in Figure \ref{fig:simple_fisher_1}. 
+\jax\ implementations of the two methods agree to near perfect accuracy, as shown in Figure \ref{fig:simple_fisher_1}. \begin{figure} \centering \includegraphics[width=0.7\columnwidth]{figures/simple_fisher_1.png} -\caption{Comparison of the two methods to compute the Fisher matrix (Eqs.\ref{eq:fisher_way1},\ref{eq:fisher_way2}).} +\caption{Comparison of the two methods to compute the Fisher matrix: method 1 (method 2) is using Eq.~\ref{eq:fisher_way1} (Eq.~\ref{eq:fisher_way2}).} \label{fig:simple_fisher_1} \end{figure} % -It is worth noting that in the two methods for Fisher matrix computations described above, the user does not need to vary individual parameter values to compute the 1st or 2nd order derivatives; this is in contrast to the usual finite difference methods. As an illustration, we have used the CCL library to compute the Fisher matrix via Equation \ref{eq:fisher_way2}. To do so, the Jacobian ($\partial \mu/\partial\theta_\alpha$) is computed with centered order 4 finite differences available in the \texttt{Numdifftools} Python package\footnote{\url{https://numdifftools.readthedocs.io}}. Using the parameter values spaced by ($10^{-6}$, $10^{-2}$, $10^{-1}$) one obtains three different approximation of the 1-sigma contours as shown on Figure \ref{fig:simple_fisher_2}. The contour that agrees best with the \jaxcosmo\ method is obtained with the intermediate spacing parameter $10^{-2}$, implying that the user must tune this parameter carefully. Although very simple, this case study exhibits a significant challenge of finite difference methods for computing the Fisher matrix as has been shown for instance in a more advanced case study in \citet{2021arXiv210100298B}. +It is worth noting that in the two methods for Fisher matrix computations described above, the user does not need to vary individual parameter values to compute the 1st or 2nd order derivatives; this is in contrast to the usual finite difference methods. As an illustration, we have used CCL to compute the Fisher matrix via Equation \ref{eq:fisher_way2}. To do so, the Jacobian ($\partial \mu/\partial\theta_\alpha$) is computed with centered order 4 finite differences available in the \texttt{Numdifftools} Python package\footnote{\url{https://numdifftools.readthedocs.io}}. Using the parameter values spaced by ($10^{-6}$, $10^{-2}$, $10^{-1}$) one obtains three different approximations of the 1-sigma contours, as shown in Figure \ref{fig:simple_fisher_2}. The contour that agrees best with the \jaxcosmo\ method is obtained with the intermediate spacing parameter $10^{-2}$, implying that the user must tune this parameter carefully. Although very simple, this case study demonstrates the significant challenge of using finite difference methods for computing the Fisher matrix, as has been shown for instance in a more advanced example in \citet{2021arXiv210100298B}. \begin{figure} \centering @@ -537,11 +546,11 @@ \subsection{Survey Optimization by FoM Maximization} % Needs to mention the tomo challenge, mention that we can backpropagate through a NN and the cosmology model % Maybe this can be only mentioned in the discussion. % \FrL{Insist here on the fact that the FoM itself is differentiable.} -Fisher forecasts are also commonly used in survey and analysis strategy strategy, where running a full posterior analysis for each possible choice would be unfeasible.
The inverse area of a projection of a Fisher matrix in parameters of interest can be used as a metric for survey constraining power, such as in the Dark Energy Task Force report \citep{2006astro.ph..9591A}. +Fisher forecasts are also commonly used in survey and analysis strategy, where running a full posterior analysis for each possible choice would be unfeasible. The inverse area of a projection of a Fisher matrix in parameters of interest can be used as a metric for survey constraining power, such as in the Dark Energy Task Force report \citep{2006astro.ph..9591A}. \jaxcosmo\ was used in a recent example of such a process, for the the LSST-DESC 3x2pt tomography optimization challenge \citep{2021OJAp....4E..13Z}, where the best methodology for assignment of galaxies to tomographic bins was assessed using several such figures of merit and related calculations. The \jaxcosmo\ metric proved to be stable and fast. -Because \texttt{JAX} functions are differentiable with respect to all their inputs, including survey configuration parameters (e.g. depth, area, etc), we can even compute the derivative of an FoM with respect to these inputs, allowing for rapid and complete survey optimization. +Because \jax\ functions are differentiable with respect to all their inputs, including survey configuration parameters (e.g. depth, area, etc), we can even compute the derivative of an FoM with respect to these inputs, allowing for rapid and complete survey optimization. \subsection{Massive Optimal Compression in 3 Lines} @@ -550,7 +559,7 @@ \subsection{Massive Optimal Compression in 3 Lines} Once the Fisher matrix has been accurately estimated, the MOPED\footnote{Massively Optimised Parameter Estimation and Data compression} algorithm can be used to compress data sets with minimal information loss \citep{2000MNRAS.317..965H,2016PhRvD..93h3525Z, 2017MNRAS.472.4244H}. In the case of the constant covariance matrix the algorithm compresses data in a way that is lossless at the Fisher matrix level (i.e. Fisher matrices estimated using the compressed and full data are identical, by construction) which reduces a possibly large data vector $\mu$ of size $N$ to $M$ numbers, where $M$ is the number of parameters $\theta_i$ considered. For instance, in the previous section, $N=500$ as $\mu=(C_\ell^{p,q})$ and $M=2$ for $(\Omega_c,\sigma_8)$. The algorithm computes by iteration $M$ vectors of size $N$ such that (taking the notation of the previous section) -%JEC 13/1/23 Typo fixed + \begin{equation} b_i = \frac{C^{-1}\mu_{,i}-\sum_{j=1}^{i-1}(\mu_{,i}^T b_j)b_j}{\sqrt{F_{i,i}-\sum_{j=1}^{i-1}(\mu_{,i}^T b_j)^2}} \label{eq:moped} @@ -572,7 +581,7 @@ \subsection{Massive Optimal Compression in 3 Lines} % \FrL{hummm but we still do need a full covariance in the compression algorithm, right?} \FrL{And we still do need to invert it... is the idea rather that if we dont have a perfect matrix at that stage the worse that can happen is sub optimal compression?} %JEC{Well there are several ways that $C$ can be questioned: parameter dependence (ie. fiducial model/true model), non-gaussianities, errors... So not sure that all these effects induce only sub-optimal compression. But it may be out of the scope of this paragraph/paper.} -Another key advantage of the MOPED algorithm is to eliminate the need of large covariance matrix inversion of size $N\times N$ requiring $O(N^3)$ operations. 
@@ -572,7 +581,7 @@ \subsection{Massive Optimal Compression in 3 Lines} % \FrL{hummm but we still do need a full covariance in the compression algorithm, right?} \FrL{And we still do need to invert it... is the idea rather that if we dont have a perfect matrix at that stage the worse that can happen is sub optimal compression?} %JEC{Well there are several ways that $C$ can be questioned: parameter dependence (ie. fiducial model/true model), non-gaussianities, errors... So not sure that all these effects induce only sub-optimal compression. But it may be out of the scope of this paragraph/paper.} -Another key advantage of the MOPED algorithm is to eliminate the need of large covariance matrix inversion of size $N\times N$ requiring $O(N^3)$ operations. This inversion takes place not only for the Fisher matrix computation (Eq.~\ref{eq:fisher_way2}, but more importantly in the log-likelihood computation (see the snippet of the previous section). The MOPED algorithm reduces the complexity to $O(M)$ operations. +Another key advantage of the MOPED algorithm is to eliminate the need for large covariance matrix inversion of size $N\times N$ requiring $O(N^3)$ operations. This inversion takes place not only for the Fisher matrix computation (Eq.~\ref{eq:fisher_way2}), but more importantly in the log-likelihood computation (see the snippet in the previous section). The MOPED algorithm reduces the complexity to $O(M^3)$ operations. To give an illustration, the following snippet uses the mock data set and results on the Fisher matrix computation from the previous section, to obtain the MOPED compressed data set composed of two parameters $(y_0,y_1)$ with maximal information on $(\Omega_c, \sigma_8)$: \begin{lstlisting}[language=iPython] @@ -584,7 +593,7 @@ \subsection{Massive Optimal Compression in 3 Lines} y0 = b0.T @ data y1 = b1.T @ data \end{lstlisting} -Then, the log-likelihood can easily be implemented as: +Then, the log-likelihood can be implemented as: \begin{lstlisting}[language=iPython] @jax.jit def compressed_likelihood(p): @@ -595,10 +604,10 @@ \subsection{Massive Optimal Compression in 3 Lines} # likelihood using the MOPED vector return -0.5 * ((y0 - b0.T @ mu)**2 + (y1 - b1.T @ mu)**2) \end{lstlisting} -The comparison between contour lines obtained with the original likelihood (uncompressed data set) and the MOPED version are shown in Figure \ref{fig:moped} for the case study of the previous section. Clearly, close to the negative likelihood minimum the lines agree very well. +The comparison between contour lines obtained with the original likelihood (uncompressed data set) and the MOPED version is shown in Figure \ref{fig:moped} for the case study of the previous section. Close to the negative likelihood minimum, the lines agree very well. \begin{figure} \centering - \includegraphics[width=\columnwidth]{figures/moded.png} + \includegraphics[width=\columnwidth]{figures/moped.png} \caption{Illustration of the log-likelihood contours obtained with an uncompressed data set (plain lines) and a MOPED version (dashed lines).} \label{fig:moped} \end{figure} @@ -609,7 +618,7 @@ \subsection{Massive Optimal Compression in 3 Lines} \section{Posterior Inference made fast by Gradient-Based Inference Methods} \label{sec:chmc} % -In the following sections we review some statistical methods which directly benefit from the automatic differentiablity of \jaxcosmo\ likelihoods. We demonstrate a range of gradient-based methods, from Hamiltonian Monte Carlo (HMC), and its \textit{No-U-Turn Sampler} (NUTS) variant, to Stochastic Variational Inference (SVI). We further report on their respective computational costs in a DES-Y1 like analysis. All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL). +In the following sections we review further statistical methods which directly benefit from the automatic differentiability of \jaxcosmo\ likelihoods. We demonstrate a range of gradient-based methods, from Hamiltonian Monte Carlo (HMC), and its \textit{No-U-Turn Sampler} (NUTS) variant, to Stochastic Variational Inference (SVI). We further report on their respective computational costs in a DES-Y1-like analysis. All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL). 
% % and NUTS after a \textit{Neural Transport} perform using a \textit{Stochastic Variational Inference} (aka SVI). All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL). % @@ -619,16 +628,16 @@ \section{Posterior Inference made fast by Gradient-Based Inference Methods} \subsection{Description of the DES-Y1 exercise} \label{sec-DESY1} %JEC{Transfer of the model context from appendix here to ease the reading.} -From the DES Year 1 lensing and clustering data release \footnote{\url{http://desdr-server.ncsa.illinois.edu/despublic/y1a1_files/chains/2pt_NG_mcal_1110.fit}} we have extracted the $N(z)$ distributions of four source and five lens samples. We normalize the sources to $[1.47, 1.46, 1.50, 0.73]$ effective number of sources per $\mathrm{arcmin}^2$. These distributions are modelled in \jaxcosmo\ using a kernel density estimation in the \texttt{jc.redshift.kde\_nz} function and are presented in Figure \ref{fig-DESY1-src-lens-redshift}. +From the DES Year 1 lensing and clustering data release\footnote{\url{http://desdr-server.ncsa.illinois.edu/despublic/y1a1_files/chains/2pt_NG_mcal_1110.fit}} we have extracted the $N(z)$ distributions of the four source and five lens samples. We normalize the sources to $[1.47, 1.46, 1.50, 0.73]$ effective numbers of sources per $\mathrm{arcmin}^2$. These distributions are modelled in \jaxcosmo\ using kernel density estimation via the \texttt{jc.redshift.kde\_nz} function and are presented in Figure \ref{fig-DESY1-src-lens-redshift}. \begin{figure} \centering -\includegraphics[height=3cm]{figures/DESY1-source-redshift.png} -\includegraphics[height=3cm]{figures/DESY1-lens-redshift.png} +\includegraphics[height=5cm]{figures/DESY1-source-redshift.png}\\ +\includegraphics[height=5cm]{figures/DESY1-lens-redshift.png} \caption{Distributions of the sources and lenses for the different redshift bins considered.} \label{fig-DESY1-src-lens-redshift} \end{figure} \begin{table}[htb] -\caption{Priors of the 21 variables of the DES-Y1 of the 3x2pt likelihood (number counts and shear).} +\caption{Priors on the 21 variables of the DES-Y1 3x2pt likelihood (number counts and shear).} \label{tab-DESY1} \centering \begin{tabular}{ccccccccccc} @@ -677,7 +686,7 @@ \subsection{Description of the DES-Y1 exercise} f_sky=0.25, sparse=True) \end{lstlisting} -with \texttt{cosmo} an instance of the \texttt{jc.Cosmology} setting the cosmological parameters generated with the priors, and \texttt{ell} (ie. $\ell$) a series of 50 angular modes. After encapsulating the code above with input sampled parameters (using \numpyro\ distribution classes) in a function \texttt{model}, we can generate our mock data (\texttt{cl\_obs}). This comes comes from this model function evaluated at a fiducial cosmology, with \numpyro\ dealing with the random number generation that generates added noise: +with \texttt{cosmo} an instance of the \texttt{jc.Cosmology} setting the cosmological parameters generated with the priors, and \texttt{ell} (ie. $\ell$) a series of 50 angular modes. After encapsulating the code above with input sampled parameters (using \numpyro\ distribution classes) in a function \texttt{model}, we can generate our mock data (\texttt{cl\_obs}). 
This comes from this model function evaluated at a fiducial cosmology, with \numpyro\ dealing with the random number generation that adds noise: \begin{lstlisting}[language=iPython] fiducial_model = numpyro.condition(model, @@ -693,19 +702,19 @@ \subsection{Description of the DES-Y1 exercise} cl_obs, P, C = fiducial_model() \end{lstlisting} -The \textit{inference model} looks very similar to the \textit{forward model} excepts that first we fix the angular power spectra covariance matrix $C$ (and its inverse $P$) and secondly we let \numpyro\ generates the model $C_{\ell}$ from the priors: -\begin{lstlisting}[language=iPython] -cl = jc.angular_cl.angular_cl(cosmo, ell, - probes).flatten() -return numpyro.sample("cl", MultivariateNormal(cl, - precision_matrix=P, - covariance_matrix=C)) -\end{lstlisting} +% The \textit{inference model} looks very similar to the \textit{forward model} excepts that first we fix the angular power spectra covariance matrix $C$ (and its inverse $P$) and secondly we let \numpyro\ generate the model $C_{\ell}$ from the priors: +% \begin{lstlisting}[language=iPython] +% cl = jc.angular_cl.angular_cl(cosmo, ell, +% probes).flatten() +% return numpyro.sample("cl", MultivariateNormal(cl, +% precision_matrix=P, +% covariance_matrix=C)) +% \end{lstlisting} These theoretical $C_{\ell}$ are in turn conditioned on the mock $C_\ell$: \begin{lstlisting}[language=iPython] -observed_model = numpyro.condition(model, {"cl": data}) +observed_model = numpyro.condition(model, {"cl": cl_obs}) \end{lstlisting} @@ -735,11 +744,11 @@ \subsection{Vanilla Hamiltonian Monte Carlo} % % Show Joe's vanilla HMC results against Cobaya Hamiltonian Monte Carlo (HMC) is an MCMC-type method particularly suited to drawing -samples from high dimensional parameter spaces. It was introduced in \citep{1987PhLB..195..216D} +samples from high dimensional parameter spaces. It was introduced in \citet{1987PhLB..195..216D} and developed extensively since. See \citet{betancourt} for a full review; we describe very basic features here. -HMC samples spaces by generating particle trajectories through the space, using the log-posterior as the negative potential energy of a particle at each point $q$ in the space. Associated with $q$, we introduce an auxiliary $p$ variable as Hamiltonian momentum such that +HMC samples a space by generating particle trajectories through it, using the log-posterior as the negative potential energy of a particle at each point $q$ in the space. Associated with $q$, we introduce an auxiliary variable $p$ as the Hamiltonian momentum, such that \begin{equation} - \log{\cal P}(q) = V(q) \quad H(q,p) = V(q) + U(p) \end{equation} @@ -747,12 +756,12 @@ \subsection{Vanilla Hamiltonian Monte Carlo} \begin{equation} U(p) = p^T M^{-1} p \end{equation} -At each sample, a trajectory is initialized with a random momentum $p$, and then Hamilton's equations are integrated: +where $M$ is a mass matrix which should be set to approximate the covariance of the posterior. At each sample, a trajectory is initialized with a random momentum $p$, and then Hamilton's equations are integrated: \begin{align} \frac{\mathrm{d}p}{\mathrm{d}t} &= - \frac{\partial V}{\mathrm{d} q} = \frac{\partial \log{\cal P}}{\mathrm{d} q} \\ \frac{\mathrm{d}q}{\mathrm{d}t} &= + \frac{\partial U}{\mathrm{d} p} = M^{-1} p \end{align} -where $M$ is a mass matrix which should be set to approximate the covariance of the posterior. This is also used to set the scale of the random initial velocities. 
These differential equations may be integrated numerically, taking $L$ small steps of the \textit{leapfrog} algorithm: +This is also used to set the scale of the random initial velocities. These differential equations may be integrated numerically, taking $L$ small steps of the \textit{leapfrog} algorithm: \begin{align} p_{n+\frac{1}{2}} &= p_n -\frac{\varepsilon}{2} \frac{\partial V}{\mathrm{d} q}(q_n) \\ q_{n+1} & = q_n +\varepsilon M^{-1} p_{n+\frac{1}{2}} \\ @@ -765,15 +774,15 @@ and a Metropolis-Hastings acceptance criterion on the total energy $H(q,p)$ is applied. If the trajectory is perfectly simulated then this acceptance is unity, since energy is conserved; applying it allows a relaxation of the integration accuracy. -The gradients $\partial \log{\cal P} / \mathrm{d} q$ can be estimated using finite differencing, -but this typically requires at least $4 n_{\mathrm{dim}} + 1$ posterior evaluations per point, greatly slowing it +The gradients $\partial \log{\cal P} / \mathrm{d} q$ can be estimated using finite differences, +but this requires at least $2 n_{\mathrm{dim}} + 1$ posterior evaluations per point, greatly slowing it in high dimension, and as with the Fisher forecasting is highly prone to numerical error. Automatically calculating the derivative, as in \jaxcosmo, makes it feasible and efficient. Metropolis-Hastings, and related methods like \texttt{emcee} \citep{goodman-weare,emcee}, suffer as dimensionality increases, as the region of high probability mass (the \textit{typical set}) becomes a very small fraction of the total parameter space volume. At high enough dimension they become a slow random walk around the space and cannot remain in typical set regions. The dynamics of HMC allows it to make large jumps that nonetheless stay within the region of high posterior. -The tricky part of HMC is that the \textit{leapfrog} algorithm needs tuning to set the number of steps as well as the step size of integration. The next section addresses this problem thanks to the No-U-Turn HMC version. +The tricky part of HMC is that the \textit{leapfrog} algorithm needs tuning to set the number of steps as well as the step size of integration. The next section examines a solution to this problem: the No-U-Turn HMC version. @@ -812,10 +821,10 @@ \subsection{NUTS} The No-U-Turn Sampler (\textit{NUTS}) variant of the traditional HMC sampler was introduced in \citet{10.5555/2627435.2638586}. It aims to help find a new point $x_{i+1}$ from the current $x_i$ by finding good and dynamic choices for the leapfrog integration parameters in the root HMC algorithms, the step size $\varepsilon$ and the number of steps $L$. -NUTS sets iterates the leapfrog algorithm not for a fixed $L$, but until the trajectory starts to ``double back'' and return to previously visited region, at the cost of increasing the number of model evaluations. The user has to set a new parameter (\texttt{max\_tree\_depth}) which gives as a power of $2$ the maximum number of model calls at each generation. +NUTS iterates the leapfrog algorithm not for a fixed $L$, but until the trajectory starts to ``double back'' and return to previously visited regions, at the cost of increasing the number of model evaluations. The user has to set a new parameter (\texttt{max\_tree\_depth}) which gives as a power of $2$ the maximum number of model calls at each generation. Both samplers, HMC and NUTS, are available in the \numpyro\ library. 
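For orientation, a typical configuration looks as follows. This is a sketch with illustrative settings only: \texttt{model} denotes the \numpyro\ model built in section \ref{sec-DESY1}, and the values mirror those quoted below.
\begin{lstlisting}[language=iPython]
import jax
from numpyro.infer import MCMC, NUTS

kernel = NUTS(model,
              max_tree_depth=7,
              dense_mass=True)  # a list of parameter-name tuples can be
                                # passed instead for a block-diagonal mass matrix
mcmc = MCMC(kernel,
            num_warmup=200,
            num_samples=1000,
            num_chains=16,
            chain_method="vectorized")  # chains batched on one GPU via vmap
mcmc.run(jax.random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)
\end{lstlisting}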
After the forward model creation for the DES-Y1 3x2pt exercise described in section \ref{sec-DESY1}: \begin{itemize} - \item we apply an affine transformation on the cosmological, intrinsic alignment and bias parameters to use a consistent uniform prior $\mathcal{U}[-5,5]$ (Table~\ref{tab-DESY1}); + \item we apply a transformation to the cosmological, intrinsic alignment and bias parameters to use a consistent uniform prior $\mathcal{U}[-5,5]$ (Table~\ref{tab-DESY1}); \item we use a structured mass matrix $M$ in a block diagonal form with the blocks as the following sets of parameters $(\Omega_b,\Omega_c,\sigma_8,w_0,h)$ and $(b_i)_{i:1\dots5}$. The remaining parameters have uncorrelated masses. This matrix structure is motivated by the expected degree of parameter correlation as shown for instance in Figure \ref{fig_cobaya_NUTS_SVI_bis}. \end{itemize} @@ -838,12 +847,12 @@ \subsection{NUTS} % mcmc.run(jax.random.PRNGKey(42)) % \end{lstlisting} -We have run the NUTS sampler using \texttt{numpyro.infer.NUTS} on the DES Y1 likelihood, with 16 chains of 1,000 samples each after a warm-up phase consisting of 200 samples, with the \texttt{max\_tree\_depth} set to seven (ie. 128 steps for each iteration). %\FrL{it should explain that vectorized means 16 chains in parallel on a single GPU.} -Using the ``vectorized'' \texttt{numpyro} option we ran all 16 chains simultaneously on a single GPU, made possible by the \texttt{JAX} \textit{vmap} mechanism. If one has several GPU devices available, then the using the \texttt{JAX} paralellization mechanism (\textit{pmap}), it is further possible to launch the vectorized sampling across the devices, and get back all the MCMC chains. However, the experiments have all been undertaken on a single GPU, either a NVidia Titan Xp (12GB RAM) on a desktop or an NVidia V100 (32GB RAM) at the IN2P3 Computing Centre\footnote{\url{https://cc.in2p3.fr/en/}}. The elapsed time for these experiments was 20 hours. +We ran the NUTS sampler using \texttt{numpyro.infer.NUTS} on the DES Y1 likelihood, with 16 chains of 1,000 samples each after a warm-up phase consisting of 200 samples, with the \texttt{max\_tree\_depth} set to seven (ie. 128 steps for each iteration). %\FrL{it should explain that vectorized means 16 chains in parallel on a single GPU.} +Using the ``vectorized'' \texttt{numpyro} option we ran all 16 chains simultaneously on a single GPU, made possible by the \jax\ \textit{vmap} mechanism. If one has several GPU devices available, then, using the \jax\ parallelisation mechanism (\textit{pmap}), it is further possible to launch the vectorized sampling across the devices, and get back all the MCMC chains. However, these experiments have all been undertaken on single GPUs, either an NVidia Titan Xp (12GB RAM) on a desktop or an NVidia V100 (32GB RAM) at the IN2P3 Computing Centre\footnote{\url{https://cc.in2p3.fr/en/}}. The elapsed time for these experiments was 20 hours. -The results in terms of relative effective sample sizes (ESS) are detailed in Table~\ref{tab-ESS-NUTS_SVI-1} while the confidential level (CL) contours are presented in Figure \ref{fig_cobaya_NUTS_SVI}. We compare to a reference sample from the highly-optimized \texttt{Cobaya} Metropolis-Hastings implementation \citep{2019ascl.soft10019T,2021JCAP...05..057T}, which is widely used in cosmology and which we ran for around 40 hours on CPU to obtain the set of contours shown. 
+The results in terms of relative effective sample sizes (ESS) are detailed in Table~\ref{tab-ESS-NUTS_SVI-1} while the confidence level (CL) contours are presented in Figure \ref{fig_cobaya_NUTS_SVI}. We compare to a reference sample from the highly-optimized \texttt{Cobaya} Metropolis-Hastings implementation \citep{2019ascl.soft10019T,2021JCAP...05..057T}, which is widely used in cosmology and which we ran for around 40 hours on CPU to obtain the set of contours shown. -There is a dramatic improvement ofinthe ESS by about a factor of 10 using the NUTS sampler compared to Cobaya, with very good agreement between the CL contours. It is worth noting that the mass matrix structure described above increases the sampling efficiency by about a factor of two. +There is a dramatic improvement of the ESS by about a factor of 10 using the NUTS sampler compared to Cobaya, with very good agreement between the CL contours. It is worth noting that the mass matrix structure described above increases the sampling efficiency by about a factor of two. The speed of the sampling could be improved: we have tested using the parameter \texttt{max\_tree\_depth=5} and found convergence in five hours, showing a linear scaling in this parameter while keeping the sampling efficiencies at a high level; the user is highly encouraged to tune this critical parameter. @@ -870,7 +879,7 @@ \subsection{Stochastic Variational Inference} \log p(\mathcal{D}) &= \mathtt{ELBO} + KL(q(z;\lambda)||p(z|\mathcal{D})) \label{eq-ELBO} \\ \mathrm{with} \ \mathtt{ELBO} &\equiv -\mathbb{E}_{q(z;\lambda)}\left[ \log q(z;\lambda)\right] + \mathbb{E}_{q(z;\lambda)}\left[ \log p(z,\mathcal{D}) \right] \end{align} -which defines the \textit{evidence lower bound} (aka ELBO) that one aims to maximize to get the $\lambda$ values. So, the optimal variational distribution satisfies +which defines the \textit{evidence lower bound} (ELBO) that one aims to maximize to get the $\lambda$ values. So, the optimal variational distribution satisfies \begin{equation} q(z;\lambda^\ast) = \underset{q(z;\lambda)}{\mathrm{argmax}}\ \mathtt{ELBO} = \underset{\lambda}{\mathrm{argmin}}\ \mathcal{L}(\lambda) @@ -880,10 +889,10 @@ \subsection{Stochastic Variational Inference} \mathcal{L}(\lambda) = \underbrace{\mathbb{E}_{q(z;\lambda)}\left[ \log q(z;\lambda)\right]}_{guide} - \underbrace{\mathbb{E}_{q(z;\lambda)}\left[ \log p(z,\mathcal{D}) \right]}_{model} \label{eq-loss-svi-1} \end{equation} -where the \textit{guide} in the \numpyro\ library (i.e. the parametrized function $q$) may be a multi-variate Gaussian distribution (MVN) for instance. +where the \textit{guide} in the \numpyro\ library (i.e. the parameterised function $q$) may be a multi-variate Gaussian distribution (MVN) for instance. -Using the auto-differentation tool, one can use ``black-box'' guides (aka \textit{automatic differentiation variational inference}). As stated by the authors of \citep{10.5555/3122009.3122023} ADVI specifies a variational family appropriate to the model, computes the corresponding objective -function, takes derivatives, and runs a gradient-based or coordinate-ascent optimization. First we define a invertible differentiable transformation $T$ of the original latent variables $z$ into new variables $\xi$, such $\xi=T(z)$ and $z=T^{-1}(\xi)$, where the new $\xi$ parameters are unbounded, $xi_i in (\-infty, \infty)$ and so the subsequent minimization step can be performed with no bound constraints. 
+Using the auto-differentiation tool, one can use ``black-box'' guides (\textit{automatic differentiation variational inference}, ADVI). As stated by the authors of \citet{10.5555/3122009.3122023}, ADVI specifies a variational family appropriate to the model, computes the corresponding objective +function, takes derivatives, and runs a gradient-based or coordinate-ascent optimization. First we define an invertible differentiable transformation $T$ of the original latent variables $z$ into new variables $\xi$, such that $\xi=T(z)$ and $z=T^{-1}(\xi)$, where the new $\xi$ parameters are unbounded, $\xi_i \in (-\infty, \infty)$, and so the subsequent minimization step can be performed with no bound constraints. % One can develop ``AutoGuides'' (\numpyro\ terminology) that can be adapted to the user models. The cost function then reads \begin{equation} @@ -934,7 +943,7 @@ \subsection{Stochastic Variational Inference} % samples = guide.sample_posterior(jax.random.PRNGKey(1), svi_result.params, sample_shape=(100_000,)) % \end{lstlisting} % -Once the optimisation is done one can obtain i.i.d. $z$ samples from the $q(z,\lambda^\ast)$ distribution applying the inverse of the inverse of the $S_{\lambda^\ast}$ and $T$ transformations. Using the same DES-Y1 simulation as in previous section, we use both a MultiVariate Normal distribution (MVN) and a Block Nueral Autoregressive Flow (B-NAF) \citep{pmlr-v115-de-cao20a} as \textit{guides} to approximate the true posterior (\textit{model}). The B-NAF architecture is composed of a single flow using a block autoregressive structure with 2 hidden layers of 8 units each. The SVI optimization has been performed with the Adam optimizer \citep{KingmaB14} and a learning rate set to $10^{-3}$. We have stopped the optimization after 20,000 (30,000) steps to ensure a stable ELBO loss convergence when using the B-NAF (MVN) guides, and no tuning of the optimizer has been undertaken as for instance a learning rate scheduling. It takes about 2 or 3 hours on the NVidia V100 GPU scaling according to the number of steps. +Once the optimisation is done one can obtain i.i.d. $z$ samples from the $q(z,\lambda^\ast)$ distribution applying the inverses of the $S_{\lambda^\ast}$ and $T$ transformations. Using the same DES-Y1 simulation as in the previous section, we use both a Multivariate Normal distribution (MVN) and a Block Neural Autoregressive Flow (B-NAF) \citep{pmlr-v115-de-cao20a} as \textit{guides} to approximate the true posterior (\textit{model}). The B-NAF architecture is composed of a single flow using a block autoregressive structure with 2 hidden layers of 8 units each. The SVI optimization has been performed with the Adam optimizer \citep{KingmaB14} and a learning rate set to $10^{-3}$. We have stopped the optimization after 20,000 (30,000) steps to ensure a stable ELBO loss convergence when using the B-NAF (MVN) guides, and no tuning of the optimizer learning rate scheduling or other parameters was performed. This takes about 2 or 3 hours on the NVidia V100 GPU, depending on the number of steps. In figure \ref{fig_cobaya_SVIs}, we compare the contours obtained with \texttt{Cobaya} (as in figure \ref{fig_cobaya_NUTS_SVI}) and those obtained with the MVN and B-NAF guided SVI. As noted in \cite{NEURIPS2020_7cac11e2} one challenge with variational inference is assessing how close the variational approximation $q(z,\lambda^\ast)$ is to the true posterior distribution. 
It is not in the scope of this article to elaborate a statistical diagnosis; rather we show that both guided SVIs exhibit rather similar contours and both estimates are close to the Cobaya posterior sampling. The difference is that these SVI approximate posteriors have been obtained in a much shorter time, and can serve as a starting point for a NUTS sampler as described in the next section. % @@ -948,7 +957,7 @@ \subsection{Stochastic Variational Inference} \subsubsection{Neural Transport} \label{sec-Neural-Reparametrisation} % -If the SVI method can be used as is to get $z$ i.i.d. samples from the $q(z,\lambda^\ast)$ distribution as shown on the previous section, the \textit{Neural Transport MCMC} method \citep{Parno2018,2019arXiv190303704H} is an efficient way to boost HMC efficiency, especially in target distribution with unfavourable geometry where for instance the leapfrog integration algorithm has to face squeezed join distributions for a subset of variables. From SVI, one obtains a first approximation of the target distribution, and this approximation is used to choose a better transform $T$ to map the parameter space to a more convenient one $z=F_\lambda(\zeta)$ (eg. $F_\lambda=T^{-1}\circ S^{-1}_\lambda$) is such that +While the SVI method can be used as is to get $z$ i.i.d. samples from the $q(z,\lambda^\ast)$ distribution as shown in the previous section, the \textit{Neural Transport MCMC} method \citep{Parno2018,2019arXiv190303704H} is an efficient way to boost HMC efficiency, especially for target distributions with unfavourable geometry where for instance the leapfrog integration algorithm has to face squeezed joint distributions for a subset of variables. From SVI, one obtains a first approximation of the target distribution, and this approximation is used to choose a better transform $T$ to map the parameter space to a more convenient one, $z=F_\lambda(\zeta)$ (eg. $F_\lambda=T^{-1}\circ S^{-1}_\lambda$), such that \begin{equation} q(z;\lambda) \rightarrow q(\zeta;\lambda) \bydef q(F_\lambda(\zeta)) |J_{F_\lambda}(\zeta)| \end{equation} @@ -969,7 +978,7 @@ \subsubsection{Neural Transport} % neutra_model = neutra.reparam(model_spl) % \end{lstlisting} % -We have used the MVN guided SVI described in the previous section on the same DES-Y1 analysis described above. The NUTS sampler was run using 1 chain of 200 samples (a fast configuration), one chain with 1000 samples and a set of ten chains with 1000 samples each combined into a single run of 10,000 samples (each setup began all chains with 200 samples for initialisation). All NUTS sampling was performed with dense mass matrix optimisation without the special block structuring and with \texttt{max\_tree\_depth=5} which is different than the default NUTS setting described in section \ref{sec-NUTS}. The elapsed time for each of the 3 setups was 50 minutes, 150 minutes and 5 hours, respectively. Naturally, more samples lead to better contour precision. But what is illustrative is the fast configuration results as shown in the Figures \ref{fig_cobaya_NUTS_SVI} and \ref{fig_cobaya_NUTS_SVI_bis}, compared to \texttt{Cobaya} and NUTS results presented in section \ref{sec-NUTS}. The results are good even for the highly correlated lens bias parameters. It is noticeable that running NUTS with the same settings but without the SVI Neural Transport phase has demonstrated a rather poor behaviour with only 200 samples. +We have used the MVN guided SVI described in the previous section on the same DES-Y1 analysis. 
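Schematically, this SVI plus Neural Transport pipeline can be assembled in \numpyro\ as follows. This is a sketch under illustrative settings: \texttt{model} is the DES-Y1 \numpyro\ model of section \ref{sec-DESY1}, and the numbers follow those quoted in the text.
\begin{lstlisting}[language=iPython]
import jax
import numpyro
from numpyro.infer import SVI, Trace_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam

# 1. Fit the MVN guide by stochastic variational inference.
guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(1e-3), Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 30_000)

# i.i.d. samples from the fitted guide, if only the SVI posterior is needed
svi_samples = guide.sample_posterior(jax.random.PRNGKey(2),
                                     svi_result.params,
                                     sample_shape=(1_000,))

# 2. Use the fitted guide to reparametrise the model (Neural Transport).
neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(model)

# 3. Run NUTS in the better-conditioned, reparametrised space.
mcmc = MCMC(NUTS(neutra_model, max_tree_depth=5, dense_mass=True),
            num_warmup=200, num_samples=200)
mcmc.run(jax.random.PRNGKey(1))
\end{lstlisting}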
NUTS was run using one chain of 200 samples (a fast configuration), one chain of 1000 samples, and a set of ten chains of 1000 samples each combined into a single run of 10,000 samples (each setup began all chains with 200 samples for initialisation). All NUTS sampling was performed with dense mass matrix optimisation without the special block structuring and with \texttt{max\_tree\_depth=5}, which is different from the default NUTS setting described in section \ref{sec-NUTS}. The elapsed time for each of the 3 setups was 50 minutes, 150 minutes and 5 hours, respectively. Naturally, more samples lead to better contour precision. But what is most illustrative are the fast configuration results shown in Figures \ref{fig_cobaya_NUTS_SVI} and \ref{fig_cobaya_NUTS_SVI_bis}, compared to the \texttt{Cobaya} and NUTS results presented in section \ref{sec-NUTS}. The results are good even for the highly correlated lens bias parameters. It is noticeable that running NUTS with the same settings but without the SVI Neural Transport phase has demonstrated rather poor behaviour with only 200 samples. Results in terms of sampling efficiency are shown in Table~\ref{tab-ESS-NUTS_SVI-1}. SVI followed by Neural Transport gives high efficiency at a low number of samples, which may be particularly useful during early phases of model development. % @@ -1001,24 +1010,24 @@ \subsection{Sampling efficiency} \begin{equation} \eta = \frac{n_\mathrm{eff}}{n_\mathrm{eval}} = \frac{N_s \times \varepsilon}{N_s \times n_\mathrm{step}} = \frac{\varepsilon}{n_\mathrm{step}} \end{equation} -with $N_s$ the total number of samples, $\varepsilon$ the effective sampler efficiency and $n_\mathrm{step}$ the number of steps (calls) per generated sample. For \texttt{Cobaya} we find $\varepsilon\approx 3\%$ with $n_\mathrm{step}=1$ while for the NUTS sampler $\varepsilon\approx 50\%$ but at the expense of $n_\mathrm{step}=2^5$ or more. With better tuning of the sampling parameters we would expect the $\eta$ values for both methods would to become more comparable, but at this intermediate dimensionality the gain from HMC/NUTS compared to standard MCMC sampler. The power of these approaches will become most evident at higher dimensionality still, such has when marginalizing over increasingly complex systematic models. NUTS also makes post-processing simpler, since samples are nearly uncorrelated, removing the need for a \textit{thinning} step, which is a rather delicate procedure, needing how-know to be conducted correctly \citep{doi:10.1146/annurev-statistics-040220-091727, Owen2017}. +with $N_s$ the total number of samples, $\varepsilon$ the effective sampler efficiency and $n_\mathrm{step}$ the number of steps (calls) per generated sample. For \texttt{Cobaya} we find $\varepsilon\approx 3\%$ with $n_\mathrm{step}=1$ while for the NUTS sampler $\varepsilon\approx 50\%$ but at the expense of $n_\mathrm{step}=2^5$ or more. Numerically, this gives $\eta \approx 0.03$ for \texttt{Cobaya} and $\eta \approx 0.5/2^5 \approx 0.016$ for NUTS. With better tuning of the sampling parameters we would expect the $\eta$ values for both methods to become more comparable, but at this intermediate dimensionality the gain from HMC/NUTS compared to a standard MCMC sampler is small. The power of these approaches will become most evident at higher dimensionality still, such as when marginalizing over increasingly complex systematic models. 
NUTS also makes post-processing simpler, since samples are nearly uncorrelated, removing the need for a \textit{thinning} step, which is a rather delicate procedure requiring know-how to be conducted correctly \citep{doi:10.1146/annurev-statistics-040220-091727, Owen2017}. % The $\eta$ metric may be too crude, though, to cover all aspects of the sample generation. \JZ{I don't understand the next point. Doesn't the HMC/NUTS take a long time too on dedicated resources?} One should probably have in mind that the low sampling efficiency of a standard MCMC sampler requires mobilizing a large amount of resources to produce a sufficiently large sample batch in a reasonable time scale, ie. several days on dedicated infrastructure. %\JZ{I think you need to remove the burn-in when using NUTS too?}. -As the dimensionality of cosmological models increases, methods like HMC/NUTS that by construction are more efficient will become increasingly important. Moreover, using SVI with neural reparametrisation offers an effective way to undertake a progressive validation of a model with rather modest sample set (eg. starting with 200 samples) producing good enough marginal contours in few hours. In practice, this validation phase can save time before producing sizeable batch for final analysis. The authors have not investigated higher dimensional ($O(10^2)$ parameters) or multi-modal problems, but the key argument in favour of HMC/NUTS sampling is that it exploits the geometry of the typical set of the posterior distribution automatically, unlike the standard random walk of Metropolis-Hasting sampling. Furthermore, using reparametrisation one can adapt to poor geometry cases (eg. \cite{2019arXiv190303704H}). +As the dimensionality of cosmological models increases, methods like HMC/NUTS that by construction are more efficient will become increasingly important. Moreover, using SVI with neural reparametrisation offers an effective way to undertake a progressive validation of a model with a rather modest sample set (eg. starting with 200 samples), producing good enough marginal contours in a few hours. In practice, this validation phase can save time before producing a sizeable batch for the final analysis. The authors have not investigated higher dimensional ($O(10^2)$ parameters) or multi-modal problems, but the key argument in favour of HMC/NUTS sampling is that it exploits the geometry of the typical set of the posterior distribution automatically, unlike the standard random walk of Metropolis-Hastings sampling. Furthermore, using reparametrisation one can adapt to poor geometry cases (eg. \citealp{2019arXiv190303704H}). % % \section{General discussion} \label{sec-discussion} % -Having demonstrated the utility of \jaxcosmo\ as a differentiable cosmology library library, we now discuss several limitations, and open questions. +Having demonstrated the utility of \jaxcosmo\ as a differentiable cosmology library, we now discuss several limitations and raise a few open questions. %Lack of autodiff Boltzmann code (emulator) -The first essential barrier to a fully-fledged automatically differentiable cosmology library is the need for a differentiable Boltzmann solver to compute the CMB or matter power spectra. At this stage, \jaxcosmo\ relies on the analytic Einsenstein \& Hu fitting formula for the latter, which is not accurate enough for Stage IV requirements, and it does not include models beyond $\Lambda$CDM. 
Existing solvers such as CLASS or CAMB \citep{2011JCAP...07..034B,camb} are large and complex codes which are not easily reimplemented in an \autodiff\ framework and therefore cannot be directly integrated in \jaxcosmo\ . +The first essential barrier to a fully-fledged automatically differentiable cosmology library is the need for a differentiable Boltzmann solver to compute the CMB or matter power spectra. At this stage, \jaxcosmo\ relies on the analytic Eisenstein \& Hu fitting formula for the latter, which is not accurate enough for Stage IV \citep{detf} requirements, and it does not include models beyond $\Lambda$CDM. Existing solvers such as CLASS \citep{2011JCAP...07..034B} or CAMB \citep{camb} are large and complex codes which are not easily reimplemented in an \autodiff\ framework and therefore cannot be directly integrated in \jaxcosmo\ . -A first option to resolve this issue would be to implement from scratch a new Boltzmann code in a framework that supports automatic differentiation. This is the approach behind works such as \texttt{Bolt.jl}\footnote{\url{https://github.com/xzackli/Bolt.jl}} which is provides a simplified Boltzmann solver in Julia, or PyCosmo \citep{pycosmo} which is based on a the SymPy symbolic mathematics library and could be relatively easily compatible with JAX. However, even if very promising, both of these options thus far remain limited. While we do believe a automatically differentiable Boltzmann code is the best option, it seems that the cost of developing such a code remains very high at this time. +A first option to resolve this issue would be to implement from scratch a new Boltzmann code in a framework that supports automatic differentiation. This is the approach behind works such as \texttt{Bolt.jl}\footnote{\url{https://github.com/xzackli/Bolt.jl}} which provides a simplified Boltzmann solver in Julia, or PyCosmo \citep{pycosmo} which is based on the SymPy symbolic mathematics library and could be relatively compatible with \jax. However, even if very promising, both of these options thus far remain limited. While we do believe an automatically differentiable Boltzmann code is the best option, it seems that the cost of developing such a code remains very high at this time. A second approach would be to develop emulators of a fully-fledged Boltzmann code. Emulators based on neural networks or Gaussian processes are themselves automatically differentiable with respect to cosmological parameters. In fact, the literature is now rich in examples of such emulators \citep[e.g.][and references therein]{Gunther_2022, nygaard,cosmopower,cosmicnet, emucmb}. After validating their accuracy against a reference CAMB or CLASS implementation, they could be directly integrated as a plug-and-play replacement for the computation of the matter power spectrum. At this time, it seems that using emulators will be the most straightforward approach to bring more accurate models to \jaxcosmo. We believe, though, that one of the reason for the wide diversity in this is a lack of standardization - a unified interface and validation suite for such methods would provide a much simpler comparison between them and enable wider usage. 
@@ -1026,21 +1035,21 @@ \section{General discussion} % Reasoning behind not making a full emulator of the correlation functions -Connected to this discussion about emulators, one point that could be discussed is whether one even needs a library like \jaxcosmo\ if anyway one can build emulators of a CCL likelihood for use inside gradient-based inference algorithms. While this could indeed be feasible, and of similar cost as just making an emulator of the matter power spectrum, the drawback of this approach is that many analysis parameters and choices become hard coded in the emulator. While the model for the linear matter power spectrum is typically kept fixed in practical analyses, all the of the choices related in particular to systematics modeling (e.g. photometric redshift errors, galaxy bias, or intrinsic galaxy alignments) will vary significantly in the process of developing the analysis. Building an emulator for the likelihood would mean the emulator needs to be retrained every time the likelihood is changed. +Connected to this discussion about emulators, one may ask whether one even needs a library like \jaxcosmo\ if one can build emulators of a CCL likelihood for use inside gradient-based inference algorithms. While this could indeed be feasible, and of similar cost to just making an emulator of the matter power spectrum, the drawback of this approach is that many analysis parameters and choices become hard coded in the emulator. Although the model for the linear matter power spectrum is typically kept fixed in practical analyses, all of the choices related in particular to systematics modeling (e.g. photometric redshift errors, galaxy bias, or intrinsic galaxy alignments) will vary significantly in the process of developing the analysis. Building an emulator for the likelihood would require the emulator to be retrained every time the likelihood is changed. \bigskip % - Massively Parallel MCMC on GPU (linked to vmap and/or pmap) % - Efficiency of gradient-based inference methods -Another aspect worth discussing are the prospects for scaling up and speeding up cosmological inference in practice given tools such as \jaxcosmo. As we illustrated in the previous section, gradient-based inference techniques yield significantly less correlated MCMC chains, scale better than any other known sampling method as dimension increases, and can provide very fast approximate posteriors if needed. \jaxcosmo\ is also well suited to aid in the parallelization of likelihoods and algorithms, especially on multiple GPUs, which will become increasingly important as the high-performance computing landscape evolves. +Another aspect worth discussing is the prospects for scaling and speeding up cosmological inference in practice, given tools such as \jaxcosmo. As we illustrated in the previous section, gradient-based inference techniques yield significantly less correlated MCMC chains, scale better than any other known sampling method as dimension increases, and can provide very fast approximate posteriors if needed. \jaxcosmo\ is also well suited to aid in the parallelisation of likelihoods and algorithms, especially on multiple GPUs, which will become increasingly important as the high-performance computing landscape evolves. \bigskip % - Respective role of CCL and jax-cosmo -Finally, one question worth asking is how does \jaxcosmo\ position itself against classical codes such as CCL? 
While we are convinced of the benefits of a \texttt{JAX} implementation, we expect CCL and other key codes to remain critical as standard cosmology implementations. Ultimately, a natural transition may occur towards differentiable frameworks like \jaxcosmo\ when they reach the ability to run full-fledged Stage IV likelihoods. +Finally, how does \jaxcosmo\ position itself against classical codes such as CCL? While we are convinced of the benefits of a \jax\ implementation, we expect CCL and other key codes to remain critical as standard cosmology implementations. Ultimately, a natural transition may occur towards differentiable frameworks like \jaxcosmo\ when they reach the ability to run fully-fledged Stage IV likelihoods. \section{Conclusions \& Prospects} @@ -1051,15 +1060,15 @@ \section{Conclusions \& Prospects} %- We have demonstrated the efficiency in JC of the derivative-aware MCMC methods HMC and NUTS, which are regarded as the only way for samplers to evade the curse of dimensionality up to the hundreds of dimensions we are likely to need for the next generation of surveys. %- We have shown a proof-of-concept for the use of JC with machine learning methods, opening a whole new space of methods in cosmology. -In this paper, we have presented \jaxcosmo, a cosmology library implemented in the \texttt{JAX} framework that enables the automatic computation of the derivatives of cosmological likelihoods with respect to their parameters, and greatly speeds up likelihood evaluations thanks to automatic parallelisation and just-in-time compilation on GPUs. Currently, \jaxcosmo\ only contains a small set of features corresponding to a DES-Y1 3x2pt analysis. Being an open source project, contributions of additional features is additional areas such as CMB or spectroscopic galaxy clustering are warmly welcome. +In this paper, we have presented \jaxcosmo, a cosmology library implemented in the \jax\ framework that enables the automatic computation of the derivatives of cosmological likelihoods with respect to their parameters, and greatly speeds up likelihood evaluations thanks to automatic parallelisation and just-in-time compilation on GPUs. Currently, \jaxcosmo\ contains a small set of features corresponding to a DES-Y1 3x2pt analysis. Being an open source project, contributions of new features for additional scientific areas such as CMB or spectroscopic galaxy clustering are warmly welcome. -To demonstrate the value of an automatically differentiable library, we have illustrated with concrete examples how Fisher matrices, which are notoriously unstable and require extensive and careful fine tuning, can now be computed robustly and at much lower cost. In addition, Fisher matrices becomes themselves differentiable, which allows for Figure of Merit optimization by gradient descent, making survey optimization extremely simple. +To demonstrate the value of an automatically differentiable library, we have illustrated with concrete examples how Fisher matrices, which are notoriously unstable and require extensive and careful fine tuning, can now be computed robustly and at much lower cost. In addition, Fisher matrices themselves become differentiable, which allows for Figure of Merit optimization by gradient descent, making survey optimization extremely straightforward. -Going beyond Fisher forecasts, we have also compared simple Metropolis-Hastings to several gradient-based inference techniques (Hamiltonian Monte-Carlo, No-U-Turn-Sampler, and Stochastic Variational Inference). 
We show that the posterior samples with gradient-based methods can reproduce classic methods very efficiently, and can provide approximate posteriors very rapidly. These inference techniques can scale to hundreds of dimensions, and may become necessary in Stage IV analysis as the number of nuisance parameters is likely to become very large. +Going beyond Fisher forecasts, we have also compared simple Metropolis-Hastings to several gradient-based inference techniques (Hamiltonian Monte-Carlo, No-U-Turn-Sampler, and Stochastic Variational Inference). We have shown that posterior sampling with gradient-based methods can reproduce the results of classical methods very efficiently, and can provide approximate posteriors very rapidly. These inference techniques can scale to hundreds of dimensions, and may become necessary in Stage IV analysis, as the number of nuisance parameters is likely to become large. -The next extensions of this framework will be the inclusion of additional cosmological probes, as well as the integration of emulators for the matter power spectrum trained on CAMB or CLASS, as a means to go beyond the current analytic Eisenstein \& Hu model. +The next extensions to this framework will be the inclusion of additional cosmological probes, as well as the integration of emulators for the matter power spectrum trained on CAMB or CLASS, as a means to go beyond the current analytic Eisenstein \& Hu model. -In the spirit of reproducible research, all results presented in this paper can be reproduced from the following GitHub repository: +In the spirit of reproducible research, all results presented in this paper can be reproduced with code contained in the following GitHub repository: \url{https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/} @@ -1069,11 +1078,14 @@ \section{Conclusions \& Prospects} \section*{Credit authorship contribution statement} \textbf{A. Boucaud:} Software and comments. \textbf{J.E Campagne:} Conceptualization, Methodology, Software, Validation, Writing, Visualization. +\textbf{D.~Kirkby:} Software and validation for sparse linear algebra. \textbf{F. Lanusse:} Conceptualization, Methodology, Software, Validation, Writing, Project administration. \textbf{D. Lanzieri:} Software and validation for redshift distribution. +\textbf{S. Casas:} Software contribution for growth rate and power spectra. +\textbf{Y. Li:} Software contributions. +\textbf{A. Peel:} Software and validation for spline interpolations in \jax. \textbf{J. Zuntz:} Investigation, Writing. -\textbf{D.~Kirkby:} Software and validation for sparse linear algebra. -\textbf{A. Peel:} Software and validation for spline interpolations in JAX. +\textbf{M.~Karamanis:} Software for Gaussian likelihood computation. %\section*{Declaration of Competing Interest} %The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper. @@ -1095,7 +1107,7 @@ \section*{Acknowledgements} %\label{eq-loss-svi-3b} %\end{multline} -%In case of the MVN $\nabla_\mu \mathbb{H}(q)=0$, $\nabla_L \mathbb{H}(q)=(L^{-1})^T$ while the other gradients are computed by the \texttt{JAX} \texttt{autodiff} mechanism. +%In case of the MVN $\nabla_\mu \mathbb{H}(q)=0$, $\nabla_L \mathbb{H}(q)=(L^{-1})^T$ while the other gradients are computed by the \jax\ \texttt{autodiff} mechanism. %\FrL{Yeah. 
this is way TMI, and also goes against the idea that with autodiff we don't need to manually work out all these details.} %The expectations used in the above expressions are computed with $\zeta$ i.i.d samples from $\mathcal{N}(0,1)$ distribution, hence the ``S'' of SVI. As one can imagine all the jacobian computations take benefit of the automatic differentiation offers by JAX. diff --git a/refs.bib b/refs.bib index c10c1f2..622c57b 100644 --- a/refs.bib +++ b/refs.bib @@ -1511,3 +1511,32 @@ @ARTICLE{cosmicnet adsnote = {Provided by the SAO/NASA Astrophysics Data System} } + +@ARTICLE{detf, + author = {{Albrecht}, Andreas and {Bernstein}, Gary and {Cahn}, Robert and {Freedman}, Wendy L. and {Hewitt}, Jacqueline and {Hu}, Wayne and {Huth}, John and {Kamionkowski}, Marc and {Kolb}, Edward W. and {Knox}, Lloyd and {Mather}, John C. and {Staggs}, Suzanne and {Suntzeff}, Nicholas B.}, + title = "{Report of the Dark Energy Task Force}", + journal = {arXiv e-prints}, + keywords = {Astrophysics}, + year = 2006, + month = sep, + eid = {astro-ph/0609591}, + pages = {astro-ph/0609591}, + doi = {10.48550/arXiv.astro-ph/0609591}, +archivePrefix = {arXiv}, + eprint = {astro-ph/0609591}, + primaryClass = {astro-ph}, + adsurl = {https://ui.adsabs.harvard.edu/abs/2006astro.ph..9591A}, + adsnote = {Provided by the SAO/NASA Astrophysics Data System} +} + +@article{bezanson2017julia, + title={Julia: A fresh approach to numerical computing}, + author={Bezanson, Jeff and Edelman, Alan and Karpinski, Stefan and Shah, Viral B}, + journal={SIAM review}, + volume={59}, + number={1}, + pages={65--98}, + year={2017}, + publisher={SIAM}, + url={https://doi.org/10.1137/141000671} +} \ No newline at end of file