GPX is a library for Gaussian Process Regression (GPR), written in JAX.
GPX currently supports:
- Standard GPR
- Sparse GPR (SGPR) with the Projected Processes approximation
- SGPR with the Projected Processes approximation and landmark selection via the Randomly Pivoted Cholesky decomposition (RPCholesky)
- Radial Basis Function Networks
- Training on target values or on derivative values (using the Hessian kernel)
- Kernels with automatic support for gradients and Hessians
- Dense and sparse operations; the latter are important for scaling GPs to large datasets
- Iterative estimation of the log marginal likelihood with stochastic trace estimation and Lanczos quadrature
- Interfaces to SciPy, NLopt, and Optax optimizers
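As a reference point for the standard GPR item above, the core exact-GPR computations can be sketched in plain JAX. This is a minimal, hypothetical sketch and not GPX's API: the names `rbf_kernel`, `log_marginal_likelihood`, and `predict_mean` are illustrative, and it assumes a squared exponential kernel with a single lengthscale and Gaussian noise with standard deviation `sigma`. It computes log p(y) = -1/2 y^T (K + σ²I)^{-1} y - 1/2 log|K + σ²I| - n/2 log 2π through a Cholesky factorization.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Use double precision: Cholesky factorizations of kernel matrices are sensitive to rounding
jax.config.update("jax_enable_x64", True)


def rbf_kernel(x1, x2, lengthscale):
    # Squared exponential kernel: k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))
    sqdist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sqdist / lengthscale**2)


def log_marginal_likelihood(x, y, lengthscale, sigma):
    n = x.shape[0]
    # Covariance of the noisy targets: K + sigma^2 I
    K = rbf_kernel(x, x, lengthscale) + sigma**2 * jnp.eye(n)
    c, lower = cho_factor(K, lower=True)
    # alpha = (K + sigma^2 I)^{-1} y, solved with the Cholesky factor
    alpha = cho_solve((c, lower), y)
    # log|K + sigma^2 I| = 2 * sum(log(diag(L)))
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(c)))
    return -0.5 * y @ alpha - 0.5 * logdet - 0.5 * n * jnp.log(2.0 * jnp.pi)


def predict_mean(x_train, y_train, x_test, lengthscale, sigma):
    n = x_train.shape[0]
    K = rbf_kernel(x_train, x_train, lengthscale) + sigma**2 * jnp.eye(n)
    alpha = cho_solve(cho_factor(K, lower=True), y_train)
    # Posterior mean: k(x_test, X) (K + sigma^2 I)^{-1} y
    return rbf_kernel(x_test, x_train, lengthscale) @ alpha
```

Because the log marginal likelihood is an ordinary JAX function, `jax.grad(log_marginal_likelihood, argnums=(2, 3))` returns its gradient with respect to the lengthscale and the noise, which is the kind of quantity a SciPy, NLopt, or Optax optimizer consumes. A real implementation would usually also constrain the hyperparameters to be positive (for example by optimizing their logarithms).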
An environment with Python 3.10 is recommended. You can create it with conda, virtualenv, or pyenv.
Then clone the project and install it with pip.
For example, using conda:
conda create -n gpx-env python=3.10
conda activate gpx-env
git clone https://github.com/Molecolab-Pisa/GPX
cd GPX
pip install .
If you need JAX with GPU support, install JAX first by following the official JAX installation instructions, then install GPX.
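After installation, one quick way to check which backend JAX picked up (for example, whether a GPU build is active) is to query its devices. These are standard JAX calls, not part of GPX; the exact device names printed depend on your hardware and JAX build.

```python
import jax

# Installed JAX version
print(jax.__version__)
# Devices visible to JAX; a working GPU install should list GPU devices here
print(jax.devices())
# Backend used by default for computations, e.g. "cpu" or "gpu"
print(jax.default_backend())
```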
You may want to look at our list of examples:
- GPR
- SGPR
- SGPR with RPCholesky
- GPR with derivatives
- Simple Multioutput GP
- Interface to NLOpt
- Kernelizers and Kernel Operations
- Maximum a Posteriori estimate
- Model Persistence in GPX
- Kernel derivatives
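The RPCholesky landmark selection listed among the features (and used in the SGPR with RPCholesky example) can be illustrated with a small sketch of the randomly pivoted Cholesky algorithm. This is a readable, hypothetical version and not GPX's implementation: `rp_cholesky` is an illustrative name, and the sketch assumes a dense positive semi-definite kernel matrix `K`, a modest number of pivots `k`, and a plain Python loop instead of a jitted one.

```python
import jax
import jax.numpy as jnp


def rp_cholesky(key, K, k):
    """Randomly pivoted Cholesky: returns F (n x k) with K ~ F F^T, and the pivot indices."""
    n = K.shape[0]
    F = jnp.zeros((n, k), dtype=K.dtype)
    # Diagonal of the residual K - F F^T (F is all zeros at the start)
    d = jnp.diag(K)
    pivots = []
    for i in range(k):
        key, subkey = jax.random.split(key)
        # Sample a pivot with probability proportional to the residual diagonal
        s = jax.random.choice(subkey, n, p=d / jnp.sum(d))
        # Residual column at the pivot; columns of F not filled yet are zero
        g = K[:, s] - F @ F[s, :]
        F = F.at[:, i].set(g / jnp.sqrt(g[s]))
        # Update the residual diagonal, clipping small negative values from rounding
        d = jnp.maximum(d - F[:, i] ** 2, 0.0)
        pivots.append(int(s))
    return F, jnp.array(pivots)
```

The selected pivots can serve as landmark (inducing) points, and `F @ F.T` is a rank-`k` approximation of `K` that reproduces `K` on the pivot columns up to rounding. Sampling in proportion to the residual diagonal tends to avoid picking points that are already well represented by earlier pivots.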
To cite GPX, you can use the following BibTeX entry:
@software{gpx2023github,
author = {Edoardo Cignoni and Patrizia Mazzeo and Amanda Arcidiacono and Lorenzo Cupellini and Benedetta Mennucci},
title = {GPX: Gaussian Process Regression in JAX},
url = {https://github.com/Molecolab-Pisa/GPX},
version = {0.1.0},
year = {2023},
}