Merge pull request #81 from ziatdinovmax/sparse
Add Sparse GP
ziatdinovmax authored Feb 7, 2024
2 parents 5762b1f + 2ad3fa9 commit 1138108
Showing 7 changed files with 384 additions and 6 deletions.
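
For context, a minimal usage sketch of the new model on synthetic data. The data, the hyperparameter choices, and the predict call (inherited from the viGP parent) are illustrative assumptions, not part of this diff:

import numpy as np
import gpax

gpax.utils.enable_x64()  # double precision is generally recommended for GPs
rng_key, rng_key_predict = gpax.utils.get_keys()

# Hypothetical 1D training data
X = np.linspace(0, 1, 500)
y = np.sin(10 * X) + 0.1 * np.random.randn(500)

# Fit a sparse GP, using 10% of the training points as inducing points
sgp = gpax.viSparseGP(input_dim=1, kernel='Matern')
sgp.fit(rng_key, X, y, inducing_points_ratio=0.1, num_steps=1000)

# Posterior prediction (assumed to follow the parent viGP API)
y_pred, y_var = sgp.predict(rng_key_predict, X)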
5 changes: 3 additions & 2 deletions gpax/__init__.py
@@ -4,8 +4,9 @@
 from . import acquisition
 from .hypo import sample_next
 from .models import (DKL, CoregGP, ExactGP, MultiTaskGP, iBNN, vExactGP,
-                     vi_iBNN, viDKL, viGP, viMTDKL, VarNoiseGP, UIGP, MeasuredNoiseGP)
+                     vi_iBNN, viDKL, viGP, viMTDKL, VarNoiseGP, UIGP,
+                     MeasuredNoiseGP, viSparseGP)
 
 __all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
            "viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP",
-           "UIGP", "MeasuredNoiseGP", "CoregGP", "sample_next", "__version__"]
+           "UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "sample_next", "__version__"]
2 changes: 2 additions & 0 deletions gpax/models/__init__.py
@@ -13,6 +13,7 @@
 from .uigp import UIGP
 from .mngp import MeasuredNoiseGP
 from .linreg import LinReg
+from .sparse_gp import viSparseGP

 __all__ = [
     "ExactGP",
@@ -30,4 +31,5 @@
     "UIGP",
     "LinReg",
-    "MeasuredNoiseGP"
+    "MeasuredNoiseGP",
+    "viSparseGP"
 ]
223 changes: 223 additions & 0 deletions gpax/models/sparse_gp.py
@@ -0,0 +1,223 @@
"""
sparse_gp.py
============
Variational inference implementation of sparse Gaussian process regression
Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Callable, Dict, Optional, Tuple, Type

import jax
import jaxlib
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO

from .vigp import viGP
from ..utils import initialize_inducing_points


class viSparseGP(viGP):
"""
Variational inference-based sparse Gaussian process
Args:
input_dim:
Number of input dimensions
kernel:
Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
mean_fn:
Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
kernel_prior:
Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default
mean_fn_prior:
Optional priors over mean function parameters
        noise_prior:
            Optional custom prior over the observational noise variance
            (deprecated; use `noise_prior_dist` instead)
        noise_prior_dist:
Optional custom prior distribution over the observational noise variance.
Defaults to LogNormal(0,1).
lengthscale_prior_dist:
Optional custom prior distribution over kernel lengthscale.
Defaults to LogNormal(0, 1).
guide:
Auto-guide option, use 'delta' (default) or 'normal'
"""
def __init__(self, input_dim: int, kernel: str,
mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_prior_dist: Optional[dist.Distribution] = None,
lengthscale_prior_dist: Optional[dist.Distribution] = None,
guide: str = 'delta') -> None:
args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior,
noise_prior_dist, lengthscale_prior_dist, guide)
super(viSparseGP, self).__init__(*args)
self.Xu = None

def model(self,
X: jnp.ndarray,
y: jnp.ndarray = None,
Xu: jnp.ndarray = None,
**kwargs: float) -> None:
"""
Probabilistic sparse Gaussian process regression model
"""
if Xu is not None:
Xu = numpyro.param("Xu", Xu)
# Initialize mean function at zeros
f_loc = jnp.zeros(X.shape[0])
# Sample kernel parameters
if self.kernel_prior:
kernel_params = self.kernel_prior()
else:
kernel_params = self._sample_kernel_params()
# Sample noise
        if self.noise_prior:  # this will be removed in future releases
noise = self.noise_prior()
else:
noise = self._sample_noise()
        D = jnp.broadcast_to(noise, (X.shape[0],))
# Add mean function (if any)
if self.mean_fn is not None:
args = [X]
if self.mean_fn_prior is not None:
args += [self.mean_fn_prior()]
f_loc += self.mean_fn(*args).squeeze()
# Compute kernel between inducing points
Kuu = self.kernel(Xu, Xu, kernel_params, **kwargs)
        # Cholesky factor of Kuu (upper-triangular by default, so transpose to lower)
        Luu = cholesky(Kuu).T
# Compute kernel between inducing and training points
Kuf = self.kernel(Xu, X, kernel_params)
# Solve triangular system
W = solve_triangular(Luu, Kuf, lower=True).T
# Diagonal of the kernel matrix
Kffdiag = jnp.diag(self.kernel(X, X, kernel_params, jitter=0))
# Sum of squares computation
Qffdiag = jnp.square(W).sum(axis=-1)
# Trace term computation
trace_term = (Kffdiag - Qffdiag).sum() / noise
# Clamping the trace term
trace_term = jnp.clip(trace_term, a_min=0)

# VFE approximation
numpyro.factor("trace_term", -trace_term / 2.0)

numpyro.sample(
"y",
dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D),
obs=y)

def fit(self,
rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
inducing_points_ratio: float = 0.1, inducing_points_selection: str = 'random',
num_steps: int = 1000, step_size: float = 5e-3,
progress_bar: bool = True, print_summary: bool = True,
device: Type[jaxlib.xla_extension.Device] = None,
**kwargs: float
) -> None:
"""
Run variational inference to learn sparse GP (hyper)parameters
Args:
rng_key: random number generator key
X: 2D feature vector with *(number of points, number of features)* dimensions
y: 1D target vector with *(n,)* dimensions
            inducing_points_ratio: fraction of training points used as inducing
                points; must be a float between 0 and 1 (default: 0.1)
            inducing_points_selection: method for selecting the inducing points,
                'random' (default), 'uniform', or 'kmeans'
num_steps: number of SVI steps
step_size: step size schedule for Adam optimizer
progress_bar: show progress bar
print_summary: print summary at the end of training
device:
optionally specify a cpu or gpu device on which to run the inference;
e.g., ``device=jax.devices("cpu")[0]``
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
"""
X, y = self._set_data(X, y)
if device:
X = jax.device_put(X, device)
y = jax.device_put(y, device)
Xu = initialize_inducing_points(
X.copy(), inducing_points_ratio,
inducing_points_selection, rng_key)
self.X_train = X
self.y_train = y

optim = numpyro.optim.Adam(step_size=step_size, b1=0.5)
self.svi = SVI(
self.model,
guide=self.guide_type(self.model),
optim=optim,
loss=Trace_ELBO(),
X=X,
y=y,
Xu=Xu,
**kwargs
)

self.kernel_params = self.svi.run(
rng_key, num_steps, progress_bar=progress_bar)[0]

self.Xu = self.kernel_params['Xu']

if print_summary:
self._print_summary()

def get_mvn_posterior(self, X_new: jnp.ndarray,
params: Dict[str, jnp.ndarray],
noiseless: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns parameters (mean and cov) of multivariate normal posterior
for a single sample of GP parameters
"""
noise = params["noise"]
N = self.X_train.shape[0]
D = jnp.broadcast_to(noise, (N,))
noise_p = noise * (1 - jnp.array(noiseless, int))

y_residual = self.y_train.copy()
if self.mean_fn is not None:
args = [self.X_train, params] if self.mean_fn_prior else [self.X_train]
y_residual -= self.mean_fn(*args).squeeze()

# Compute self- and cross-covariance matrices
Kuu = self.kernel(self.Xu, self.Xu, params, **kwargs)
Luu = cholesky(Kuu, lower=True)
Kuf = self.kernel(self.Xu, self.X_train, params, jitter=0)

W = solve_triangular(Luu, Kuf, lower=True)
W_Dinv = W / D
K = W_Dinv @ W.T
K = K.at[jnp.diag_indices(K.shape[0])].add(1)
L = cholesky(K, lower=True)

y_2D = y_residual.reshape(-1, N).T
W_Dinv_y = W_Dinv @ y_2D

Kus = self.kernel(self.Xu, X_new, params, jitter=0)
Ws = solve_triangular(Luu, Kus, lower=True)
pack = jnp.concatenate((W_Dinv_y, Ws), axis=1)
Linv_pack = solve_triangular(L, pack, lower=True)

Linv_W_Dinv_y = Linv_pack[:, :W_Dinv_y.shape[1]]
Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:]
mean = (Linv_W_Dinv_y.T @ Linv_Ws).squeeze()

Kss = self.kernel(X_new, X_new, params, noise_p, **kwargs)
Qss = Ws.T @ Ws
cov = Kss - Qss + Linv_Ws.T @ Linv_Ws

if self.mean_fn is not None:
args = [X_new, params] if self.mean_fn_prior else [X_new]
mean += self.mean_fn(*args).squeeze()

return mean, cov
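
For reference, the `trace_term` factor together with the `LowRankMultivariateNormal` likelihood in `model` implements a Titsias-style variational free energy (VFE) bound on the log marginal likelihood, since `W @ W.T` equals the Nystrom approximation Q_ff:

\log p(\mathbf{y}) \;\geq\; \log \mathcal{N}\!\left(\mathbf{y} \mid \mathbf{f}_{\mathrm{loc}},\; Q_{ff} + \sigma^2 I\right) \;-\; \frac{1}{2\sigma^2}\,\operatorname{tr}\!\left(K_{ff} - Q_{ff}\right), \qquad Q_{ff} = K_{fu} K_{uu}^{-1} K_{uf}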
8 changes: 6 additions & 2 deletions gpax/models/vigp.py
@@ -35,8 +35,12 @@ class viGP(ExactGP):
Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default
mean_fn_prior:
Optional priors over mean function parameters
noise_prior:
Optional custom prior for the observation noise variance; uses LogNormal(0,1) by default.
noise_prior_dist:
Optional custom prior distribution over the observational noise variance.
Defaults to LogNormal(0,1).
lengthscale_prior_dist:
Optional custom prior distribution over kernel lengthscale.
Defaults to LogNormal(0, 1).
guide:
Auto-guide option, use 'delta' (default) or 'normal'
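
To illustrate the documented prior options, a short sketch of constructing a viGP with custom priors (the particular distributions are arbitrary examples, not library defaults):

import numpyro.distributions as dist
import gpax

model = gpax.viGP(
    input_dim=1,
    kernel='RBF',
    noise_prior_dist=dist.HalfNormal(0.1),            # tighter-than-default noise prior
    lengthscale_prior_dist=dist.LogNormal(0.0, 0.5),  # narrower than the default LogNormal(0, 1)
)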
44 changes: 44 additions & 0 deletions gpax/utils/utils.py
@@ -165,3 +165,47 @@ def preprocess_sparse_image(sparse_image):
# Generate indices for the entire image
full_indices = onp.array(onp.meshgrid(*[onp.arange(dim) for dim in sparse_image.shape])).T.reshape(-1, sparse_image.ndim)
return gp_input, targets, full_indices


def initialize_inducing_points(X, ratio=0.1, method='uniform', key=None):
"""
Initialize inducing points for a sparse Gaussian Process in JAX.
Parameters:
- X: A (n_samples, num_features) array of training data.
- ratio: A float between 0 and 1 indicating the fraction of inducing points.
- method: A string indicating the method for selecting inducing points ('uniform', 'random', 'kmeans').
- key: A JAX random key, required if method is 'random'.
Returns:
- inducing_points: A subset of X used as inducing points.
"""
if not 0 < ratio < 1:
raise ValueError("The 'ratio' value must be between 0 and 1")

n_samples = X.shape[0]
n_inducing = int(n_samples * ratio)

if method == 'uniform':
        # int32 indices (int8 would overflow for more than 127 samples)
        indices = jnp.linspace(0, n_samples - 1, n_inducing, dtype=jnp.int32)
inducing_points = X[indices]
elif method == 'random':
if key is None:
raise ValueError("A JAX random key must be provided for random selection")
indices = jax.random.choice(key, n_samples, shape=(n_inducing,), replace=False)
inducing_points = X[indices]
elif method == 'kmeans':
try:
from sklearn.cluster import KMeans # noqa: F401
except ImportError as e:
            raise ImportError(
                "You need to install `scikit-learn` to be able to use this feature. "
                "It can be installed with `pip install scikit-learn`."
) from e
# Use sklearn for KMeans clustering, then convert result to JAX array
kmeans = KMeans(n_clusters=n_inducing, random_state=0).fit(X)
inducing_points = jnp.array(kmeans.cluster_centers_)
else:
raise ValueError("Method must be 'uniform', 'random', or 'kmeans'")

return inducing_points
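
A quick sketch of how this helper behaves under the different selection methods, on toy data (the shapes in the final comment assume ratio=0.1 and 200 samples):

import jax
import jax.numpy as jnp
from gpax.utils import initialize_inducing_points

X = jnp.linspace(0., 1., 200).reshape(-1, 1)  # (n_samples, num_features)

# 'uniform' picks evenly spaced points and needs no key
Xu_uniform = initialize_inducing_points(X, ratio=0.1, method='uniform')

# 'random' subsamples without replacement and requires a JAX key
Xu_random = initialize_inducing_points(X, ratio=0.1, method='random',
                                       key=jax.random.PRNGKey(0))

print(Xu_uniform.shape, Xu_random.shape)  # (20, 1) (20, 1)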
64 changes: 64 additions & 0 deletions tests/test_sparsegp.py
@@ -0,0 +1,64 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
from numpy.testing import assert_equal

sys.path.insert(0, "../gpax/")

from gpax.models.sparse_gp import viSparseGP
from gpax.utils import get_keys, enable_x64

enable_x64()


def get_dummy_data(jax_ndarray=True, unsqueeze=False):
X = onp.linspace(1, 2, 50) + 0.1 * onp.random.randn(50,)
y = (10 * X**2)
if unsqueeze:
X = X[:, None]
if jax_ndarray:
return jnp.array(X), jnp.array(y)
return X, y


@pytest.mark.parametrize("jax_ndarray", [True, False])
@pytest.mark.parametrize("unsqueeze", [True, False])
def test_fit(jax_ndarray, unsqueeze):
rng_key = get_keys()[0]
X, y = get_dummy_data(jax_ndarray, unsqueeze)
m = viSparseGP(1, 'Matern')
m.fit(rng_key, X, y, num_steps=100)
assert m.svi is not None
assert isinstance(m.Xu, jnp.ndarray)


def test_inducing_points_optimization():
rng_key = get_keys()[0]
X, y = get_dummy_data()
m1 = viSparseGP(1, 'Matern')
m1.fit(rng_key, X, y, num_steps=1)
m2 = viSparseGP(1, 'Matern')
m2.fit(rng_key, X, y, num_steps=100)
assert not jnp.array_equal(m1.Xu, m2.Xu)


def test_get_mvn_posterior():
rng_keys = get_keys()
X, y = get_dummy_data(unsqueeze=True)
X_test, _ = get_dummy_data(unsqueeze=True)
params = {"k_length": jax.random.normal(rng_keys[0], shape=(1, 1)),
"k_scale": jax.random.normal(rng_keys[0], shape=(1,)),
"noise": jax.random.normal(rng_keys[0], shape=(1,))}
m = viSparseGP(1, 'RBF')
m.X_train = X
m.y_train = y
m.Xu = X[::2].copy()

mean, cov = m.get_mvn_posterior(X_test, params)

assert isinstance(mean, jnp.ndarray)
assert isinstance(cov, jnp.ndarray)
assert_equal(mean.shape, X_test.squeeze().shape)
assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))
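
The new test module can be run locally with a standard pytest invocation from the repository root:

python -m pytest tests/test_sparsegp.py -v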