From 0b926fe5ae845006d8855bb047b3ec0f355daaf8 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 10 Feb 2023 09:34:12 +0100 Subject: [PATCH] before submission --- paper_OJA.tex | 55 +++++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/paper_OJA.tex b/paper_OJA.tex index 51968c4..a390e2b 100644 --- a/paper_OJA.tex +++ b/paper_OJA.tex @@ -167,7 +167,7 @@ %\submitted{submitted xxx, revised yyy, accepted zzz} -\title{\jaxcosmo: An End-to-End Differentiable and GPU Accelerated Cosmology Library} +\title{\texttt{JAX-COSMO}: An End-to-End Differentiable and GPU Accelerated Cosmology Library} %% use optional labels to link authors explicitly to addresses: %% \author[label1,label2]{} @@ -208,15 +208,15 @@ \begin{abstract} We present \jaxcosmo, a library for automatically differentiable cosmological theory calculations. \jaxcosmo\ uses the \jax\ library, which has created a new coding ecosystem, especially in probabilistic programming. % -As well as batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU), \jax\ exposes an \textit{automatic differentiation} (AD) mechanism. Thanks to AD, \jaxcosmo\ gives access to the derivatives of cosmological likelihoods with respect to any of their parameters, and thus enables a range of powerful % AD can take a Python function that uses \jax\ and return its (vector) derivative, computed not with finite differences but by successively differentiating each instruction in turn. This gives us the opportunity to apply a range of powerful +As well as batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU), \jax\ exposes an \textit{automatic differentiation} (autodiff) mechanism. Thanks to autodiff, \jaxcosmo\ gives access to the derivatives of cosmological likelihoods with respect to any of their parameters, and thus enables a range of powerful % AD can take a Python function that uses \jax\ and return its (vector) derivative, computed not with finite differences but by successively differentiating each instruction in turn. This gives us the opportunity to apply a range of powerful Bayesian inference algorithms, otherwise impractical in cosmology, such as Hamiltonian Monte Carlo and Variational Inference. % In its initial release, \jaxcosmo\ implements background evolution, linear and non-linear power spectra (using \texttt{halofit} or the Eisenstein and Hu transfer function), as well as angular power spectra ($C_\ell$) with the Limber approximation for galaxy and weak lensing probes, all differentiable with respect to the cosmological parameters and their other inputs. % -We illustrate how automatic differentiation can be a game-changer for common tasks involving Fisher matrix computations, or full posterior inference with gradient-based techniques (e.g. Hamiltonian Monte Carlo). In particular, we show how Fisher matrices are now fast, exact, no longer require any fine tuning, and are themselves differentiable with respect to parameters of the likelihood enabling complex survey optimization by simple gradient descent. 
Finally, using a Dark Energy Survey Year 1 3x2pt analysis as a benchmark, we demonstrate how \jaxcosmo\ can be combined with Probabilistic Programming Languages such as \numpyro\ to perform posterior inference with state-of-the-art algorithms including a No U-Turn Sampler (NUTS), Automatic Differentiation Variational Inference (ADVI), and Neural Transport HMC (NeuTra).
% We discuss algorithms made possible by this library, present comparisons with the Core Cosmology Library as a benchmark, and run a series of tests using the Dark Energy Survey Year 1 3x2pt analysis with the \numpyro\ library to demonstrate practical inference.
-%
-We show that clear improvements are possible using HMC compared to Metropolis-Hasting, and that Normalizing Flows using the Neural Transport are a promising methodology.\FrL{we need a better statement than this, what are our conclusions? Are we faster/better than Cobaya?}
-
+We illustrate how automatic differentiation can be a game-changer for common tasks involving Fisher matrix computations, or full posterior inference with gradient-based techniques (e.g. Hamiltonian Monte Carlo). In particular, we show how Fisher matrices are now fast, exact, no longer require any fine tuning, and are themselves differentiable with respect to parameters of the likelihood, enabling complex survey optimization by simple gradient descent. Finally, using a Dark Energy Survey Year 1 3x2pt analysis as a benchmark, we demonstrate how \jaxcosmo\ can be combined with Probabilistic Programming Languages such as \numpyro\ to perform posterior inference with state-of-the-art algorithms including a No U-Turn Sampler (NUTS), Automatic Differentiation Variational Inference (ADVI), and Neural Transport HMC (NeuTra).
+% We discuss algorithms made possible by this library, present comparisons with the Core Cosmology Library as a benchmark, and run a series of tests using the Dark Energy Survey Year 1 3x2pt analysis with the \numpyro\ library to demonstrate practical inference.
+We show that the effective sample size per node (1 GPU or 32 CPUs) per hour of wall time is about 5 times better for a JAX NUTS sampler compared to the well-optimized Cobaya Metropolis-Hastings sampler. We further demonstrate that Normalizing Flows using Neural Transport are a promising methodology for model validation in the early stages of analysis.
%
+\github
% The recent \jax\ library has created a new ecosystem from which probabilistic programming software can take enormous benefit: batch acceleration, just-in-time compilation, and automatic optimization of code for different hardware modalities (CPU, GPU, TPU) can provide huge speed-ups to a wide range of different problems. In particular, \jax\ exposes an \textit{automatic differentiation} mechanism, which can take a python function that uses \jax\ and return its (vector) derivative. This gives us the opportunity to apply a range of powerful but otherwise unfeasible algorithms used in Bayesian inference, such as Hamiltonian Monte Carlo (HMC) and Variational Inference.
@@ -249,7 +249,7 @@ \section{Introduction}
While these tools have especially been applied to neural network optimization and machine learning (ML) in general, they can also enable classical statistical methods that require the derivatives of (e.g. likelihood) functions to operate: we consider such methods in this paper.
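As a minimal illustration of this mechanism (a toy sketch, not code from \jaxcosmo; the function and numbers below are purely illustrative), \jax\ can return the exact derivative of a likelihood-style function and accelerate it:
\begin{lstlisting}[language=iPython]
import jax
import jax.numpy as jnp

# Toy Gaussian log-likelihood in the mean parameter mu
# (illustrative only, not part of jax-cosmo).
def log_likelihood(mu, data, sigma=1.0):
    return -0.5 * jnp.sum((data - mu)**2) / sigma**2

data = jnp.array([0.9, 1.1, 1.2])

# Exact gradient with respect to mu via autodiff,
# rather than finite differences.
grad_loglike = jax.grad(log_likelihood)
print(grad_loglike(1.0, data))

# jit compiles with XLA; vmap evaluates the gradient
# for a whole batch of mu values in one call.
grads = jax.vmap(jax.jit(grad_loglike),
                 in_axes=(0, None))(jnp.linspace(0.5, 1.5, 8), data)
\end{lstlisting}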
\textit{Autodiff} has been implemented in widely used libraries like \texttt{Stan} \citep{JSSv076i01}, \texttt{TensorFlow} \citep{tensorflow2015-whitepaper}, \texttt{Julia} \citep{bezanson2017julia}, and \texttt{PyTorch} \citep{NEURIPS2019_9015}.
-A recent entrant to this field is the \jax\ library\footnote{\url{https://jax.readthedocs.io}} \citep{jax2018github} which has undergone rapid development and can automatically differentiate native \texttt{Python} and \texttt{NumPy} functions, offering a speed up to the development process and indeed code runtimes. \jax\ offers an easy parallelization mechanism (\texttt{vmap}), just-in-time compilation (\texttt{jit}), and optimization targeting CPU, GPU, and TPU hardware thanks to the \texttt{XLA} library that powers TensorFlow. These attractive features have driven wide adoption of \jax\ in computational research, and motivate us to consider its usage in cosmology.
+A recent entrant to this field is the \jax\ library\footnote{\url{https://jax.readthedocs.io}} \citep{jax2018github}, which has undergone rapid development and can automatically differentiate native \texttt{Python} and \texttt{NumPy} functions, offering a speed-up both to the development process and to code runtimes. \jax\ offers an easy parallelization mechanism (\texttt{vmap}), just-in-time compilation (\texttt{jit}), and optimization targeting CPU, GPU, and TPU hardware thanks to the \texttt{XLA} library (which also powers TensorFlow). These attractive features have driven wide adoption of \jax\ in computational research, and motivate us to consider its usage in cosmology.
\jax\ contains bespoke reimplementations of packages such as \texttt{jax.numpy} and \texttt{jax.scipy}, as well as example libraries such as \texttt{Stax} for simple but flexible neural network development. It has formed the seed for a wider ecosystem of packages, including, for example: \texttt{Flax} \citep{flax2020github} a high-performance neural network library, \texttt{JAXopt} \citep{jaxopt_implicit_diff} a hardware accelerated, batchable and differentiable collection of optimizers, \texttt{Optax} \citep{optax2020github} a gradient processing and optimization library, and \numpyro\ \citep{phan2019composable,bingham2019pyro}, a probabilistic programming language (PPL) that is used in this paper. Other PPL packages such as \texttt{PyMC} \citep{Salvatier2016} have switched to a \jax\ backend in recent versions\footnote{A more exhaustive and rapidly growing list of packages can be found at \url{https://project-awesome.org/n2cholas/awesome-jax}}.
@@ -265,7 +265,10 @@ \section{Introduction}
% \footnote{\url{https://pytorch.org/docs/}}
% \footnote{\url{https://mc-stan.org/users/documentation}}
-In this context we have developed the open source \jaxcosmo\ library\footnote{\url{https://github.com/DifferentiableUniverseInitiative/jax_cosmo}}, which we present in this paper.
+To explore alternative inference methods to the well-known Metropolis-Hastings likelihood sampler, and to make use of GPU devices within the \jax\ framework, we have developed the open source \jaxcosmo\ library\footnote{\url{https://github.com/DifferentiableUniverseInitiative/jax_cosmo}}, which we present in this paper.
The package represents a first step in making the powerful features described above useful for cosmology; it implements a selection of theory predictions for key cosmology observables as differentiable \jax\ functions.
+
+%\FrL{Missing a paragraph on what we propose to do with this framework in this paper, and why that's interesting in the context of the very long time that it takes to sample common likelihood.}
+
We give an overview of the code's design and contents in Section~\ref{sec-jaxcosmo-design}. We show how to use it for rapid and numerically stable Fisher forecasts in Section~\ref{sec-fisher-forecast}, in more complete gradient-based cosmological inference with variants of Hamiltonian Monte Carlo including the No-U-Turn Sampler, and ML-accelerated Stochastic Variational Inference in Section~\ref{sec:chmc}. We discuss and compare these methods in Section~\ref{sec-discussion} and conclude in Section~\ref{sec-conclusion}.
@@ -432,7 +435,7 @@ \subsection{Validation against the Core Cosmology Library (CCL)}
\ref{fig:halofit_comparison} and \ref{fig:Cell_comparison} show the radial comoving distance (Eq.~\ref{eq:radial_comoving}), the non-linear matter power spectrum computation, and the angular power spectrum for galaxy-galaxy lensing (Eq.~\ref{eq:Cell_limber}) using the \texttt{NumberCounts} and \texttt{WeakLensing} kernel functions. \jaxcosmo\ features a suite of validation tests against CCL, automatically validating the precision of all computations to within the desired numerical accuracy; the relative differences between the two libraries are at the level of a few $10^{-3}$ or better.
-These numerical differences are mostly due to different choices of integration methods and accuracy parameters (eg. number of quadrature points). Increasing these parameters leads to performance degradation for \jaxcosmo\, but increases the XLA compilation memory requirements significantly, especially for the angular power spectra computation. Since these differences are likely to be within the tolerance of the current generation of cosmological surveys, this trade-off is an acceptable one.
+These numerical differences are mostly due to different choices of integration methods and accuracy parameters (e.g. number of quadrature points). Increasing these parameters not only leads to performance degradation for \jaxcosmo\ but also significantly increases the XLA compilation memory requirements, especially for the angular power spectra computation. Since these differences are likely to be within the tolerance of the current generation of cosmological surveys, this trade-off is an acceptable one.
\begin{figure}
\centering
@@ -443,7 +446,7 @@ \subsection{Validation against the Core Cosmology Library (CCL)}
\begin{figure}
\centering
\includegraphics[width=\columnwidth]{figures/halofit_pk.png}
- \caption{Comparison of the non-linear matter power spectrum (\textit{halofit} function) between CCL and \jaxcosmo. Is also shown the linear power spectrum.} \label{fig:halofit_comparison}
+ \caption{Comparison of the non-linear matter power spectrum (\textit{halofit} function) between CCL and \jaxcosmo. 
Also shown is the linear power spectrum.} \label{fig:halofit_comparison} \end{figure} \begin{figure} \centering @@ -454,7 +457,7 @@ \subsection{Validation against the Core Cosmology Library (CCL)} % \section{Unleashing the Power of Fisher Forecasts with Automatic Differentiation} -\section{Fisher Forecasts \& Data Compression} +\section{Fisher Information Matrices Made Easy} \label{sec-fisher-forecast} As a first illustration of the value of \textit{autodiff} in cosmological computations, we present in this section a few direct applications involving the computation of the Fisher information matrix, widely used in cosmology \citep{1997ApJ...480...22T,Stuart1991}. \\ @@ -527,7 +530,7 @@ \subsection{Instantaneous Fisher Forecasts} \begin{figure} \centering \includegraphics[width=0.7\columnwidth]{figures/simple_fisher_1.png} - \caption{Comparison of the two methods to compute the Fisher matrix: method 1 (method 2) is using Eq.~\ref{eq:fisher_way1} (Eq.~\ref{eq:fisher_way2}).} + \caption{Comparison of the two methods to compute the Fisher matrix: Eq.~\ref{eq:fisher_way1} (method 1) and Eq.~\ref{eq:fisher_way2} (method 2).} \label{fig:simple_fisher_1} \end{figure} % @@ -579,7 +582,7 @@ \subsection{Massive Optimal Compression in 3 Lines} % \FrL{hummm but we still do need a full covariance in the compression algorithm, right?} \FrL{And we still do need to invert it... is the idea rather that if we dont have a perfect matrix at that stage the worse that can happen is sub optimal compression?} -%JEC{Well there are several ways that $C$ can be questioned: parameter dependence (ie. fiducial model/true model), non-gaussianities, errors... So not sure that all these effects induce only sub-optimal compression. But it may be out of the scope of this paragraph/paper.} +%JEC{Well there are several ways that $C$ can be questioned: parameter dependence (i.e. fiducial model/true model), non-gaussianities, errors... So not sure that all these effects induce only sub-optimal compression. But it may be out of the scope of this paragraph/paper.} Another key advantage of the MOPED algorithm is to eliminate the need for large covariance matrix inversion of size $N\times N$ requiring $O(N^3)$ operations. This inversion takes place not only for the Fisher matrix computation (Eq.~\ref{eq:fisher_way2}), but more importantly in the log-likelihood computation (see the snippet in the previous section). The MOPED algorithm reduces the complexity to $O(M^3)$ operations. @@ -618,7 +621,7 @@ \subsection{Massive Optimal Compression in 3 Lines} \section{Posterior Inference made fast by Gradient-Based Inference Methods} \label{sec:chmc} % -In the following sections we review more statistical methods which directly benefit from the automatic differentiablity of \jaxcosmo\ likelihoods. We demonstrate a range of gradient-based methods, from Hamiltonian Monte Carlo (HMC), and its \textit{No-U-Turn Sampler} (NUTS) variant, to Stochastic Variational Inference (SVI). We further report on their respective computational costs in a DES-Y1 like analysis. All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL). +In the following sections we review more statistical methods which directly benefit from the automatic differentiablity of \jaxcosmo\ likelihoods. We demonstrate a range of gradient-based methods, from Hamiltonian Monte Carlo (HMC), and its \textit{No-U-Turn Sampler} (NUTS) variant, to Stochastic Variational Inference (SVI). 
We further report their respective computational costs in a DES-Y1-like analysis. All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL).
%
% and NUTS after a \textit{Neural Transport} perform using a \textit{Stochastic Variational Inference} (aka SVI). All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL).
%
@@ -686,7 +689,7 @@ \subsection{Description of the DES-Y1 exercise}
f_sky=0.25, sparse=True)
\end{lstlisting}
-with \texttt{cosmo} an instance of the \texttt{jc.Cosmology} setting the cosmological parameters generated with the priors, and \texttt{ell} (ie. $\ell$) a series of 50 angular modes. After encapsulating the code above with input sampled parameters (using \numpyro\ distribution classes) in a function \texttt{model}, we can generate our mock data (\texttt{cl\_obs}). This comes comes from this model function evaluated at a fiducial cosmology, with \numpyro\ dealing with the random number generation that adds noise:
+with \texttt{cosmo} an instance of the \texttt{jc.Cosmology} class holding the cosmological parameters drawn from the priors, and \texttt{ell} (i.e. $\ell$) a series of 50 angular modes. After encapsulating the code above with input sampled parameters (using \numpyro\ distribution classes) in a function \texttt{model}, we can generate our mock data (\texttt{cl\_obs}). It comes from this model function evaluated at a fiducial cosmology, with random noise generated by \numpyro:
\begin{lstlisting}[language=iPython]
fiducial_model = numpyro.condition(model,
@@ -847,14 +850,14 @@ \subsection{NUTS}
% mcmc.run(jax.random.PRNGKey(42))
% \end{lstlisting}
-We ran the NUTS sampler using \texttt{numpyro.infer.NUTS} on the DES Y1 likelihood, with 16 chains of 1,000 samples each after a warm-up phase consisting of 200 samples, with the \texttt{max\_tree\_depth} set to seven (ie. 128 steps for each iteration). %\FrL{it should explain that vectorized means 16 chains in parallel on a single GPU.}
+We ran the NUTS sampler using \texttt{numpyro.infer.NUTS} on the DES Y1 likelihood, with 16 chains of 1,000 samples each after a warm-up phase consisting of 200 samples, with the \texttt{max\_tree\_depth} set to seven (i.e. at most $2^7=128$ leapfrog steps per iteration). %\FrL{it should explain that vectorized means 16 chains in parallel on a single GPU.}
Using the ``vectorized'' \texttt{numpyro} option we ran all 16 chains simultaneously on a single GPU, made possible by the \jax\ \textit{vmap} mechanism. If one has several GPU devices available, then, using the \jax\ parallelization mechanism (\textit{pmap}), it is further possible to launch the vectorized sampling across the devices and get back all the MCMC chains. However, these experiments have all been undertaken on single GPUs, either an NVidia Titan Xp (12GB RAM) on a desktop or an NVidia V100 (32GB RAM) at the IN2P3 Computing Centre\footnote{\url{https://cc.in2p3.fr/en/}}. The elapsed time for these experiments was 20 hours. The results in terms of relative effective sample sizes (ESS) are detailed in Table~\ref{tab-ESS-NUTS_SVI-1}, while the confidence level (CL) contours are presented in Figure \ref{fig_cobaya_NUTS_SVI}. We compare to a reference sample from the highly-optimized \texttt{Cobaya} Metropolis-Hastings implementation \citep{2019ascl.soft10019T,2021JCAP...05..057T}, which is widely used in cosmology and which we ran for around 40 hours on CPU to obtain the set of contours shown.
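For reference, the sampler configuration described above looks roughly as follows in \numpyro\ (a sketch only: we use a hypothetical toy stand-in for the conditioned DES-Y1 model and omit the mass-matrix structure discussed above):
\begin{lstlisting}[language=iPython]
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Hypothetical toy stand-in for the conditioned DES-Y1 model;
# the real model wraps the jax-cosmo C_ell computation and is
# conditioned on the mock data cl_obs.
def toy_model():
    omega_c = numpyro.sample("omega_c", dist.Normal(0.25, 0.05))
    numpyro.sample("obs", dist.Normal(omega_c, 0.01), obs=0.26)

kernel = NUTS(toy_model, max_tree_depth=7)
mcmc = MCMC(kernel,
            num_warmup=200,
            num_samples=1000,
            num_chains=16,
            chain_method="vectorized")  # 16 chains at once on one GPU
mcmc.run(jax.random.PRNGKey(42))
samples = mcmc.get_samples()
\end{lstlisting}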
-There is a dramatic improvement of the ESS by about a factor of 10 using the NUTS sampler compared to Cobaya, with very good agreement between the CL contours. It is worth noting that the mass matrix structure described above increases the sampling efficiency by about a factor of two. +There is a dramatic improvement of the ESS by about a factor of 10 using the NUTS sampler compared to Cobaya, with very good agreement between the CL contours. It is worth mentioning that the mass matrix structure described above increases the sampling efficiency by about a factor of two. -The speed of the sampling could be improved: we have tested using the parameter \texttt{max\_tree\_depth=5} and found convergence in five hours, showing a linear scaling in this parameter while keeping the sampling efficiencies at a high level; the user is highly encouraged to tune this critical parameter. +The speed of the sampling could be further improved: we have tested using the parameter \texttt{max\_tree\_depth=5} and found convergence in five hours, showing a linear scaling in this parameter while keeping the sampling efficiencies at a high level; the user is highly encouraged to tune this critical parameter. %\JZ{I don't think the anticipation of the next section was needed here to make the point so I removed it, but please let me know if you're unhappy about this. Otherwise if you're happy please delete this comment.} @@ -874,7 +877,7 @@ \subsection{Stochastic Variational Inference} \label{sec-SVI} % -We now explore \textit{Stochastic Variational Inference} \citep{10.5555/2567709.2502622, 8588399}, another inference algorithm enabled by auto-differentiation. If we write $p(z)$ the prior, $p(\mathcal{D}|z)$ the likelihood and $p(\mathcal{D})$ the marginal likelihood, then thanks to Bayes theorem we have $p(z|\mathcal{D})=p(z)p(\mathcal{D}|z)/p(\mathcal{D})$ as the posterior distribution of a model with latent variables $z$ and a set of observations $\mathcal{D}$. Variational Inference (VI) aims to find an approximation to this distribution, ie. $p(z|\mathcal{D}) \approx q(z;\lambda)$, by determining the variational parameters $\lambda$ of a predefined distribution. To do so, one uses the Kullback-Leibler divergence of the two distributions $KL(q(z;\lambda)||p(z|\mathcal{D}))$ leading to the following relation +We now explore \textit{Stochastic Variational Inference} \citep{10.5555/2567709.2502622, 8588399}, another inference algorithm enabled by auto-differentiation. If we write $p(z)$ the prior, $p(\mathcal{D}|z)$ the likelihood and $p(\mathcal{D})$ the marginal likelihood, then thanks to Bayes theorem we have $p(z|\mathcal{D})=p(z)p(\mathcal{D}|z)/p(\mathcal{D})$ as the posterior distribution of a model with latent variables $z$ and a set of observations $\mathcal{D}$. Variational Inference (VI) aims to find an approximation to this distribution, i.e. $p(z|\mathcal{D}) \approx q(z;\lambda)$, by determining the variational parameters $\lambda$ of a predefined distribution. 
To do so, one uses the Kullback-Leibler divergence of the two distributions $KL(q(z;\lambda)||p(z|\mathcal{D}))$, leading to the following relation
\begin{align}
\log p(\mathcal{D}) &= \mathtt{ELBO} + KL(q(z;\lambda)||p(z|\mathcal{D})) \label{eq-ELBO} \\
\mathrm{with} \ \mathtt{ELBO} &\equiv -\mathbb{E}_{q(z;\lambda)}\left[ \log q(z;\lambda)\right] + \mathbb{E}_{q(z;\lambda)}\left[ \log p(z,\mathcal{D}) \right]
@@ -957,11 +960,11 @@ \subsection{Stochastic Variational Inference}
\subsubsection{Neural Transport}
\label{sec-Neural-Reparametrisation}
%
-If the SVI method can be used as is to get $z$ i.i.d. samples from the $q(z,\lambda^\ast)$ distribution as shown on the previous section, the \textit{Neural Transport MCMC} method \citep{Parno2018,2019arXiv190303704H} is an efficient way to boost HMC efficiency, especially in target distribution with unfavourable geometry where for instance the leapfrog integration algorithm has to face squeezed joint distributions for a subset of variables. From SVI, one obtains a first approximation of the target distribution, and this approximation is used to choose a better transform $T$ to map the parameter space to a more convenient one $z=F_\lambda(\zeta)$ (eg. $F_\lambda=T^{-1}\circ S^{-1}_\lambda$) is such that
+While the SVI method can be used as is to draw i.i.d. samples of $z$ from the $q(z;\lambda^\ast)$ distribution, as shown in the previous section, the \textit{Neural Transport MCMC} method \citep{Parno2018,2019arXiv190303704H} is an efficient way to boost HMC efficiency, especially for target distributions with unfavourable geometry, where for instance the leapfrog integration algorithm has to face squeezed joint distributions for a subset of variables. From SVI, one obtains a first approximation of the target distribution, and this approximation is used to choose a better transform $T$ mapping the parameter space to a more convenient one: the map $z=F_\lambda(\zeta)$ (e.g. $F_\lambda=T^{-1}\circ S^{-1}_\lambda$) is such that
\begin{equation}
q(z;\lambda) \rightarrow q(\zeta;\lambda) \bydef q(F_\lambda(\zeta)) |J_{F_\lambda}(\zeta)|
\end{equation}
-where $F_\lambda$ with the optimal $\lambda^\ast$ maps the best-fitting $q(z;\lambda^\ast)$ to a geometrically simple function like a unit multivariate normal distribution. So, one can use a HMC sampler (eg. NUTS) based on $p(\zeta;\mathcal{D})$ distribution, initialized with $\zeta$ samples from $q(\zeta;\lambda^\ast)$, to get a Markov Chain of $N$ samples $(\zeta_i)_{i