Skip to content

Commit

Permalink
Add _set_data to sPM
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 25, 2024
1 parent b6e7884 commit 7d317ca
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions gpax/models/spm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
optionally specify a cpu or gpu device on which to run the inference;
e.g., ``device=jax.devices("cpu")[0]``
"""
X, y = self._set_data(X, y)
if device:
X = jax.device_put(X, device)
y = jax.device_put(y, device)
Expand Down Expand Up @@ -165,6 +166,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
Returns:
Point predictions (or their mean) and posterior predictive distribution
"""
X_new = self._set_data(X_new)
if samples is None:
samples = self.get_samples(chain_dim=False)
if device:
Expand All @@ -183,3 +185,10 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,

def _print_summary(self):
self.mcmc.print_summary()

def _set_data(self,
X: jnp.ndarray, y: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray]:
if y is not None:
return X, y
return X

0 comments on commit 7d317ca

Please sign in to comment.