diff --git a/src/cfp/data/_datamanager.py b/src/cfp/data/_datamanager.py index 9fa7cc8f..5eb1e108 100644 --- a/src/cfp/data/_datamanager.py +++ b/src/cfp/data/_datamanager.py @@ -505,23 +505,24 @@ def _get_cell_data( self, adata: anndata.AnnData, sample_rep: str | None = None, + device: str = "cpu", ) -> jax.Array: sample_rep = self._sample_rep if sample_rep is None else sample_rep if sample_rep == "X": sample_rep = adata.X if isinstance(sample_rep, sp.csr_matrix): - return jnp.asarray(sample_rep.toarray()) + return jnp.asarray(sample_rep.toarray(), device=jax.devices("cpu")[0]) else: - return jnp.asarray(sample_rep) + return jnp.asarray(sample_rep, device=jax.devices("cpu")[0]) if isinstance(self._sample_rep, str): if self._sample_rep not in adata.obsm: raise KeyError( f"Sample representation '{self._sample_rep}' not found in `adata.obsm`." ) - return jnp.asarray(adata.obsm[self._sample_rep]) + return jnp.asarray(adata.obsm[self._sample_rep], device=jax.devices("cpu")[0]) attr, key = next(iter(sample_rep.items())) # type: ignore[union-attr] - return jnp.asarray(getattr(adata, attr)[key]) - + return jnp.asarray(getattr(adata, attr)[key], device=jax.devices("cpu")[0]) + def _verify_control_data(self, adata: anndata.AnnData | None) -> None: if adata is None: return None diff --git a/src/cfp/training/_trainer.py b/src/cfp/training/_trainer.py index 0040a826..a316ed21 100644 --- a/src/cfp/training/_trainer.py +++ b/src/cfp/training/_trainer.py @@ -114,6 +114,7 @@ def train( for it in pbar: rng, rng_step_fn = jax.random.split(rng, 2) batch = dataloader.sample(rng) + jax.device_put(batch) loss = self.solver.step_fn(rng_step_fn, batch) self.training_logs["loss"].append(float(loss))