Merge pull request #81 from ziatdinovmax/sparse
Add Sparse GP
Showing 7 changed files with 384 additions and 6 deletions.
gpax/models/sparse_gp.py
@@ -0,0 +1,223 @@
""" | ||
sparse_gp.py | ||
============ | ||
Variational inference implementation of sparse Gaussian process regression | ||
Created by Maxim Ziatdinov (email: [email protected]) | ||
""" | ||
|
||
from typing import Callable, Dict, Optional, Tuple, Type | ||
|
||
import jax | ||
import jaxlib | ||
import jax.numpy as jnp | ||
from jax.scipy.linalg import cholesky, solve_triangular | ||
|
||
import numpyro | ||
import numpyro.distributions as dist | ||
from numpyro.infer import SVI, Trace_ELBO | ||
|
||
from .vigp import viGP | ||
from ..utils import initialize_inducing_points | ||
|
||
|
||
class viSparseGP(viGP): | ||
""" | ||
Variational inference-based sparse Gaussian process | ||
Args: | ||
input_dim: | ||
Number of input dimensions | ||
kernel: | ||
Kernel function ('RBF', 'Matern', 'Periodic', or custom function) | ||
mean_fn: | ||
Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic) | ||
kernel_prior: | ||
Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default | ||
mean_fn_prior: | ||
Optional priors over mean function parameters | ||
noise_prior_dist: | ||
Optional custom prior distribution over the observational noise variance. | ||
Defaults to LogNormal(0,1). | ||
lengthscale_prior_dist: | ||
Optional custom prior distribution over kernel lengthscale. | ||
Defaults to LogNormal(0, 1). | ||
guide: | ||
Auto-guide option, use 'delta' (default) or 'normal' | ||
""" | ||
def __init__(self, input_dim: int, kernel: str, | ||
mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, | ||
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, | ||
mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, | ||
noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, | ||
noise_prior_dist: Optional[dist.Distribution] = None, | ||
lengthscale_prior_dist: Optional[dist.Distribution] = None, | ||
guide: str = 'delta') -> None: | ||
args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, | ||
noise_prior_dist, lengthscale_prior_dist, guide) | ||
super(viSparseGP, self).__init__(*args) | ||
self.Xu = None | ||
|
||
def model(self, | ||
X: jnp.ndarray, | ||
y: jnp.ndarray = None, | ||
Xu: jnp.ndarray = None, | ||
**kwargs: float) -> None: | ||
""" | ||
Probabilistic sparse Gaussian process regression model | ||
""" | ||
if Xu is not None: | ||
Xu = numpyro.param("Xu", Xu) | ||
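            # Registering Xu with numpyro.param makes the inducing point
            # locations trainable, so they are optimized jointly with the
            # kernel hyperparameters during SVI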
        # Initialize mean function at zeros
        f_loc = jnp.zeros(X.shape[0])
        # Sample kernel parameters
        if self.kernel_prior:
            kernel_params = self.kernel_prior()
        else:
            kernel_params = self._sample_kernel_params()
        # Sample noise
        if self.noise_prior:  # this will be removed in the future releases
            noise = self.noise_prior()
        else:
            noise = self._sample_noise()
        D = jnp.broadcast_to(noise, (X.shape[0],))
        # Add mean function (if any)
        if self.mean_fn is not None:
            args = [X]
            if self.mean_fn_prior is not None:
                args += [self.mean_fn_prior()]
            f_loc += self.mean_fn(*args).squeeze()
        # Compute kernel between inducing points
        Kuu = self.kernel(Xu, Xu, kernel_params, **kwargs)
        # Cholesky decomposition
        Luu = cholesky(Kuu).T
        # Compute kernel between inducing and training points
        Kuf = self.kernel(Xu, X, kernel_params)
        # Solve triangular system
        W = solve_triangular(Luu, Kuf, lower=True).T
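        # W satisfies W @ W.T = Kfu @ inv(Kuu) @ Kuf = Qff, the Nystrom
        # low-rank approximation of the full N x N kernel matrix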
        # Diagonal of the kernel matrix
        Kffdiag = jnp.diag(self.kernel(X, X, kernel_params, jitter=0))
        # Sum of squares computation
        Qffdiag = jnp.square(W).sum(axis=-1)
        # Trace term computation
        trace_term = (Kffdiag - Qffdiag).sum() / noise
        # Clamping the trace term
        trace_term = jnp.clip(trace_term, 0)
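        # Tr(Kff - Qff) / noise is nonnegative analytically (Kff - Qff is
        # positive semi-definite), so any negative value here is numerical
        # round-off and is safely clipped away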

        # VFE approximation
        numpyro.factor("trace_term", -trace_term / 2.0)

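        # Under the low-rank approximation, y is marginally distributed as
        # N(f_loc, W @ W.T + D); LowRankMultivariateNormal evaluates this
        # density without ever forming the full N x N covariance matrix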
        numpyro.sample(
            "y",
            dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D),
            obs=y)

    def fit(self,
            rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
            inducing_points_ratio: float = 0.1, inducing_points_selection: str = 'random',
            num_steps: int = 1000, step_size: float = 5e-3,
            progress_bar: bool = True, print_summary: bool = True,
            device: Type[jaxlib.xla_extension.Device] = None,
            **kwargs: float
            ) -> None:
        """
        Run variational inference to learn sparse GP (hyper)parameters
        Args:
            rng_key: random number generator key
            X: 2D feature vector with *(number of points, number of features)* dimensions
            y: 1D target vector with *(n,)* dimensions
            inducing_points_ratio: ratio of inducing points to training points.
                Must be a float between 0 and 1. Default value is 0.1.
            inducing_points_selection: strategy for selecting the initial
                inducing points; 'random' by default
            num_steps: number of SVI steps
            step_size: step size schedule for Adam optimizer
            progress_bar: show progress bar
            print_summary: print summary at the end of training
            device:
                optionally specify a cpu or gpu device on which to run the inference;
                e.g., ``device=jax.devices("cpu")[0]``
            **jitter:
                Small positive term added to the diagonal part of a covariance
                matrix for numerical stability (Default: 1e-6)
        """
        X, y = self._set_data(X, y)
        if device:
            X = jax.device_put(X, device)
            y = jax.device_put(y, device)
        Xu = initialize_inducing_points(
            X.copy(), inducing_points_ratio,
            inducing_points_selection, rng_key)
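        # Xu is initialized from the training inputs and then refined during
        # SVI, since the model registers it as a trainable numpyro.param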
        self.X_train = X
        self.y_train = y

        optim = numpyro.optim.Adam(step_size=step_size, b1=0.5)
        self.svi = SVI(
            self.model,
            guide=self.guide_type(self.model),
            optim=optim,
            loss=Trace_ELBO(),
            X=X,
            y=y,
            Xu=Xu,
            **kwargs
        )

        self.kernel_params = self.svi.run(
            rng_key, num_steps, progress_bar=progress_bar)[0]

        self.Xu = self.kernel_params['Xu']

        if print_summary:
            self._print_summary()

    def get_mvn_posterior(self, X_new: jnp.ndarray,
                          params: Dict[str, jnp.ndarray],
                          noiseless: bool = False,
                          **kwargs: float
                          ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns parameters (mean and cov) of multivariate normal posterior
        for a single sample of GP parameters
        """
        noise = params["noise"]
        N = self.X_train.shape[0]
        D = jnp.broadcast_to(noise, (N,))
        noise_p = noise * (1 - jnp.array(noiseless, int))

        y_residual = self.y_train.copy()
        if self.mean_fn is not None:
            args = [self.X_train, params] if self.mean_fn_prior else [self.X_train]
            y_residual -= self.mean_fn(*args).squeeze()
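
        # The predictive mean and covariance are obtained via the low-rank
        # (Woodbury) route: every dense factorization below involves M x M or
        # M x N matrices, where M is the number of inducing points, so the
        # full N x N training kernel is never assembled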
        # Compute self- and cross-covariance matrices
        Kuu = self.kernel(self.Xu, self.Xu, params, **kwargs)
        Luu = cholesky(Kuu, lower=True)
        Kuf = self.kernel(self.Xu, self.X_train, params, jitter=0)

        W = solve_triangular(Luu, Kuf, lower=True)
        W_Dinv = W / D
        K = W_Dinv @ W.T
        K = K.at[jnp.diag_indices(K.shape[0])].add(1)
        L = cholesky(K, lower=True)
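        # L is the Cholesky factor of K = I + W @ inv(D) @ W.T, an M x M
        # matrix; it is reused for both the mean and covariance solves below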

        y_2D = y_residual.reshape(-1, N).T
        W_Dinv_y = W_Dinv @ y_2D

        Kus = self.kernel(self.Xu, X_new, params, jitter=0)
        Ws = solve_triangular(Luu, Kus, lower=True)
        pack = jnp.concatenate((W_Dinv_y, Ws), axis=1)
        Linv_pack = solve_triangular(L, pack, lower=True)

        Linv_W_Dinv_y = Linv_pack[:, :W_Dinv_y.shape[1]]
        Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:]
        mean = (Linv_W_Dinv_y.T @ Linv_Ws).squeeze()

        Kss = self.kernel(X_new, X_new, params, noise_p, **kwargs)
        Qss = Ws.T @ Ws
        cov = Kss - Qss + Linv_Ws.T @ Linv_Ws

        if self.mean_fn is not None:
            args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*args).squeeze()

        return mean, cov
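
For orientation, a minimal usage sketch (not part of the commit): it assumes gpax is installed, reuses the get_keys/enable_x64 helpers and synthetic data from the tests below, and the step count is arbitrary.

import numpy as onp
from gpax.models.sparse_gp import viSparseGP
from gpax.utils import get_keys, enable_x64

enable_x64()  # double precision improves Cholesky stability

rng_key = get_keys()[0]
X = onp.linspace(1, 2, 50)
y = 10 * X**2

m = viSparseGP(input_dim=1, kernel='Matern')
# 10% of the 50 training points -> 5 trainable inducing points
m.fit(rng_key, X, y, inducing_points_ratio=0.1, num_steps=500)
print(m.Xu)  # optimized inducing point locations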
@@ -0,0 +1,64 @@
import sys
import pytest
import numpy as onp
import jax.numpy as jnp
import jax
from numpy.testing import assert_equal

sys.path.insert(0, "../gpax/")

from gpax.models.sparse_gp import viSparseGP
from gpax.utils import get_keys, enable_x64

enable_x64()


def get_dummy_data(jax_ndarray=True, unsqueeze=False):
    X = onp.linspace(1, 2, 50) + 0.1 * onp.random.randn(50,)
    y = 10 * X**2
    if unsqueeze:
        X = X[:, None]
    if jax_ndarray:
        return jnp.array(X), jnp.array(y)
    return X, y


@pytest.mark.parametrize("jax_ndarray", [True, False])
@pytest.mark.parametrize("unsqueeze", [True, False])
def test_fit(jax_ndarray, unsqueeze):
    rng_key = get_keys()[0]
    X, y = get_dummy_data(jax_ndarray, unsqueeze)
    m = viSparseGP(1, 'Matern')
    m.fit(rng_key, X, y, num_steps=100)
    assert m.svi is not None
    assert isinstance(m.Xu, jnp.ndarray)


def test_inducing_points_optimization():
    rng_key = get_keys()[0]
    X, y = get_dummy_data()
    m1 = viSparseGP(1, 'Matern')
    m1.fit(rng_key, X, y, num_steps=1)
    m2 = viSparseGP(1, 'Matern')
    m2.fit(rng_key, X, y, num_steps=100)
    assert not jnp.array_equal(m1.Xu, m2.Xu)


def test_get_mvn_posterior():
    rng_keys = get_keys()
    X, y = get_dummy_data(unsqueeze=True)
    X_test, _ = get_dummy_data(unsqueeze=True)
    params = {"k_length": jax.random.normal(rng_keys[0], shape=(1, 1)),
              "k_scale": jax.random.normal(rng_keys[0], shape=(1,)),
              "noise": jax.random.normal(rng_keys[0], shape=(1,))}
    m = viSparseGP(1, 'RBF')
    m.X_train = X
    m.y_train = y
    m.Xu = X[::2].copy()
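    # every other training point serves as a fixed set of inducing points,
    # bypassing fit() so the posterior math is tested in isolation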

    mean, cov = m.get_mvn_posterior(X_test, params)

    assert isinstance(mean, jnp.ndarray)
    assert isinstance(cov, jnp.ndarray)
    assert_equal(mean.shape, X_test.squeeze().shape)
    assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))