From 47179bf911d4b12926d6007b27d92aeba8e04c16 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Wed, 12 Feb 2025 13:51:18 +0100 Subject: [PATCH 1/3] temporary fix for large data --- src/cfp/data/_datamanager.py | 5 +++-- src/cfp/training/_trainer.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cfp/data/_datamanager.py b/src/cfp/data/_datamanager.py index 9fa7cc8f..c9ac661b 100644 --- a/src/cfp/data/_datamanager.py +++ b/src/cfp/data/_datamanager.py @@ -505,6 +505,7 @@ def _get_cell_data( self, adata: anndata.AnnData, sample_rep: str | None = None, + device: str = "cpu", ) -> jax.Array: sample_rep = self._sample_rep if sample_rep is None else sample_rep if sample_rep == "X": @@ -520,8 +521,8 @@ def _get_cell_data( ) return jnp.asarray(adata.obsm[self._sample_rep]) attr, key = next(iter(sample_rep.items())) # type: ignore[union-attr] - return jnp.asarray(getattr(adata, attr)[key]) - + return jax.device_put(jnp.asarray(getattr(adata, attr)[key]), device=jax.devices("cpu")[0]) + def _verify_control_data(self, adata: anndata.AnnData | None) -> None: if adata is None: return None diff --git a/src/cfp/training/_trainer.py b/src/cfp/training/_trainer.py index 0040a826..a316ed21 100644 --- a/src/cfp/training/_trainer.py +++ b/src/cfp/training/_trainer.py @@ -114,6 +114,7 @@ def train( for it in pbar: rng, rng_step_fn = jax.random.split(rng, 2) batch = dataloader.sample(rng) + jax.device_put(batch) loss = self.solver.step_fn(rng_step_fn, batch) self.training_logs["loss"].append(float(loss)) From 64b8952dd6a44a6529d48ef3614454cdb2fa4731 Mon Sep 17 00:00:00 2001 From: Manuel Lubetzki Date: Tue, 4 Mar 2025 21:43:36 +0100 Subject: [PATCH 2/3] fixed other returns --- src/cfp/data/_datamanager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cfp/data/_datamanager.py b/src/cfp/data/_datamanager.py index c9ac661b..f6b08bf0 100644 --- a/src/cfp/data/_datamanager.py +++ b/src/cfp/data/_datamanager.py @@ -511,15 +511,15 @@ def _get_cell_data( if sample_rep == "X": sample_rep = adata.X if isinstance(sample_rep, sp.csr_matrix): - return jnp.asarray(sample_rep.toarray()) + return jax.device_put(jnp.asarray(sample_rep.toarray()), device=jax.devices("cpu")[0]) else: - return jnp.asarray(sample_rep) + return jax.device_put(jnp.asarray(sample_rep), device=jax.devices("cpu")[0]) if isinstance(self._sample_rep, str): if self._sample_rep not in adata.obsm: raise KeyError( f"Sample representation '{self._sample_rep}' not found in `adata.obsm`." ) - return jnp.asarray(adata.obsm[self._sample_rep]) + return jax.device_put(jnp.asarray(adata.obsm[self._sample_rep]), device=jax.devices("cpu")[0]) attr, key = next(iter(sample_rep.items())) # type: ignore[union-attr] return jax.device_put(jnp.asarray(getattr(adata, attr)[key]), device=jax.devices("cpu")[0]) From 5b4fa1645ecaf87f52569e5cad42c062185b15c2 Mon Sep 17 00:00:00 2001 From: Manuel Lubetzki Date: Wed, 5 Mar 2025 10:58:13 +0100 Subject: [PATCH 3/3] Create array on CPU --- src/cfp/data/_datamanager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cfp/data/_datamanager.py b/src/cfp/data/_datamanager.py index f6b08bf0..5eb1e108 100644 --- a/src/cfp/data/_datamanager.py +++ b/src/cfp/data/_datamanager.py @@ -511,17 +511,17 @@ def _get_cell_data( if sample_rep == "X": sample_rep = adata.X if isinstance(sample_rep, sp.csr_matrix): - return jax.device_put(jnp.asarray(sample_rep.toarray()), device=jax.devices("cpu")[0]) + return jnp.asarray(sample_rep.toarray(), device=jax.devices("cpu")[0]) else: - return jax.device_put(jnp.asarray(sample_rep), device=jax.devices("cpu")[0]) + return jnp.asarray(sample_rep, device=jax.devices("cpu")[0]) if isinstance(self._sample_rep, str): if self._sample_rep not in adata.obsm: raise KeyError( f"Sample representation '{self._sample_rep}' not found in `adata.obsm`." ) - return jax.device_put(jnp.asarray(adata.obsm[self._sample_rep]), device=jax.devices("cpu")[0]) + return jnp.asarray(adata.obsm[self._sample_rep], device=jax.devices("cpu")[0]) attr, key = next(iter(sample_rep.items())) # type: ignore[union-attr] - return jax.device_put(jnp.asarray(getattr(adata, attr)[key]), device=jax.devices("cpu")[0]) + return jnp.asarray(getattr(adata, attr)[key], device=jax.devices("cpu")[0]) def _verify_control_data(self, adata: anndata.AnnData | None) -> None: if adata is None: