Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
EiffL authored Jan 13, 2023
2 parents 546e106 + ac911d2 commit 73ed089
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
48 changes: 30 additions & 18 deletions paper_OJA.tex
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,22 @@

\author{
J.~E. Campagne$^{1,\ast}$,
F. Lanusse$^2$,
F. Lanusse$^2$,
J. Zuntz$^3$,\\
D. Lanzieri$^4$\\
D.~Kirkby$^4$,
D. Lanzieri$^2$,
Y.~Li$^{5,6}$,
A. Peel$^7$
\jaxcosmo\ contributors}
\thanks{$^\ast$[email protected]}

\affiliation{$^1$Université Paris-Saclay, CNRS/IN2P3, IJCLab, 91405 Orsay, France \\
$^4$Université Paris Cité, Université Paris-Saclay, CEA, CNRS, AIM, F-91191, Gif-sur-Yvette, France}
\affiliation{
$^1$Université Paris-Saclay, CNRS/IN2P3, IJCLab, 91405 Orsay, France\\
}

\affiliation{$^5$Department of Mathematics and Theory, Peng Cheng Laboratory, Shenzhen, Guangdong 518066, China}
\affiliation{$^6$Center for Computational Astrophysics \& Center for Computational Mathematics, Flatiron Institute, New York, New York 10010, USA}
\affiliation{$^7$Ecole Polytechnique F\'ed\'erale de Lausanne (EPFL), Observatoire de Sauverny, 1290 Versoix, Switzerland}

%\date{\today}

Expand Down Expand Up @@ -229,7 +237,7 @@ \section{Introduction}
% Note that some direct optimization of likelihood function has also been successfully undertaken for instance for some Planck analysis \citep{2014A&A...566A..54P},
Since the development of these MCMC packages, major advances have been made in automatic differentiation (\textit{autodiff}) \citep{10.5555/3122009.3242010, Margossian2019}, a set of technologies for transforming pieces of code into their derivatives.

While these tools have especially been applied to neural network optimization and machine learning in general, they can also enable classical statistical methods that require the derivatives of (e.g. likelihood) functions to operate: we consider such methods in this paper. \textit{Autodiff} has been implemented in widely used libraries like \texttt{Stan} \citep{JSSv076i01}, \texttt{TensorFlow} \citep{tensorflow2015-whitepaper}, and \texttt{PyTorch} \citep{NEURIPS2019_9015}.
While these tools have especially been applied to neural network optimization and machine learning (ML) in general, they can also enable classical statistical methods that require the derivatives of (e.g. likelihood) functions to operate: we consider such methods in this paper. \textit{Autodiff} has been implemented in widely used libraries like \texttt{Stan} \citep{JSSv076i01}, \texttt{TensorFlow} \citep{tensorflow2015-whitepaper}, and \texttt{PyTorch} \citep{NEURIPS2019_9015}.

A recent entrant to this field is the \texttt{JAX} library\footnote{\url{https://jax.readthedocs.io}} \citep{jax2018github} which has undergone rapid development and can automatically differentiate native \texttt{Python} and \texttt{NumPy} functions, offering a speed up to the development process and indeed code runtimes. \texttt{JAX} offers an easy parallelization mechanism (\texttt{vmap}), and just-in-time compilation (\texttt{jit}) and optimization targetting CPU, GPU and TPU hardware thanks to the \texttt{XLA} library that powers TensorFlow. These attractive features have driven wide adoption of \texttt{JAX} in computational research, and motivate us to consider its usage in cosmology.

Expand All @@ -250,7 +258,7 @@ \section{Introduction}
In this context we have developed the open source \jaxcosmo\ library\footnote{\url{https://github.com/DifferentiableUniverseInitiative/jax_cosmo}}, which we present in this paper. The package represents a first step in making the powerful features described above useful for cosmology; it implements a selection of theory predictions for key cosmology observables as differentiable \texttt{JAX} functions.


We give an overview of the code's design and contents in Section~\ref{sec-jaxcosmo-design}. We show how to use it for rapid and numerically stable Fisher forecasts in Section~\ref{sec-fisher-forecast}, in more complete gradient-based cosmological inference with variants of Hamiltonian Monte Carlo and its variant No-U-Turn, and ML-accelerated Stochastic Variational Inference in Section~\ref{sec:chmc}. We discuss and compare these methods in Section~\ref{sec-results} and conclude in Section~\ref{sec-conclusion}.
We give an overview of the code's design and contents in Section~\ref{sec-jaxcosmo-design}. We show how to use it for rapid and numerically stable Fisher forecasts in Section~\ref{sec-fisher-forecast}, in more complete gradient-based cosmological inference with variants of Hamiltonian Monte Carlo including the No-U-Turn Sampler, and ML-accelerated Stochastic Variational Inference in Section~\ref{sec:chmc}. We discuss and compare these methods in Section~\ref{sec-discussion} and conclude in Section~\ref{sec-conclusion}.



Expand Down Expand Up @@ -326,7 +334,7 @@ \section{Design of the \jaxcosmo\ library}

All \jaxcosmo\ data structures are organized as \texttt{JAX} \textit{pyTree} container objects, ensuring that two key \texttt{JAX} features are available to them: \textit{autodiff} and \textit{vmap}. The vmap feature enables any operation defined in \texttt{JAX} (including complicated composite operations) to be applied efficiently as a vector operation. The autodiff feature further makes it possible to take the derivative of any operation, automatically transforming a function that takes $n$ inputs and produces $m$ outputs into a new function that generates an $m \times n$ matrix of partial derivatives.

\jaxcosmo\ implements, for example, Runge-Kutta solvers for ODEs, and Simpson \& Romberg integration routines through this framework, so that we can automatically compute the derivative of their solutions with respect to their inputs. This includes the cosmological parameters (the standard set of $w_0 w_a CDM$ cosmological parameters is exposed, using $\sigma_8$ as an amplitude parameter), but also the other input quantities such as redshift distributions described below.
\jaxcosmo\ implements, for example, Runge-Kutta solvers for ODEs, and Simpson \& Romberg integration routines through this framework, so that we can automatically compute the derivatives of their solutions with respect to their inputs. This includes not only the cosmological parameters (the standard set of $w_0 w_a CDM$ cosmological parameters is exposed, using $\sigma_8$ as an amplitude parameter), but also other input quantities such as redshift distributions as described below.

In the rest of this section we describe the cosmological calculations that are implemented using these facilities.

Expand Down Expand Up @@ -380,14 +388,14 @@ \subsection{Angular power spectra}
\begin{align}
C_\ell^{i,j} \approx \left(\ell+\frac{1}{2}\right)^{m_i+m_j}\int\frac{d\chi}{c^2\chi^2}K_i(\chi)K_j(\chi)\,P\left(k=\frac{\ell+1/2}{\chi},z\right),\label{eq:Cell_limber}
\end{align}
The factors $m_i$ are equal to $(0,-2)$ for the galaxy clustering and weak lensing respectively, and the functions $K(z)$ each represent a single tomographic redshift bin's number density. These tracers are implemented as two kernel functions:
The $m_i$ factors are $(0,-2)$ for the galaxy clustering and weak lensing, respectively, and each $K(z)$ function represent a single tomographic redshift bin's number density. These tracers are implemented as two kernel functions:

\begin{description}
\item[\texttt{NumberCounts}]
\begin{equation}
K_i(z) = n_i(z)\ b(z)\ H(z)
\end{equation}
where $n_i(s)$ is the redshift distribution of the sample (eg. \texttt{jc.redshift.kde\_nz} function), $b(z)$ a galaxy bias function (see \texttt{jc.bias.constant\_linear\_bias}). No redshift space distortions are taken into account.
where $n_i(s)$ is the redshift distribution of the sample (e.g., \texttt{jc.redshift.kde\_nz} function), and $b(z)$ is the galaxy bias function (see \texttt{jc.bias.constant\_linear\_bias}). No redshift space distortions are taken into account.

\item [\texttt{WeakLensing}]
\begin{multline}
Expand All @@ -398,14 +406,14 @@ \subsection{Angular power spectra}
\begin{equation}
K_{IA}(z) = \left(\frac{(\ell+2)!}{(\ell-2)!}\right)^{1/2}\ p_i(z)\ b(z)\ H(z)\ \frac{C\ \Omega_m}{D(z)}
\end{equation}
with $C\approx 0.0134$ a dimensionless constant and $D(z)$ the growth factor of linear perturbations.
with $C\approx 0.0134$ being a dimensionless constant and $D(z)$ the growth factor of linear perturbations.
\end{description}

% Notice that all described kernel functions can be also viewed as function of the scale factor $a$ or as function of the radial comoving distance $\chi$.

% \FrL{We probably actually want to start by the Limber formula, which explicits what a kernel is, otherwise people don't necessarily know where these kernels come from.}

Because, like the other ingredients, the kernels are implemented as pyTree objects, \texttt{NumberCounts} and \texttt{WeakLensing}, all the integrals involved in these computations can be differentiated with respect to the cosmological parameters or the number densities using \textit{autodiff} and \jaxcosmo's implementation of integration quadrature. An example is given in the context of DES Y1 3x2pts analysis (Sec~\ref{sec-DESY1}).
Because, like the other ingredients, the kernels are implemented as pyTree objects, namely \texttt{NumberCounts} and \texttt{WeakLensing}, all the integrals involved in these computations can be differentiated with respect to the cosmological parameters and the number densities, using \textit{autodiff} and \jaxcosmo's implementation of integration quadrature. An example is given in the context of DES Y1 3x2pts analysis (Sec~\ref{sec-DESY1}).
%
\subsection{Validation against the Core Cosmology Library (CCL)}
%
Expand Down Expand Up @@ -540,8 +548,9 @@ \subsection{Massive Optimal Compression in 3 Lines}
Once the Fisher matrix has been accurately estimated, the MOPED\footnote{Massively Optimised Parameter Estimation and Data compression} algorithm can be used to compress data sets with minimal information loss \citep{2000MNRAS.317..965H,2016PhRvD..93h3525Z, 2017MNRAS.472.4244H}. In the case of the constant covariance matrix the algorithm compresses data in a way that is lossless at the Fisher matrix level (i.e. Fisher matrices estimated using the compressed and full data are identical, by construction) which reduces a possibly large data vector $\mu$ of size $N$ to $M$ numbers, where $M$ is the number of parameters $\theta_i$ considered. For instance, in the previous section, $N=500$ as $\mu=(C_\ell^{p,q})$ and $M=2$ for $(\Omega_c,\sigma_8)$.

The algorithm computes by iteration $M$ vectors of size $N$ such that (taking the notation of the previous section)
%JEC 13/1/23 Typo fixed
\begin{equation}
b_i = \frac{C^{-1}\mu_{,i}-\sum_{j=1}^{i-1}(\mu_{,j}^T b_j)b_j}{\sqrt{F_{i,i}-\sum_{j=1}^{i-1}(\mu_{,j}^T b_j)^2}}
b_i = \frac{C^{-1}\mu_{,i}-\sum_{j=1}^{i-1}(\mu_{,i}^T b_j)b_j}{\sqrt{F_{i,i}-\sum_{j=1}^{i-1}(\mu_{,i}^T b_j)^2}}
\label{eq:moped}
\end{equation}
where $\mu_{,i}=\partial \mu/\partial \theta_i$ ($i=1,\dots,M$). The vectors $(b_i)_{i\leq M}$ satisfy the following orthogonality constraint
Expand All @@ -555,7 +564,7 @@ \subsection{Massive Optimal Compression in 3 Lines}
These numbers are uncorrelated and of unit variance and this construction ensures that the log-likelihood of $y_i$ given $\theta_i$ is identical to that of $x$ up to second order, meaning that the Fisher matrices derived from the two parameters should be identical, and in general the $y$ values should lose very little information compared to the full likelihood.

In problems where a (constant) covariance matrix is estimated from simulations, the number of such simulations required for a given accuracy typically scales with the number of data points used.
MOPED therefore greatly reduces this number, often by a factor of hundreds. Since the uncertainty in the covariance due to a finite number of simulations must be accounted for \citep{2018MNRAS.473.2355S,2007A&A...464..399H}, this reduction can also offset any loss of information from the compression. Inaccuracies in the full covariance matrix used in the data compression result only in information loss, not bias, as does mis-specification of the fiducial parameter set $\theta_i$ where equation \ref{eq:moped}.
MOPED therefore greatly reduces this number, often by a factor of hundreds. Since the uncertainty in the covariance due to a finite number of simulations must be accounted for \citep{2018MNRAS.473.2355S,2007A&A...464..399H}, this reduction can also offset any loss of information from the compression. Inaccuracies in the full covariance matrix used in the data compression result only in information loss, not bias, as does mis-specification of the fiducial parameter set $\theta_i$ in equation \ref{eq:moped}.


% \FrL{hummm but we still do need a full covariance in the compression algorithm, right?} \FrL{And we still do need to invert it... is the idea rather that if we dont have a perfect matrix at that stage the worse that can happen is sub optimal compression?}
Expand Down Expand Up @@ -598,7 +607,7 @@ \subsection{Massive Optimal Compression in 3 Lines}
\section{Posterior Inference made fast by Gradient-Based Inference Methods}
\label{sec:chmc}
%
In the following sections we review some statistical methods which directly benefit from the automatic differentiablity of \jaxcosmo\ likelihoods. We demonstrate a range of gradient-based methods, from Hamiltonian Monte Carlo (HMC), and its \textit{No-U-Turn} (NUTS) variant, to Stochastic Variational Inference (SVI). We further report on their respective computational costs in a DES-Y1 like analysis. All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL).
In the following sections we review some statistical methods which directly benefit from the automatic differentiablity of \jaxcosmo\ likelihoods. We demonstrate a range of gradient-based methods, from Hamiltonian Monte Carlo (HMC), and its \textit{No-U-Turn Sampler} (NUTS) variant, to Stochastic Variational Inference (SVI). We further report on their respective computational costs in a DES-Y1 like analysis. All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL).
%
% and NUTS after a \textit{Neural Transport} perform using a \textit{Stochastic Variational Inference} (aka SVI). All the methods have been implemented using the \numpyro\ probabilistic programming language (PPL).
%
Expand Down Expand Up @@ -799,7 +808,7 @@ \subsection{NUTS}
% Show the advantage of using NUTS
% look at difference in efficiency in terms of how many times we need to call the model.

The No-U-Turn (\textit{NUTS}) variant of the traditional HMC sampler was introduced in \citet{10.5555/2627435.2638586}. It aims to help finding a new point $x_{i+1}$ from the current $x_i$ by finding good and dynamic choices for the leapfrog integration parameters in the root HMC algorithms, the step size $\varepsilon$ and the number of steps $L$.
The No-U-Turn Sampler (\textit{NUTS}) variant of the traditional HMC sampler was introduced in \citet{10.5555/2627435.2638586}. It aims to help finding a new point $x_{i+1}$ from the current $x_i$ by finding good and dynamic choices for the leapfrog integration parameters in the root HMC algorithms, the step size $\varepsilon$ and the number of steps $L$.

NUTS sets iterates the leapfrog algorithm not for a fixed $L$, but until the trajectory starts to ``double back'' and return to previously visited region, at the cost of increasing the number of model evaluations. The user has to set a new parameter (\texttt{max\_tree\_depth}) which gives as a power of $2$ the maximum number of model calls at each generation.

Expand Down Expand Up @@ -1055,10 +1064,13 @@ \section{Conclusions \& Prospects}


%%%
\section*{CRedit authorship contribution statement}
\section*{Credit authorship contribution statement}
\textbf{J.E Campagne:} Conceptualization, Methodology, Software, Validation, Writing, Visualization.
\textbf{F. Lanusse:}, Conceptualization, Methodology, Software, Validation, Writing, Project administration \textbf{J. Zuntz:} Investigation, Writing.
\textbf{F. Lanusse:}, Conceptualization, Methodology, Software, Validation, Writing, Project administration.
\textbf{D. Lanzieri:}, Software and validation for redshift distribution.
\textbf{J. Zuntz:} Investigation, Writing.
\textbf{D.~Kirkby:}, Software and validation for sparse linear algebra.
\textbf{A. Peel:} Software and validation for spline interpolations in JAX.

%\section*{Declaration of Competing Interest}
%The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.
Expand Down Expand Up @@ -1102,4 +1114,4 @@ \section*{Acknowledgements}

\endinput
%%
%% End of file `elsarticle-template-harv.tex'.
%% End of file `elsarticle-template-harv.tex'.
2 changes: 1 addition & 1 deletion refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ @inproceedings{KingmaB14


@article{10.5555/2627435.2638586,
author = {Homan, Matthew D. and Gelman, Andrew},
author = {Hoffman, Matthew D. and Gelman, Andrew},
title = {The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo},
year = {2014},
issue_date = {January 2014},
Expand Down

0 comments on commit 73ed089

Please sign in to comment.