Skip to content

Commit

Permalink
Update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax authored Feb 7, 2024
1 parent e2d8c40 commit 54f319a
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions gpax/models/sparse_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def model(self,
y: jnp.ndarray = None,
Xu: jnp.ndarray = None,
**kwargs: float) -> None:
"""
Probabilistic sparse Gaussian process regression model
"""
if Xu is not None:
Xu = numpyro.param("Xu", Xu)
# Initialize mean function at zeros
Expand Down Expand Up @@ -115,7 +118,7 @@ def fit(self,
**kwargs: float
) -> None:
"""
Run variational inference to learn GP (hyper)parameters
Run variational inference to learn sparse GP (hyper)parameters
Args:
rng_key: random number generator key
Expand Down Expand Up @@ -163,9 +166,11 @@ def fit(self,
if print_summary:
self._print_summary()

def get_mvn_posterior(
self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
def get_mvn_posterior(self, X_new: jnp.ndarray,
params: Dict[str, jnp.ndarray],
noiseless: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns parameters (mean and cov) of multivariate normal posterior
for a single sample of GP parameters
Expand Down

0 comments on commit 54f319a

Please sign in to comment.