diff --git a/gpax/models/spm.py b/gpax/models/spm.py index d7d08da..6ace03a 100644 --- a/gpax/models/spm.py +++ b/gpax/models/spm.py @@ -103,6 +103,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, optionally specify a cpu or gpu device on which to run the inference; e.g., ``device=jax.devices("cpu")[0]`` """ + X, y = self._set_data(X, y) if device: X = jax.device_put(X, device) y = jax.device_put(y, device) @@ -165,6 +166,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, Returns: Point predictions (or their mean) and posterior predictive distribution """ + X_new = self._set_data(X_new) if samples is None: samples = self.get_samples(chain_dim=False) if device: @@ -183,3 +185,10 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, def _print_summary(self): self.mcmc.print_summary() + + def _set_data(self, + X: jnp.ndarray, y: Optional[jnp.ndarray] = None, + ) -> Tuple[jnp.ndarray]: + if y is not None: + return X, y + return X