From 54f319a03a7e0f9f4a9f08c73a5b581ce80b5f56 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 7 Feb 2024 05:23:53 +0000 Subject: [PATCH] Update docstrings --- gpax/models/sparse_gp.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/gpax/models/sparse_gp.py b/gpax/models/sparse_gp.py index 25c9050..c99494e 100644 --- a/gpax/models/sparse_gp.py +++ b/gpax/models/sparse_gp.py @@ -60,6 +60,9 @@ def model(self, y: jnp.ndarray = None, Xu: jnp.ndarray = None, **kwargs: float) -> None: + """ + Probabilistic sparse Gaussian process regression model + """ if Xu is not None: Xu = numpyro.param("Xu", Xu) # Initialize mean function at zeros @@ -115,7 +118,7 @@ def fit(self, **kwargs: float ) -> None: """ - Run variational inference to learn GP (hyper)parameters + Run variational inference to learn sparse GP (hyper)parameters Args: rng_key: random number generator key @@ -163,9 +166,11 @@ def fit(self, if print_summary: self._print_summary() - def get_mvn_posterior( - self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + def get_mvn_posterior(self, X_new: jnp.ndarray, + params: Dict[str, jnp.ndarray], + noiseless: bool = False, + **kwargs: float + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters