From 6e9f26229020c2122d7535be0ae09843a0be081d Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Mon, 5 Feb 2024 19:52:12 +0000
Subject: [PATCH 01/17] Add sparse GP

---
 gpax/models/__init__.py  |   3 ++-
 gpax/models/sparse_gp.py | 226 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 229 insertions(+), 1 deletion(-)
 create mode 100644 gpax/models/sparse_gp.py

diff --git a/gpax/models/__init__.py b/gpax/models/__init__.py
index 5fa0f2d..09d309a 100644
--- a/gpax/models/__init__.py
+++ b/gpax/models/__init__.py
@@ -13,6 +13,7 @@
 from .uigp import UIGP
 from .mngp import MeasuredNoiseGP
 from .linreg import LinReg
+from .sparse_gp import viSparseGP
 
 __all__ = [
     "ExactGP",
@@ -30,4 +31,5 @@
     "UIGP",
     "LinReg",
-    "MeasuredNoiseGP"
+    "MeasuredNoiseGP",
+    "viSparseGP"
 ]

diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py
new file mode 100644
index 0000000..a2e3a07
--- /dev/null
+++ b/gpax/models/sparse_gp.py
@@ -0,0 +1,226 @@
+"""
+sparse_gp.py
+=======
+
+Variational inference implementation of sparse Gaussian process regression
+
+Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
+"""
+
+from typing import Callable, Dict, Optional, Tuple, Type
+
+import jax
+import jaxlib
+import jax.numpy as jnp
+from jax.scipy.linalg import cholesky, solve_triangular
+
+import numpyro
+import numpyro.distributions as dist
+from numpyro.infer import SVI, Trace_ELBO
+from numpyro.infer.autoguide import AutoDelta, AutoNormal
+
+from .gp import ExactGP
+
+
+class viSparseGP(ExactGP):
+    """
+    Variational inference based sparse Gaussian process
+
+    Args:
+        input_dim:
+            Number of input dimensions
+        kernel:
+            Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
+        mean_fn:
+            Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
+        kernel_prior:
+            Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default
+        mean_fn_prior:
+            Optional priors over mean function parameters
+        noise_prior:
+            Optional custom prior for the observation noise variance; uses LogNormal(0,1) by default.
+ guide: + Auto-guide option, use 'delta' (default) or 'normal' + """ + def __init__(self, input_dim: int, kernel: str, + mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, + kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_prior_dist: Optional[dist.Distribution] = None, + lengthscale_prior_dist: Optional[dist.Distribution] = None, + guide: str = 'delta') -> None: + args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, + noise_prior_dist, lengthscale_prior_dist) + super(viSparseGP, self).__init__(*args) + self.X_train = None + self.y_train = None + self.Xu = None + self.guide_type = AutoNormal if guide == 'normal' else AutoDelta + self.svi = None + + def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: + + # Initialize mean function at zeros + f_loc = jnp.zeros(X.shape[0]) + # Sample kernel parameters + if self.kernel_prior: + kernel_params = self.kernel_prior() + else: + kernel_params = self._sample_kernel_params() + # Sample noise + if self.noise_prior: # this will be removed in the future releases + noise = self.noise_prior() + else: + noise = self._sample_noise() + D = jnp.broadcast_to(noise, (X.shape[0],) ) + # Add mean function (if any) + if self.mean_fn is not None: + args = [X] + if self.mean_fn_prior is not None: + args += [self.mean_fn_prior()] + f_loc += self.mean_fn(*args).squeeze() + # compute kernel between inducing points + Kuu = self.kernel(self.Xu, self.Xu, kernel_params) + # Cholesky decomposition + Luu = cholesky(Kuu).T + # Kernel computation + Kuf = self.kernel(self.Xu, X, kernel_params) + # Solve triangular system + W = solve_triangular(Luu, Kuf, lower=True).T + # Diagonal of the kernel matrix + Kffdiag = jnp.diag(self.kernel(X, X, kernel_params, jitter=0)) + # Sum of squares computation + Qffdiag = jnp.square(W).sum(axis=-1) + # Trace term computation + trace_term = (Kffdiag - Qffdiag).sum() / noise + # Clamping the trace term + trace_term = jnp.clip(trace_term, a_min=0) + + # VFE + numpyro.factor("trace_term", -trace_term / 2.0) + + numpyro.sample( + "y", + dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D), + obs=y) + + def fit(self, + rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, Xu: jnp.ndarray, + num_steps: int = 1000, step_size: float = 5e-3, + progress_bar: bool = True, print_summary: bool = True, + device: Type[jaxlib.xla_extension.Device] = None, + **kwargs: float + ) -> None: + """ + Run variational inference to learn GP (hyper)parameters + + Args: + rng_key: random number generator key + X: 2D feature vector with *(number of points, number of features)* dimensions + y: 1D target vector with *(n,)* dimensions + num_steps: number of SVI steps + step_size: step size schedule for Adam optimizer + progress_bar: show progress bar + print_summary: print summary at the end of training + device: + optionally specify a cpu or gpu device on which to run the inference; + e.g., ``device=jax.devices("cpu")[0]`` + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + X, y = self._set_data(X, y) + if device: + X = jax.device_put(X, device) + y = jax.device_put(y, device) + self.X_train = X + self.y_train = y + + self.Xu = numpyro.param(Xu) + + optim = numpyro.optim.Adam(step_size=step_size, b1=0.5) + self.svi = SVI( + 
self.model, + guide=self.guide_type(self.model), + optim=optim, + loss=Trace_ELBO(), + X=X, + y=y, + **kwargs + ) + + self.kernel_params = self.svi.run( + rng_key, num_steps, progress_bar=progress_bar)[0] + + if print_summary: + self._print_summary() + + def get_samples(self) -> Dict[str, jnp.ndarray]: + """Get posterior samples""" + return self.svi.guide.median(self.kernel_params) + + def predict_in_batches(self, rng_key: jnp.ndarray, + X_new: jnp.ndarray, batch_size: int = 100, + samples: Optional[Dict[str, jnp.ndarray]] = None, + predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, + noiseless: bool = False, + device: Type[jaxlib.xla_extension.Device] = None, + **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Make prediction at X_new with sampled GP parameters + by spitting the input array into chunks ("batches") and running + predict_fn (defaults to self.predict) on each of them one-by-one + to avoid a memory overflow + """ + predict_fn = lambda xi: self.predict( + rng_key, xi, samples, noiseless, **kwargs) + y_pred, y_var = self._predict_in_batches( + rng_key, X_new, batch_size, 0, samples, + predict_fn=predict_fn, noiseless=noiseless, + device=device, **kwargs) + y_pred = jnp.concatenate(y_pred, 0) + y_var = jnp.concatenate(y_var, 0) + return y_pred, y_var + + def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, + samples: Optional[Dict[str, jnp.ndarray]] = None, + noiseless: bool = False, + device: Type[jaxlib.xla_extension.Device] = None, **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Make prediction at X_new points using posterior samples for GP parameters + + Args: + rng_key: random number generator key + X_new: new inputs with *(number of points, number of features)* dimensions + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + by default for the training data, we also want to include that noise in our prediction. 
+        device:
+            optionally specify a cpu or gpu device on which to make a prediction;
+            e.g., ```device=jax.devices("gpu")[0]```
+        **jitter:
+            Small positive term added to the diagonal part of a covariance
+            matrix for numerical stability (Default: 1e-6)
+
+        Returns
+            Center of the mass of sampled means and all the sampled predictions
+        """
+        X_new = self._set_data(X_new)
+        if device:
+            self._set_training_data(device=device)
+            X_new = jax.device_put(X_new, device)
+        if samples is None:
+            samples = self.get_samples()
+        mean, cov = self.get_mvn_posterior(X_new, samples, noiseless, **kwargs)
+        return mean, cov.diagonal()
+
+    def _print_summary(self) -> None:
+        params_map = self.get_samples()
+        print('\nInferred GP parameters')
+        for (k, vals) in params_map.items():
+            spaces = " " * (15 - len(k))
+            print(k, spaces, jnp.around(vals, 4))

From 918025a8c4541abc48e4b470bcbbe4fd1b395e1a Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Mon, 5 Feb 2024 20:03:08 +0000
Subject: [PATCH 02/17] Update the inducing points initialization

---
 gpax/models/sparse_gp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py
index a2e3a07..a11df00 100644
--- a/gpax/models/sparse_gp.py
+++ b/gpax/models/sparse_gp.py
@@ -137,7 +137,7 @@ def fit(self,
         self.X_train = X
         self.y_train = y
 
-        self.Xu = numpyro.param(Xu)
+        self.Xu = numpyro.param("Xu", Xu)
 
         optim = numpyro.optim.Adam(step_size=step_size, b1=0.5)
         self.svi = SVI(

From 6a79271e7a55c4e909dd410ba1681c4429a64fc2 Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Mon, 5 Feb 2024 22:43:52 +0000
Subject: [PATCH 03/17] Update handling of inducing points

---
 gpax/models/sparse_gp.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py
index a11df00..b42f97b 100644
--- a/gpax/models/sparse_gp.py
+++ b/gpax/models/sparse_gp.py
@@ -59,8 +59,9 @@ def __init__(self, input_dim: int, kernel: str,
         self.guide_type = AutoNormal if guide == 'normal' else AutoDelta
         self.svi = None
 
-    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
-
+    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray, **kwargs: float) -> None:
+        if Xu is not None:
+            self.Xu = numpyro.param("Xu", Xu)
         # Initialize mean function at zeros
         f_loc = jnp.zeros(X.shape[0])
         # Sample kernel parameters
@@ -137,8 +137,6 @@ def fit(self,
         self.X_train = X
         self.y_train = y
 
-        self.Xu = numpyro.param("Xu", Xu)
-
         optim = numpyro.optim.Adam(step_size=step_size, b1=0.5)
         self.svi = SVI(
             self.model,
@@ -147,6 +146,7 @@ def fit(self,
             loss=Trace_ELBO(),
             X=X,
             y=y,
+            Xu=Xu,
             **kwargs
         )

From 00372f1f0f09d44f5f2bf5037c7cea09babce90b Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Mon, 5 Feb 2024 22:46:55 +0000
Subject: [PATCH 04/17] Fix the arguments order

---
 gpax/models/sparse_gp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py
index b42f97b..afe92ae 100644
--- a/gpax/models/sparse_gp.py
+++ b/gpax/models/sparse_gp.py
@@ -59,7 +59,7 @@ def __init__(self, input_dim: int, kernel: str,
         self.guide_type = AutoNormal if guide == 'normal' else AutoDelta
         self.svi = None
 
-    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray, **kwargs: float) -> None:
+    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray
= None, **kwargs: float) -> None: if Xu is not None: self.Xu = numpyro.param("Xu", Xu) # Initialize mean function at zeros From 593509555ec33bc3c445d2addffc8210eb5f29b1 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Mon, 5 Feb 2024 22:59:37 +0000 Subject: [PATCH 05/17] Update Xu handling --- gpax/models/sparse_gp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index afe92ae..ce654c3 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -55,13 +55,12 @@ def __init__(self, input_dim: int, kernel: str, super(viSparseGP, self).__init__(*args) self.X_train = None self.y_train = None - self.Xu = None self.guide_type = AutoNormal if guide == 'normal' else AutoDelta self.svi = None def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, **kwargs: float) -> None: if Xu is not None: - self.Xu = numpyro.param("Xu", Xu) + Xu = numpyro.param("Xu", Xu) # Initialize mean function at zeros f_loc = jnp.zeros(X.shape[0]) # Sample kernel parameters @@ -82,11 +81,11 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, * args += [self.mean_fn_prior()] f_loc += self.mean_fn(*args).squeeze() # compute kernel between inducing points - Kuu = self.kernel(self.Xu, self.Xu, kernel_params) + Kuu = self.kernel(Xu, Xu, kernel_params) # Cholesky decomposition Luu = cholesky(Kuu).T # Kernel computation - Kuf = self.kernel(self.Xu, X, kernel_params) + Kuf = self.kernel(Xu, X, kernel_params) # Solve triangular system W = solve_triangular(Luu, Kuf, lower=True).T # Diagonal of the kernel matrix From a5b87cf0005e7a45293e01bbf4c19ddd7b2bd2b3 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 20:08:00 +0000 Subject: [PATCH 06/17] Add a utility for selecting inducing points --- gpax/utils/utils.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/gpax/utils/utils.py b/gpax/utils/utils.py index d94b742..2f5e72e 100644 --- a/gpax/utils/utils.py +++ b/gpax/utils/utils.py @@ -165,3 +165,47 @@ def preprocess_sparse_image(sparse_image): # Generate indices for the entire image full_indices = onp.array(onp.meshgrid(*[onp.arange(dim) for dim in sparse_image.shape])).T.reshape(-1, sparse_image.ndim) return gp_input, targets, full_indices + + +def initialize_inducing_points(X, ratio=0.1, method='uniform', key=None): + """ + Initialize inducing points for a sparse Gaussian Process in JAX. + + Parameters: + - X: A (n_samples, num_features) array of training data. + - ratio: A float between 0 and 1 indicating the fraction of inducing points. + - method: A string indicating the method for selecting inducing points ('uniform', 'random', 'kmeans'). + - key: A JAX random key, required if method is 'random'. + + Returns: + - inducing_points: A subset of X used as inducing points. 
+ """ + if not 0 < ratio < 1: + raise ValueError("The 'ratio' value must be between 0 and 1") + + n_samples = X.shape[0] + n_inducing = int(n_samples * ratio) + + if method == 'uniform': + indices = jnp.linspace(0, n_samples - 1, n_inducing, dtype=jnp.int8) + inducing_points = X[indices] + elif method == 'random': + if key is None: + raise ValueError("A JAX random key must be provided for random selection") + indices = jax.random.choice(key, n_samples, shape=(n_inducing,), replace=False) + inducing_points = X[indices] + elif method == 'kmeans': + try: + from sklearn.cluster import KMeans # noqa: F401 + except ImportError as e: + raise ImportError( + "You need to install `seaborn` to be able to use this feature. " + "It can be installed with `pip install scikit-learn`." + ) from e + # Use sklearn for KMeans clustering, then convert result to JAX array + kmeans = KMeans(n_clusters=n_inducing, random_state=0).fit(X) + inducing_points = jnp.array(kmeans.cluster_centers_) + else: + raise ValueError("Method must be 'uniform', 'random', or 'kmeans'") + + return inducing_points From 0bceabc34c2b9a6b3a31583b6b09e2dc2b3c96f5 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:05:32 +0000 Subject: [PATCH 07/17] Update inducing points hadling --- gpax/models/sparse_gp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index ce654c3..7fb546b 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -20,6 +20,7 @@ from numpyro.infer.autoguide import AutoDelta, AutoNormal from .gp import ExactGP +from ..utils import initialize_inducing_points class viSparseGP(ExactGP): @@ -55,6 +56,7 @@ def __init__(self, input_dim: int, kernel: str, super(viSparseGP, self).__init__(*args) self.X_train = None self.y_train = None + self.Xu = None self.guide_type = AutoNormal if guide == 'normal' else AutoDelta self.svi = None @@ -97,7 +99,7 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, * # Clamping the trace term trace_term = jnp.clip(trace_term, a_min=0) - # VFE + # VFE approximation numpyro.factor("trace_term", -trace_term / 2.0) numpyro.sample( @@ -106,7 +108,8 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, * obs=y) def fit(self, - rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, Xu: jnp.ndarray, + rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, + inducing_points_ratio: float = 0.1, inducing_points_selection: str = 'uniform', num_steps: int = 1000, step_size: float = 5e-3, progress_bar: bool = True, print_summary: bool = True, device: Type[jaxlib.xla_extension.Device] = None, @@ -119,6 +122,7 @@ def fit(self, rng_key: random number generator key X: 2D feature vector with *(number of points, number of features)* dimensions y: 1D target vector with *(n,)* dimensions + Xu: Inducing points ratio. Must be a float between 0 and 1. Default value is 0.1. 
num_steps: number of SVI steps step_size: step size schedule for Adam optimizer progress_bar: show progress bar @@ -134,6 +138,9 @@ def fit(self, if device: X = jax.device_put(X, device) y = jax.device_put(y, device) + Xu = initialize_inducing_points( + X.copy(), inducing_points_ratio, + inducing_points_selection, rng_key) self.X_train = X self.y_train = y @@ -152,6 +159,8 @@ def fit(self, self.kernel_params = self.svi.run( rng_key, num_steps, progress_bar=progress_bar)[0] + self.Xu = self.kernel_params['Xu'] + if print_summary: self._print_summary() From 99ee5b02f7ce4ea68fe061488876a57d3926c5a3 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:42:26 +0000 Subject: [PATCH 08/17] Add sparse GP prediction --- gpax/models/sparse_gp.py | 54 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index 7fb546b..880ddf2 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -82,11 +82,11 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, * if self.mean_fn_prior is not None: args += [self.mean_fn_prior()] f_loc += self.mean_fn(*args).squeeze() - # compute kernel between inducing points + # Xompute kernel between inducing points Kuu = self.kernel(Xu, Xu, kernel_params) # Cholesky decomposition Luu = cholesky(Kuu).T - # Kernel computation + # Compute kernel between inducing and training points Kuf = self.kernel(Xu, X, kernel_params) # Solve triangular system W = solve_triangular(Luu, Kuf, lower=True).T @@ -164,6 +164,56 @@ def fit(self, if print_summary: self._print_summary() + def get_mvn_posterior( + self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Returns parameters (mean and cov) of multivariate normal posterior + for a single sample of GP parameters + """ + noise = params["noise"] + N = self.X_train.shape[0] + D = jnp.broadcast_to(noise, (N,)) + noise_p = noise * (1 - jnp.array(noiseless, int)) + + y_residual = self.y_train.copy() + if self.mean_fn is not None: + args = [self.X_train, params] if self.mean_fn_prior else [self.X_train] + y_residual -= self.mean_fn(*args).squeeze() + + # Compute self- and cross-covariance matrices + Kuu = self.kernel(self.Xu, self.Xu, params) + Luu = cholesky(Kuu, lower=True) + Kuf = self.kernel(self.Xu, self.X_train, params, jitter=0) + + W = solve_triangular(Luu, Kuf, lower=True) + W_Dinv = W / D + K = W_Dinv @ W.T + K = K.at[jnp.diag_indices(K.shape[0])].add(1) + L = cholesky(K, lower=True) + + y_2D = y_residual.reshape(-1, N).T + W_Dinv_y = W_Dinv @ y_2D + + Kus = self.kernel(self.Xu, X_new, params, jitter=0) + Ws = solve_triangular(Luu, Kus, lower=True) + pack = jnp.concatenate((W_Dinv_y, Ws), axis=1) + Linv_pack = solve_triangular(L, pack, lower=True) + + Linv_W_Dinv_y = Linv_pack[:, :W_Dinv_y.shape[1]] + Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:] + mean = Linv_W_Dinv_y.T @ Linv_Ws + + Kss = self.kernel(X_new, X_new, params, noise_p) + Qss = Ws.T @ Ws + cov = Kss - Qss + Linv_Ws.T @ Linv_Ws + + if self.mean_fn is not None: + args = [X_new, params] if self.mean_fn_prior else [X_new] + mean += self.mean_fn(*args).squeeze() + + return mean, cov + def get_samples(self) -> Dict[str, jnp.ndarray]: """Get posterior samples""" return self.svi.guide.median(self.kernel_params) From b1c0a58a8ee96e578341211ab1d0ce9a9b743c04 Mon Sep 17 00:00:00 2001 From: 
Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 22:26:45 +0000 Subject: [PATCH 09/17] Subclass sparse GP from viGP --- gpax/models/sparse_gp.py | 90 ++++------------------------------------ 1 file changed, 8 insertions(+), 82 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index 880ddf2..0b0b3ec 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -1,6 +1,6 @@ """ sparse_gp.py -======= +============ Variational inference implementation of sparse Gaussian process regression @@ -17,15 +17,14 @@ import numpyro import numpyro.distributions as dist from numpyro.infer import SVI, Trace_ELBO -from numpyro.infer.autoguide import AutoDelta, AutoNormal -from .gp import ExactGP +from .vigp import viGP from ..utils import initialize_inducing_points -class viSparseGP(ExactGP): +class viSparseGP(viGP): """ - Variational inference based sparse Gaussian process + Variational inference-based sparse Gaussian process Args: input_dim: @@ -52,13 +51,9 @@ def __init__(self, input_dim: int, kernel: str, lengthscale_prior_dist: Optional[dist.Distribution] = None, guide: str = 'delta') -> None: args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior, - noise_prior_dist, lengthscale_prior_dist) + noise_prior_dist, lengthscale_prior_dist, guide) super(viSparseGP, self).__init__(*args) - self.X_train = None - self.y_train = None self.Xu = None - self.guide_type = AutoNormal if guide == 'normal' else AutoDelta - self.svi = None def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, **kwargs: float) -> None: if Xu is not None: @@ -83,7 +78,7 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, * args += [self.mean_fn_prior()] f_loc += self.mean_fn(*args).squeeze() # Xompute kernel between inducing points - Kuu = self.kernel(Xu, Xu, kernel_params) + Kuu = self.kernel(Xu, Xu, kernel_params) # Cholesky decomposition Luu = cholesky(Kuu).T # Compute kernel between inducing and training points @@ -185,7 +180,7 @@ def get_mvn_posterior( Kuu = self.kernel(self.Xu, self.Xu, params) Luu = cholesky(Kuu, lower=True) Kuf = self.kernel(self.Xu, self.X_train, params, jitter=0) - + W = solve_triangular(Luu, Kuf, lower=True) W_Dinv = W / D K = W_Dinv @ W.T @@ -212,73 +207,4 @@ def get_mvn_posterior( args = [X_new, params] if self.mean_fn_prior else [X_new] mean += self.mean_fn(*args).squeeze() - return mean, cov - - def get_samples(self) -> Dict[str, jnp.ndarray]: - """Get posterior samples""" - return self.svi.guide.median(self.kernel_params) - - def predict_in_batches(self, rng_key: jnp.ndarray, - X_new: jnp.ndarray, batch_size: int = 100, - samples: Optional[Dict[str, jnp.ndarray]] = None, - predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, - noiseless: bool = False, - device: Type[jaxlib.xla_extension.Device] = None, - **kwargs: float - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Make prediction at X_new with sampled GP parameters - by spitting the input array into chunks ("batches") and running - predict_fn (defaults to self.predict) on each of them one-by-one - to avoid a memory overflow - """ - predict_fn = lambda xi: self.predict( - rng_key, xi, samples, noiseless, **kwargs) - y_pred, y_var = self._predict_in_batches( - rng_key, X_new, batch_size, 0, samples, - predict_fn=predict_fn, noiseless=noiseless, - device=device, **kwargs) - y_pred = jnp.concatenate(y_pred, 0) - y_var = jnp.concatenate(y_var, 0) - return y_pred, y_var - - def predict(self, 
rng_key: jnp.ndarray, X_new: jnp.ndarray, - samples: Optional[Dict[str, jnp.ndarray]] = None, - noiseless: bool = False, - device: Type[jaxlib.xla_extension.Device] = None, **kwargs: float - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Make prediction at X_new points using posterior samples for GP parameters - - Args: - rng_key: random number generator key - X_new: new inputs with *(number of points, number of features)* dimensions - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - by default for the training data, we also want to include that noise in our prediction. - device: - optionally specify a cpu or gpu device on which to make a prediction; - e.g., ```device=jax.devices("gpu")[0]``` - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) - - Returns - Center of the mass of sampled means and all the sampled predictions - """ - X_new = self._set_data(X_new) - if device: - self._set_training_data(device=device) - X_new = jax.device_put(X_new, device) - if samples is None: - samples = self.get_samples() - mean, cov = self.get_mvn_posterior(X_new, samples, noiseless, **kwargs) - return mean, cov.diagonal() - - def _print_summary(self) -> None: - params_map = self.get_samples() - print('\nInferred GP parameters') - for (k, vals) in params_map.items(): - spaces = " " * (15 - len(k)) - print(k, spaces, jnp.around(vals, 4)) + return mean, cov \ No newline at end of file From b438d069622fb34ce0ac2719d27a4ec0a5136c1d Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 22:41:53 +0000 Subject: [PATCH 10/17] Add tests for inducing points selection --- tests/test_utils.py | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 45c91fe..4d4ef69 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ import sys import pytest import numpy as onp +import jax import jax.numpy as jnp import jax.random as jra import numpyro @@ -8,7 +9,7 @@ sys.path.insert(0, "../gpax/") -from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys +from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys, initialize_inducing_points from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, gamma_dist, uniform_dist, normal_dist, halfnormal_dist from gpax.utils import set_fn, auto_normal_priors @@ -107,4 +108,43 @@ def test_get_keys_different_seeds(): key1, key2 = get_keys() key1a, key2a = get_keys(42) assert_(not onp.array_equal(key1, key1a)) - assert_(not onp.array_equal(key2, key2a)) \ No newline at end of file + assert_(not onp.array_equal(key2, key2a)) + + +def test_ratio_out_of_bounds(): + X = jax.random.normal(jax.random.PRNGKey(0), (100, 5)) + with pytest.raises(ValueError): + initialize_inducing_points(X, ratio=-0.1) + with pytest.raises(ValueError): + initialize_inducing_points(X, ratio=1.5) + + +def test_invalid_method(): + X = jax.random.normal(jax.random.PRNGKey(0), (100, 5)) + with pytest.raises(ValueError): + initialize_inducing_points(X, method='invalid_method') + + +def test_missing_key_for_random_method(): + X = jax.random.normal(jax.random.PRNGKey(0), (100, 5)) + with pytest.raises(ValueError): + 
initialize_inducing_points(X, method='random') + + +@pytest.mark.parametrize("method", ["uniform", "random"]) +def test_output_shape(method): + X = jax.random.normal(jax.random.PRNGKey(0), (100, 5)) + ratio = 0.1 + inducing_points = initialize_inducing_points( + X, ratio=ratio, method=method, key=jax.random.PRNGKey(0)) + expected_shape = (int(100 * ratio), 5) + assert inducing_points.shape == expected_shape, "Output shape is incorrect" + + +@pytest.mark.skipif('sklearn' not in sys.modules, reason="sklearn is not installed") +def test_kmeans_dependency(): + X = jax.random.normal(jax.random.PRNGKey(0), (100, 5)) + try: + inducing_points = initialize_inducing_points(X, method='kmeans') + except ImportError: + pytest.fail("KMeans test failed due to missing sklearn dependency") From ce93965fd05285c875bb295e3b75046e3be43849 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:35:43 +0000 Subject: [PATCH 11/17] Update imports --- gpax/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gpax/__init__.py b/gpax/__init__.py index f3d7d53..9a1df20 100644 --- a/gpax/__init__.py +++ b/gpax/__init__.py @@ -4,8 +4,9 @@ from . import acquisition from .hypo import sample_next from .models import (DKL, CoregGP, ExactGP, MultiTaskGP, iBNN, vExactGP, - vi_iBNN, viDKL, viGP, viMTDKL, VarNoiseGP, UIGP, MeasuredNoiseGP) + vi_iBNN, viDKL, viGP, viMTDKL, VarNoiseGP, UIGP, + MeasuredNoiseGP, viSparseGP) __all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL", "viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP", - "UIGP", "MeasuredNoiseGP", "CoregGP", "sample_next", "__version__"] + "UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "sample_next", "__version__"] From eca96a20fb39cafda947d74c528f346b9f4571a6 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:44:03 +0000 Subject: [PATCH 12/17] Squeeze predictive mean --- gpax/models/sparse_gp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index 0b0b3ec..ab98bcd 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -197,7 +197,7 @@ def get_mvn_posterior( Linv_W_Dinv_y = Linv_pack[:, :W_Dinv_y.shape[1]] Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:] - mean = Linv_W_Dinv_y.T @ Linv_Ws + mean = (Linv_W_Dinv_y.T @ Linv_Ws).squeeze() Kss = self.kernel(X_new, X_new, params, noise_p) Qss = Ws.T @ Ws From 01fcbd6460fd71febca874dde08ebe3353416bc1 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:49:32 +0000 Subject: [PATCH 13/17] Add tests for viSparseGP --- tests/test_sparsegp.py | 64 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/test_sparsegp.py diff --git a/tests/test_sparsegp.py b/tests/test_sparsegp.py new file mode 100644 index 0000000..70d9352 --- /dev/null +++ b/tests/test_sparsegp.py @@ -0,0 +1,64 @@ +import sys +import pytest +import numpy as onp +import jax.numpy as jnp +import jax +from numpy.testing import assert_equal + +sys.path.insert(0, "../gpax/") + +from gpax.models.sparse_gp import viSparseGP +from gpax.utils import get_keys, enable_x64 + +enable_x64() + + +def get_dummy_data(jax_ndarray=True, unsqueeze=False): + X = onp.linspace(1, 2, 50) + 0.1 * onp.random.randn(50,) + y = (10 * X**2) + if unsqueeze: + X = X[:, None] + 
if jax_ndarray: + return jnp.array(X), jnp.array(y) + return X, y + + +@pytest.mark.parametrize("jax_ndarray", [True, False]) +@pytest.mark.parametrize("unsqueeze", [True, False]) +def test_fit(jax_ndarray, unsqueeze): + rng_key = get_keys()[0] + X, y = get_dummy_data(jax_ndarray, unsqueeze) + m = viSparseGP(1, 'Matern') + m.fit(rng_key, X, y, num_steps=100) + assert m.svi is not None + assert isinstance(m.Xu, jnp.ndarray) + + +def test_inducing_points_optimization(): + rng_key = get_keys()[0] + X, y = get_dummy_data() + m1 = viSparseGP(1, 'Matern') + m1.fit(rng_key, X, y, num_steps=1) + m2 = viSparseGP(1, 'Matern') + m2.fit(rng_key, X, y, num_steps=100) + assert not jnp.array_equal(m1.Xu, m2.Xu) + + +def test_get_mvn_posterior(): + rng_keys = get_keys() + X, y = get_dummy_data(unsqueeze=True) + X_test, _ = get_dummy_data(unsqueeze=True) + params = {"k_length": jax.random.normal(rng_keys[0], shape=(1, 1)), + "k_scale": jax.random.normal(rng_keys[0], shape=(1,)), + "noise": jax.random.normal(rng_keys[0], shape=(1,))} + m = viSparseGP(1, 'RBF') + m.X_train = X + m.y_train = y + m.Xu = X[::2].copy() + + mean, cov = m.get_mvn_posterior(X_test, params) + + assert isinstance(mean, jnp.ndarray) + assert isinstance(cov, jnp.ndarray) + assert_equal(mean.shape, X_test.squeeze().shape) + assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0])) \ No newline at end of file From e2d8c401edda1f4c8846804c59ec42f5925c49d6 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 7 Feb 2024 05:22:23 +0000 Subject: [PATCH 14/17] Allow passing jitter as kwargs --- gpax/models/sparse_gp.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index ab98bcd..25c9050 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -55,7 +55,11 @@ def __init__(self, input_dim: int, kernel: str, super(viSparseGP, self).__init__(*args) self.Xu = None - def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, **kwargs: float) -> None: + def model(self, + X: jnp.ndarray, + y: jnp.ndarray = None, + Xu: jnp.ndarray = None, + **kwargs: float) -> None: if Xu is not None: Xu = numpyro.param("Xu", Xu) # Initialize mean function at zeros @@ -77,8 +81,8 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, Xu: jnp.ndarray = None, * if self.mean_fn_prior is not None: args += [self.mean_fn_prior()] f_loc += self.mean_fn(*args).squeeze() - # Xompute kernel between inducing points - Kuu = self.kernel(Xu, Xu, kernel_params) + # Compute kernel between inducing points + Kuu = self.kernel(Xu, Xu, kernel_params, **kwargs) # Cholesky decomposition Luu = cholesky(Kuu).T # Compute kernel between inducing and training points @@ -177,7 +181,7 @@ def get_mvn_posterior( y_residual -= self.mean_fn(*args).squeeze() # Compute self- and cross-covariance matrices - Kuu = self.kernel(self.Xu, self.Xu, params) + Kuu = self.kernel(self.Xu, self.Xu, params, **kwargs) Luu = cholesky(Kuu, lower=True) Kuf = self.kernel(self.Xu, self.X_train, params, jitter=0) @@ -199,7 +203,7 @@ def get_mvn_posterior( Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:] mean = (Linv_W_Dinv_y.T @ Linv_Ws).squeeze() - Kss = self.kernel(X_new, X_new, params, noise_p) + Kss = self.kernel(X_new, X_new, params, noise_p, **kwargs) Qss = Ws.T @ Ws cov = Kss - Qss + Linv_Ws.T @ Linv_Ws From 54f319a03a7e0f9f4a9f08c73a5b581ce80b5f56 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> 
Date: Wed, 7 Feb 2024 05:23:53 +0000 Subject: [PATCH 15/17] Update docstrings --- gpax/models/sparse_gp.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index 25c9050..c99494e 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -60,6 +60,9 @@ def model(self, y: jnp.ndarray = None, Xu: jnp.ndarray = None, **kwargs: float) -> None: + """ + Probabilistic sparse Gaussian process regression model + """ if Xu is not None: Xu = numpyro.param("Xu", Xu) # Initialize mean function at zeros @@ -115,7 +118,7 @@ def fit(self, **kwargs: float ) -> None: """ - Run variational inference to learn GP (hyper)parameters + Run variational inference to learn sparse GP (hyper)parameters Args: rng_key: random number generator key @@ -163,9 +166,11 @@ def fit(self, if print_summary: self._print_summary() - def get_mvn_posterior( - self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + def get_mvn_posterior(self, X_new: jnp.ndarray, + params: Dict[str, jnp.ndarray], + noiseless: bool = False, + **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters From a51e695c30ff98bc1aa6675934a5a4c98251f62a Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 7 Feb 2024 06:24:53 +0000 Subject: [PATCH 16/17] Set 'random' as default method for inducing point selection --- gpax/models/sparse_gp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index c99494e..c336180 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -111,7 +111,7 @@ def model(self, def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, - inducing_points_ratio: float = 0.1, inducing_points_selection: str = 'uniform', + inducing_points_ratio: float = 0.1, inducing_points_selection: str = 'random', num_steps: int = 1000, step_size: float = 5e-3, progress_bar: bool = True, print_summary: bool = True, device: Type[jaxlib.xla_extension.Device] = None, From 2ad3fa9c7081a2cd892745e5eef292944aa610b0 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:29:49 +0000 Subject: [PATCH 17/17] Update docstrings --- gpax/models/sparse_gp.py | 8 ++++++-- gpax/models/vigp.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index c336180..37f0210 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -37,8 +37,12 @@ class viSparseGP(viGP): Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default mean_fn_prior: Optional priors over mean function parameters - noise_prior: - Optional custom prior for the observation noise variance; uses LogNormal(0,1) by default. + noise_prior_dist: + Optional custom prior distribution over the observational noise variance. + Defaults to LogNormal(0,1). + lengthscale_prior_dist: + Optional custom prior distribution over kernel lengthscale. + Defaults to LogNormal(0, 1). 
guide: Auto-guide option, use 'delta' (default) or 'normal' """ diff --git a/gpax/models/vigp.py b/gpax/models/vigp.py index 11f8de3..e734441 100644 --- a/gpax/models/vigp.py +++ b/gpax/models/vigp.py @@ -35,8 +35,12 @@ class viGP(ExactGP): Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default mean_fn_prior: Optional priors over mean function parameters - noise_prior: - Optional custom prior for the observation noise variance; uses LogNormal(0,1) by default. + noise_prior_dist: + Optional custom prior distribution over the observational noise variance. + Defaults to LogNormal(0,1). + lengthscale_prior_dist: + Optional custom prior distribution over kernel lengthscale. + Defaults to LogNormal(0, 1). guide: Auto-guide option, use 'delta' (default) or 'normal'
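
The model introduced in PATCH 01 implements the variational free energy (VFE) approximation of Titsias (2009): the exact GP likelihood N(y | f_loc, K_ff + noise*I) is replaced by the low-rank surrogate N(y | f_loc, Q_ff + noise*I), where Q_ff = K_fu K_uu^-1 K_uf is built from the inducing points, and the information discarded by that approximation is penalized through the trace term tr(K_ff - Q_ff) / noise added via numpyro.factor(..., -trace_term / 2.0). A self-contained sketch of these quantities, with a hand-rolled RBF kernel standing in for gpax's kernel module (illustrative only, not part of the patches):

    import jax.numpy as jnp
    from jax.scipy.linalg import cholesky, solve_triangular

    def rbf(x, z, length=1.0, scale=1.0):
        # Squared-exponential kernel between the rows of x (N, d) and z (M, d)
        d2 = ((x[:, None, :] - z[None, :, :]) / length) ** 2
        return scale * jnp.exp(-0.5 * d2.sum(-1))

    def vfe_terms(X, Xu, noise, jitter=1e-6):
        Kuu = rbf(Xu, Xu) + jitter * jnp.eye(Xu.shape[0])
        Luu = cholesky(Kuu, lower=True)                    # K_uu = Luu @ Luu.T
        W = solve_triangular(Luu, rbf(Xu, X), lower=True)  # W.T @ W = Q_ff
        Qff_diag = jnp.square(W).sum(0)                    # diag(Q_ff) without forming Q_ff
        Kff_diag = jnp.diag(rbf(X, X))
        # Titsias' trace penalty, clipped at zero to guard against round-off
        # (the patch writes this as jnp.clip(..., a_min=0))
        trace_term = jnp.clip((Kff_diag - Qff_diag).sum() / noise, 0)
        return W.T, trace_term

The returned W.T plays the role of cov_factor in dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D), which represents N(y | f_loc, W @ W.T + D) without ever materializing the full N x N covariance matrix.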
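get_mvn_posterior (PATCH 08) evaluates the standard sparse-GP predictive equations through triangular solves, in the same way as Pyro's VariationalSparseGP. Algebraically, with Sigma = K_uu + K_uf K_fu / noise, it computes mean = K_su Sigma^-1 K_uf y / noise and cov = K_ss - K_su K_uu^-1 K_us + K_su Sigma^-1 K_us. An explicit-inverse sketch of the same algebra, useful only as a correctness reference (scalar noise assumed here; the patch's Cholesky-based path is the numerically preferred route):

    import jax.numpy as jnp

    def sparse_posterior_naive(Kuu, Kuf, Kus, Kss, y, noise):
        # Sigma = K_uu + K_uf K_fu / noise, shape (M, M)
        Sigma = Kuu + Kuf @ Kuf.T / noise
        Sigma_inv = jnp.linalg.inv(Sigma)
        Kuu_inv = jnp.linalg.inv(Kuu)
        # Predictive mean at the test points
        mean = Kus.T @ Sigma_inv @ Kuf @ y / noise
        # Nystrom correction plus the Sigma^-1 term
        cov = Kss - Kus.T @ Kuu_inv @ Kus + Kus.T @ Sigma_inv @ Kus
        return mean, cov

In the patch, W = Luu^-1 K_uf and K = W D^-1 W.T + I; the solve_triangular calls against Luu and cholesky(K) apply K_uu^-1 and Sigma^-1 implicitly, so only M x M factorizations are ever formed.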
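The initialize_inducing_points utility from PATCH 06 supports three selection strategies; a short usage sketch (array sizes illustrative):

    import jax
    from gpax.utils import initialize_inducing_points

    X = jax.random.normal(jax.random.PRNGKey(0), (100, 2))

    # Evenly spaced subset of the training inputs
    Xu_uniform = initialize_inducing_points(X, ratio=0.1, method='uniform')
    # Random subset drawn without replacement; requires a PRNG key
    Xu_random = initialize_inducing_points(X, ratio=0.1, method='random',
                                           key=jax.random.PRNGKey(1))
    # K-means cluster centres (requires scikit-learn)
    Xu_kmeans = initialize_inducing_points(X, ratio=0.1, method='kmeans')

    assert Xu_uniform.shape == Xu_random.shape == Xu_kmeans.shape == (10, 2)

'uniform' and 'random' return actual training points, while 'kmeans' returns cluster centres that generally coincide with no training point. All three only provide initial values: fit registers Xu with numpyro.param, so the inducing locations are optimized jointly with the kernel hyperparameters.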
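Putting the series together, a minimal end-to-end sketch of the API as it stands after PATCH 17 (the toy data, kernel choice, and step count are illustrative):

    import numpy as onp
    from gpax.models import viSparseGP
    from gpax.utils import get_keys, enable_x64

    enable_x64()

    # Toy 1D regression problem
    X = onp.linspace(0, 1, 200)
    y = onp.sin(10 * X) + 0.1 * onp.random.randn(200)

    key_fit, key_pred = get_keys()
    model = viSparseGP(input_dim=1, kernel='RBF')
    # 10% of the training points become trainable inducing locations
    model.fit(key_fit, X, y, inducing_points_ratio=0.1, num_steps=1000)
    y_pred, y_var = model.predict(key_pred, onp.linspace(0, 1, 500))

With the default guide='delta', predict uses the point estimates returned by get_samples (the median of the auto-guide); the optimized inducing locations are afterwards available as model.Xu.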