From 08c0bc9c7fefae00a91ec47d8e1686e35c083a1c Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 8 Aug 2025 12:17:21 +0200 Subject: [PATCH 01/35] integrate new dataloaders --- src/cellflow/data/__init__.py | 18 +++- src/cellflow/data/_data.py | 175 +++++++++++++++++++++++++++++++ src/cellflow/data/_dataloader.py | 81 +++++++++++--- 3 files changed, 259 insertions(+), 15 deletions(-) diff --git a/src/cellflow/data/__init__.py b/src/cellflow/data/__init__.py index e6f6f2de..94998cef 100644 --- a/src/cellflow/data/__init__.py +++ b/src/cellflow/data/__init__.py @@ -1,5 +1,17 @@ -from cellflow.data._data import BaseDataMixin, ConditionData, PredictionData, TrainingData, ValidationData -from cellflow.data._dataloader import PredictionSampler, TrainSampler, ValidationSampler +from cellflow.data._data import ( + BaseDataMixin, + ConditionData, + PredictionData, + TrainingData, + ValidationData, + ZarrTrainingData, +) +from cellflow.data._dataloader import ( + PredictionSampler, + TrainSampler, + ValidationSampler, + CombinedTrainSampler, +) from cellflow.data._datamanager import DataManager __all__ = [ @@ -9,7 +21,9 @@ "PredictionData", "TrainingData", "ValidationData", + "ZarrTrainingData", "TrainSampler", "ValidationSampler", "PredictionSampler", + "CombinedTrainSampler", ] diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 0f51d304..93d2750a 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -13,6 +15,7 @@ "PredictionData", "TrainingData", "ValidationData", + "ZarrTrainingData", ] @@ -121,6 +124,112 @@ class TrainingData(BaseDataMixin): null_value: Any data_manager: Any + # --- Zarr export helpers ------------------------------------------------- + def to_zarr( + self, + path: str, + *, + chunk_size: int = 4096, + shard_size: int = 65536, + compressors: Any | None = None, + ) -> None: + """Write this training data to Zarr v3 with sharded, compressed arrays. + + Parameters + ---------- + path + Path to a Zarr group to create or open for writing. + chunk_size + Chunk size along the first axis. + shard_size + Shard size along the first axis. + compressors + Optional list/tuple of Zarr codecs. If ``None``, a sensible default is used. + """ + try: + import anndata as ad # lazy import + import importlib + zarr = importlib.import_module("zarr") + zarr_codecs = importlib.import_module("zarr.codecs") + except Exception as exc: # pragma: no cover + raise ImportError( + "Writing to Zarr requires 'anndata>=0.10' and 'zarr>=3'." 
+ ) from exc + + if compressors is None: + compressors = (zarr_codecs.BloscCodec(cname="lz4", clevel=3),) + + # Convert to numpy-backed containers for serialization + cell_data = np.asarray(self.cell_data) + split_covariates_mask = np.asarray(self.split_covariates_mask) + perturbation_covariates_mask = np.asarray(self.perturbation_covariates_mask) + condition_data = { + str(k): np.asarray(v) for k, v in (self.condition_data or {}).items() + } + control_to_perturbation = { + str(k): np.asarray(v) + for k, v in (self.control_to_perturbation or {}).items() + } + split_idx_to_covariates = { + str(k): np.asarray(v) + for k, v in (self.split_idx_to_covariates or {}).items() + } + perturbation_idx_to_covariates = { + str(k): np.asarray(v) + for k, v in (self.perturbation_idx_to_covariates or {}).items() + } + perturbation_idx_to_id = { + str(k): v for k, v in (self.perturbation_idx_to_id or {}).items() + } + + train_data_dict: dict[str, Any] = { + "cell_data": cell_data, + "split_covariates_mask": split_covariates_mask, + "perturbation_covariates_mask": perturbation_covariates_mask, + "split_idx_to_covariates": split_idx_to_covariates, + "perturbation_idx_to_covariates": perturbation_idx_to_covariates, + "perturbation_idx_to_id": perturbation_idx_to_id, + "condition_data": condition_data, + "control_to_perturbation": control_to_perturbation, + "max_combination_length": int(self.max_combination_length), + } + + # Ensure Zarr v3 write format for sharding + ad.settings.zarr_write_format = 3 + + + def _write_sharded_callback( + func: Any, + group: Any, + key: str, + element: Any, + dataset_kwargs: dict[str, Any], + iospec: Any, + ) -> None: + # Only shard/chunk along the first dimension + if getattr(iospec, "encoding_type", None) in {"array"}: + dataset_kwargs = { + "shards": (shard_size,) + tuple(element.shape[1:]), + "chunks": (chunk_size,) + tuple(element.shape[1:]), + "compressors": compressors, + **dataset_kwargs, + } + elif getattr(iospec, "encoding_type", None) in {"csr_matrix", "csc_matrix"}: + dataset_kwargs = { + "shards": (shard_size,), + "chunks": (chunk_size,), + "compressors": compressors, + **dataset_kwargs, + } + + func(group, key, element, dataset_kwargs=dataset_kwargs) + + zgroup = zarr.open_group(path, mode="a") + ad.experimental.write_dispatched( + zgroup, "/", train_data_dict, callback=_write_sharded_callback + ) + zarr.consolidate_metadata(zgroup.store) + @dataclass class ValidationData(BaseDataMixin): @@ -203,3 +312,69 @@ class PredictionData(BaseDataMixin): max_combination_length: int null_value: Any data_manager: Any + + +@dataclass +class ZarrTrainingData(BaseDataMixin): + """Lazy, Zarr-backed variant of :class:`TrainingData`. + + Fields mirror those in :class:`TrainingData`, but array-like members are + Zarr arrays or Zarr-backed mappings. This enables out-of-core training and + composition without loading everything into memory. + + Use :meth:`read_zarr` to construct from a Zarr v3 group written via + :meth:`TrainingData.to_zarr`. + """ + + # Note: annotations use Any to allow zarr.Array and zarr groups without + # importing zarr at module import time. 
+ cell_data: Any + split_covariates_mask: Any + perturbation_covariates_mask: Any + split_idx_to_covariates: dict[int, tuple[Any, ...]] + perturbation_idx_to_covariates: dict[int, tuple[str, ...]] + perturbation_idx_to_id: dict[int, Any] + condition_data: dict[str, Any] + control_to_perturbation: dict[int, Any] + max_combination_length: int + + @classmethod + def read_zarr(cls, path: str) -> "ZarrTrainingData": + try: + import anndata as ad # lazy import + import importlib + zarr = importlib.import_module("zarr") + except Exception as exc: # pragma: no cover + raise ImportError( + "Reading from Zarr requires 'anndata>=0.10' and 'zarr>=3'." + ) from exc + + group = zarr.open_group(path, mode="r") + max_len_node = group.get("max_combination_length") + if max_len_node is None: + max_combination_length = 0 + else: + try: + max_combination_length = int(max_len_node[()]) + except Exception: + max_combination_length = int(max_len_node) + + return cls( + cell_data=group["cell_data"], + split_covariates_mask=group["split_covariates_mask"], + perturbation_covariates_mask=group["perturbation_covariates_mask"], + split_idx_to_covariates=ad.io.read_elem( + group["split_idx_to_covariates"] + ), + perturbation_idx_to_covariates=ad.io.read_elem( + group["perturbation_idx_to_covariates"] + ), + perturbation_idx_to_id=ad.io.read_elem( + group["perturbation_idx_to_id"] + ), + condition_data=ad.io.read_elem(group["condition_data"]), + control_to_perturbation=ad.io.read_elem( + group["control_to_perturbation"] + ), + max_combination_length=max_combination_length, + ) diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py index 70bd91ef..5b27b641 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/cellflow/data/_dataloader.py @@ -7,9 +7,20 @@ import jax import numpy as np -from cellflow.data._data import PredictionData, TrainingData, ValidationData - -__all__ = ["TrainSampler", "ValidationSampler", "PredictionSampler", "OOCTrainSampler"] +from cellflow.data._data import ( + PredictionData, + TrainingData, + ValidationData, + ZarrTrainingData, +) + +__all__ = [ + "TrainSampler", + "ValidationSampler", + "PredictionSampler", + "OOCTrainSampler", + "CombinedTrainSampler", +] class TrainSampler: @@ -24,7 +35,7 @@ class TrainSampler: """ - def __init__(self, data: TrainingData, batch_size: int = 1024): + def __init__(self, data: TrainingData | ZarrTrainingData, batch_size: int = 1024): self._data = data self._data_idcs = np.arange(data.cell_data.shape[0]) self.batch_size = batch_size @@ -94,7 +105,7 @@ def sample(self, rng) -> dict[str, Any]: } @property - def data(self): + def data(self) -> TrainingData | ZarrTrainingData: """The training data.""" return self._data @@ -234,9 +245,7 @@ def data(self) -> PredictionData: return self._data -def prefetch_to_device( - sampler: TrainSampler, seed: int, num_iterations: int, prefetch_factor: int = 2, num_workers: int = 4 -) -> Generator[dict[str, Any], None, None]: +def prefetch_to_device(sampler: TrainSampler, seed: int, num_iterations: int, prefetch_factor: int = 2, num_workers: int = 4) -> Generator[dict[str, Any], None, None]: seq = np.random.SeedSequence(seed) random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] @@ -267,7 +276,7 @@ def worker(rng: np.random.Generator): try: for _ in range(num_iterations): - # Yield batches from the queue; will block waiting for available batch + # Yield batches from the queue; blocks waiting for available batch yield q.get() finally: # When the generator is closed or garbage 
collected, clean up the worker threads @@ -278,7 +287,12 @@ def worker(rng: np.random.Generator): class OOCTrainSampler: def __init__( - self, data: TrainingData, seed: int, batch_size: int = 1024, num_workers: int = 4, prefetch_factor: int = 2 + self, + data: TrainingData | ZarrTrainingData, + seed: int, + batch_size: int = 1024, + num_workers: int = 4, + prefetch_factor: int = 2, ): self.inner = TrainSampler(data=data, batch_size=batch_size) self.num_workers = num_workers @@ -287,9 +301,7 @@ def __init__( self._iterator = None def set_sampler(self, num_iterations: int) -> None: - self._iterator = prefetch_to_device( - sampler=self.inner, seed=self.seed, num_iterations=num_iterations, prefetch_factor=self.prefetch_factor - ) + self._iterator = prefetch_to_device(sampler=self.inner, seed=self.seed, num_iterations=num_iterations, prefetch_factor=self.prefetch_factor) def sample(self, rng=None) -> dict[str, Any]: if self._iterator is None: @@ -301,3 +313,46 @@ def sample(self, rng=None) -> dict[str, Any]: if rng is not None: del rng return next(self._iterator) + + +class CombinedTrainSampler: + """Sample batches from multiple datasets with optional sampling weights. + + Returns batches from the chosen dataset and includes the chosen + ``dataset_index`` in the returned dict. + + Parameters + ---------- + datasets + List of training datasets (in-memory or Zarr-backed). + weights + Sampling weights for each dataset. If ``None``, use uniform weights. + batch_size + Batch size for each inner sampler. + """ + + def __init__( + self, + datasets: list[TrainingData | ZarrTrainingData], + *, + weights: np.ndarray | list[float] | None = None, + batch_size: int = 1024, + ) -> None: + if len(datasets) == 0: + raise ValueError("'datasets' must be a non-empty list.") + self.samplers: list[TrainSampler] = [TrainSampler(d, batch_size=batch_size) for d in datasets] + if weights is None: + self.weights = np.full(len(datasets), 1.0 / len(datasets), dtype=float) + else: + w = np.asarray(weights, dtype=float) + if w.shape[0] != len(datasets): + raise ValueError("'weights' length must match number of datasets.") + total = float(w.sum()) + if total <= 0: + raise ValueError("'weights' must sum to a positive value.") + self.weights = w / total + + def sample(self, rng: np.random.Generator) -> dict[str, Any]: + dataset_index = int(rng.choice(len(self.samplers), p=self.weights)) + batch = self.samplers[dataset_index].sample(rng) + return {**batch, "dataset_index": dataset_index} From 14dda75707e4f905dce4dd07adba9a1e40de862e Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sat, 9 Aug 2025 02:23:22 +0300 Subject: [PATCH 02/35] put working state --- src/cellflow/_optional.py | 9 ++ src/cellflow/compat/__init__.py | 3 + src/cellflow/compat/torch_.py | 17 ++++ src/cellflow/data/_data.py | 30 ++----- src/cellflow/data/_dataloader.py | 119 ------------------------- src/cellflow/data/_jax_dataloader.py | 110 +++++++++++++++++++++++ src/cellflow/data/_torch_dataloader.py | 77 ++++++++++++++++ 7 files changed, 222 insertions(+), 143 deletions(-) create mode 100644 src/cellflow/_optional.py create mode 100644 src/cellflow/compat/__init__.py create mode 100644 src/cellflow/compat/torch_.py create mode 100644 src/cellflow/data/_jax_dataloader.py create mode 100644 src/cellflow/data/_torch_dataloader.py diff --git a/src/cellflow/_optional.py b/src/cellflow/_optional.py new file mode 100644 index 00000000..53ea8e31 --- /dev/null +++ b/src/cellflow/_optional.py @@ -0,0 +1,9 @@ +class OptionalDependencyNotAvailable(ImportError): + 
pass + + +def torch_required_msg() -> str: + return ( + "Optional dependency 'torch' is required for this feature.\n" + "Install it via: pip install torch # or pip install 'cellflow-tools[torch]'" + ) diff --git a/src/cellflow/compat/__init__.py b/src/cellflow/compat/__init__.py new file mode 100644 index 00000000..82ff1adb --- /dev/null +++ b/src/cellflow/compat/__init__.py @@ -0,0 +1,3 @@ +from .torch_ import TorchIterableDataset + +__all__ = ["TorchIterableDataset"] diff --git a/src/cellflow/compat/torch_.py b/src/cellflow/compat/torch_.py new file mode 100644 index 00000000..c2120f94 --- /dev/null +++ b/src/cellflow/compat/torch_.py @@ -0,0 +1,17 @@ +from typing import TYPE_CHECKING + +from cellflow._optional import OptionalDependencyNotAvailable, torch_required_msg + +try: + from torch.utils.data import IterableDataset as TorchIterableDataset # type: ignore + TORCH_AVAILABLE = True +except ImportError as _: + TORCH_AVAILABLE = False + + class TorchIterableDataset: # type: ignore + def __init__(self, *args, **kwargs): + raise OptionalDependencyNotAvailable(torch_required_msg()) + +if TYPE_CHECKING: + # keeps type checkers aligned with the real type + from torch.utils.data import IterableDataset as TorchIterableDataset # noqa: F401 diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 93d2750a..2de0dd9e 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -4,8 +4,10 @@ from dataclasses import dataclass from typing import Any -import jax +import anndata as ad import numpy as np +import zarr +from zarr.codecs import BloscCodec from cellflow._types import ArrayLike @@ -146,18 +148,8 @@ def to_zarr( compressors Optional list/tuple of Zarr codecs. If ``None``, a sensible default is used. """ - try: - import anndata as ad # lazy import - import importlib - zarr = importlib.import_module("zarr") - zarr_codecs = importlib.import_module("zarr.codecs") - except Exception as exc: # pragma: no cover - raise ImportError( - "Writing to Zarr requires 'anndata>=0.10' and 'zarr>=3'." - ) from exc - if compressors is None: - compressors = (zarr_codecs.BloscCodec(cname="lz4", clevel=3),) + compressors = (BloscCodec(cname="lz4", clevel=3),) # Convert to numpy-backed containers for serialization cell_data = np.asarray(self.cell_data) @@ -197,7 +189,6 @@ def to_zarr( # Ensure Zarr v3 write format for sharding ad.settings.zarr_write_format = 3 - def _write_sharded_callback( func: Any, group: Any, @@ -300,8 +291,8 @@ class PredictionData(BaseDataMixin): Token to use for masking ``null_value``. """ - cell_data: jax.Array # (n_cells, n_features) - split_covariates_mask: jax.Array # (n_cells,), which cell assigned to which source distribution + cell_data: ArrayLike # (n_cells, n_features) + split_covariates_mask: ArrayLike # (n_cells,), which cell assigned to which source distribution split_idx_to_covariates: dict[int, tuple[Any, ...]] # (n_sources,) dictionary explaining split_covariates_mask perturbation_idx_to_covariates: dict[ int, tuple[str, ...] @@ -340,15 +331,6 @@ class ZarrTrainingData(BaseDataMixin): @classmethod def read_zarr(cls, path: str) -> "ZarrTrainingData": - try: - import anndata as ad # lazy import - import importlib - zarr = importlib.import_module("zarr") - except Exception as exc: # pragma: no cover - raise ImportError( - "Reading from Zarr requires 'anndata>=0.10' and 'zarr>=3'." 
- ) from exc - group = zarr.open_group(path, mode="r") max_len_node = group.get("max_combination_length") if max_len_node is None: diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py index 5b27b641..8df61581 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/cellflow/data/_dataloader.py @@ -1,10 +1,6 @@ import abc -import queue -import threading -from collections.abc import Generator from typing import Any, Literal -import jax import numpy as np from cellflow.data._data import ( @@ -18,8 +14,6 @@ "TrainSampler", "ValidationSampler", "PredictionSampler", - "OOCTrainSampler", - "CombinedTrainSampler", ] @@ -243,116 +237,3 @@ def sample(self) -> Any: def data(self) -> PredictionData: """The training data.""" return self._data - - -def prefetch_to_device(sampler: TrainSampler, seed: int, num_iterations: int, prefetch_factor: int = 2, num_workers: int = 4) -> Generator[dict[str, Any], None, None]: - seq = np.random.SeedSequence(seed) - random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] - - q: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=prefetch_factor * num_workers) - sem = threading.Semaphore(num_iterations) - stop_event = threading.Event() - - def worker(rng: np.random.Generator): - while not stop_event.is_set() and sem.acquire(blocking=False): - batch = sampler.sample(rng) - batch = jax.device_put(batch, jax.devices()[0], donate=True) - jax.block_until_ready(batch) - while not stop_event.is_set(): - try: - q.put(batch, timeout=1.0) - break # Batch successfully put into the queue; break out of retry loop - except queue.Full: - continue - - return - - # Start multiple worker threads - ts = [] - for i in range(num_workers): - t = threading.Thread(target=worker, daemon=True, name=f"worker-{i}", args=(random_generators[i],)) - t.start() - ts.append(t) - - try: - for _ in range(num_iterations): - # Yield batches from the queue; blocks waiting for available batch - yield q.get() - finally: - # When the generator is closed or garbage collected, clean up the worker threads - stop_event.set() # Signal all workers to exit - for t in ts: - t.join() # Wait for all worker threads to finish - - -class OOCTrainSampler: - def __init__( - self, - data: TrainingData | ZarrTrainingData, - seed: int, - batch_size: int = 1024, - num_workers: int = 4, - prefetch_factor: int = 2, - ): - self.inner = TrainSampler(data=data, batch_size=batch_size) - self.num_workers = num_workers - self.prefetch_factor = prefetch_factor - self.seed = seed - self._iterator = None - - def set_sampler(self, num_iterations: int) -> None: - self._iterator = prefetch_to_device(sampler=self.inner, seed=self.seed, num_iterations=num_iterations, prefetch_factor=self.prefetch_factor) - - def sample(self, rng=None) -> dict[str, Any]: - if self._iterator is None: - raise ValueError( - "Sampler not set. Use `set_sampler` to set the sampler with" - "the number of iterations. Without the number of iterations," - " the sampler will not be able to sample the data." - ) - if rng is not None: - del rng - return next(self._iterator) - - -class CombinedTrainSampler: - """Sample batches from multiple datasets with optional sampling weights. - - Returns batches from the chosen dataset and includes the chosen - ``dataset_index`` in the returned dict. - - Parameters - ---------- - datasets - List of training datasets (in-memory or Zarr-backed). - weights - Sampling weights for each dataset. If ``None``, use uniform weights. - batch_size - Batch size for each inner sampler. 
- """ - - def __init__( - self, - datasets: list[TrainingData | ZarrTrainingData], - *, - weights: np.ndarray | list[float] | None = None, - batch_size: int = 1024, - ) -> None: - if len(datasets) == 0: - raise ValueError("'datasets' must be a non-empty list.") - self.samplers: list[TrainSampler] = [TrainSampler(d, batch_size=batch_size) for d in datasets] - if weights is None: - self.weights = np.full(len(datasets), 1.0 / len(datasets), dtype=float) - else: - w = np.asarray(weights, dtype=float) - if w.shape[0] != len(datasets): - raise ValueError("'weights' length must match number of datasets.") - total = float(w.sum()) - if total <= 0: - raise ValueError("'weights' must sum to a positive value.") - self.weights = w / total - - def sample(self, rng: np.random.Generator) -> dict[str, Any]: - dataset_index = int(rng.choice(len(self.samplers), p=self.weights)) - batch = self.samplers[dataset_index].sample(rng) - return {**batch, "dataset_index": dataset_index} diff --git a/src/cellflow/data/_jax_dataloader.py b/src/cellflow/data/_jax_dataloader.py new file mode 100644 index 00000000..b0c40358 --- /dev/null +++ b/src/cellflow/data/_jax_dataloader.py @@ -0,0 +1,110 @@ +import queue +import threading +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from cellflow.data._data import ( + TrainingData, + ZarrTrainingData, +) +from cellflow.data._dataloader import TrainSampler + + +def _prefetch_to_device( + sampler: TrainSampler, + seed: int, + num_iterations: int, + prefetch_factor: int = 2, + num_workers: int = 4, +) -> Generator[dict[str, Any], None, None]: + import jax + + seq = np.random.SeedSequence(seed) + random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] + + q: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=prefetch_factor * num_workers) + sem = threading.Semaphore(num_iterations) + stop_event = threading.Event() + + def worker(rng: np.random.Generator): + while not stop_event.is_set() and sem.acquire(blocking=False): + batch = sampler.sample(rng) + batch = jax.device_put(batch, jax.devices()[0], donate=True) + jax.block_until_ready(batch) + while not stop_event.is_set(): + try: + q.put(batch, timeout=1.0) + break # Batch successfully put into the queue; break out of retry loop + except queue.Full: + continue + + return + + # Start multiple worker threads + ts = [] + for i in range(num_workers): + t = threading.Thread(target=worker, daemon=True, name=f"worker-{i}", args=(random_generators[i],)) + t.start() + ts.append(t) + + try: + for _ in range(num_iterations): + # Yield batches from the queue; blocks waiting for available batch + yield q.get() + finally: + # When the generator is closed or garbage collected, clean up the worker threads + stop_event.set() # Signal all workers to exit + for t in ts: + t.join() # Wait for all worker threads to finish + + +@dataclass +class JaxOutOfCoreTrainSampler: + """ + A sampler that prefetches batches to the GPU for out-of-core training. + + Here out-of-core means that data can be more than the GPU memory. + + Parameters + ---------- + data + The training data. + seed + The seed for the random number generator. + batch_size + The batch size. + num_workers + The number of workers to use for prefetching. + prefetch_factor + The prefetch factor similar to PyTorch's DataLoader. 
+ + """ + + data: TrainingData | ZarrTrainingData + seed: int + batch_size: int = 1024 + num_workers: int = 4 + prefetch_factor: int = 2 + + def __post_init__(self): + self.inner = TrainSampler(data=self.data, batch_size=self.batch_size) + self._iterator = None + + def set_sampler(self, num_iterations: int) -> None: + self._iterator = _prefetch_to_device( + sampler=self.inner, seed=self.seed, num_iterations=num_iterations, prefetch_factor=self.prefetch_factor + ) + + def sample(self, rng=None) -> dict[str, Any]: + if self._iterator is None: + raise ValueError( + "Sampler not set. Use `set_sampler` to set the sampler with" + "the number of iterations. Without the number of iterations," + " the sampler will not be able to sample the data." + ) + if rng is not None: + del rng + return next(self._iterator) diff --git a/src/cellflow/data/_torch_dataloader.py b/src/cellflow/data/_torch_dataloader.py new file mode 100644 index 00000000..f93835af --- /dev/null +++ b/src/cellflow/data/_torch_dataloader.py @@ -0,0 +1,77 @@ +import numpy as np +from dataclasses import dataclass +from functools import partial +import numpy as np +import torch +from cellflow.data._data import ZarrTrainingData +from cellflow.compat import TorchIterableDataset +from cellflow.data._dataloader import TrainSampler + + +def _worker_init_fn_helper(worker_id, random_generators): + import torch + + del worker_id + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id # type: ignore[union-attr] + rng = random_generators[worker_id] + worker_info.dataset.set_rng(rng) # type: ignore[union-attr] + return rng + + +@dataclass +class CombinedTrainingSampler(TorchIterableDataset): + """ + Combined training sampler that iterates over multiple samplers. + + Need to call set_rng(rng) before using the sampler. + + Args: + samplers: List of training samplers. + rng: Random number generator. 
+ """ + + samplers: list[TrainSampler] + weights: np.ndarray | None = None + rng: np.random.Generator | None = None + + def __post_init__(self): + if self.weights is None: + self.weights = np.ones(len(self.samplers)) + assert len(self.weights) == len(self.samplers) + self.weights = self.weights / self.weights.sum() + + def set_rng(self, rng: np.random.Generator): + self.rng = rng + + def __iter__(self): + return self + + def __next__(self): + if self.rng is None: + raise ValueError("Please call set_rng() before using the sampler.") + return self.samplers[self.rng.choice(len(self.samplers), p=self.weights)].sample(self.rng) + + @classmethod + def combine_zarr_training_samplers( + cls, + data_paths: list[str], + batch_size: int = 1024, + seed: int = 42, + num_workers: int = 4, + prefetch_factor: int = 2, + weights: np.ndarray | None = None, + ): + seq = np.random.SeedSequence(seed) + random_generators = [np.random.default_rng(s) for s in seq.spawn(len(data_paths))] + worker_init_fn = partial(_worker_init_fn_helper, random_generators=random_generators) + data = [ZarrTrainingData.read_zarr(path) for path in data_paths] + samplers = [TrainSampler(data[i], batch_size) for i in range(len(data))] + combined_sampler = cls(samplers, weights=weights) + return torch.utils.data.DataLoader( + combined_sampler, + batch_size=None, + worker_init_fn=worker_init_fn, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) From e40e57579ec2aedab768b75445795d66fd2c0348 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sat, 9 Aug 2025 02:47:22 +0300 Subject: [PATCH 03/35] add new files --- src/cellflow/compat/torch_.py | 4 +- src/cellflow/data/__init__.py | 6 ++- src/cellflow/data/_data.py | 45 +++++----------- src/cellflow/data/_torch_dataloader.py | 7 +-- src/cellflow/model/_cellflow.py | 10 ++-- src/cellflow/training/_trainer.py | 6 +-- tests/compat/test_torch_.py | 61 +++++++++++++++++++++ tests/data/test_cfsampler.py | 6 +-- tests/data/test_jax_dataloader.py | 43 +++++++++++++++ tests/data/test_torch_dataloader.py | 75 ++++++++++++++++++++++++++ tests/test_optional.py | 15 ++++++ 11 files changed, 228 insertions(+), 50 deletions(-) create mode 100644 tests/compat/test_torch_.py create mode 100644 tests/data/test_jax_dataloader.py create mode 100644 tests/data/test_torch_dataloader.py create mode 100644 tests/test_optional.py diff --git a/src/cellflow/compat/torch_.py b/src/cellflow/compat/torch_.py index c2120f94..5a51fa5e 100644 --- a/src/cellflow/compat/torch_.py +++ b/src/cellflow/compat/torch_.py @@ -4,14 +4,16 @@ try: from torch.utils.data import IterableDataset as TorchIterableDataset # type: ignore + TORCH_AVAILABLE = True except ImportError as _: TORCH_AVAILABLE = False - class TorchIterableDataset: # type: ignore + class TorchIterableDataset: # noqa: D101 def __init__(self, *args, **kwargs): raise OptionalDependencyNotAvailable(torch_required_msg()) + if TYPE_CHECKING: # keeps type checkers aligned with the real type from torch.utils.data import IterableDataset as TorchIterableDataset # noqa: F401 diff --git a/src/cellflow/data/__init__.py b/src/cellflow/data/__init__.py index 94998cef..6d3374df 100644 --- a/src/cellflow/data/__init__.py +++ b/src/cellflow/data/__init__.py @@ -10,8 +10,9 @@ PredictionSampler, TrainSampler, ValidationSampler, - CombinedTrainSampler, ) +from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler +from cellflow.data._torch_dataloader import TorchCombinedTrainSampler from cellflow.data._datamanager import DataManager __all__ = [ @@ -25,5 +26,6 @@ 
"TrainSampler", "ValidationSampler", "PredictionSampler", - "CombinedTrainSampler", + "TorchCombinedTrainSampler", + "JaxOutOfCoreTrainSampler", ] diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 2de0dd9e..5804b5d0 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -155,24 +155,13 @@ def to_zarr( cell_data = np.asarray(self.cell_data) split_covariates_mask = np.asarray(self.split_covariates_mask) perturbation_covariates_mask = np.asarray(self.perturbation_covariates_mask) - condition_data = { - str(k): np.asarray(v) for k, v in (self.condition_data or {}).items() - } - control_to_perturbation = { - str(k): np.asarray(v) - for k, v in (self.control_to_perturbation or {}).items() - } - split_idx_to_covariates = { - str(k): np.asarray(v) - for k, v in (self.split_idx_to_covariates or {}).items() - } + condition_data = {str(k): np.asarray(v) for k, v in (self.condition_data or {}).items()} + control_to_perturbation = {str(k): np.asarray(v) for k, v in (self.control_to_perturbation or {}).items()} + split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (self.split_idx_to_covariates or {}).items()} perturbation_idx_to_covariates = { - str(k): np.asarray(v) - for k, v in (self.perturbation_idx_to_covariates or {}).items() - } - perturbation_idx_to_id = { - str(k): v for k, v in (self.perturbation_idx_to_id or {}).items() + str(k): np.asarray(v) for k, v in (self.perturbation_idx_to_covariates or {}).items() } + perturbation_idx_to_id = {str(k): v for k, v in (self.perturbation_idx_to_id or {}).items()} train_data_dict: dict[str, Any] = { "cell_data": cell_data, @@ -216,9 +205,7 @@ def _write_sharded_callback( func(group, key, element, dataset_kwargs=dataset_kwargs) zgroup = zarr.open_group(path, mode="a") - ad.experimental.write_dispatched( - zgroup, "/", train_data_dict, callback=_write_sharded_callback - ) + ad.experimental.write_dispatched(zgroup, "/", train_data_dict, callback=_write_sharded_callback) zarr.consolidate_metadata(zgroup.store) @@ -330,7 +317,7 @@ class ZarrTrainingData(BaseDataMixin): max_combination_length: int @classmethod - def read_zarr(cls, path: str) -> "ZarrTrainingData": + def read_zarr(cls, path: str) -> ZarrTrainingData: group = zarr.open_group(path, mode="r") max_len_node = group.get("max_combination_length") if max_len_node is None: @@ -338,25 +325,17 @@ def read_zarr(cls, path: str) -> "ZarrTrainingData": else: try: max_combination_length = int(max_len_node[()]) - except Exception: + except Exception: # noqa: BLE001 max_combination_length = int(max_len_node) return cls( cell_data=group["cell_data"], split_covariates_mask=group["split_covariates_mask"], perturbation_covariates_mask=group["perturbation_covariates_mask"], - split_idx_to_covariates=ad.io.read_elem( - group["split_idx_to_covariates"] - ), - perturbation_idx_to_covariates=ad.io.read_elem( - group["perturbation_idx_to_covariates"] - ), - perturbation_idx_to_id=ad.io.read_elem( - group["perturbation_idx_to_id"] - ), + split_idx_to_covariates=ad.io.read_elem(group["split_idx_to_covariates"]), + perturbation_idx_to_covariates=ad.io.read_elem(group["perturbation_idx_to_covariates"]), + perturbation_idx_to_id=ad.io.read_elem(group["perturbation_idx_to_id"]), condition_data=ad.io.read_elem(group["condition_data"]), - control_to_perturbation=ad.io.read_elem( - group["control_to_perturbation"] - ), + control_to_perturbation=ad.io.read_elem(group["control_to_perturbation"]), max_combination_length=max_combination_length, ) diff --git 
a/src/cellflow/data/_torch_dataloader.py b/src/cellflow/data/_torch_dataloader.py index f93835af..cf5eb11a 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ b/src/cellflow/data/_torch_dataloader.py @@ -1,10 +1,11 @@ -import numpy as np from dataclasses import dataclass from functools import partial + import numpy as np import torch -from cellflow.data._data import ZarrTrainingData + from cellflow.compat import TorchIterableDataset +from cellflow.data._data import ZarrTrainingData from cellflow.data._dataloader import TrainSampler @@ -20,7 +21,7 @@ def _worker_init_fn_helper(worker_id, random_generators): @dataclass -class CombinedTrainingSampler(TorchIterableDataset): +class TorchCombinedTrainSampler(TorchIterableDataset): """ Combined training sampler that iterates over multiple samplers. diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index 3bab3691..2a5a1e1f 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -18,7 +18,7 @@ from cellflow import _constants from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t from cellflow.data._data import ConditionData, TrainingData, ValidationData -from cellflow.data._dataloader import OOCTrainSampler, PredictionSampler, TrainSampler, ValidationSampler +from cellflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler from cellflow.data._datamanager import DataManager from cellflow.model._utils import _write_predictions from cellflow.networks import _velocity_field @@ -54,7 +54,7 @@ def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot"] = "otfm") if solver == "otfm" else _velocity_field.GENOTConditionalVelocityField ) - self._dataloader: TrainSampler | OOCTrainSampler | None = None + self._dataloader: TrainSampler | JaxOutOfCoreTrainSampler | None = None self._trainer: CellFlowTrainer | None = None self._validation_data: dict[str, ValidationData] = {} self._solver: _otfm.OTFlowMatching | _genot.GENOT | None = None @@ -532,7 +532,7 @@ def train( monitor_metrics Metrics to monitor. out_of_core_dataloading - If :obj:`True`, use out-of-core dataloading. Uses the :class:`cellflow.data._dataloader.OOCTrainSampler` + If :obj:`True`, use out-of-core dataloading. Uses the :class:`cellflow.data.JaxOutOfCoreTrainSampler` to load data that does not fit into GPU memory. Returns @@ -549,7 +549,7 @@ def train( raise ValueError("Model not initialized. 
Please call `prepare_model` first.") if out_of_core_dataloading: - self._dataloader = OOCTrainSampler(data=self.train_data, batch_size=batch_size) + self._dataloader = JaxOutOfCoreTrainSampler(data=self.train_data, batch_size=batch_size) else: self._dataloader = TrainSampler(data=self.train_data, batch_size=batch_size) validation_loaders = {k: ValidationSampler(v) for k, v in self.validation_data.items() if k != "predict_kwargs"} @@ -808,7 +808,7 @@ def solver(self) -> _otfm.OTFlowMatching | _genot.GENOT | None: return self._solver @property - def dataloader(self) -> TrainSampler | OOCTrainSampler | None: + def dataloader(self) -> TrainSampler | JaxOutOfCoreTrainSampler | None: """The dataloader used for training.""" return self._dataloader diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py index 84973c38..0e664d3f 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/cellflow/training/_trainer.py @@ -6,7 +6,7 @@ from numpy.typing import ArrayLike from tqdm import tqdm -from cellflow.data._dataloader import OOCTrainSampler, TrainSampler, ValidationSampler +from cellflow.data import JaxOutOfCoreTrainSampler, TrainSampler, ValidationSampler from cellflow.solvers import _genot, _otfm from cellflow.training._callbacks import BaseCallback, CallbackRunner @@ -83,7 +83,7 @@ def _update_logs(self, logs: dict[str, Any]) -> None: def train( self, - dataloader: TrainSampler | OOCTrainSampler, + dataloader: TrainSampler | JaxOutOfCoreTrainSampler, num_iterations: int, valid_freq: int, valid_loaders: dict[str, ValidationSampler] | None = None, @@ -124,7 +124,7 @@ def train( pbar = tqdm(range(num_iterations)) sampler = dataloader - if isinstance(dataloader, OOCTrainSampler): + if isinstance(dataloader, JaxOutOfCoreTrainSampler): dataloader.set_sampler(num_iterations=num_iterations) for it in pbar: rng_jax, rng_step_fn = jax.random.split(rng_jax, 2) diff --git a/tests/compat/test_torch_.py b/tests/compat/test_torch_.py new file mode 100644 index 00000000..9b9f7e4a --- /dev/null +++ b/tests/compat/test_torch_.py @@ -0,0 +1,61 @@ +import importlib +import types + +import pytest + + +class ImportBlocker: + """Block importing of specific top-level packages. + + Inserts into sys.meta_path to raise ImportError for names starting with + blocked_prefix. + """ + + def __init__(self, blocked_prefix: str): + self.blocked_prefix = blocked_prefix + + def find_spec(self, fullname, path, target=None): # noqa: D401 + if fullname == self.blocked_prefix or fullname.startswith( + f"{self.blocked_prefix}." 
+ ): + raise ImportError(f"blocked import: {fullname}") + return None + + +def test_torch_iterabledataset_fallback_raises(monkeypatch): + # Block importing torch to trigger fallback path + import sys + + blocker = ImportBlocker("torch") + monkeypatch.setattr(sys, "meta_path", [blocker] + sys.meta_path, raising=False) + + # Ensure module is re-imported fresh + if "cellflow.compat.torch_" in sys.modules: + del sys.modules["cellflow.compat.torch_"] + + torch_mod = importlib.import_module("cellflow.compat.torch_") + + # Fallback class should be defined locally and raise on init + from cellflow._optional import OptionalDependencyNotAvailable + + with pytest.raises(OptionalDependencyNotAvailable) as excinfo: + _ = torch_mod.TorchIterableDataset() # type: ignore[call-arg] + assert "Optional dependency 'torch'" in str(excinfo.value) + + +def test_torch_iterabledataset_when_torch_available(monkeypatch): + torch = pytest.importorskip("torch") + + # Make sure previously inserted blockers are not present by fully + # reloading module + mod_name = "cellflow.compat.torch_" + if mod_name in list(importlib.sys.modules): + del importlib.sys.modules[mod_name] + compat_torch = importlib.import_module(mod_name) + + # Should alias to torch.utils.data.IterableDataset + from torch.utils.data import IterableDataset as TorchIterableDatasetReal + + assert compat_torch.TorchIterableDataset is TorchIterableDatasetReal + + diff --git a/tests/data/test_cfsampler.py b/tests/data/test_cfsampler.py index 0ae5405a..1d7b127c 100644 --- a/tests/data/test_cfsampler.py +++ b/tests/data/test_cfsampler.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from cellflow.data._dataloader import OOCTrainSampler, PredictionSampler, TrainSampler +from cellflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler from cellflow.data._datamanager import DataManager @@ -45,7 +45,7 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): assert sample_2["condition"]["dosage"].shape[0] == 1 -class TestOOCTrainSampler: +class TestJaxOutOfCoreTrainSampler: @pytest.mark.parametrize("batch_size", [1, 31]) def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): sample_rep = "X" @@ -67,7 +67,7 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): ) train_data = dm.get_train_data(adata_perturbation) - sampler = OOCTrainSampler(data=train_data, batch_size=batch_size, seed=0) + sampler = JaxOutOfCoreTrainSampler(data=train_data, batch_size=batch_size, seed=0) sampler.set_sampler(num_iterations=2) sample_1 = sampler.sample() sample_2 = sampler.sample() diff --git a/tests/data/test_jax_dataloader.py b/tests/data/test_jax_dataloader.py new file mode 100644 index 00000000..b06b59cd --- /dev/null +++ b/tests/data/test_jax_dataloader.py @@ -0,0 +1,43 @@ +import sys +import threading + +import numpy as np +import pytest + +from cellflow.data._dataloader import TrainSampler +from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler + + +class _DummyData: + def __init__(self): + self.cell_data = np.arange(20).reshape(10, 2) + self.split_covariates_mask = np.array([0] * 5 + [1] * 5) + self.perturbation_covariates_mask = np.array([0] * 5 + [1] * 5) + self.control_to_perturbation = {0: np.array([0]), 1: np.array([1])} + self.condition_data = None + + +def test_jax_out_of_core_sampler_no_jax(monkeypatch): + # Skip if jax is installed; this test ensures no import errors when jax missing + if "jax" in sys.modules: + pytest.skip("JAX present in environment; skip 
missing-JAX behavior test") + + sampler = JaxOutOfCoreTrainSampler(data=_DummyData(), seed=0, batch_size=2, num_workers=1, prefetch_factor=1) + # set_sampler imports jax; confirm it raises ImportError when jax not present + with pytest.raises(ImportError): + sampler.set_sampler(num_iterations=1) + + +@pytest.mark.skipif("jax" not in sys.modules, reason="Requires JAX runtime in environment") +def test_jax_out_of_core_sampler_with_jax(monkeypatch): + # Basic smoke test when JAX is available + data = _DummyData() + sampler = JaxOutOfCoreTrainSampler(data=data, seed=0, batch_size=2, num_workers=1, prefetch_factor=1) + sampler.set_sampler(num_iterations=2) + b1 = sampler.sample() + b2 = sampler.sample() + assert set(b1.keys()) == {"src_cell_data", "tgt_cell_data"} + assert b1["src_cell_data"].shape[0] == 2 + assert b2["src_cell_data"].shape[0] == 2 + + diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py new file mode 100644 index 00000000..92cb51f4 --- /dev/null +++ b/tests/data/test_torch_dataloader.py @@ -0,0 +1,75 @@ +import sys +from dataclasses import dataclass + +import numpy as np +import pytest + +# Skip these tests entirely if torch is not available because the module +# under test imports torch at module import time. +pytest.importorskip("torch") + +from cellflow.data._torch_dataloader import ( # noqa: E402 + CombinedTrainingSampler, + _worker_init_fn_helper, +) + + +@dataclass +class DummySampler: + label: str + + def sample(self, rng: np.random.Generator): # noqa: D401 + return {"label": self.label, "rand": rng.random()} + + +def test_combined_sampler_requires_rng(): + s = CombinedTrainingSampler([DummySampler("a"), DummySampler("b")]) + with pytest.raises(ValueError): + next(iter(s)) + + +def test_combined_sampler_respects_weights_choice_first(): + s = CombinedTrainingSampler([DummySampler("a"), DummySampler("b")], weights=np.array([1.0, 0.0])) + s.set_rng(np.random.default_rng(123)) + batch = next(iter(s)) + assert batch["label"] == "a" + + +def test_combined_sampler_respects_weights_choice_second(): + s = CombinedTrainingSampler([DummySampler("a"), DummySampler("b")], weights=np.array([0.0, 1.0])) + s.set_rng(np.random.default_rng(123)) + batch = next(iter(s)) + assert batch["label"] == "b" + + +class _FakeDataset: + def __init__(self): + self._rng = None + + def set_rng(self, rng): + self._rng = rng + + +def test_worker_init_fn_helper_sets_rng(monkeypatch): + # Provide a fake torch with minimal API for get_worker_info + class _FakeWorkerInfo: + def __init__(self): + self.id = 0 + self.dataset = _FakeDataset() + + class _FakeTorch: + class utils: + class data: + @staticmethod + def get_worker_info(): + return _FakeWorkerInfo() + + monkeypatch.setitem(sys.modules, "torch", _FakeTorch()) + + rngs = [np.random.default_rng(42)] + out = _worker_init_fn_helper(0, rngs) + # Verify returned rng is the same and dataset received it + assert out is rngs[0] + assert _FakeTorch.utils.data.get_worker_info().dataset._rng is rngs[0] + + diff --git a/tests/test_optional.py b/tests/test_optional.py new file mode 100644 index 00000000..c7132d6e --- /dev/null +++ b/tests/test_optional.py @@ -0,0 +1,15 @@ +import importlib + +import pytest + + +def test_optional_dependency_exception_message(): + opt = importlib.import_module("cellflow._optional") + # Ensure exception type exists and message contains installation hint + with pytest.raises(opt.OptionalDependencyNotAvailable) as excinfo: + raise opt.OptionalDependencyNotAvailable(opt.torch_required_msg()) + msg = 
str(excinfo.value) + assert "Optional dependency 'torch' is required" in msg + assert "pip install torch" in msg + + From 691e9410525bd8d955d396051b5204bc22820b20 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sat, 9 Aug 2025 02:49:26 +0300 Subject: [PATCH 04/35] fix --- tests/data/test_torch_dataloader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py index 92cb51f4..bac9003a 100644 --- a/tests/data/test_torch_dataloader.py +++ b/tests/data/test_torch_dataloader.py @@ -9,7 +9,7 @@ pytest.importorskip("torch") from cellflow.data._torch_dataloader import ( # noqa: E402 - CombinedTrainingSampler, + TorchCombinedTrainSampler, _worker_init_fn_helper, ) @@ -23,20 +23,20 @@ def sample(self, rng: np.random.Generator): # noqa: D401 def test_combined_sampler_requires_rng(): - s = CombinedTrainingSampler([DummySampler("a"), DummySampler("b")]) + s = TorchCombinedTrainSampler([DummySampler("a"), DummySampler("b")]) with pytest.raises(ValueError): next(iter(s)) def test_combined_sampler_respects_weights_choice_first(): - s = CombinedTrainingSampler([DummySampler("a"), DummySampler("b")], weights=np.array([1.0, 0.0])) + s = TorchCombinedTrainSampler([DummySampler("a"), DummySampler("b")], weights=np.array([1.0, 0.0])) s.set_rng(np.random.default_rng(123)) batch = next(iter(s)) assert batch["label"] == "a" def test_combined_sampler_respects_weights_choice_second(): - s = CombinedTrainingSampler([DummySampler("a"), DummySampler("b")], weights=np.array([0.0, 1.0])) + s = TorchCombinedTrainSampler([DummySampler("a"), DummySampler("b")], weights=np.array([0.0, 1.0])) s.set_rng(np.random.default_rng(123)) batch = next(iter(s)) assert batch["label"] == "b" From a2a8ab10bcf79867ccba1deb805baa6c68847c41 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sat, 9 Aug 2025 02:52:34 +0300 Subject: [PATCH 05/35] fix this --- tests/data/test_torch_dataloader.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py index bac9003a..abb09028 100644 --- a/tests/data/test_torch_dataloader.py +++ b/tests/data/test_torch_dataloader.py @@ -60,9 +60,12 @@ def __init__(self): class _FakeTorch: class utils: class data: - @staticmethod - def get_worker_info(): - return _FakeWorkerInfo() + pass + + worker_info = _FakeWorkerInfo() + def _get_worker_info(): + return worker_info + _FakeTorch.utils.data.get_worker_info = staticmethod(_get_worker_info) # type: ignore[attr-defined] monkeypatch.setitem(sys.modules, "torch", _FakeTorch()) @@ -70,6 +73,6 @@ def get_worker_info(): out = _worker_init_fn_helper(0, rngs) # Verify returned rng is the same and dataset received it assert out is rngs[0] - assert _FakeTorch.utils.data.get_worker_info().dataset._rng is rngs[0] + assert worker_info.dataset._rng is rngs[0] From 4a2cb7606f9bac4875c86ffd8254fdfbcefb2b66 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sat, 9 Aug 2025 02:53:06 +0300 Subject: [PATCH 06/35] format --- src/cellflow/data/__init__.py | 2 +- src/cellflow/model/_cellflow.py | 2 +- tests/compat/test_torch_.py | 9 ++------- tests/data/test_jax_dataloader.py | 4 ---- tests/data/test_torch_dataloader.py | 8 ++++---- tests/test_optional.py | 2 -- 6 files changed, 8 insertions(+), 19 deletions(-) diff --git a/src/cellflow/data/__init__.py b/src/cellflow/data/__init__.py index 6d3374df..2de3519c 100644 --- a/src/cellflow/data/__init__.py +++ b/src/cellflow/data/__init__.py @@ -11,9 
+11,9 @@ TrainSampler, ValidationSampler, ) +from cellflow.data._datamanager import DataManager from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler from cellflow.data._torch_dataloader import TorchCombinedTrainSampler -from cellflow.data._datamanager import DataManager __all__ = [ "DataManager", diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index 2a5a1e1f..e4b09a9a 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -17,8 +17,8 @@ from cellflow import _constants from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t -from cellflow.data._data import ConditionData, TrainingData, ValidationData from cellflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler +from cellflow.data._data import ConditionData, TrainingData, ValidationData from cellflow.data._datamanager import DataManager from cellflow.model._utils import _write_predictions from cellflow.networks import _velocity_field diff --git a/tests/compat/test_torch_.py b/tests/compat/test_torch_.py index 9b9f7e4a..c899f9e6 100644 --- a/tests/compat/test_torch_.py +++ b/tests/compat/test_torch_.py @@ -1,5 +1,4 @@ import importlib -import types import pytest @@ -14,10 +13,8 @@ class ImportBlocker: def __init__(self, blocked_prefix: str): self.blocked_prefix = blocked_prefix - def find_spec(self, fullname, path, target=None): # noqa: D401 - if fullname == self.blocked_prefix or fullname.startswith( - f"{self.blocked_prefix}." - ): + def find_spec(self, fullname, path, target=None): + if fullname == self.blocked_prefix or fullname.startswith(f"{self.blocked_prefix}."): raise ImportError(f"blocked import: {fullname}") return None @@ -57,5 +54,3 @@ def test_torch_iterabledataset_when_torch_available(monkeypatch): from torch.utils.data import IterableDataset as TorchIterableDatasetReal assert compat_torch.TorchIterableDataset is TorchIterableDatasetReal - - diff --git a/tests/data/test_jax_dataloader.py b/tests/data/test_jax_dataloader.py index b06b59cd..d0874b76 100644 --- a/tests/data/test_jax_dataloader.py +++ b/tests/data/test_jax_dataloader.py @@ -1,10 +1,8 @@ import sys -import threading import numpy as np import pytest -from cellflow.data._dataloader import TrainSampler from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler @@ -39,5 +37,3 @@ def test_jax_out_of_core_sampler_with_jax(monkeypatch): assert set(b1.keys()) == {"src_cell_data", "tgt_cell_data"} assert b1["src_cell_data"].shape[0] == 2 assert b2["src_cell_data"].shape[0] == 2 - - diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py index abb09028..79874247 100644 --- a/tests/data/test_torch_dataloader.py +++ b/tests/data/test_torch_dataloader.py @@ -8,7 +8,7 @@ # under test imports torch at module import time. 
pytest.importorskip("torch") -from cellflow.data._torch_dataloader import ( # noqa: E402 +from cellflow.data._torch_dataloader import ( TorchCombinedTrainSampler, _worker_init_fn_helper, ) @@ -18,7 +18,7 @@ class DummySampler: label: str - def sample(self, rng: np.random.Generator): # noqa: D401 + def sample(self, rng: np.random.Generator): return {"label": self.label, "rand": rng.random()} @@ -63,8 +63,10 @@ class data: pass worker_info = _FakeWorkerInfo() + def _get_worker_info(): return worker_info + _FakeTorch.utils.data.get_worker_info = staticmethod(_get_worker_info) # type: ignore[attr-defined] monkeypatch.setitem(sys.modules, "torch", _FakeTorch()) @@ -74,5 +76,3 @@ def _get_worker_info(): # Verify returned rng is the same and dataset received it assert out is rngs[0] assert worker_info.dataset._rng is rngs[0] - - diff --git a/tests/test_optional.py b/tests/test_optional.py index c7132d6e..fa8f40f0 100644 --- a/tests/test_optional.py +++ b/tests/test_optional.py @@ -11,5 +11,3 @@ def test_optional_dependency_exception_message(): msg = str(excinfo.value) assert "Optional dependency 'torch' is required" in msg assert "pip install torch" in msg - - From 73900f6bc24b9294b249465d6293f52db618e7f4 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 19 Aug 2025 11:50:04 +0200 Subject: [PATCH 07/35] remove extra test files --- tests/data/test_jax_dataloader.py | 39 --------------- tests/data/test_torch_dataloader.py | 78 ----------------------------- 2 files changed, 117 deletions(-) delete mode 100644 tests/data/test_jax_dataloader.py delete mode 100644 tests/data/test_torch_dataloader.py diff --git a/tests/data/test_jax_dataloader.py b/tests/data/test_jax_dataloader.py deleted file mode 100644 index d0874b76..00000000 --- a/tests/data/test_jax_dataloader.py +++ /dev/null @@ -1,39 +0,0 @@ -import sys - -import numpy as np -import pytest - -from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler - - -class _DummyData: - def __init__(self): - self.cell_data = np.arange(20).reshape(10, 2) - self.split_covariates_mask = np.array([0] * 5 + [1] * 5) - self.perturbation_covariates_mask = np.array([0] * 5 + [1] * 5) - self.control_to_perturbation = {0: np.array([0]), 1: np.array([1])} - self.condition_data = None - - -def test_jax_out_of_core_sampler_no_jax(monkeypatch): - # Skip if jax is installed; this test ensures no import errors when jax missing - if "jax" in sys.modules: - pytest.skip("JAX present in environment; skip missing-JAX behavior test") - - sampler = JaxOutOfCoreTrainSampler(data=_DummyData(), seed=0, batch_size=2, num_workers=1, prefetch_factor=1) - # set_sampler imports jax; confirm it raises ImportError when jax not present - with pytest.raises(ImportError): - sampler.set_sampler(num_iterations=1) - - -@pytest.mark.skipif("jax" not in sys.modules, reason="Requires JAX runtime in environment") -def test_jax_out_of_core_sampler_with_jax(monkeypatch): - # Basic smoke test when JAX is available - data = _DummyData() - sampler = JaxOutOfCoreTrainSampler(data=data, seed=0, batch_size=2, num_workers=1, prefetch_factor=1) - sampler.set_sampler(num_iterations=2) - b1 = sampler.sample() - b2 = sampler.sample() - assert set(b1.keys()) == {"src_cell_data", "tgt_cell_data"} - assert b1["src_cell_data"].shape[0] == 2 - assert b2["src_cell_data"].shape[0] == 2 diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py deleted file mode 100644 index 79874247..00000000 --- a/tests/data/test_torch_dataloader.py +++ /dev/null @@ -1,78 +0,0 @@ -import sys 
-from dataclasses import dataclass - -import numpy as np -import pytest - -# Skip these tests entirely if torch is not available because the module -# under test imports torch at module import time. -pytest.importorskip("torch") - -from cellflow.data._torch_dataloader import ( - TorchCombinedTrainSampler, - _worker_init_fn_helper, -) - - -@dataclass -class DummySampler: - label: str - - def sample(self, rng: np.random.Generator): - return {"label": self.label, "rand": rng.random()} - - -def test_combined_sampler_requires_rng(): - s = TorchCombinedTrainSampler([DummySampler("a"), DummySampler("b")]) - with pytest.raises(ValueError): - next(iter(s)) - - -def test_combined_sampler_respects_weights_choice_first(): - s = TorchCombinedTrainSampler([DummySampler("a"), DummySampler("b")], weights=np.array([1.0, 0.0])) - s.set_rng(np.random.default_rng(123)) - batch = next(iter(s)) - assert batch["label"] == "a" - - -def test_combined_sampler_respects_weights_choice_second(): - s = TorchCombinedTrainSampler([DummySampler("a"), DummySampler("b")], weights=np.array([0.0, 1.0])) - s.set_rng(np.random.default_rng(123)) - batch = next(iter(s)) - assert batch["label"] == "b" - - -class _FakeDataset: - def __init__(self): - self._rng = None - - def set_rng(self, rng): - self._rng = rng - - -def test_worker_init_fn_helper_sets_rng(monkeypatch): - # Provide a fake torch with minimal API for get_worker_info - class _FakeWorkerInfo: - def __init__(self): - self.id = 0 - self.dataset = _FakeDataset() - - class _FakeTorch: - class utils: - class data: - pass - - worker_info = _FakeWorkerInfo() - - def _get_worker_info(): - return worker_info - - _FakeTorch.utils.data.get_worker_info = staticmethod(_get_worker_info) # type: ignore[attr-defined] - - monkeypatch.setitem(sys.modules, "torch", _FakeTorch()) - - rngs = [np.random.default_rng(42)] - out = _worker_init_fn_helper(0, rngs) - # Verify returned rng is the same and dataset received it - assert out is rngs[0] - assert worker_info.dataset._rng is rngs[0] From 245b59572bb1fe3def012faca5f403393a31a327 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 25 Aug 2025 13:41:55 +0200 Subject: [PATCH 08/35] update the write function --- src/cellflow/data/_data.py | 39 +++++----------------- src/cellflow/data/_utils.py | 66 +++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 5804b5d0..57c14f93 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -8,6 +8,7 @@ import numpy as np import zarr from zarr.codecs import BloscCodec +from cellflow.data._utils import write_sharded from cellflow._types import ArrayLike @@ -175,38 +176,14 @@ def to_zarr( "max_combination_length": int(self.max_combination_length), } - # Ensure Zarr v3 write format for sharding - ad.settings.zarr_write_format = 3 - - def _write_sharded_callback( - func: Any, - group: Any, - key: str, - element: Any, - dataset_kwargs: dict[str, Any], - iospec: Any, - ) -> None: - # Only shard/chunk along the first dimension - if getattr(iospec, "encoding_type", None) in {"array"}: - dataset_kwargs = { - "shards": (shard_size,) + tuple(element.shape[1:]), - "chunks": (chunk_size,) + tuple(element.shape[1:]), - "compressors": compressors, - **dataset_kwargs, - } - elif getattr(iospec, "encoding_type", None) in {"csr_matrix", "csc_matrix"}: - dataset_kwargs = { - "shards": (shard_size,), - "chunks": (chunk_size,), - "compressors": compressors, - **dataset_kwargs, - } - - 
func(group, key, element, dataset_kwargs=dataset_kwargs) - zgroup = zarr.open_group(path, mode="a") - ad.experimental.write_dispatched(zgroup, "/", train_data_dict, callback=_write_sharded_callback) - zarr.consolidate_metadata(zgroup.store) + write_sharded( + zgroup, + train_data_dict, + chunk_size=chunk_size, + shard_size=shard_size, + compressors=compressors, + ) @dataclass diff --git a/src/cellflow/data/_utils.py b/src/cellflow/data/_utils.py index c22d50bb..530790f4 100644 --- a/src/cellflow/data/_utils.py +++ b/src/cellflow/data/_utils.py @@ -1,5 +1,67 @@ -from collections.abc import Iterable -from typing import Any +from typing import Any, Mapping, Iterable + +import anndata as ad +import zarr +from zarr.codecs import BloscCodec, BytesBytesCodec + + +def write_sharded( + group: zarr.Group, + data: dict[str, Any], + chunk_size: int = 4096, + shard_size: int = 65536, + compressors: Iterable[BytesBytesCodec] = ( + BloscCodec( + cname="lz4", + clevel=3, + ), + ), +): + """Function to write data to a zarr group in a sharded format. + + Parameters + ---------- + group + The zarr group to write to. + data + The data to write. + chunk_size + The chunk size. + shard_size + The shard size. + """ + # TODO: this is a copy of the function in arrayloaders + # when it is no longer public we should use the function from arrayloaders + # https://github.com/laminlabs/arrayloaders/blob/main/arrayloaders/io/store_creation.py + ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr + + def callback( + func: ad.experimental.Write, + g: zarr.Group, + k: str, + elem: ad.typing.RWAble, + dataset_kwargs: Mapping[str, Any], + iospec: ad.experimental.IOSpec, + ): + if iospec.encoding_type in {"array"}: + dataset_kwargs = { + "shards": (shard_size,) + (elem.shape[1:]), # only shard over 1st dim + "chunks": (chunk_size,) + (elem.shape[1:]), # only chunk over 1st dim + "compressors": compressors, + **dataset_kwargs, + } + elif iospec.encoding_type in {"csr_matrix", "csc_matrix"}: + dataset_kwargs = { + "shards": (shard_size,), + "chunks": (chunk_size,), + "compressors": compressors, + **dataset_kwargs, + } + + func(g, k, elem, dataset_kwargs=dataset_kwargs) + + ad.experimental.write_dispatched(group, "/", data, callback=callback) + zarr.consolidate_metadata(group.store) def _to_list(x: list[Any] | tuple[Any] | Any) -> list[Any] | tuple[Any]: From 2a0d8703327269945d506df40f5190797f40c02d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:44:48 +0000 Subject: [PATCH 09/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/cellflow/data/_data.py | 2 +- src/cellflow/data/_utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 57c14f93..ba4d1953 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -8,9 +8,9 @@ import numpy as np import zarr from zarr.codecs import BloscCodec -from cellflow.data._utils import write_sharded from cellflow._types import ArrayLike +from cellflow.data._utils import write_sharded __all__ = [ "BaseDataMixin", diff --git a/src/cellflow/data/_utils.py b/src/cellflow/data/_utils.py index 530790f4..90383fcf 100644 --- a/src/cellflow/data/_utils.py +++ b/src/cellflow/data/_utils.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Iterable +from collections.abc import Iterable, Mapping +from typing import Any import anndata as ad import 
zarr From 6e34cc7c97ab63bfa8c7282921c58d093cdd9cf7 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 25 Aug 2025 13:45:12 +0200 Subject: [PATCH 10/35] remove compat test --- tests/compat/test_torch_.py | 56 ------------------------------------- 1 file changed, 56 deletions(-) delete mode 100644 tests/compat/test_torch_.py diff --git a/tests/compat/test_torch_.py b/tests/compat/test_torch_.py deleted file mode 100644 index c899f9e6..00000000 --- a/tests/compat/test_torch_.py +++ /dev/null @@ -1,56 +0,0 @@ -import importlib - -import pytest - - -class ImportBlocker: - """Block importing of specific top-level packages. - - Inserts into sys.meta_path to raise ImportError for names starting with - blocked_prefix. - """ - - def __init__(self, blocked_prefix: str): - self.blocked_prefix = blocked_prefix - - def find_spec(self, fullname, path, target=None): - if fullname == self.blocked_prefix or fullname.startswith(f"{self.blocked_prefix}."): - raise ImportError(f"blocked import: {fullname}") - return None - - -def test_torch_iterabledataset_fallback_raises(monkeypatch): - # Block importing torch to trigger fallback path - import sys - - blocker = ImportBlocker("torch") - monkeypatch.setattr(sys, "meta_path", [blocker] + sys.meta_path, raising=False) - - # Ensure module is re-imported fresh - if "cellflow.compat.torch_" in sys.modules: - del sys.modules["cellflow.compat.torch_"] - - torch_mod = importlib.import_module("cellflow.compat.torch_") - - # Fallback class should be defined locally and raise on init - from cellflow._optional import OptionalDependencyNotAvailable - - with pytest.raises(OptionalDependencyNotAvailable) as excinfo: - _ = torch_mod.TorchIterableDataset() # type: ignore[call-arg] - assert "Optional dependency 'torch'" in str(excinfo.value) - - -def test_torch_iterabledataset_when_torch_available(monkeypatch): - torch = pytest.importorskip("torch") - - # Make sure previously inserted blockers are not present by fully - # reloading module - mod_name = "cellflow.compat.torch_" - if mod_name in list(importlib.sys.modules): - del importlib.sys.modules[mod_name] - compat_torch = importlib.import_module(mod_name) - - # Should alias to torch.utils.data.IterableDataset - from torch.utils.data import IterableDataset as TorchIterableDatasetReal - - assert compat_torch.TorchIterableDataset is TorchIterableDatasetReal From cc2d53b330f225ea7b7c56c08bd20a619da6bf93 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 25 Aug 2025 14:26:43 +0200 Subject: [PATCH 11/35] fix import problems and rename function to write_zarr --- src/cellflow/data/_data.py | 20 +++++++++++++------- src/cellflow/data/_utils.py | 3 ++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index ba4d1953..2670a0b3 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -7,7 +7,6 @@ import anndata as ad import numpy as np import zarr -from zarr.codecs import BloscCodec from cellflow._types import ArrayLike from cellflow.data._utils import write_sharded @@ -128,7 +127,7 @@ class TrainingData(BaseDataMixin): data_manager: Any # --- Zarr export helpers ------------------------------------------------- - def to_zarr( + def write_zarr( self, path: str, *, @@ -149,9 +148,6 @@ def to_zarr( compressors Optional list/tuple of Zarr codecs. If ``None``, a sensible default is used. 
""" - if compressors is None: - compressors = (BloscCodec(cname="lz4", clevel=3),) - # Convert to numpy-backed containers for serialization cell_data = np.asarray(self.cell_data) split_covariates_mask = np.asarray(self.split_covariates_mask) @@ -176,13 +172,17 @@ def to_zarr( "max_combination_length": int(self.max_combination_length), } - zgroup = zarr.open_group(path, mode="a") + additional_kwargs = {} + if compressors is not None: + additional_kwargs["compressors"] = compressors + + zgroup = zarr.open_group(path, mode="w") write_sharded( zgroup, train_data_dict, chunk_size=chunk_size, shard_size=shard_size, - compressors=compressors, + **additional_kwargs, ) @@ -293,6 +293,12 @@ class ZarrTrainingData(BaseDataMixin): control_to_perturbation: dict[int, Any] max_combination_length: int + def __post_init__(self): + self.control_to_perturbation = {int(k): v for k, v in self.control_to_perturbation.items()} + self.perturbation_idx_to_id = {int(k): v for k, v in self.perturbation_idx_to_id.items()} + self.perturbation_idx_to_covariates = {int(k): v for k, v in self.perturbation_idx_to_covariates.items()} + self.split_idx_to_covariates = {int(k): v for k, v in self.split_idx_to_covariates.items()} + @classmethod def read_zarr(cls, path: str) -> ZarrTrainingData: group = zarr.open_group(path, mode="r") diff --git a/src/cellflow/data/_utils.py b/src/cellflow/data/_utils.py index 90383fcf..d7e4a728 100644 --- a/src/cellflow/data/_utils.py +++ b/src/cellflow/data/_utils.py @@ -3,7 +3,8 @@ import anndata as ad import zarr -from zarr.codecs import BloscCodec, BytesBytesCodec +from zarr.abc.codec import BytesBytesCodec +from zarr.codecs import BloscCodec def write_sharded( From 297a83c0986c4e5f96742826b2bfa643015c51d3 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 25 Aug 2025 15:26:18 +0200 Subject: [PATCH 12/35] hide explicit torch imports --- src/cellflow/data/_torch_dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cellflow/data/_torch_dataloader.py b/src/cellflow/data/_torch_dataloader.py index cf5eb11a..1e2f8050 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ b/src/cellflow/data/_torch_dataloader.py @@ -2,7 +2,6 @@ from functools import partial import numpy as np -import torch from cellflow.compat import TorchIterableDataset from cellflow.data._data import ZarrTrainingData @@ -63,6 +62,8 @@ def combine_zarr_training_samplers( prefetch_factor: int = 2, weights: np.ndarray | None = None, ): + import torch + seq = np.random.SeedSequence(seed) random_generators = [np.random.default_rng(s) for s in seq.spawn(len(data_paths))] worker_init_fn = partial(_worker_init_fn_helper, random_generators=random_generators) From 2be2bd61d023b263639f179ccf7b0ed7b8f919f2 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 25 Aug 2025 15:36:40 +0200 Subject: [PATCH 13/35] add read and write zarr tests --- tests/data/test_cfsampler.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/data/test_cfsampler.py b/tests/data/test_cfsampler.py index 1d7b127c..4c15e44a 100644 --- a/tests/data/test_cfsampler.py +++ b/tests/data/test_cfsampler.py @@ -1,13 +1,16 @@ +from pathlib import Path + import numpy as np import pytest from cellflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler +from cellflow.data._data import ZarrTrainingData from cellflow.data._datamanager import DataManager class TestTrainSampler: @pytest.mark.parametrize("batch_size", [1, 31]) - def test_sampling_no_combinations(self, 
adata_perturbation, batch_size: int): + def test_sampling_no_combinations(self, adata_perturbation, batch_size: int, tmp_path): sample_rep = "X" split_covariates = ["cell_type"] control_key = "control" @@ -27,20 +30,29 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): ) train_data = dm.get_train_data(adata_perturbation) + train_data.write_zarr(Path(tmp_path) / "test_train_data.zarr") sampler = TrainSampler(data=train_data, batch_size=batch_size) + zarr_sampler = TrainSampler(ZarrTrainingData.read_zarr(Path(tmp_path) / "test_train_data.zarr"), batch_size=batch_size) rng_1 = np.random.default_rng(0) rng_2 = np.random.default_rng(1) + rng_3 = np.random.default_rng(2) sample_1 = sampler.sample(rng_1) sample_2 = sampler.sample(rng_2) + sample_3 = zarr_sampler.sample(rng_3) assert "src_cell_data" in sample_1 assert "tgt_cell_data" in sample_1 assert "condition" in sample_1 + assert "src_cell_data" in sample_3 + assert "tgt_cell_data" in sample_3 + assert "condition" in sample_3 assert sample_1["src_cell_data"].shape[0] == batch_size assert sample_2["src_cell_data"].shape[0] == batch_size + assert sample_3["src_cell_data"].shape[0] == batch_size assert sample_1["tgt_cell_data"].shape[0] == batch_size assert sample_2["tgt_cell_data"].shape[0] == batch_size + assert sample_3["tgt_cell_data"].shape[0] == batch_size assert sample_1["condition"]["dosage"].shape[0] == 1 assert sample_2["condition"]["dosage"].shape[0] == 1 From a1f974c2132ad247c71fe434b78ce97dcc1ebcfe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 13:36:52 +0000 Subject: [PATCH 14/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/data/test_cfsampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data/test_cfsampler.py b/tests/data/test_cfsampler.py index 4c15e44a..18406e52 100644 --- a/tests/data/test_cfsampler.py +++ b/tests/data/test_cfsampler.py @@ -32,7 +32,9 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int, tmp train_data = dm.get_train_data(adata_perturbation) train_data.write_zarr(Path(tmp_path) / "test_train_data.zarr") sampler = TrainSampler(data=train_data, batch_size=batch_size) - zarr_sampler = TrainSampler(ZarrTrainingData.read_zarr(Path(tmp_path) / "test_train_data.zarr"), batch_size=batch_size) + zarr_sampler = TrainSampler( + ZarrTrainingData.read_zarr(Path(tmp_path) / "test_train_data.zarr"), batch_size=batch_size + ) rng_1 = np.random.default_rng(0) rng_2 = np.random.default_rng(1) rng_3 = np.random.default_rng(2) From f4062bbb07da7f3a9a1cc514d92cef4027956ae2 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Mon, 25 Aug 2025 16:41:49 +0200 Subject: [PATCH 15/35] push working state --- tests/model/test_cellflow.py | 62 ++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/model/test_cellflow.py b/tests/model/test_cellflow.py index 90aa259b..26029049 100644 --- a/tests/model/test_cellflow.py +++ b/tests/model/test_cellflow.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest +from pathlib import Path import cellflow from cellflow.networks import _velocity_field @@ -492,3 +493,64 @@ def test_cellflow_get_condition_embedding( assert out[1].index.name == condition_id_key cond_id_vals = conds[condition_id_key].values assert out[1].index.isin(cond_id_vals).all() + + def test_cellflow_solver_with_torch_dataloader( + self, + 
adata_perturbation, + tmp_path, + ): + solver = "otfm" + sample_rep = "X" + control_key = "control" + condition_mode = "deterministic" + regularization = 0.1 + conditioning = "concatenation" + perturbation_covariates = {"drug": ["drug1", "drug2"]} + perturbation_covariate_reps = {"drug": "drug"} + condition_embedding_dim = 4 + vf_kwargs = None + + cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf.prepare_data( + sample_rep=sample_rep, + control_key=control_key, + perturbation_covariates=perturbation_covariates, + perturbation_covariate_reps=perturbation_covariate_reps, + ) + + cf.train_data.write_zarr(Path(tmp_path) / "train_data1.zarr") + cf.train_data.write_zarr(Path(tmp_path) / "train_data2.zarr") + cf.train_data.write_zarr(Path(tmp_path) / "train_data3.zarr") + + + assert cf.train_data is not None + assert hasattr(cf, "_data_dim") + + cf.prepare_model( + condition_mode=condition_mode, + regularization=regularization, + condition_embedding_dim=condition_embedding_dim, + hidden_dims=(2, 2), + decoder_dims=(2, 2), + vf_kwargs=vf_kwargs, + conditioning=conditioning, + ) + assert cf._trainer is not None + + cf.train(num_iterations=3) + assert cf._dataloader is not None + + conds = adata_perturbation.obs.drop_duplicates(subset=["drug1", "drug2"]) + cond_embed_mean, cond_embed_var = cf.get_condition_embedding(conds, rep_dict=adata_perturbation.uns) + assert isinstance(cond_embed_mean, pd.DataFrame) + assert isinstance(cond_embed_var, pd.DataFrame) + assert cond_embed_mean.shape[0] == conds.shape[0] + assert cond_embed_mean.shape[1] == condition_embedding_dim + assert cond_embed_var.shape[0] == conds.shape[0] + assert cond_embed_var.shape[1] == condition_embedding_dim + assert cond_embed_mean.shape[0] == conds.shape[0] + assert cond_embed_mean.shape[1] == condition_embedding_dim + assert cond_embed_var.shape[0] == conds.shape[0] + assert cond_embed_var.shape[1] == condition_embedding_dim + # TODO: add mode to make the training independent of anndata or at least make it + # read only \ No newline at end of file From 7ac0f8f5ddec0c8e9d7ba9018d8bd5afdf0d45d9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 Aug 2025 11:33:29 +0200 Subject: [PATCH 16/35] remove torch test for cellflow workflow --- tests/model/test_cellflow.py | 63 +----------------------------------- 1 file changed, 1 insertion(+), 62 deletions(-) diff --git a/tests/model/test_cellflow.py b/tests/model/test_cellflow.py index 26029049..c382b100 100644 --- a/tests/model/test_cellflow.py +++ b/tests/model/test_cellflow.py @@ -492,65 +492,4 @@ def test_cellflow_get_condition_embedding( assert out[1].shape[0] == out[1].shape[0] assert out[1].index.name == condition_id_key cond_id_vals = conds[condition_id_key].values - assert out[1].index.isin(cond_id_vals).all() - - def test_cellflow_solver_with_torch_dataloader( - self, - adata_perturbation, - tmp_path, - ): - solver = "otfm" - sample_rep = "X" - control_key = "control" - condition_mode = "deterministic" - regularization = 0.1 - conditioning = "concatenation" - perturbation_covariates = {"drug": ["drug1", "drug2"]} - perturbation_covariate_reps = {"drug": "drug"} - condition_embedding_dim = 4 - vf_kwargs = None - - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) - cf.prepare_data( - sample_rep=sample_rep, - control_key=control_key, - perturbation_covariates=perturbation_covariates, - perturbation_covariate_reps=perturbation_covariate_reps, - ) - - cf.train_data.write_zarr(Path(tmp_path) / "train_data1.zarr") - 
cf.train_data.write_zarr(Path(tmp_path) / "train_data2.zarr") - cf.train_data.write_zarr(Path(tmp_path) / "train_data3.zarr") - - - assert cf.train_data is not None - assert hasattr(cf, "_data_dim") - - cf.prepare_model( - condition_mode=condition_mode, - regularization=regularization, - condition_embedding_dim=condition_embedding_dim, - hidden_dims=(2, 2), - decoder_dims=(2, 2), - vf_kwargs=vf_kwargs, - conditioning=conditioning, - ) - assert cf._trainer is not None - - cf.train(num_iterations=3) - assert cf._dataloader is not None - - conds = adata_perturbation.obs.drop_duplicates(subset=["drug1", "drug2"]) - cond_embed_mean, cond_embed_var = cf.get_condition_embedding(conds, rep_dict=adata_perturbation.uns) - assert isinstance(cond_embed_mean, pd.DataFrame) - assert isinstance(cond_embed_var, pd.DataFrame) - assert cond_embed_mean.shape[0] == conds.shape[0] - assert cond_embed_mean.shape[1] == condition_embedding_dim - assert cond_embed_var.shape[0] == conds.shape[0] - assert cond_embed_var.shape[1] == condition_embedding_dim - assert cond_embed_mean.shape[0] == conds.shape[0] - assert cond_embed_mean.shape[1] == condition_embedding_dim - assert cond_embed_var.shape[0] == conds.shape[0] - assert cond_embed_var.shape[1] == condition_embedding_dim - # TODO: add mode to make the training independent of anndata or at least make it - # read only \ No newline at end of file + assert out[1].index.isin(cond_id_vals).all() \ No newline at end of file From 042e07a34e64d87bf1f1b17e248f5b7efb731456 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 09:34:57 +0000 Subject: [PATCH 17/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/model/test_cellflow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/model/test_cellflow.py b/tests/model/test_cellflow.py index 0932281e..6bfbeb3f 100644 --- a/tests/model/test_cellflow.py +++ b/tests/model/test_cellflow.py @@ -1,7 +1,6 @@ import numpy as np import pandas as pd import pytest -from pathlib import Path import cellflow from cellflow.networks import _velocity_field @@ -528,4 +527,3 @@ def test_time_embedding( time_freqs=time_freqs, time_max_period=time_max_period, ) - From d454e34e073bba274fbebd4010437184d03b340c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Tue, 26 Aug 2025 11:35:32 +0200 Subject: [PATCH 18/35] Delete tests/test_optional.py --- tests/test_optional.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 tests/test_optional.py diff --git a/tests/test_optional.py b/tests/test_optional.py deleted file mode 100644 index fa8f40f0..00000000 --- a/tests/test_optional.py +++ /dev/null @@ -1,13 +0,0 @@ -import importlib - -import pytest - - -def test_optional_dependency_exception_message(): - opt = importlib.import_module("cellflow._optional") - # Ensure exception type exists and message contains installation hint - with pytest.raises(opt.OptionalDependencyNotAvailable) as excinfo: - raise opt.OptionalDependencyNotAvailable(opt.torch_required_msg()) - msg = str(excinfo.value) - assert "Optional dependency 'torch' is required" in msg - assert "pip install torch" in msg From 9e56b376bb72ade690f59a9f732ef974648b8e71 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 Aug 2025 11:37:41 +0200 Subject: [PATCH 19/35] fix unintentionally removed line --- tests/model/test_cellflow.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/tests/model/test_cellflow.py b/tests/model/test_cellflow.py index 6bfbeb3f..024349cb 100644 --- a/tests/model/test_cellflow.py +++ b/tests/model/test_cellflow.py @@ -491,6 +491,7 @@ def test_cellflow_get_condition_embedding( assert out[1].shape[0] == out[1].shape[0] assert out[1].index.name == condition_id_key cond_id_vals = conds[condition_id_key].values + assert out[1].index.isin(cond_id_vals).all() @pytest.mark.parametrize("time_max_period", [None, 10000, -3]) def test_time_embedding( From e67de7d827871169c6555820fbf6a994bbdd2972 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 Aug 2025 12:21:46 +0200 Subject: [PATCH 20/35] ability to add names and tests --- src/cellflow/data/_torch_dataloader.py | 12 ++++-- tests/data/test_torch_dataloader.py | 57 ++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 tests/data/test_torch_dataloader.py diff --git a/src/cellflow/data/_torch_dataloader.py b/src/cellflow/data/_torch_dataloader.py index 1e2f8050..f78e7f0c 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ b/src/cellflow/data/_torch_dataloader.py @@ -34,10 +34,11 @@ class TorchCombinedTrainSampler(TorchIterableDataset): samplers: list[TrainSampler] weights: np.ndarray | None = None rng: np.random.Generator | None = None - + dataset_names: list[str] | None = None def __post_init__(self): if self.weights is None: self.weights = np.ones(len(self.samplers)) + self.weights = np.asarray(self.weights) assert len(self.weights) == len(self.samplers) self.weights = self.weights / self.weights.sum() @@ -50,7 +51,11 @@ def __iter__(self): def __next__(self): if self.rng is None: raise ValueError("Please call set_rng() before using the sampler.") - return self.samplers[self.rng.choice(len(self.samplers), p=self.weights)].sample(self.rng) + dataset_idx = self.rng.choice(len(self.samplers), p=self.weights) + res = self.samplers[dataset_idx].sample(self.rng) + if self.dataset_names is not None: + res["dataset_name"] = self.dataset_names[dataset_idx] + return res @classmethod def combine_zarr_training_samplers( @@ -61,6 +66,7 @@ def combine_zarr_training_samplers( num_workers: int = 4, prefetch_factor: int = 2, weights: np.ndarray | None = None, + dataset_names: list[str] | None = None, ): import torch @@ -69,7 +75,7 @@ def combine_zarr_training_samplers( worker_init_fn = partial(_worker_init_fn_helper, random_generators=random_generators) data = [ZarrTrainingData.read_zarr(path) for path in data_paths] samplers = [TrainSampler(data[i], batch_size) for i in range(len(data))] - combined_sampler = cls(samplers, weights=weights) + combined_sampler = cls(samplers, weights=weights, dataset_names=dataset_names) return torch.utils.data.DataLoader( combined_sampler, batch_size=None, diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py new file mode 100644 index 00000000..1fc031a3 --- /dev/null +++ b/tests/data/test_torch_dataloader.py @@ -0,0 +1,57 @@ +import pytest + +import cellflow +from cellflow.data import TorchCombinedTrainSampler +import torch + +class TestTorchDataloader: + def test_torch_dataloader_shapes( + self, + adata_perturbation, + tmp_path, + ): + solver = "otfm" + sample_rep = "X" + control_key = "control" + perturbation_covariates = {"drug": ["drug1", "drug2"]} + perturbation_covariate_reps = {"drug": "drug"} + batch_size = 18 + + cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf.prepare_data( + sample_rep=sample_rep, + control_key=control_key, + 
perturbation_covariates=perturbation_covariates, + perturbation_covariate_reps=perturbation_covariate_reps, + ) + assert cf.train_data is not None + assert hasattr(cf, "_data_dim") + cf.train_data.write_zarr(tmp_path / "train_data1.zarr") + cf.train_data.write_zarr(tmp_path / "train_data2.zarr") + cf.train_data.write_zarr(tmp_path / "train_data3.zarr") + + combined_dataloader = TorchCombinedTrainSampler.combine_zarr_training_samplers( + data_paths=[ + tmp_path / "train_data1.zarr", + tmp_path / "train_data2.zarr", + tmp_path / "train_data3.zarr", + ], + batch_size=batch_size, + num_workers=2, + weights=[0.3, 0.3, 0.4], + seed=42, + dataset_names=["train_data1", "train_data2", "train_data3"], + ) + iter_dl = iter(combined_dataloader) + batch = next(iter_dl) + assert "dataset_name" in batch + assert batch["dataset_name"] in ["train_data1", "train_data2", "train_data3"] + assert "src_cell_data" in batch + assert "tgt_cell_data" in batch + assert "condition" in batch + dim = adata_perturbation.shape[1] + assert batch["src_cell_data"].shape == (batch_size, dim) + assert batch["tgt_cell_data"].shape == (batch_size, dim) + assert "drug" in batch["condition"] + drug_dim = adata_perturbation.uns["drug"]["drug_a"].shape[0] + assert batch["condition"]["drug"].shape == (1, len(perturbation_covariates["drug"]), drug_dim) \ No newline at end of file From feae2dd34db9a61b013fdd09a38cab61584900f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 10:21:10 +0000 Subject: [PATCH 21/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/cellflow/data/_torch_dataloader.py | 1 + tests/data/test_torch_dataloader.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cellflow/data/_torch_dataloader.py b/src/cellflow/data/_torch_dataloader.py index f78e7f0c..721a2c17 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ b/src/cellflow/data/_torch_dataloader.py @@ -35,6 +35,7 @@ class TorchCombinedTrainSampler(TorchIterableDataset): weights: np.ndarray | None = None rng: np.random.Generator | None = None dataset_names: list[str] | None = None + def __post_init__(self): if self.weights is None: self.weights = np.ones(len(self.samplers)) diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py index 1fc031a3..fa2a2cf0 100644 --- a/tests/data/test_torch_dataloader.py +++ b/tests/data/test_torch_dataloader.py @@ -1,8 +1,6 @@ -import pytest - import cellflow from cellflow.data import TorchCombinedTrainSampler -import torch + class TestTorchDataloader: def test_torch_dataloader_shapes( @@ -54,4 +52,4 @@ def test_torch_dataloader_shapes( assert batch["tgt_cell_data"].shape == (batch_size, dim) assert "drug" in batch["condition"] drug_dim = adata_perturbation.uns["drug"]["drug_a"].shape[0] - assert batch["condition"]["drug"].shape == (1, len(perturbation_covariates["drug"]), drug_dim) \ No newline at end of file + assert batch["condition"]["drug"].shape == (1, len(perturbation_covariates["drug"]), drug_dim) From 5c611f9c516e61751aaa578d5231ca6f3cda9681 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 17 Sep 2025 11:15:49 +0200 Subject: [PATCH 22/35] bug fix --- src/cellflow/data/_torch_dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellflow/data/_torch_dataloader.py b/src/cellflow/data/_torch_dataloader.py index 721a2c17..22560ee2 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ 
b/src/cellflow/data/_torch_dataloader.py @@ -72,7 +72,7 @@ def combine_zarr_training_samplers( import torch seq = np.random.SeedSequence(seed) - random_generators = [np.random.default_rng(s) for s in seq.spawn(len(data_paths))] + random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] worker_init_fn = partial(_worker_init_fn_helper, random_generators=random_generators) data = [ZarrTrainingData.read_zarr(path) for path in data_paths] samplers = [TrainSampler(data[i], batch_size) for i in range(len(data))] From 3423c3919185ff63509e9acea46ae1d7657ef4a8 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 19 Sep 2025 19:46:10 +0200 Subject: [PATCH 23/35] add trainsampler with pool --- docs/notebooks/600_trainsampler.ipynb | 308 ++++++++++++++++++++++++++ src/cellflow/data/__init__.py | 2 + src/cellflow/data/_data.py | 32 ++- src/cellflow/data/_dataloader.py | 97 +++++++- 4 files changed, 425 insertions(+), 14 deletions(-) create mode 100644 docs/notebooks/600_trainsampler.ipynb diff --git a/docs/notebooks/600_trainsampler.ipynb b/docs/notebooks/600_trainsampler.ipynb new file mode 100644 index 00000000..0d369e24 --- /dev/null +++ b/docs/notebooks/600_trainsampler.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5765bb6c", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import cellflow as cf\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78c43f9f", + "metadata": {}, + "outputs": [], + "source": [ + "data_paths = [\n", + " \"/lustre/groups/ml01/workspace/100mil/tahoe_train_10000_rep1.zarr\",\n", + " \"/lustre/groups/ml01/workspace/100mil/tahoe_train_55000_rep1.zarr\",\n", + "]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ed731bd", + "metadata": {}, + "outputs": [], + "source": [ + "from cellflow.data import TrainSamplerWithPool, ZarrTrainingData" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62955dea", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def calculate_memory_cost(\n", + " data: ZarrTrainingData,\n", + " src_idx: int,\n", + " include_condition_data: bool = True\n", + ") -> dict[str, int | list | dict]:\n", + " \"\"\"Calculate memory cost in bytes for a given source index and its target distributions.\n", + " \n", + " Parameters\n", + " ----------\n", + " data\n", + " The training data.\n", + " src_idx\n", + " The source distribution index.\n", + " include_condition_data\n", + " Whether to include condition data in memory calculations.\n", + " \n", + " Returns\n", + " -------\n", + " Dictionary with memory statistics in bytes for the source and its targets.\n", + " \"\"\"\n", + " if src_idx not in data.control_to_perturbation:\n", + " raise ValueError(f\"Source index {src_idx} not found in control_to_perturbation mapping\")\n", + " \n", + " # Get target indices for this source\n", + " target_indices = data.control_to_perturbation[src_idx]\n", + " \n", + " # Calculate memory for source cells\n", + " source_mask = data.split_covariates_mask == src_idx\n", + " n_source_cells = np.sum(source_mask)\n", + " source_memory = n_source_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", + " \n", + " # Calculate memory for target cells\n", + " target_memories = {}\n", + " total_target_memory = 0\n", + " \n", + " for target_idx in target_indices:\n", + " target_mask = data.perturbation_covariates_mask == target_idx\n", + " 
n_target_cells = np.sum(target_mask)\n", + " target_memory = n_target_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", + " target_memories[f\"target_{target_idx}\"] = target_memory\n", + " total_target_memory += target_memory\n", + " \n", + " # Calculate condition data memory if available and requested\n", + " condition_memory = 0\n", + " condition_details = {}\n", + " if include_condition_data and data.condition_data is not None:\n", + " for cond_name, cond_array in data.condition_data.items():\n", + " # Condition data is indexed by target indices\n", + " relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize\n", + " condition_details[f\"condition_{cond_name}\"] = relevant_condition_size\n", + " condition_memory += relevant_condition_size\n", + " \n", + " # Calculate total memory\n", + " total_memory = source_memory + total_target_memory + condition_memory\n", + " \n", + " # Calculate average target memory\n", + " avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0\n", + " \n", + " result = {\n", + " \"source_idx\": src_idx,\n", + " \"target_indices\": target_indices.tolist(),\n", + " \"source_memory\": source_memory,\n", + " \"source_cell_count\": int(n_source_cells),\n", + " \"total_target_memory\": total_target_memory,\n", + " \"avg_target_memory\": avg_target_memory,\n", + " \"condition_memory\": condition_memory,\n", + " \"total_memory\": total_memory,\n", + " \"target_details\": target_memories,\n", + " }\n", + " \n", + " if condition_details:\n", + " result[\"condition_details\"] = condition_details\n", + " \n", + " return result\n", + "\n", + "def format_memory_stats(memory_stats: dict, unit: str = \"auto\", summary: bool = False) -> str:\n", + " \"\"\"Format memory statistics into a human-readable string.\n", + " \n", + " Parameters\n", + " ----------\n", + " memory_stats\n", + " Dictionary with memory statistics from calculate_memory_cost.\n", + " unit\n", + " Memory unit to use for display. 
Options: 'B', 'KB', 'MB', 'GB', 'auto'.\n", + " If 'auto', the most appropriate unit will be chosen automatically.\n", + " summary\n", + " If True, includes a summary with average, min, and max target memory statistics\n", + " and omits detailed per-target breakdown.\n", + " \n", + " Returns\n", + " -------\n", + " Human-readable string representation of memory statistics.\n", + " \"\"\"\n", + " def format_bytes(bytes_value, unit=\"auto\"):\n", + " if unit == \"auto\":\n", + " # Choose appropriate unit\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " elif unit == \"KB\":\n", + " bytes_value /= 1024\n", + " elif unit == \"MB\":\n", + " bytes_value /= (1024 * 1024)\n", + " elif unit == \"GB\":\n", + " bytes_value /= (1024 * 1024 * 1024)\n", + " \n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " src_idx = memory_stats[\"source_idx\"]\n", + " target_indices = memory_stats[\"target_indices\"]\n", + " \n", + " # Base information\n", + " lines = [\n", + " f\"Memory statistics for source index {src_idx} with {len(target_indices)} targets:\",\n", + " f\"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}\",\n", + " f\"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}\",\n", + " ]\n", + " \n", + " # Calculate min and max target memory if summary is requested\n", + " if summary and memory_stats[\"target_details\"]:\n", + " target_memories = list(memory_stats[\"target_details\"].values())\n", + " min_target = min(target_memories)\n", + " max_target = max(target_memories)\n", + " \n", + " lines.extend([\n", + " \"\\nTarget memory summary:\",\n", + " f\"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}\",\n", + " f\"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}\",\n", + " f\"- Min: {format_bytes(min_target, unit)}\",\n", + " f\"- Max: {format_bytes(max_target, unit)}\",\n", + " f\"- Range: {format_bytes(max_target - min_target, unit)}\"\n", + " ])\n", + " \n", + " # Add condition memory summary if available\n", + " if memory_stats[\"condition_memory\"] > 0:\n", + " lines.append(f\"\\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}\")\n", + " else:\n", + " # Detailed output (original format)\n", + " lines.extend([\n", + " f\"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target\",\n", + " f\"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}\",\n", + " \"\\nTarget details:\"\n", + " ])\n", + " \n", + " for target_key, target_memory in memory_stats[\"target_details\"].items():\n", + " target_id = target_key.split(\"_\")[1]\n", + " lines.append(f\" - Target {target_id}: {format_bytes(target_memory, unit)}\")\n", + " \n", + " if \"condition_details\" in memory_stats:\n", + " lines.append(\"\\nCondition details:\")\n", + " for cond_key, cond_memory in memory_stats[\"condition_details\"].items():\n", + " cond_name = cond_key.split(\"_\", 1)[1]\n", + " lines.append(f\" - {cond_name}: {format_bytes(cond_memory, unit)}\")\n", + " \n", + " return \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "316e3a6a", + "metadata": {}, + "outputs": [], + "source": [ + "ztd = ZarrTrainingData.read_zarr(data_paths[0])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d101216", 
+ "metadata": {}, + "outputs": [], + "source": [ + "stats = calculate_memory_cost(ztd, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a79f9fc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics for source index 0 with 194 targets:\n", + "- Source cells: 60135 cells, 68.82 MB\n", + "- Total memory: 548.11 MB\n", + "\n", + "Target memory summary:\n", + "- Total: 479.28 MB\n", + "- Average: 2.47 MB\n", + "- Min: 44.53 KB\n", + "- Max: 6.35 MB\n", + "- Range: 6.31 MB\n", + "\n", + "Condition memory: 4.55 KB\n" + ] + } + ], + "source": [ + "print(format_memory_stats(stats, summary=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c400080", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'ztd' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mztd\u001b[49m\n", + "\u001b[31mNameError\u001b[39m: name 'ztd' is not defined" + ] + } + ], + "source": [ + "ztd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17f1fc6c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/cellflow/data/__init__.py b/src/cellflow/data/__init__.py index 2de3519c..554e68de 100644 --- a/src/cellflow/data/__init__.py +++ b/src/cellflow/data/__init__.py @@ -14,6 +14,7 @@ from cellflow.data._datamanager import DataManager from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler from cellflow.data._torch_dataloader import TorchCombinedTrainSampler +from cellflow.data._dataloader import TrainSamplerWithPool __all__ = [ "DataManager", @@ -28,4 +29,5 @@ "PredictionSampler", "TorchCombinedTrainSampler", "JaxOutOfCoreTrainSampler", + "TrainSamplerWithPool", ] diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 2670a0b3..0159d81d 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -7,9 +7,9 @@ import anndata as ad import numpy as np import zarr - from cellflow._types import ArrayLike from cellflow.data._utils import write_sharded +from zarr.storage import LocalStore __all__ = [ "BaseDataMixin", @@ -235,6 +235,10 @@ class ValidationData(BaseDataMixin): n_conditions_on_train_end: int | None = None +def _read_dict(zgroup: zarr.Group, key: str) -> dict[int, Any]: + keys = zgroup[key].keys() + return {k: zgroup[key][k] for k in keys} + @dataclass class PredictionData(BaseDataMixin): """Data container to perform prediction. 
@@ -294,13 +298,21 @@ class ZarrTrainingData(BaseDataMixin): max_combination_length: int def __post_init__(self): - self.control_to_perturbation = {int(k): v for k, v in self.control_to_perturbation.items()} - self.perturbation_idx_to_id = {int(k): v for k, v in self.perturbation_idx_to_id.items()} - self.perturbation_idx_to_covariates = {int(k): v for k, v in self.perturbation_idx_to_covariates.items()} - self.split_idx_to_covariates = {int(k): v for k, v in self.split_idx_to_covariates.items()} + # load everything except cell_data to memory + + # load masks as numpy arrays + self.split_covariates_mask = self.split_covariates_mask[...] + self.perturbation_covariates_mask = self.perturbation_covariates_mask[...] + self.condition_data = {k: np.asarray(v) for k, v in self.condition_data.items()} + self.control_to_perturbation = {int(k): np.asarray(v) for k, v in self.control_to_perturbation.items()} + self.perturbation_idx_to_id = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_id.items()} + self.perturbation_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items()} + self.split_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.split_idx_to_covariates.items()} @classmethod def read_zarr(cls, path: str) -> ZarrTrainingData: + if isinstance(path, str): + path = LocalStore(path) group = zarr.open_group(path, mode="r") max_len_node = group.get("max_combination_length") if max_len_node is None: @@ -315,10 +327,10 @@ def read_zarr(cls, path: str) -> ZarrTrainingData: cell_data=group["cell_data"], split_covariates_mask=group["split_covariates_mask"], perturbation_covariates_mask=group["perturbation_covariates_mask"], - split_idx_to_covariates=ad.io.read_elem(group["split_idx_to_covariates"]), - perturbation_idx_to_covariates=ad.io.read_elem(group["perturbation_idx_to_covariates"]), - perturbation_idx_to_id=ad.io.read_elem(group["perturbation_idx_to_id"]), - condition_data=ad.io.read_elem(group["condition_data"]), - control_to_perturbation=ad.io.read_elem(group["control_to_perturbation"]), + split_idx_to_covariates=_read_dict(group, "split_idx_to_covariates"), + perturbation_idx_to_covariates=_read_dict(group, "perturbation_idx_to_covariates"), + perturbation_idx_to_id=_read_dict(group, "perturbation_idx_to_id"), + condition_data=_read_dict(group, "condition_data"), + control_to_perturbation=_read_dict(group, "control_to_perturbation"), max_combination_length=max_combination_length, ) diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py index 8df61581..226543c5 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/cellflow/data/_dataloader.py @@ -43,6 +43,10 @@ def _sample_target_dist_idx(self, source_dist_idx, rng): """Sample a target distribution index given the source distribution index.""" return rng.choice(self._data.control_to_perturbation[source_dist_idx]) + def _sample_source_dist_idx(self, rng) -> int: + """Sample a source distribution index.""" + return rng.choice(self.n_source_dists) + def _get_embeddings(self, idx, condition_data) -> dict[str, np.ndarray]: """Get embeddings for a given index.""" result = {} @@ -54,7 +58,6 @@ def _sample_from_mask(self, rng, mask) -> np.ndarray: """Sample indices according to a mask.""" # Convert mask to probability distribution valid_indices = np.where(mask)[0] - # Handle case with no valid indices (should not happen in practice) if len(valid_indices) == 0: raise ValueError("No valid indices found in the mask") @@ -63,6 +66,7 @@ def _sample_from_mask(self, rng, 
mask) -> np.ndarray: batch_idcs = rng.choice(valid_indices, self.batch_size, replace=True) return batch_idcs + def sample(self, rng) -> dict[str, Any]: """Sample a batch of data. @@ -76,16 +80,16 @@ def sample(self, rng) -> dict[str, Any]: Dictionary with source and target data """ # Sample source distribution index - source_dist_idx = rng.integers(0, self.n_source_dists) + source_dist_idx = self._sample_source_dist_idx(rng) # Get source cells source_cells_mask = self._data.split_covariates_mask == source_dist_idx source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) - source_batch = self._data.cell_data[source_batch_idcs] - target_dist_idx = self._sample_target_dist_idx(source_dist_idx, rng) target_cells_mask = self._data.perturbation_covariates_mask == target_dist_idx target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) + + source_batch = self._data.cell_data[source_batch_idcs] target_batch = self._data.cell_data[target_batch_idcs] if not self._has_condition_data: @@ -104,6 +108,91 @@ def data(self) -> TrainingData | ZarrTrainingData: return self._data +class TrainSamplerWithPool(TrainSampler): + """Data sampler with gradual pool replacement using reservoir sampling. + + This approach replaces pool elements one by one rather than refreshing + the entire pool, providing better cache locality while maintaining + reasonable randomness. + + Parameters + ---------- + data + The training data. + batch_size + The batch size. + pool_size + The size of the pool of source distribution indices. + replacement_prob + Probability of replacing a pool element after each sample. + Lower values = longer cache retention, less randomness. + Higher values = faster cache turnover, more randomness. + replace_in_pool + Whether to allow replacement when sampling from the pool. 
+ """ + + def __init__( + self, + data: TrainingData | ZarrTrainingData, + batch_size: int = 1024, + pool_size: int = 100, + replacement_prob: float = 0.01, + ): + super().__init__(data, batch_size) + self._pool_size = pool_size + self._replacement_prob = replacement_prob + self._src_idx_pool = np.empty(self._pool_size, dtype=int) + self._pool_usage_count = np.zeros(self._pool_size, dtype=int) + self._all_src_idx_pool = set(range(self.n_source_dists)) + self._initialized = False + + def _init_pool(self, rng): + """Initialize the pool with random source distribution indices.""" + self._src_idx_pool = rng.choice(self.n_source_dists, size=self._pool_size, replace=False) + self._initialized = True + + def _sample_source_dist_idx(self, rng) -> int: + """Sample a source distribution index with gradual pool replacement.""" + if not self._initialized: + self._init_pool(rng) + + # Sample from current pool + pool_idx = rng.choice(self._pool_size) + source_idx = self._src_idx_pool[pool_idx] + + # Increment usage count for monitoring + self._pool_usage_count[pool_idx] += 1 + + # Gradually replace elements based on replacement probability + if rng.random() < self._replacement_prob: + self._replace_pool_element(rng, pool_idx) + + return source_idx + + def _replace_pool_element(self, rng, pool_idx): + """Replace a single pool element with a new one.""" + # Get all indices not currently in the pool + available_indices = list(self._all_src_idx_pool - set(self._src_idx_pool)) + + if available_indices: + # Choose new element (could be weighted by usage count) + new_idx = rng.choice(available_indices) + self._src_idx_pool[pool_idx] = new_idx + self._pool_usage_count[pool_idx] = 0 + + def get_pool_stats(self) -> dict: + """Get statistics about the current pool state.""" + if self._src_idx_pool is None: + return {"pool_size": 0, "avg_usage": 0, "unique_sources": 0} + return { + "pool_size": self._pool_size, + "avg_usage": float(np.mean(self._pool_usage_count)), + "unique_sources": len(set(self._src_idx_pool)), + "pool_elements": self._src_idx_pool.copy(), + "usage_counts": self._pool_usage_count.copy(), + } + + class BaseValidSampler(abc.ABC): @abc.abstractmethod def sample(*args, **kwargs): From 8291b7a5d3ff0145aeb522c2eb5c702725020432 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sat, 20 Sep 2025 10:01:30 +0200 Subject: [PATCH 24/35] save current state --- docs/notebooks/600_trainsampler.ipynb | 198 +++++++++++++++++++++++--- scripts/create_tahoe.py | 32 +++++ src/cellflow/data/_data.py | 4 +- src/cellflow/data/_dataloader.py | 166 ++++++++++++++++----- 4 files changed, 349 insertions(+), 51 deletions(-) create mode 100644 scripts/create_tahoe.py diff --git a/docs/notebooks/600_trainsampler.ipynb b/docs/notebooks/600_trainsampler.ipynb index 0d369e24..431f0634 100644 --- a/docs/notebooks/600_trainsampler.ipynb +++ b/docs/notebooks/600_trainsampler.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "5765bb6c", "metadata": {}, "outputs": [], @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "78c43f9f", "metadata": {}, "outputs": [], @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "3ed731bd", "metadata": {}, "outputs": [], @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "62955dea", "metadata": {}, "outputs": [], @@ -205,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": null, + 
"execution_count": 3, "id": "316e3a6a", "metadata": {}, "outputs": [], @@ -216,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "3d101216", "metadata": {}, "outputs": [], @@ -226,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "a79f9fc2", "metadata": {}, "outputs": [ @@ -255,31 +255,193 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "8c400080", "metadata": {}, + "outputs": [], + "source": [ + "ztd_stats = {}\n", + "for i in range(ztd.n_controls):\n", + " ztd_stats[i] = calculate_memory_cost(ztd, i)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "710fb69d", + "metadata": {}, + "outputs": [], + "source": [ + "def print_average_memory_per_source(stats_dict):\n", + " \"\"\"Print the average total memory per source index.\n", + " \n", + " Parameters\n", + " ----------\n", + " stats_dict\n", + " Optional pre-calculated memory statistics dictionary.\n", + " If None, statistics will be calculated for all source indices.\n", + " \"\"\"\n", + " \n", + " \n", + " # Extract total memory for each source index\n", + " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", + " \n", + " # Calculate statistics\n", + " avg_memory = np.mean(total_memories)\n", + " min_memory = np.min(total_memories)\n", + " max_memory = np.max(total_memories)\n", + " median_memory = np.median(total_memories)\n", + " \n", + " # Format the output\n", + " def format_bytes(bytes_value):\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", + " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", + " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", + " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", + " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", + " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", + " \n", + " # Identify source indices with min and max memory\n", + " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " \n", + " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", + " print(f\"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "e2f8f809", + "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'ztd' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mztd\u001b[49m\n", - "\u001b[31mNameError\u001b[39m: name 'ztd' is not defined" + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics across 50 source indices:\n", + "- Average total memory per source: 423.18 MB\n", + "- Minimum total memory: 4.33 MB\n", + "- Maximum total memory: 1.29 GB\n", + "- Median total memory: 404.51 MB\n", + "- 
Range: 1.28 GB\n", + "\n", + "Source index with minimum memory: 39 (4.33 MB)\n", + "Source index with maximum memory: 22 (1.29 GB)\n" ] } ], "source": [ - "ztd" + "print_average_memory_per_source(ztd_stats)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, + "id": "91207483", + "metadata": {}, + "outputs": [], + "source": [ + "from cellflow.data import TrainSamplerWithPool\n", + "import numpy as np\n", + "rng = np.random.default_rng(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, "id": "17f1fc6c", "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing target to cell data idcs: 100%|██████████| 9980/9980 [00:11<00:00, 891.95it/s] \n", + "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 1232.06it/s]\n" + ] + } + ], + "source": [ + "tswp = TrainSamplerWithPool(ztd, batch_size=1024, pool_size=20, replacement_prob=0.01)\n", + "tswp.init_pool_n_cache(rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "782380b2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81017ffd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "replaced 47 with 34\n", + "replaced 32 with 30\n" + ] + } + ], + "source": [ + "import time\n", + "iter_times = []\n", + "rng = np.random.default_rng(0)\n", + "start_time = time.time()\n", + "for iter in range(40):\n", + " batch = tswp.sample(rng)\n", + " end_time = time.time()\n", + " iter_times.append(end_time - start_time)\n", + " start_time = end_time\n", + "\n", + "print(\"average time per iteration: \", np.mean(iter_times))\n", + "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "001e842a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'pool_size': 20,\n", + " 'avg_usage': 1.95,\n", + " 'unique_sources': 20,\n", + " 'pool_elements': array([31, 18, 47, 34, 12, 35, 29, 23, 32, 14, 6, 41, 25, 3, 1, 49, 24,\n", + " 10, 46, 33]),\n", + " 'usage_counts': array([2, 2, 3, 2, 1, 0, 2, 2, 2, 0, 3, 1, 2, 0, 3, 3, 2, 6, 1, 2])}" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tswp.get_pool_stats()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f07c55d9", + "metadata": {}, "outputs": [], "source": [] } diff --git a/scripts/create_tahoe.py b/scripts/create_tahoe.py new file mode 100644 index 00000000..f17d1676 --- /dev/null +++ b/scripts/create_tahoe.py @@ -0,0 +1,32 @@ +from sc_exp_design.models import CellFlow +import anndata as ad +import h5py + +from anndata.experimental import read_lazy + +print("loading data") +with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f: + adata_all = ad.AnnData( + obs=ad.io.read_elem(f["obs"]), + var=read_lazy(f["var"]), + uns = read_lazy(f["uns"]), + obsm = read_lazy(f["obsm"]), + ) +cf = CellFlow() + +print(" preparing train data ") +cf.prepare_train_data(adata_all, + sample_rep="X_pca", + control_key="control", + perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)}, + perturbation_covariate_reps={"drugs": "drug_embeddings"}, + sample_covariates=["cell_line"], + sample_covariate_reps={"cell_line": "cell_line_embeddings"}, + split_covariates=["cell_line"]) + + + + +print("writing zarr") 
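+# write_zarr() converts the in-memory TrainingData to numpy-backed containers and
+# serializes them as a sharded Zarr v3 group via write_sharded(); the resulting store
+# can be reopened lazily with ZarrTrainingData.read_zarr(path) for out-of-core sampling.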
+cf.train_data.write_zarr(f"/lustre/groups/ml01/workspace/100mil/tahoe_train_data.zarr") +print("zarr written") diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 0159d81d..cfa39e8f 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -11,6 +11,7 @@ from cellflow.data._utils import write_sharded from zarr.storage import LocalStore + __all__ = [ "BaseDataMixin", "ConditionData", @@ -309,10 +310,11 @@ def __post_init__(self): self.perturbation_idx_to_id = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_id.items()} self.perturbation_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items()} self.split_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.split_idx_to_covariates.items()} + @classmethod def read_zarr(cls, path: str) -> ZarrTrainingData: if isinstance(path, str): - path = LocalStore(path) + path = LocalStore(path, read_only=True) group = zarr.open_group(path, mode="r") max_len_node = group.get("max_combination_length") if max_len_node is None: diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py index 226543c5..40636227 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/cellflow/data/_dataloader.py @@ -1,7 +1,9 @@ import abc from typing import Any, Literal +import tqdm import numpy as np +import dask.array as da from cellflow.data._data import ( PredictionData, @@ -39,7 +41,7 @@ def __init__(self, data: TrainingData | ZarrTrainingData, batch_size: int = 1024 self._control_to_perturbation_keys = sorted(data.control_to_perturbation.keys()) self._has_condition_data = data.condition_data is not None - def _sample_target_dist_idx(self, source_dist_idx, rng): + def _sample_target_dist_idx(self, rng, source_dist_idx: int) -> int: """Sample a target distribution index given the source distribution index.""" return rng.choice(self._data.control_to_perturbation[source_dist_idx]) @@ -67,6 +69,32 @@ def _sample_from_mask(self, rng, mask) -> np.ndarray: return batch_idcs + def _get_source_cells_mask(self, source_dist_idx: int) -> np.ndarray: + return self._data.split_covariates_mask == source_dist_idx + + def _get_target_cells_mask(self, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: + return self._data.perturbation_covariates_mask == target_dist_idx + + def _sample_source_batch_idcs(self, rng, source_dist_idx: int) -> dict[str, Any]: + source_cells_mask = self._get_source_cells_mask(source_dist_idx) + source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) + return source_batch_idcs + + def _sample_target_batch_idcs(self, rng, source_dist_idx: int, target_dist_idx: int) -> dict[str, Any]: + target_cells_mask = self._get_target_cells_mask(source_dist_idx, target_dist_idx) + target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) + return target_batch_idcs + + def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: + source_cells_mask = self._get_source_cells_mask(source_dist_idx) + source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) + return self._data.cell_data[source_batch_idcs] + + def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: + target_cells_mask = self._get_target_cells_mask(source_dist_idx, target_dist_idx) + target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) + return self._data.cell_data[target_batch_idcs] + def sample(self, rng) -> dict[str, Any]: """Sample a batch of data. 
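The rewritten sample() body in the next hunk only composes the helpers added above: draw a source distribution index, draw one of its target distributions, gather the two cell batches, and attach condition embeddings when condition data is present. As a minimal usage sketch for the pooled variant, mirroring the 600_trainsampler notebook in this patch series (the store path is a placeholder; init_pool_n_cache goes through _compute_idx_mappings, which imports cupy, so a CUDA-capable environment is assumed):

import numpy as np

from cellflow.data import TrainSamplerWithPool, ZarrTrainingData

rng = np.random.default_rng(0)
train_data = ZarrTrainingData.read_zarr("tahoe_train_data.zarr")  # placeholder path
sampler = TrainSamplerWithPool(train_data, batch_size=1024, pool_size=20, replacement_prob=0.01)
sampler.init_pool_n_cache(rng)  # builds source/target index maps and caches the pooled cells
batch = sampler.sample(rng)     # dict with "src_cell_data", "tgt_cell_data" and, if conditions exist, "condition"
print(sampler.get_pool_stats()["unique_sources"])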
@@ -79,29 +107,19 @@ def sample(self, rng) -> dict[str, Any]: ------- Dictionary with source and target data """ - # Sample source distribution index + # Sample source and target source_dist_idx = self._sample_source_dist_idx(rng) + target_dist_idx = self._sample_target_dist_idx(rng, source_dist_idx) - # Get source cells - source_cells_mask = self._data.split_covariates_mask == source_dist_idx - source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) - target_dist_idx = self._sample_target_dist_idx(source_dist_idx, rng) - target_cells_mask = self._data.perturbation_covariates_mask == target_dist_idx - target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) - - source_batch = self._data.cell_data[source_batch_idcs] - target_batch = self._data.cell_data[target_batch_idcs] + # Sample source and target cells + source_batch = self._sample_source_cells(rng, source_dist_idx) + target_batch = self._sample_target_cells(rng, source_dist_idx, target_dist_idx) - if not self._has_condition_data: - return {"src_cell_data": source_batch, "tgt_cell_data": target_batch} - else: + res = {"src_cell_data": source_batch, "tgt_cell_data": target_batch} + if self._has_condition_data: condition_batch = self._get_embeddings(target_dist_idx, self._data.condition_data) - return { - "src_cell_data": source_batch, - "tgt_cell_data": target_batch, - "condition": condition_batch, - } - + res["condition"] = condition_batch + return res @property def data(self) -> TrainingData | ZarrTrainingData: """The training data.""" @@ -142,10 +160,80 @@ def __init__( self._pool_size = pool_size self._replacement_prob = replacement_prob self._src_idx_pool = np.empty(self._pool_size, dtype=int) - self._pool_usage_count = np.zeros(self._pool_size, dtype=int) - self._all_src_idx_pool = set(range(self.n_source_dists)) + self._pool_usage_count = np.zeros(self.n_source_dists, dtype=int) self._initialized = False + + def _compute_idx_mappings(self): + import cupy as cp + self._tgt_to_cell_data_idcs = [None] * self.n_target_dists + gpu_per_cov_mask = cp.asarray(self._data.perturbation_covariates_mask) + gpu_spl_cov_mask = cp.asarray(self._data.split_covariates_mask) + + for tgt_idx in tqdm.tqdm(range(self.n_target_dists), desc="Computing target to cell data idcs"): + mask = gpu_per_cov_mask == tgt_idx + self._tgt_to_cell_data_idcs[tgt_idx] = cp.where(mask)[0].get() + self._src_to_cell_data_idcs = [None] * self.n_source_dists + for src_idx in tqdm.tqdm(range(self.n_source_dists), desc="Computing source to cell data idcs"): + mask = (gpu_spl_cov_mask == src_idx) + self._src_to_cell_data_idcs[src_idx] = cp.where(mask)[0].get() + + + def init_pool_n_cache(self, rng): + self._compute_idx_mappings() + self._init_pool(rng) + self._init_cache_pool_elements() + + @staticmethod + def _get_target_idx_pool(src_idx_pool: np.ndarray, control_to_perturbation: dict[int, np.ndarray]) -> set[int]: + tgt_idx_pool = set() + for src_idx in src_idx_pool: + tgt_idx_pool.update(control_to_perturbation[src_idx].tolist()) + return tgt_idx_pool + + def _init_cache_pool_elements(self): + if not self._initialized: + raise ValueError("Pool not initialized. 
Call init_pool_n_cache(rng) first.") + + # Build concatenated row indices and slice maps for sources + src_concat = [] + src_slices: dict[int, slice] = {} + offset = 0 + for src_idx in self._src_idx_pool: + idcs = self._src_to_cell_data_idcs[src_idx] + n = len(idcs) + src_slices[src_idx] = slice(offset, offset + n) + src_concat.append(idcs) + offset += n + src_concat = np.concatenate(src_concat) if len(src_concat) else np.empty((0,), dtype=int) + + # Build concatenated row indices and slice maps for targets + tgt_pool = TrainSamplerWithPool._get_target_idx_pool( + self._src_idx_pool, self._data.control_to_perturbation + ) + tgt_concat = [] + tgt_slices: dict[int, slice] = {} + offset = 0 + for tgt_idx in tqdm.tqdm(sorted(tgt_pool), desc="Caching target cells"): + idcs = self._tgt_to_cell_data_idcs[tgt_idx] + n = len(idcs) + tgt_slices[tgt_idx] = slice(offset, offset + n) + tgt_concat.append(idcs) + offset += n + tgt_concat = np.concatenate(tgt_concat) if len(tgt_concat) else np.empty((0,), dtype=int) + + # Single orthogonal-index reads (fast path) + self._src_block = self._data.cell_data.oindex[src_concat, :] if src_concat.size else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) + self._tgt_block = self._data.cell_data.oindex[tgt_concat, :] if tgt_concat.size else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) + + # Views into the blocks (no extra copies) + self._cached_srcs = {src_idx: self._src_block[sli] for src_idx, sli in src_slices.items()} + tgt_views = {tgt_idx: self._tgt_block[sli] for tgt_idx, sli in tgt_slices.items()} + self._cached_tgts = {src_idx: {tgt_idx: tgt_views[tgt_idx] for tgt_idx in self._data.control_to_perturbation[src_idx] if tgt_idx in tgt_views} + for src_idx in self._src_idx_pool} + self._initialized = True + + def _init_pool(self, rng): """Initialize the pool with random source distribution indices.""" self._src_idx_pool = rng.choice(self.n_source_dists, size=self._pool_size, replace=False) @@ -161,24 +249,33 @@ def _sample_source_dist_idx(self, rng) -> int: source_idx = self._src_idx_pool[pool_idx] # Increment usage count for monitoring - self._pool_usage_count[pool_idx] += 1 + self._pool_usage_count[source_idx] += 1 # Gradually replace elements based on replacement probability if rng.random() < self._replacement_prob: - self._replace_pool_element(rng, pool_idx) + self._replace_pool_element(rng) return source_idx - def _replace_pool_element(self, rng, pool_idx): + def _replace_pool_element(self, rng): """Replace a single pool element with a new one.""" - # Get all indices not currently in the pool - available_indices = list(self._all_src_idx_pool - set(self._src_idx_pool)) - if available_indices: - # Choose new element (could be weighted by usage count) - new_idx = rng.choice(available_indices) - self._src_idx_pool[pool_idx] = new_idx - self._pool_usage_count[pool_idx] = 0 + # instead sample weighted by usage count + # let's only consider the pool_usage_count.min() for least used + # and the pool_usage_count.max() for most used + most_used_weight = (self._pool_usage_count == self._pool_usage_count.max()).astype(float) + most_used_weight /= most_used_weight.sum() + + # weight by most used + replaced_pool_idx = rng.choice(self.n_source_dists, p=most_used_weight) + if replaced_pool_idx in set(self._src_idx_pool): + in_pool_idx = np.where(self._src_idx_pool == replaced_pool_idx)[0][0] + least_used_weight = (self._pool_usage_count == self._pool_usage_count.min()).astype(float) + least_used_weight /= 
least_used_weight.sum() + new_pool_idx = rng.choice(self.n_source_dists, p=least_used_weight) + self._src_idx_pool[in_pool_idx] = new_pool_idx + print(f"replaced {replaced_pool_idx} with {new_pool_idx}") + def get_pool_stats(self) -> dict: """Get statistics about the current pool state.""" @@ -192,6 +289,11 @@ def get_pool_stats(self) -> dict: "usage_counts": self._pool_usage_count.copy(), } + def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: + return rng.choice(self._cached_srcs[source_dist_idx], size=self.batch_size, replace=True) + + def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: + return rng.choice(self._cached_tgts[source_dist_idx][target_dist_idx], size=self.batch_size, replace=True) class BaseValidSampler(abc.ABC): @abc.abstractmethod From 2a26de9320bb34385d0203f02358592ffaa7dc92 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Sep 2025 08:01:46 +0000 Subject: [PATCH 25/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/cellflow/data/__init__.py | 2 +- src/cellflow/data/_data.py | 12 ++++---- src/cellflow/data/_dataloader.py | 50 ++++++++++++++++++-------------- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/src/cellflow/data/__init__.py b/src/cellflow/data/__init__.py index 554e68de..5121b1c3 100644 --- a/src/cellflow/data/__init__.py +++ b/src/cellflow/data/__init__.py @@ -9,12 +9,12 @@ from cellflow.data._dataloader import ( PredictionSampler, TrainSampler, + TrainSamplerWithPool, ValidationSampler, ) from cellflow.data._datamanager import DataManager from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler from cellflow.data._torch_dataloader import TorchCombinedTrainSampler -from cellflow.data._dataloader import TrainSamplerWithPool __all__ = [ "DataManager", diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index cfa39e8f..97d73ea6 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -4,13 +4,12 @@ from dataclasses import dataclass from typing import Any -import anndata as ad import numpy as np import zarr -from cellflow._types import ArrayLike -from cellflow.data._utils import write_sharded from zarr.storage import LocalStore +from cellflow._types import ArrayLike +from cellflow.data._utils import write_sharded __all__ = [ "BaseDataMixin", @@ -240,6 +239,7 @@ def _read_dict(zgroup: zarr.Group, key: str) -> dict[int, Any]: keys = zgroup[key].keys() return {k: zgroup[key][k] for k in keys} + @dataclass class PredictionData(BaseDataMixin): """Data container to perform prediction. 
@@ -308,9 +308,11 @@ def __post_init__(self): self.condition_data = {k: np.asarray(v) for k, v in self.condition_data.items()} self.control_to_perturbation = {int(k): np.asarray(v) for k, v in self.control_to_perturbation.items()} self.perturbation_idx_to_id = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_id.items()} - self.perturbation_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items()} + self.perturbation_idx_to_covariates = { + int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items() + } self.split_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.split_idx_to_covariates.items()} - + @classmethod def read_zarr(cls, path: str) -> ZarrTrainingData: if isinstance(path, str): diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py index 40636227..4061e0bc 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/cellflow/data/_dataloader.py @@ -1,9 +1,8 @@ import abc from typing import Any, Literal -import tqdm import numpy as np -import dask.array as da +import tqdm from cellflow.data._data import ( PredictionData, @@ -68,10 +67,9 @@ def _sample_from_mask(self, rng, mask) -> np.ndarray: batch_idcs = rng.choice(valid_indices, self.batch_size, replace=True) return batch_idcs - def _get_source_cells_mask(self, source_dist_idx: int) -> np.ndarray: return self._data.split_covariates_mask == source_dist_idx - + def _get_target_cells_mask(self, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: return self._data.perturbation_covariates_mask == target_dist_idx @@ -89,7 +87,7 @@ def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: source_cells_mask = self._get_source_cells_mask(source_dist_idx) source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) return self._data.cell_data[source_batch_idcs] - + def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: target_cells_mask = self._get_target_cells_mask(source_dist_idx, target_dist_idx) target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) @@ -120,6 +118,7 @@ def sample(self, rng) -> dict[str, Any]: condition_batch = self._get_embeddings(target_dist_idx, self._data.condition_data) res["condition"] = condition_batch return res + @property def data(self) -> TrainingData | ZarrTrainingData: """The training data.""" @@ -163,22 +162,21 @@ def __init__( self._pool_usage_count = np.zeros(self.n_source_dists, dtype=int) self._initialized = False - def _compute_idx_mappings(self): import cupy as cp + self._tgt_to_cell_data_idcs = [None] * self.n_target_dists gpu_per_cov_mask = cp.asarray(self._data.perturbation_covariates_mask) gpu_spl_cov_mask = cp.asarray(self._data.split_covariates_mask) - + for tgt_idx in tqdm.tqdm(range(self.n_target_dists), desc="Computing target to cell data idcs"): mask = gpu_per_cov_mask == tgt_idx self._tgt_to_cell_data_idcs[tgt_idx] = cp.where(mask)[0].get() self._src_to_cell_data_idcs = [None] * self.n_source_dists for src_idx in tqdm.tqdm(range(self.n_source_dists), desc="Computing source to cell data idcs"): - mask = (gpu_spl_cov_mask == src_idx) + mask = gpu_spl_cov_mask == src_idx self._src_to_cell_data_idcs[src_idx] = cp.where(mask)[0].get() - def init_pool_n_cache(self, rng): self._compute_idx_mappings() self._init_pool(rng) @@ -190,7 +188,7 @@ def _get_target_idx_pool(src_idx_pool: np.ndarray, control_to_perturbation: dict for src_idx in src_idx_pool: tgt_idx_pool.update(control_to_perturbation[src_idx].tolist()) return 
tgt_idx_pool - + def _init_cache_pool_elements(self): if not self._initialized: raise ValueError("Pool not initialized. Call init_pool_n_cache(rng) first.") @@ -208,9 +206,7 @@ def _init_cache_pool_elements(self): src_concat = np.concatenate(src_concat) if len(src_concat) else np.empty((0,), dtype=int) # Build concatenated row indices and slice maps for targets - tgt_pool = TrainSamplerWithPool._get_target_idx_pool( - self._src_idx_pool, self._data.control_to_perturbation - ) + tgt_pool = TrainSamplerWithPool._get_target_idx_pool(self._src_idx_pool, self._data.control_to_perturbation) tgt_concat = [] tgt_slices: dict[int, slice] = {} offset = 0 @@ -223,17 +219,30 @@ def _init_cache_pool_elements(self): tgt_concat = np.concatenate(tgt_concat) if len(tgt_concat) else np.empty((0,), dtype=int) # Single orthogonal-index reads (fast path) - self._src_block = self._data.cell_data.oindex[src_concat, :] if src_concat.size else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) - self._tgt_block = self._data.cell_data.oindex[tgt_concat, :] if tgt_concat.size else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) + self._src_block = ( + self._data.cell_data.oindex[src_concat, :] + if src_concat.size + else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) + ) + self._tgt_block = ( + self._data.cell_data.oindex[tgt_concat, :] + if tgt_concat.size + else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) + ) # Views into the blocks (no extra copies) self._cached_srcs = {src_idx: self._src_block[sli] for src_idx, sli in src_slices.items()} tgt_views = {tgt_idx: self._tgt_block[sli] for tgt_idx, sli in tgt_slices.items()} - self._cached_tgts = {src_idx: {tgt_idx: tgt_views[tgt_idx] for tgt_idx in self._data.control_to_perturbation[src_idx] if tgt_idx in tgt_views} - for src_idx in self._src_idx_pool} + self._cached_tgts = { + src_idx: { + tgt_idx: tgt_views[tgt_idx] + for tgt_idx in self._data.control_to_perturbation[src_idx] + if tgt_idx in tgt_views + } + for src_idx in self._src_idx_pool + } self._initialized = True - def _init_pool(self, rng): """Initialize the pool with random source distribution indices.""" self._src_idx_pool = rng.choice(self.n_source_dists, size=self._pool_size, replace=False) @@ -259,7 +268,6 @@ def _sample_source_dist_idx(self, rng) -> int: def _replace_pool_element(self, rng): """Replace a single pool element with a new one.""" - # instead sample weighted by usage count # let's only consider the pool_usage_count.min() for least used # and the pool_usage_count.max() for most used @@ -275,7 +283,6 @@ def _replace_pool_element(self, rng): new_pool_idx = rng.choice(self.n_source_dists, p=least_used_weight) self._src_idx_pool[in_pool_idx] = new_pool_idx print(f"replaced {replaced_pool_idx} with {new_pool_idx}") - def get_pool_stats(self) -> dict: """Get statistics about the current pool state.""" @@ -291,10 +298,11 @@ def get_pool_stats(self) -> dict: def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: return rng.choice(self._cached_srcs[source_dist_idx], size=self.batch_size, replace=True) - + def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: return rng.choice(self._cached_tgts[source_dist_idx][target_dist_idx], size=self.batch_size, replace=True) + class BaseValidSampler(abc.ABC): @abc.abstractmethod def sample(*args, **kwargs): From 7089baefb97648c1945d7240a16b32a93572f436 Mon Sep 17 00:00:00 2001 
From: selmanozleyen Date: Sat, 20 Sep 2025 10:18:41 +0200 Subject: [PATCH 26/35] get rid of old module thing --- src/cfp/__init__.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 src/cfp/__init__.py diff --git a/src/cfp/__init__.py b/src/cfp/__init__.py deleted file mode 100644 index 76f0742f..00000000 --- a/src/cfp/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cellflow import * # noqa: F403 From a8c5c0905c1d15fa9a96bc597425dcac70043b4a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sat, 20 Sep 2025 11:58:01 +0200 Subject: [PATCH 27/35] new way to save whole tahoe dataset --- .gitignore | 1 + docs/notebooks/600_trainsampler.ipynb | 354 +++++++++++++++++++++++++- scripts/create_tahoe.py | 105 ++++++-- src/cellflow/data/_datamanager.py | 8 +- 4 files changed, 441 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index 1fd1f144..9872ae7b 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,4 @@ tests/plotting/actual_figures/ # huggingface hub/ +docs/notebooks/test.zarr \ No newline at end of file diff --git a/docs/notebooks/600_trainsampler.ipynb b/docs/notebooks/600_trainsampler.ipynb index 431f0634..d8fcf0bf 100644 --- a/docs/notebooks/600_trainsampler.ipynb +++ b/docs/notebooks/600_trainsampler.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "5765bb6c", "metadata": {}, "outputs": [], @@ -15,20 +15,358 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "78c43f9f", + "execution_count": 2, + "id": "e94414bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading data\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/functools.py:912: ImplicitModificationWarning: Transforming to str index.\n", + " return dispatch(args[0].__class__)(*args, **kw)\n" + ] + } + ], + "source": [ + "from cellflow.model import CellFlow\n", + "import anndata as ad\n", + "import h5py\n", + "\n", + "from anndata.experimental import read_lazy\n", + "\n", + "print(\"loading data\")\n", + "with h5py.File(\"/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad\", \"r\") as f:\n", + " adata_all = ad.AnnData(\n", + " obs=ad.io.read_elem(f[\"obs\"]),\n", + " var=read_lazy(f[\"var\"]),\n", + " uns = read_lazy(f[\"uns\"]),\n", + " obsm = read_lazy(f[\"obsm\"]),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8f224240", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[########################################] | 100% Completed | 908.17 ms\n", + "[########################################] | 100% Completed | 21.42 s\n", + "[########################################] | 100% Completed | 375.38 s\n" + ] + } + ], + "source": [ + "from cellflow.data import DataManager\n", + "dm = DataManager(adata_all, \n", + " sample_rep=\"X_pca\",\n", + " control_key=\"control\",\n", + " perturbation_covariates={\"drugs\": (\"drug\",), \"dosage\": (\"dosage\",)},\n", + " perturbation_covariate_reps={\"drugs\": \"drug_embeddings\"},\n", + " sample_covariates=[\"cell_line\"],\n", + " sample_covariate_reps={\"cell_line\": \"cell_line_embeddings\"},\n", + " split_covariates=[\"cell_line\"],\n", + " max_combination_length=None,\n", + " null_value=0.0\n", + ")\n", + "\n", + "cond_data = dm._get_condition_data(adata=adata_all)\n", + "cell_data = dm._get_cell_data(adata_all)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 27, + "id": "c41b2a3b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 75.74it/s]\n", + "Computing target to cell data idcs: 14%|█▍ | 8030/56827 [00:27<02:48, 289.15it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[27]\u001b[39m\u001b[32m, line 21\u001b[39m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m tgt_idx \u001b[38;5;129;01min\u001b[39;00m tqdm.tqdm(\u001b[38;5;28mrange\u001b[39m(n_target_dists), desc=\u001b[33m\"\u001b[39m\u001b[33mComputing target to cell data idcs\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m 19\u001b[39m mask = gpu_per_cov_mask == tgt_idx\n\u001b[32m 20\u001b[39m tgt_cell_data[\u001b[38;5;28mstr\u001b[39m(tgt_idx)] = {\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcell_data_index\u001b[39m\u001b[33m\"\u001b[39m: \u001b[43mcp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmask\u001b[49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[32m 22\u001b[39m }\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "import cupy as cp\n", + "import tqdm\n", + "\n", + "n_source_dists = len(cond_data.split_idx_to_covariates)\n", + "n_target_dists = len(cond_data.perturbation_idx_to_covariates)\n", + "\n", + "tgt_cell_data = {}\n", + "src_cell_data = {}\n", + "gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask)\n", + "gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask)\n", + "\n", + "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data idcs\"):\n", + " mask = gpu_spl_cov_mask == src_idx\n", + " src_cell_data[str(src_idx)] = {\n", + " \"cell_data_index\": cp.where(mask)[0].get(),\n", + " }\n", + "\n", + "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data idcs\"):\n", + " mask = gpu_per_cov_mask == tgt_idx\n", + " tgt_cell_data[str(tgt_idx)] = {\n", + " \"cell_data_index\": cp.where(mask)[0].get(),\n", + " }\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a352c69", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
[dask array HTML repr elided: Bytes 106.87 GiB (1.14 MiB per chunk), Shape (95624334, 300) in (1000, 300) chunks, Dask graph 95625 chunks in 1 graph layer, Data type float32 numpy.ndarray]\n",
+      "
" + ], + "text/plain": [ + "dask.array" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cell_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26c512d3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 246.04it/s]\n", + "Computing target to cell data: 100%|██████████| 56827/56827 [00:08<00:00, 6554.24it/s]\n" + ] + } + ], + "source": [ + "import dask\n", + "\n", + "\n", + "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data\"):\n", + " indices = src_cell_data[str(src_idx)][\"cell_data_index\"]\n", + " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", + " src_cell_data[str(src_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n", + "\n", + "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data\"):\n", + " indices = tgt_cell_data[str(tgt_idx)][\"cell_data_index\"]\n", + " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", + " tgt_cell_data[str(tgt_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d01e392", "metadata": {}, "outputs": [], "source": [ - "data_paths = [\n", - " \"/lustre/groups/ml01/workspace/100mil/tahoe_train_10000_rep1.zarr\",\n", - " \"/lustre/groups/ml01/workspace/100mil/tahoe_train_55000_rep1.zarr\",\n", - "]\n" + "import numpy as np\n", + "\n", + "split_covariates_mask = np.asarray(cond_data.split_covariates_mask)\n", + "perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)\n", + "condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}\n", + "control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}\n", + "split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}\n", + "perturbation_idx_to_covariates = {\n", + " str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()\n", + "}\n", + "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "id": "e49deaf9", + "metadata": {}, + "outputs": [], + "source": [ + "train_data_dict = {\n", + " \"split_covariates_mask\": split_covariates_mask,\n", + " \"perturbation_covariates_mask\": perturbation_covariates_mask,\n", + " \"split_idx_to_covariates\": split_idx_to_covariates,\n", + " \"perturbation_idx_to_covariates\": perturbation_idx_to_covariates,\n", + " \"perturbation_idx_to_id\": perturbation_idx_to_id,\n", + " \"condition_data\": condition_data,\n", + " \"control_to_perturbation\": control_to_perturbation,\n", + " \"max_combination_length\": int(cond_data.max_combination_length),\n", + " \"src_cell_data\": src_cell_data,\n", + " \"tgt_cell_data\": tgt_cell_data,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32e27b1f", + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + 
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[26]\u001b[39m\u001b[32m, line 8\u001b[39m\n\u001b[32m 6\u001b[39m chunk_size = \u001b[32m65536\u001b[39m\n\u001b[32m 7\u001b[39m shard_size = chunk_size * \u001b[32m16\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m8\u001b[39m \u001b[43mwrite_sharded\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 9\u001b[39m \u001b[43m \u001b[49m\u001b[43mzgroup\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 10\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_data_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 11\u001b[39m \u001b[43m \u001b[49m\u001b[43mchunk_size\u001b[49m\u001b[43m=\u001b[49m\u001b[43mchunk_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 12\u001b[39m \u001b[43m \u001b[49m\u001b[43mshard_size\u001b[49m\u001b[43m=\u001b[49m\u001b[43mshard_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 13\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompressors\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 14\u001b[39m \u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_utils.py:65\u001b[39m, in \u001b[36mwrite_sharded\u001b[39m\u001b[34m(group, data, chunk_size, shard_size, compressors)\u001b[39m\n\u001b[32m 56\u001b[39m dataset_kwargs = {\n\u001b[32m 57\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mshards\u001b[39m\u001b[33m\"\u001b[39m: (shard_size,),\n\u001b[32m 58\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mchunks\u001b[39m\u001b[33m\"\u001b[39m: (chunk_size,),\n\u001b[32m 59\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcompressors\u001b[39m\u001b[33m\"\u001b[39m: compressors,\n\u001b[32m 60\u001b[39m **dataset_kwargs,\n\u001b[32m 61\u001b[39m }\n\u001b[32m 63\u001b[39m func(g, k, elem, dataset_kwargs=dataset_kwargs)\n\u001b[32m---> \u001b[39m\u001b[32m65\u001b[39m \u001b[43mad\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexperimental\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwrite_dispatched\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgroup\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m/\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 66\u001b[39m zarr.consolidate_metadata(group.store)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/experimental/_dispatch_io.py:77\u001b[39m, in \u001b[36mwrite_dispatched\u001b[39m\u001b[34m(store, key, elem, callback, dataset_kwargs)\u001b[39m\n\u001b[32m 73\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01manndata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01m_io\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mspecs\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m _REGISTRY, Writer\n\u001b[32m 75\u001b[39m writer = Writer(_REGISTRY, callback=callback)\n\u001b[32m---> \u001b[39m\u001b[32m77\u001b[39m \u001b[43mwriter\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwrite_elem\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstore\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/utils.py:248\u001b[39m, in \u001b[36mreport_write_key_on_error..func_wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 246\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(msg)\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m248\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 249\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 250\u001b[39m path = _get_display_path(store)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/registry.py:393\u001b[39m, in \u001b[36mWriter.write_elem\u001b[39m\u001b[34m(self, store, k, elem, dataset_kwargs, modifiers)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.callback \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m write_func(store, k, elem, dataset_kwargs=dataset_kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 394\u001b[39m \u001b[43m \u001b[49m\u001b[43mwrite_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 395\u001b[39m \u001b[43m \u001b[49m\u001b[43mstore\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 396\u001b[39m \u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 397\u001b[39m \u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 398\u001b[39m \u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 399\u001b[39m \u001b[43m \u001b[49m\u001b[43miospec\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mregistry\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[43melem\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 400\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_utils.py:63\u001b[39m, in \u001b[36mwrite_sharded..callback\u001b[39m\u001b[34m(func, g, k, elem, dataset_kwargs, iospec)\u001b[39m\n\u001b[32m 55\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m iospec.encoding_type \u001b[38;5;129;01min\u001b[39;00m {\u001b[33m\"\u001b[39m\u001b[33mcsr_matrix\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mcsc_matrix\u001b[39m\u001b[33m\"\u001b[39m}:\n\u001b[32m 56\u001b[39m dataset_kwargs = {\n\u001b[32m 57\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mshards\u001b[39m\u001b[33m\"\u001b[39m: (shard_size,),\n\u001b[32m 58\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mchunks\u001b[39m\u001b[33m\"\u001b[39m: 
(chunk_size,),\n\u001b[32m 59\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcompressors\u001b[39m\u001b[33m\"\u001b[39m: compressors,\n\u001b[32m 60\u001b[39m **dataset_kwargs,\n\u001b[32m 61\u001b[39m }\n\u001b[32m---> \u001b[39m\u001b[32m63\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/registry.py:77\u001b[39m, in \u001b[36mwrite_spec..decorator..wrapper\u001b[39m\u001b[34m(g, k, *args, **kwargs)\u001b[39m\n\u001b[32m 75\u001b[39m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[32m 76\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mwrapper\u001b[39m(g: GroupStorageType, k: \u001b[38;5;28mstr\u001b[39m, *args, **kwargs):\n\u001b[32m---> \u001b[39m\u001b[32m77\u001b[39m result = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 78\u001b[39m g[k].attrs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mencoding-type\u001b[39m\u001b[33m\"\u001b[39m, spec.encoding_type)\n\u001b[32m 79\u001b[39m g[k].attrs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mencoding-version\u001b[39m\u001b[33m\"\u001b[39m, spec.encoding_version)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/methods.py:387\u001b[39m, in \u001b[36mwrite_mapping\u001b[39m\u001b[34m(f, k, v, _writer, dataset_kwargs)\u001b[39m\n\u001b[32m 385\u001b[39m g = f.require_group(k)\n\u001b[32m 386\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sub_k, sub_v \u001b[38;5;129;01min\u001b[39;00m v.items():\n\u001b[32m--> \u001b[39m\u001b[32m387\u001b[39m \u001b[43m_writer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwrite_elem\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msub_k\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msub_v\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/utils.py:248\u001b[39m, in \u001b[36mreport_write_key_on_error..func_wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 246\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(msg)\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m248\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 249\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m 
\u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 250\u001b[39m path = _get_display_path(store)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/registry.py:393\u001b[39m, in \u001b[36mWriter.write_elem\u001b[39m\u001b[34m(self, store, k, elem, dataset_kwargs, modifiers)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.callback \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m write_func(store, k, elem, dataset_kwargs=dataset_kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 394\u001b[39m \u001b[43m \u001b[49m\u001b[43mwrite_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 395\u001b[39m \u001b[43m \u001b[49m\u001b[43mstore\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 396\u001b[39m \u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 397\u001b[39m \u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 398\u001b[39m \u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 399\u001b[39m \u001b[43m \u001b[49m\u001b[43miospec\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mregistry\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[43melem\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 400\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_utils.py:63\u001b[39m, in \u001b[36mwrite_sharded..callback\u001b[39m\u001b[34m(func, g, k, elem, dataset_kwargs, iospec)\u001b[39m\n\u001b[32m 55\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m iospec.encoding_type \u001b[38;5;129;01min\u001b[39;00m {\u001b[33m\"\u001b[39m\u001b[33mcsr_matrix\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mcsc_matrix\u001b[39m\u001b[33m\"\u001b[39m}:\n\u001b[32m 56\u001b[39m dataset_kwargs = {\n\u001b[32m 57\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mshards\u001b[39m\u001b[33m\"\u001b[39m: (shard_size,),\n\u001b[32m 58\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mchunks\u001b[39m\u001b[33m\"\u001b[39m: (chunk_size,),\n\u001b[32m 59\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcompressors\u001b[39m\u001b[33m\"\u001b[39m: compressors,\n\u001b[32m 60\u001b[39m **dataset_kwargs,\n\u001b[32m 61\u001b[39m }\n\u001b[32m---> \u001b[39m\u001b[32m63\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/registry.py:77\u001b[39m, in \u001b[36mwrite_spec..decorator..wrapper\u001b[39m\u001b[34m(g, k, *args, **kwargs)\u001b[39m\n\u001b[32m 75\u001b[39m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[32m 
76\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mwrapper\u001b[39m(g: GroupStorageType, k: \u001b[38;5;28mstr\u001b[39m, *args, **kwargs):\n\u001b[32m---> \u001b[39m\u001b[32m77\u001b[39m result = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 78\u001b[39m g[k].attrs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mencoding-type\u001b[39m\u001b[33m\"\u001b[39m, spec.encoding_type)\n\u001b[32m 79\u001b[39m g[k].attrs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mencoding-version\u001b[39m\u001b[33m\"\u001b[39m, spec.encoding_version)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/methods.py:387\u001b[39m, in \u001b[36mwrite_mapping\u001b[39m\u001b[34m(f, k, v, _writer, dataset_kwargs)\u001b[39m\n\u001b[32m 385\u001b[39m g = f.require_group(k)\n\u001b[32m 386\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sub_k, sub_v \u001b[38;5;129;01min\u001b[39;00m v.items():\n\u001b[32m--> \u001b[39m\u001b[32m387\u001b[39m \u001b[43m_writer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwrite_elem\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msub_k\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msub_v\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/utils.py:248\u001b[39m, in \u001b[36mreport_write_key_on_error..func_wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 246\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(msg)\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m248\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 249\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 250\u001b[39m path = _get_display_path(store)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/registry.py:393\u001b[39m, in \u001b[36mWriter.write_elem\u001b[39m\u001b[34m(self, store, k, elem, dataset_kwargs, modifiers)\u001b[39m\n\u001b[32m 391\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.callback \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 392\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m write_func(store, k, elem, dataset_kwargs=dataset_kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m393\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 394\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mwrite_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 395\u001b[39m \u001b[43m \u001b[49m\u001b[43mstore\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 396\u001b[39m \u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 397\u001b[39m \u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 398\u001b[39m \u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 399\u001b[39m \u001b[43m \u001b[49m\u001b[43miospec\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mregistry\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[43melem\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 400\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_utils.py:63\u001b[39m, in \u001b[36mwrite_sharded..callback\u001b[39m\u001b[34m(func, g, k, elem, dataset_kwargs, iospec)\u001b[39m\n\u001b[32m 55\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m iospec.encoding_type \u001b[38;5;129;01min\u001b[39;00m {\u001b[33m\"\u001b[39m\u001b[33mcsr_matrix\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mcsc_matrix\u001b[39m\u001b[33m\"\u001b[39m}:\n\u001b[32m 56\u001b[39m dataset_kwargs = {\n\u001b[32m 57\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mshards\u001b[39m\u001b[33m\"\u001b[39m: (shard_size,),\n\u001b[32m 58\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mchunks\u001b[39m\u001b[33m\"\u001b[39m: (chunk_size,),\n\u001b[32m 59\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcompressors\u001b[39m\u001b[33m\"\u001b[39m: compressors,\n\u001b[32m 60\u001b[39m **dataset_kwargs,\n\u001b[32m 61\u001b[39m }\n\u001b[32m---> \u001b[39m\u001b[32m63\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/registry.py:77\u001b[39m, in \u001b[36mwrite_spec..decorator..wrapper\u001b[39m\u001b[34m(g, k, *args, **kwargs)\u001b[39m\n\u001b[32m 75\u001b[39m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[32m 76\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mwrapper\u001b[39m(g: GroupStorageType, k: \u001b[38;5;28mstr\u001b[39m, *args, **kwargs):\n\u001b[32m---> \u001b[39m\u001b[32m77\u001b[39m result = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 78\u001b[39m g[k].attrs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mencoding-type\u001b[39m\u001b[33m\"\u001b[39m, spec.encoding_type)\n\u001b[32m 79\u001b[39m g[k].attrs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mencoding-version\u001b[39m\u001b[33m\"\u001b[39m, spec.encoding_version)\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/utils.py:308\u001b[39m, in \u001b[36mzero_dim_array_as_scalar..func_wrapper\u001b[39m\u001b[34m(f, k, elem, _writer, dataset_kwargs)\u001b[39m\n\u001b[32m 306\u001b[39m _writer.write_elem(f, k, elem[()], dataset_kwargs=dataset_kwargs)\n\u001b[32m 307\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m308\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43melem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_writer\u001b[49m\u001b[43m=\u001b[49m\u001b[43m_writer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/_io/specs/methods.py:645\u001b[39m, in \u001b[36mwrite_vlen_string_array_zarr\u001b[39m\u001b[34m(f, k, elem, _writer, dataset_kwargs)\u001b[39m\n\u001b[32m 636\u001b[39m filters, fill_value = [VLenUTF8()], \u001b[33m\"\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 637\u001b[39m f.create_array(\n\u001b[32m 638\u001b[39m k,\n\u001b[32m 639\u001b[39m shape=elem.shape,\n\u001b[32m (...)\u001b[39m\u001b[32m 643\u001b[39m **dataset_kwargs,\n\u001b[32m 644\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m645\u001b[39m \u001b[43mf\u001b[49m\u001b[43m[\u001b[49m\u001b[43mk\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m = elem\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/zarr/core/array.py:2902\u001b[39m, in \u001b[36mArray.__setitem__\u001b[39m\u001b[34m(self, selection, value)\u001b[39m\n\u001b[32m 2900\u001b[39m \u001b[38;5;28mself\u001b[39m.vindex[cast(\u001b[33m\"\u001b[39m\u001b[33mCoordinateSelection | MaskSelection\u001b[39m\u001b[33m\"\u001b[39m, selection)] = value\n\u001b[32m 2901\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m is_pure_orthogonal_indexing(pure_selection, \u001b[38;5;28mself\u001b[39m.ndim):\n\u001b[32m-> \u001b[39m\u001b[32m2902\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mset_orthogonal_selection\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpure_selection\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfields\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfields\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2903\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 2904\u001b[39m \u001b[38;5;28mself\u001b[39m.set_basic_selection(cast(\u001b[33m\"\u001b[39m\u001b[33mBasicSelection\u001b[39m\u001b[33m\"\u001b[39m, pure_selection), value, fields=fields)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/zarr/core/array.py:3354\u001b[39m, in \u001b[36mArray.set_orthogonal_selection\u001b[39m\u001b[34m(self, selection, value, fields, prototype)\u001b[39m\n\u001b[32m 3352\u001b[39m prototype = default_buffer_prototype()\n\u001b[32m 3353\u001b[39m indexer = OrthogonalIndexer(selection, \u001b[38;5;28mself\u001b[39m.shape, \u001b[38;5;28mself\u001b[39m.metadata.chunk_grid)\n\u001b[32m-> 
\u001b[39m\u001b[32m3354\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msync\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3355\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_async_array\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_set_selection\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindexer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfields\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfields\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprototype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprototype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3356\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/zarr/core/sync.py:156\u001b[39m, in \u001b[36msync\u001b[39m\u001b[34m(coro, loop, timeout)\u001b[39m\n\u001b[32m 152\u001b[39m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[32m 154\u001b[39m future = asyncio.run_coroutine_threadsafe(_runner(coro), loop)\n\u001b[32m--> \u001b[39m\u001b[32m156\u001b[39m finished, unfinished = \u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfuture\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_when\u001b[49m\u001b[43m=\u001b[49m\u001b[43masyncio\u001b[49m\u001b[43m.\u001b[49m\u001b[43mALL_COMPLETED\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 157\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(unfinished) > \u001b[32m0\u001b[39m:\n\u001b[32m 158\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTimeoutError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mCoroutine \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcoro\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m failed to finish within \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtimeout\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m s\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/concurrent/futures/_base.py:305\u001b[39m, in \u001b[36mwait\u001b[39m\u001b[34m(fs, timeout, return_when)\u001b[39m\n\u001b[32m 301\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m DoneAndNotDoneFutures(done, not_done)\n\u001b[32m 303\u001b[39m waiter = _create_and_install_waiters(fs, return_when)\n\u001b[32m--> \u001b[39m\u001b[32m305\u001b[39m \u001b[43mwaiter\u001b[49m\u001b[43m.\u001b[49m\u001b[43mevent\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 306\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m fs:\n\u001b[32m 307\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m f._condition:\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/threading.py:655\u001b[39m, in \u001b[36mEvent.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 653\u001b[39m signaled = \u001b[38;5;28mself\u001b[39m._flag\n\u001b[32m 654\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m signaled:\n\u001b[32m--> \u001b[39m\u001b[32m655\u001b[39m signaled = 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_cond\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 656\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m signaled\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/threading.py:355\u001b[39m, in \u001b[36mCondition.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 353\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[32m 354\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m355\u001b[39m \u001b[43mwaiter\u001b[49m\u001b[43m.\u001b[49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 356\u001b[39m gotit = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m 357\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "import zarr\n", + "from cellflow.data._utils import write_sharded\n", + "\n", + "path = \"test.zarr\"\n", + "zgroup = zarr.open_group(path, mode=\"w\")\n", + "chunk_size = 65536\n", + "shard_size = chunk_size * 16\n", + "write_sharded(\n", + " zgroup,\n", + " train_data_dict,\n", + " chunk_size=chunk_size,\n", + " shard_size=shard_size,\n", + " compressors=None,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "3ed731bd", "metadata": {}, "outputs": [], diff --git a/scripts/create_tahoe.py b/scripts/create_tahoe.py index f17d1676..636069bd 100644 --- a/scripts/create_tahoe.py +++ b/scripts/create_tahoe.py @@ -1,8 +1,13 @@ -from sc_exp_design.models import CellFlow import anndata as ad import h5py - +import zarr +from cellflow.data._utils import write_sharded from anndata.experimental import read_lazy +from cellflow.data import DataManager +import cupy as cp +import tqdm +import dask +import numpy as np print("loading data") with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f: @@ -12,21 +17,91 @@ uns = read_lazy(f["uns"]), obsm = read_lazy(f["obsm"]), ) -cf = CellFlow() -print(" preparing train data ") -cf.prepare_train_data(adata_all, - sample_rep="X_pca", - control_key="control", - perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)}, - perturbation_covariate_reps={"drugs": "drug_embeddings"}, - sample_covariates=["cell_line"], - sample_covariate_reps={"cell_line": "cell_line_embeddings"}, - split_covariates=["cell_line"]) +dm = DataManager(adata_all, + sample_rep="X_pca", + control_key="control", + perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)}, + perturbation_covariate_reps={"drugs": "drug_embeddings"}, + sample_covariates=["cell_line"], + sample_covariate_reps={"cell_line": "cell_line_embeddings"}, + split_covariates=["cell_line"], + max_combination_length=None, + null_value=0.0 +) + +cond_data = dm._get_condition_data(adata=adata_all) +cell_data = dm._get_cell_data(adata_all) + + + +n_source_dists = len(cond_data.split_idx_to_covariates) +n_target_dists = len(cond_data.perturbation_idx_to_covariates) + +tgt_cell_data = {} +src_cell_data = {} +gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask) +gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask) + +for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to 
cell data idcs"): + mask = gpu_spl_cov_mask == src_idx + src_cell_data[str(src_idx)] = { + "cell_data_index": cp.where(mask)[0].get(), + } + +for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data idcs"): + mask = gpu_per_cov_mask == tgt_idx + tgt_cell_data[str(tgt_idx)] = { + "cell_data_index": cp.where(mask)[0].get(), + } + + + +for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data"): + indices = src_cell_data[str(src_idx)]["cell_data_index"] + delayed_obj = dask.delayed(lambda x: cell_data[x])(indices) + src_cell_data[str(src_idx)]["cell_data"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype) + +for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data"): + indices = tgt_cell_data[str(tgt_idx)]["cell_data_index"] + delayed_obj = dask.delayed(lambda x: cell_data[x])(indices) + tgt_cell_data[str(tgt_idx)]["cell_data"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype) + +split_covariates_mask = np.asarray(cond_data.split_covariates_mask) +perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask) +condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()} +control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()} +split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()} +perturbation_idx_to_covariates = { + str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items() +} +perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()} +train_data_dict = { + "split_covariates_mask": split_covariates_mask, + "perturbation_covariates_mask": perturbation_covariates_mask, + "split_idx_to_covariates": split_idx_to_covariates, + "perturbation_idx_to_covariates": perturbation_idx_to_covariates, + "perturbation_idx_to_id": perturbation_idx_to_id, + "condition_data": condition_data, + "control_to_perturbation": control_to_perturbation, + "max_combination_length": int(cond_data.max_combination_length), + "src_cell_data": src_cell_data, + "tgt_cell_data": tgt_cell_data, +} -print("writing zarr") -cf.train_data.write_zarr(f"/lustre/groups/ml01/workspace/100mil/tahoe_train_data.zarr") -print("zarr written") +print("writing data") +path = "/lustre/groups/ml01/workspace/100mil/tahoe.zarr" +zgroup = zarr.open_group(path, mode="w") +chunk_size = 65536 +shard_size = chunk_size * 16 +write_sharded( + zgroup, + train_data_dict, + chunk_size=chunk_size, + shard_size=shard_size, + compressors=None, +) +print("done") \ No newline at end of file diff --git a/src/cellflow/data/_datamanager.py b/src/cellflow/data/_datamanager.py index 065cddd2..3ccc0793 100644 --- a/src/cellflow/data/_datamanager.py +++ b/src/cellflow/data/_datamanager.py @@ -759,15 +759,15 @@ def _get_cell_data( if sample_rep == "X": sample_rep = adata.X if isinstance(sample_rep, sp.csr_matrix): - return np.asarray(sample_rep.toarray()) + return sample_rep.toarray() else: - return np.asarray(sample_rep) + return sample_rep if isinstance(self._sample_rep, str): if self._sample_rep not in adata.obsm: raise KeyError(f"Sample representation '{self._sample_rep}' not found in `adata.obsm`.") - return np.asarray(adata.obsm[self._sample_rep]) + return adata.obsm[self._sample_rep] attr, key = next(iter(sample_rep.items())) # type: 
ignore[union-attr] - return np.asarray(getattr(adata, attr)[key]) + return getattr(adata, attr)[key] def _verify_control_data(self, adata: anndata.AnnData | None) -> None: if adata is None: From 988803f97bd2f29a2cbe75d01533e7f6a8815770 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 29 Sep 2025 13:10:45 +0200 Subject: [PATCH 28/35] update first working version --- docs/notebooks/600_trainsampler copy 2.ipynb | 1853 +++++++++++++++++ docs/notebooks/600_trainsampler copy.ipynb | 454 ++++ docs/notebooks/600_trainsampler.ipynb | 789 ++----- scripts/{create_tahoe.py => process_tahoe.py} | 133 +- scripts/process_tahoe.sbatch | 17 + src/cellflow/data/__init__.py | 8 +- src/cellflow/data/_data.py | 74 +- src/cellflow/data/_dataloader.py | 113 +- src/cellflow/data/_jax_dataloader.py | 3 +- src/cellflow/data/_torch_dataloader.py | 8 +- src/cellflow/data/_utils.py | 21 +- 11 files changed, 2755 insertions(+), 718 deletions(-) create mode 100644 docs/notebooks/600_trainsampler copy 2.ipynb create mode 100644 docs/notebooks/600_trainsampler copy.ipynb rename scripts/{create_tahoe.py => process_tahoe.py} (50%) create mode 100644 scripts/process_tahoe.sbatch diff --git a/docs/notebooks/600_trainsampler copy 2.ipynb b/docs/notebooks/600_trainsampler copy 2.ipynb new file mode 100644 index 00000000..10e98336 --- /dev/null +++ b/docs/notebooks/600_trainsampler copy 2.ipynb @@ -0,0 +1,1853 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5765bb6c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading data\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/functools.py:912: ImplicitModificationWarning: Transforming to str index.\n", + " return dispatch(args[0].__class__)(*args, **kw)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data loaded\n", + "[########################################] | 100% Completed | 1.11 sms\n", + "[########################################] | 100% Completed | 25.93 s\n", + "[########################################] | 100% Completed | 294.61 s\n" + ] + } + ], + "source": [ + "%%\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "\n", + "# %%\n", + "import anndata as ad\n", + "import h5py\n", + "import zarr\n", + "from cellflow.data._utils import write_sharded\n", + "from anndata.experimental import read_lazy\n", + "from cellflow.data import DataManager\n", + "import cupy as cp\n", + "import tqdm\n", + "import dask\n", + "import concurrent.futures\n", + "from functools import partial\n", + "import numpy as np\n", + "import dask.array as da\n", + "from dask.diagnostics import ProgressBar\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05bd4946", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 87.79it/s]\n", + "Computing target to cell data idcs: 68%|██████▊ | 38602/56827 [00:45<00:21, 854.71it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing target to cell data idcs: 100%|██████████| 56827/56827 [01:06<00:00, 852.83it/s]\n" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "35894e7d", + "metadata": {}, + "outputs": [], + "source": [ + "cell_data = cell_data.compute()" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "310df180", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Array Chunk
Bytes 106.87 GiB 1.14 MiB
Shape (95624334, 300) (1000, 300)
Dask graph 95625 chunks in 3 graph layers
Data type float32 numpy.ndarray
\n", + "
\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + " 300\n", + " 95624334\n", + "\n", + "
" + ], + "text/plain": [ + "dask.array" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cell_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6303064c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 300)" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da.take(cell_data, , axis=0).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "9a9e4693", + "metadata": {}, + "outputs": [], + "source": [ + "cell_data_batch = cell_data[:100_000].compute()" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "8b45658e", + "metadata": {}, + "outputs": [], + "source": [ + "spl_cov_mask_batch = gpu_spl_cov_mask[:100_000]" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "ee2e0760", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", + " 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", + " 33, 34, 35, 37, 38, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],\n", + " dtype=int32)" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cp.unique(spl_cov_mask_batch)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "55ded52a", + "metadata": {}, + "outputs": [], + "source": [ + "mapping = (gpu_per_cov_mask-gpu_spl_cov_mask+50)" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "99ee8158", + "metadata": {}, + "outputs": [], + "source": [ + "sorted_indices = cp.argsort(mapping)\n", + "ordered_mapping = mapping[sorted_indices]" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "615b59c9", + "metadata": {}, + "outputs": [], + "source": [ + "unique_values, inverse_indices = cp.unique(ordered_mapping, return_inverse=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "cd48c25a", + "metadata": {}, + "outputs": [], + "source": [ + "ord_cell_data = da.take(cell_data,sorted_indices.get(),axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "ffe82904", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Array Chunk
Bytes 106.87 GiB 1.21 MiB
Shape (95624334, 300) (1053, 300)
Dask graph 95720 chunks in 4 graph layers
Data type float32 numpy.ndarray
\n", + "
\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + " 300\n", + " 95624334\n", + "\n", + "
" + ], + "text/plain": [ + "dask.array" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ord_cell_data" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "a636955a", + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[90]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m ord_cell_data = \u001b[43mord_cell_data\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/base.py:373\u001b[39m, in \u001b[36mDaskMethodsMixin.compute\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 349\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcompute\u001b[39m(\u001b[38;5;28mself\u001b[39m, **kwargs):\n\u001b[32m 350\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Compute this dask collection\u001b[39;00m\n\u001b[32m 351\u001b[39m \n\u001b[32m 352\u001b[39m \u001b[33;03m This turns a lazy Dask collection into its in-memory equivalent.\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 371\u001b[39m \u001b[33;03m dask.compute\u001b[39;00m\n\u001b[32m 372\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m373\u001b[39m (result,) = \u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraverse\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 374\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/base.py:681\u001b[39m, in \u001b[36mcompute\u001b[39m\u001b[34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[39m\n\u001b[32m 678\u001b[39m expr = expr.optimize()\n\u001b[32m 679\u001b[39m keys = \u001b[38;5;28mlist\u001b[39m(flatten(expr.__dask_keys__()))\n\u001b[32m--> \u001b[39m\u001b[32m681\u001b[39m results = \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 683\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m repack(results)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/_task_spec.py:272\u001b[39m, in \u001b[36mconvert_legacy_graph\u001b[39m\u001b[34m(dsk, all_keys)\u001b[39m\n\u001b[32m 270\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(t, GraphNode):\n\u001b[32m 271\u001b[39m t = DataNode(k, t)\n\u001b[32m--> \u001b[39m\u001b[32m272\u001b[39m new_dsk[k] = t\n\u001b[32m 273\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m new_dsk\n", + 
"\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "ord_cell_data = ord_cell_data.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c48f9c4c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-1, -1, -1, ..., 32, 31, 38], shape=(95624334,), dtype=int32)" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gpu_spl_cov_mask[:100_000].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1fcc7c7", + "metadata": {}, + "outputs": [], + "source": [ + "for k in list(src_cell_data.keys()):\n", + " idx = src_cell_data[k][\"cell_data_index\"]\n", + " src_cell_data[k][\"cell_data\"] = da.take(cell_data, idx, axis=0)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df755b79", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Exception ignored in: >\n", + "Traceback (most recent call last):\n", + " File \"/home/icb/selman.ozleyen/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 790, in _clean_thread_parent_frames\n", + " active_threads = {thread.ident for thread in threading.enumerate()}\n", + " ^^^^^^^^^^^^\n", + "KeyboardInterrupt: \n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[62]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(tgt_cell_data.keys()):\n\u001b[32m 2\u001b[39m idx = tgt_cell_data[k][\u001b[33m\"\u001b[39m\u001b[33mcell_data_index\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m tgt_cell_data[k][\u001b[33m\"\u001b[39m\u001b[33mcell_data\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[43mda\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtake\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcell_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/array/routines.py:2013\u001b[39m, in \u001b[36mtake\u001b[39m\u001b[34m(a, indices, axis)\u001b[39m\n\u001b[32m 2011\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m _take_dask_array_from_numpy(a, indices, axis)\n\u001b[32m 2012\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2013\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43ma\u001b[49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mslice\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n", + "\u001b[36mFile 
[KeyboardInterrupt tracebacks truncated: both attempts to build da.take(cell_data, idx, axis=0) for tgt_cell_data were interrupted inside dask.array slicing take/_shuffle during task-graph construction]
"\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "for k in list(tgt_cell_data.keys()):\n", + " idx = tgt_cell_data[k][\"cell_data_index\"]\n", + " tgt_cell_data[k][\"cell_data\"] = da.take(cell_data, idx, axis=0)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb2e3a2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Availible mem 34.36\n", + "Using batch size: 2,300,000 rows\n", + "Estimated memory per batch: 5.52 GB\n" + ] + } + ], + "source": [ + "mempool = cp.get_default_memory_pool()\n", + "mempool.set_limit(40 * 1024**3) # Set limit to 40 GB\n", + "batch_size = 2_300_000\n", + "gpu_fraction = 0.8\n", + "available_memory = mempool.get_limit() * gpu_fraction\n", + "\n", + "# Calculate optimal batch size based on memory\n", + "bytes_per_element = cell_data.dtype.itemsize\n", + "elements_per_row = cell_data.shape[1]\n", + "bytes_per_row = bytes_per_element * elements_per_row\n", + "\n", + "# Reserve memory for both input and output\n", + "max_batch_size = int(available_memory / (bytes_per_row * 2))\n", + "actual_batch_size = min(batch_size, max_batch_size)\n", + "\n", + "print(f\"Using batch size: {actual_batch_size:,} rows\")\n", + "print(f\"Estimated memory per batch: {(actual_batch_size * bytes_per_row * 2) / 1e9:.2f} GB\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b36da19", + "metadata": {}, + "outputs": [], + "source": [ + "def process_indices_gpu(indices_dict: Dict, description: str) -> Dict:\n", + " \"\"\"Process a dictionary of indices on GPU\"\"\"\n", + " results = {}\n", + " \n", + " for key in tqdm.tqdm(indices_dict.keys(), desc=description):\n", + " indices = indices_dict[key][\"cell_data_index\"]\n", + " \n", + " if len(indices) == 0:\n", + " results[key] = {\"cell_data\": np.empty((0, cell_data.shape[1]), dtype=cell_data.dtype)}\n", + " continue\n", + " \n", + " # Process in batches if indices are large\n", + " if len(indices) <= actual_batch_size:\n", + " # Small enough to process at once\n", + " gpu_result = process_single_batch_gpu(cell_data, indices)\n", + " results[key] = {\"cell_data\": gpu_result}\n", + " else:\n", + " # Process in multiple batches\n", + " batched_results = []\n", + " n_batches = (len(indices) + actual_batch_size - 1) // actual_batch_size\n", + " \n", + " for batch_idx in range(n_batches):\n", + " start_idx = batch_idx * actual_batch_size\n", + " end_idx = min((batch_idx + 1) * actual_batch_size, len(indices))\n", + " batch_indices = indices[start_idx:end_idx]\n", + " \n", + " batch_result = process_single_batch_gpu(cell_data, batch_indices)\n", + " batched_results.append(batch_result)\n", + " \n", + " # Clear GPU memory between batches\n", + " cp.get_default_memory_pool().free_all_blocks()\n", + " \n", + " # Concatenate results\n", + " final_result = np.concatenate(batched_results, axis=0)\n", + " results[key] = {\"cell_data\": final_result}\n", + " \n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67262e7f", + "metadata": {}, + "outputs": [], + "source": [ + "import cupy as cp\n", + "import numpy as np\n", + "from typing import Dict, List, Tuple\n", + "import gc\n", + "\n", + "def process_cell_data_gpu_batched(\n", + " cell_data,\n", + " src_cell_data: Dict,\n", + " tgt_cell_data: Dict,\n", + " batch_size: int = 20000, # Adjust based on GPU memory\n", + " gpu_memory_fraction: float = 0.8\n", + ") -> Tuple[Dict, Dict]:\n", + " \"\"\"\n", + " Process cell data 
indexing on GPU in batches to manage memory efficiently.\n", + " \n", + " Parameters\n", + " ----------\n", + " cell_data : dask.array or numpy array\n", + " The main cell data array\n", + " src_cell_data : dict\n", + " Dictionary containing source cell data indices\n", + " tgt_cell_data : dict\n", + " Dictionary containing target cell data indices\n", + " batch_size : int\n", + " Number of rows to process per batch\n", + " gpu_memory_fraction : float\n", + " Fraction of GPU memory to use for cell data\n", + " \n", + " Returns\n", + " -------\n", + " Tuple of updated src_cell_data and tgt_cell_data dictionaries\n", + " \"\"\"\n", + " \n", + " # Get available GPU memory\n", + " mempool = cp.get_default_memory_pool()\n", + " available_memory = mempool.get_limit() * gpu_memory_fraction\n", + " \n", + " # Calculate optimal batch size based on memory\n", + " bytes_per_element = cell_data.dtype.itemsize\n", + " elements_per_row = cell_data.shape[1]\n", + " bytes_per_row = bytes_per_element * elements_per_row\n", + " \n", + " # Reserve memory for both input and output\n", + " max_batch_size = int(available_memory / (bytes_per_row * 2))\n", + " actual_batch_size = min(batch_size, max_batch_size)\n", + " \n", + " \n", + " def process_indices_gpu(indices_dict: Dict, description: str) -> Dict:\n", + " \"\"\"Process a dictionary of indices on GPU\"\"\"\n", + " results = {}\n", + " \n", + " for key in tqdm.tqdm(indices_dict.keys(), desc=description):\n", + " indices = indices_dict[key][\"cell_data_index\"]\n", + " \n", + " if len(indices) == 0:\n", + " results[key] = {\"cell_data\": np.empty((0, cell_data.shape[1]), dtype=cell_data.dtype)}\n", + " continue\n", + " \n", + " # Process in batches if indices are large\n", + " if len(indices) <= actual_batch_size:\n", + " # Small enough to process at once\n", + " gpu_result = process_single_batch_gpu(cell_data, indices)\n", + " results[key] = {\"cell_data\": gpu_result}\n", + " else:\n", + " # Process in multiple batches\n", + " batched_results = []\n", + " n_batches = (len(indices) + actual_batch_size - 1) // actual_batch_size\n", + " \n", + " for batch_idx in range(n_batches):\n", + " start_idx = batch_idx * actual_batch_size\n", + " end_idx = min((batch_idx + 1) * actual_batch_size, len(indices))\n", + " batch_indices = indices[start_idx:end_idx]\n", + " \n", + " batch_result = process_single_batch_gpu(cell_data, batch_indices)\n", + " batched_results.append(batch_result)\n", + " \n", + " # Clear GPU memory between batches\n", + " cp.get_default_memory_pool().free_all_blocks()\n", + " \n", + " # Concatenate results\n", + " final_result = np.concatenate(batched_results, axis=0)\n", + " results[key] = {\"cell_data\": final_result}\n", + " \n", + " return results\n", + " \n", + " def process_single_batch_gpu(data, indices):\n", + " \"\"\"Process a single batch of indices on GPU\"\"\"\n", + " # Move indices to GPU\n", + " gpu_indices = cp.asarray(indices)\n", + " \n", + " # Move data batch to GPU (only the needed rows)\n", + " if hasattr(data, 'compute'): # Dask array\n", + " # For dask arrays, compute only the needed slices\n", + " cpu_batch = data[indices].compute()\n", + " else: # Regular numpy array\n", + " cpu_batch = data[indices]\n", + " \n", + " # Move to GPU and back to CPU\n", + " gpu_batch = cp.asarray(cpu_batch)\n", + " result = cp.asnumpy(gpu_batch)\n", + " \n", + " # Clean up GPU memory\n", + " del gpu_batch, gpu_indices\n", + " cp.get_default_memory_pool().free_all_blocks()\n", + " \n", + " return result\n", + " \n", + " # Process source and 
target data\n", + " print(\"Processing source cell data on GPU...\")\n", + " src_results = process_indices_gpu(src_cell_data, \"Processing source indices on GPU\")\n", + " \n", + " print(\"Processing target cell data on GPU...\")\n", + " tgt_results = process_indices_gpu(tgt_cell_data, \"Processing target indices on GPU\")\n", + " \n", + " # Update original dictionaries\n", + " for key in src_results:\n", + " src_cell_data[key].update(src_results[key])\n", + " \n", + " for key in tgt_results:\n", + " tgt_cell_data[key].update(tgt_results[key])\n", + " \n", + " # Final memory cleanup\n", + " cp.get_default_memory_pool().free_all_blocks()\n", + " gc.collect()\n", + " \n", + " return src_cell_data, tgt_cell_data\n" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "6d1848e8", + "metadata": {}, + "outputs": [], + "source": [ + "for k in list(src_cell_data.keys()):\n", + " idx = src_cell_data[k][\"cell_data_index\"]\n", + " src_cell_data[k][\"cell_data\"] = cell_data[idx]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "dba951a7", + "metadata": {}, + "outputs": [], + "source": [ + "for k in list(tgt_cell_data.keys()):\n", + " idx = tgt_cell_data[k][\"cell_data_index\"]\n", + " tgt_cell_data[k][\"cell_data\"] = cell_data[idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "010bd308", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Writing /src_cell_data: 100%|██████████| 50/50 [00:03<00:00, 13.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "done writing src_cell_data\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Writing /tgt_cell_data: 100%|██████████| 56827/56827 [22:05<00:00, 42.86it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "done writing tgt_cell_data\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "path = \"/lustre/groups/ml01/workspace/100mil/tahoe2.zarr\"\n", + "zgroup = zarr.open_group(path, mode=\"w\")\n", + "chunk_size = 131072\n", + "shard_size = chunk_size * 8\n", + "\n", + "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n", + "\n", + "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", + " shard_size_used = shard_size\n", + " chunk_size_used = chunk_size\n", + " if chunk_size > shape[0]:\n", + " chunk_size_used = shard_size_used = shape[0]\n", + " elif chunk_size < shape[0] or shard_size > shape[0]:\n", + " chunk_size_used = shard_size_used = shape[0]\n", + " return chunk_size_used, shard_size_used\n", + "\n", + "\n", + "\n", + "\n", + "def write_single_array(group, key, arr, chunk_size, shard_size):\n", + " \"\"\"Write a single array - designed for threading\"\"\"\n", + " chunk_size_used, shard_size_used = get_size(arr.shape, chunk_size, shard_size)\n", + " \n", + " group.create_array(\n", + " name=key,\n", + " data=arr,\n", + " chunks=(chunk_size_used, arr.shape[1]),\n", + " shards=(shard_size_used, arr.shape[1]),\n", + " compressors=None,\n", + " )\n", + " return key\n", + "\n", + "def write_cell_data_threaded(group, cell_data, chunk_size, shard_size, max_workers=8):\n", + " \"\"\"Write cell data using threading for I/O parallelism\"\"\"\n", + " \n", + " write_func = partial(write_single_array, group, chunk_size=chunk_size, shard_size=shard_size)\n", + " \n", + " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as 
executor:\n", + " # Submit all write tasks\n", + " future_to_key = {\n", + " executor.submit(write_single_array, group, k, cell_data[k][\"cell_data\"], chunk_size, shard_size): k \n", + " for k in cell_data.keys()\n", + " }\n", + " \n", + " # Process results with progress bar\n", + " for future in tqdm.tqdm(\n", + " concurrent.futures.as_completed(future_to_key), \n", + " total=len(future_to_key),\n", + " desc=f\"Writing {group.name}\"\n", + " ):\n", + " key = future_to_key[future]\n", + " try:\n", + " future.result() # This will raise any exceptions\n", + " except Exception as exc:\n", + " print(f'Array {key} generated an exception: {exc}')\n", + " raise\n", + "\n", + "# %%\n", + "\n", + "\n", + "\n", + "src_group = zgroup.create_group(\"src_cell_data\", overwrite=True)\n", + "tgt_group = zgroup.create_group(\"tgt_cell_data\", overwrite=True)\n", + "\n", + "\n", + "# Use the fast threaded approach you already implemented\n", + "write_cell_data_threaded(src_group, src_cell_data, chunk_size, shard_size, max_workers=14)\n", + "print(\"done writing src_cell_data\")\n", + "write_cell_data_threaded(tgt_group, tgt_cell_data, chunk_size, shard_size, max_workers=14)\n", + "print(\"done writing tgt_cell_data\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd842ac9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hi\n" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "3bc2cc9d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 11695.68it/s]\n", + "Computing target to cell data: 100%|██████████| 56827/56827 [00:02<00:00, 27235.88it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[93]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 14\u001b[39m tgt_results = []\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ProgressBar():\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m src_results, tgt_results = \u001b[43mdask\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43msrc_delayed_objs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtgt_delayed_objs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m src_results:\n\u001b[32m 19\u001b[39m src_cell_data[k][\u001b[33m\"\u001b[39m\u001b[33mcell_data\u001b[39m\u001b[33m\"\u001b[39m] = v\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/base.py:681\u001b[39m, in \u001b[36mcompute\u001b[39m\u001b[34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[39m\n\u001b[32m 678\u001b[39m expr = expr.optimize()\n\u001b[32m 679\u001b[39m keys = \u001b[38;5;28mlist\u001b[39m(flatten(expr.__dask_keys__()))\n\u001b[32m--> \u001b[39m\u001b[32m681\u001b[39m results = \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
[intermediate dask compute/optimize traceback frames omitted]
\u001b[38;5;28;01mreturn\u001b[39;00m dsk\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/highlevelgraph.py:769\u001b[39m, in \u001b[36mHighLevelGraph.cull\u001b[39m\u001b[34m(self, keys)\u001b[39m\n\u001b[32m 767\u001b[39m layer = \u001b[38;5;28mself\u001b[39m.layers[layer_name]\n\u001b[32m 768\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m keys_set:\n\u001b[32m--> \u001b[39m\u001b[32m769\u001b[39m culled_layer, culled_deps = \u001b[43mlayer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcull\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkeys_set\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mall_ext_keys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 770\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m culled_deps:\n\u001b[32m 771\u001b[39m \u001b[38;5;28;01mcontinue\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/highlevelgraph.py:179\u001b[39m, in \u001b[36mLayer.cull\u001b[39m\u001b[34m(self, keys, all_hlg_keys)\u001b[39m\n\u001b[32m 176\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 177\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mdask\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01m_task_spec\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m cull\n\u001b[32m--> \u001b[39m\u001b[32m179\u001b[39m out = \u001b[43mcull\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 180\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m MaterializedLayer(out, annotations=\u001b[38;5;28mself\u001b[39m.annotations), {\n\u001b[32m 181\u001b[39m k: \u001b[38;5;28mset\u001b[39m(v.dependencies) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m out.items()\n\u001b[32m 182\u001b[39m }\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/_task_spec.py:1179\u001b[39m, in \u001b[36mcull\u001b[39m\u001b[34m(dsk, keys)\u001b[39m\n\u001b[32m 1177\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(keys) == \u001b[38;5;28mlen\u001b[39m(dsk):\n\u001b[32m 1178\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m dsk\n\u001b[32m-> \u001b[39m\u001b[32m1179\u001b[39m work = \u001b[38;5;28;43mset\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1180\u001b[39m seen: \u001b[38;5;28mset\u001b[39m[KeyType] = \u001b[38;5;28mset\u001b[39m()\n\u001b[32m 1181\u001b[39m dsk2 = {}\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "\n", + "\n", + "src_delayed_objs = []\n", + "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data\"):\n", + " indices = src_cell_data[str(src_idx)][\"cell_data_index\"]\n", + " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", + " src_delayed_objs.append((str(src_idx), delayed_obj))\n", + "\n", + "tgt_delayed_objs = []\n", + "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data\"):\n", + " indices = tgt_cell_data[str(tgt_idx)][\"cell_data_index\"]\n", + " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", + " 
tgt_delayed_objs.append((str(tgt_idx), delayed_obj))\n", + "\n", + "src_results = []\n", + "tgt_results = []\n", + "with ProgressBar():\n", + " src_results, tgt_results = dask.compute(src_delayed_objs, tgt_delayed_objs)\n", + "\n", + "for k, v in src_results:\n", + " src_cell_data[k][\"cell_data\"] = v\n", + "\n", + "for k, v in tgt_results:\n", + " tgt_cell_data[k][\"cell_data\"] = v\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "721747c3", + "metadata": {}, + "outputs": [], + "source": [ + "src_results = []\n", + "tgt_results = []\n", + "with ProgressBar():\n", + " src_results, tgt_results = dask.compute(src_delayed_objs, tgt_delayed_objs)\n", + "\n", + "for k, v in src_results:\n", + " src_cell_data[k][\"cell_data\"] = v\n", + "\n", + "for k, v in tgt_results:\n", + " tgt_cell_data[k][\"cell_data\"] = v" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f9ad908", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# %%\n", + "\n", + "split_covariates_mask = np.asarray(cond_data.split_covariates_mask)\n", + "perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)\n", + "condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}\n", + "control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}\n", + "split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}\n", + "perturbation_idx_to_covariates = {\n", + " str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()\n", + "}\n", + "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}\n", + "\n", + "train_data_dict = {\n", + " \"split_covariates_mask\": split_covariates_mask,\n", + " \"perturbation_covariates_mask\": perturbation_covariates_mask,\n", + " \"split_idx_to_covariates\": split_idx_to_covariates,\n", + " \"perturbation_idx_to_covariates\": perturbation_idx_to_covariates,\n", + " \"perturbation_idx_to_id\": perturbation_idx_to_id,\n", + " \"condition_data\": condition_data,\n", + " \"control_to_perturbation\": control_to_perturbation,\n", + " \"max_combination_length\": int(cond_data.max_combination_length),\n", + " # \"src_cell_data\": src_cell_data,\n", + " # \"tgt_cell_data\": tgt_cell_data,\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "402c899c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "print(\"prepared train_data_dict\")\n", + "# %%\n", + "path = \"/lustre/groups/ml01/workspace/100mil/tahoe2.zarr\"\n", + "zgroup = zarr.open_group(path, mode=\"w\")\n", + "chunk_size = 131072\n", + "shard_size = chunk_size * 8\n", + "\n", + "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ff428dd", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", + " shard_size_used = shard_size\n", + " chunk_size_used = chunk_size\n", + " if chunk_size > shape[0]:\n", + " chunk_size_used = shard_size_used = shape[0]\n", + " elif chunk_size < shape[0] or shard_size > shape[0]:\n", + " chunk_size_used = shard_size_used = shape[0]\n", + " return chunk_size_used, shard_size_used\n", + "\n", + "\n", + "\n", + "\n", + "def write_single_array(group, key, arr, chunk_size, shard_size):\n", + " \"\"\"Write a 
single array - designed for threading\"\"\"\n", + " chunk_size_used, shard_size_used = get_size(arr.shape, chunk_size, shard_size)\n", + " \n", + " group.create_array(\n", + " name=key,\n", + " data=arr,\n", + " chunks=(chunk_size_used, arr.shape[1]),\n", + " shards=(shard_size_used, arr.shape[1]),\n", + " compressors=None,\n", + " dtype=arr.dtype,\n", + " )\n", + " return key\n", + "\n", + "def write_cell_data_threaded(group, cell_data, chunk_size, shard_size, max_workers=8):\n", + " \"\"\"Write cell data using threading for I/O parallelism\"\"\"\n", + " \n", + " write_func = partial(write_single_array, group, chunk_size=chunk_size, shard_size=shard_size)\n", + " \n", + " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n", + " # Submit all write tasks\n", + " future_to_key = {\n", + " executor.submit(write_single_array, group, k, cell_data[k][\"cell_data\"], chunk_size, shard_size): k \n", + " for k in cell_data.keys()\n", + " }\n", + " \n", + " # Process results with progress bar\n", + " for future in tqdm.tqdm(\n", + " concurrent.futures.as_completed(future_to_key), \n", + " total=len(future_to_key),\n", + " desc=f\"Writing {group.name}\"\n", + " ):\n", + " key = future_to_key[future]\n", + " try:\n", + " future.result() # This will raise any exceptions\n", + " except Exception as exc:\n", + " print(f'Array {key} generated an exception: {exc}')\n", + " raise\n", + "\n", + "# %%\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68a0c4d1", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "src_group = zgroup.create_group(\"src_cell_data\", overwrite=True)\n", + "tgt_group = zgroup.create_group(\"tgt_cell_data\", overwrite=True)\n", + "\n", + "\n", + "# Use the fast threaded approach you already implemented\n", + "write_cell_data_threaded(src_group, src_cell_data, chunk_size, shard_size, max_workers=14)\n", + "print(\"done writing src_cell_data\")\n", + "write_cell_data_threaded(tgt_group, tgt_cell_data, chunk_size, shard_size, max_workers=14)\n", + "print(\"done writing tgt_cell_data\")\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "754d6fa7", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "# %%\n", + "\n", + "print(\"Writing mapping data\")\n", + "mapping_data = zgroup.create_group(\"mapping_data\", overwrite=True)\n", + "\n", + "write_sharded(\n", + " mapping_data,\n", + " train_data_dict,\n", + " chunk_size=chunk_size,\n", + " shard_size=shard_size,\n", + " compressors=None,\n", + ")\n", + "print(\"done\")\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e94414bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading data\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/functools.py:912: ImplicitModificationWarning: Transforming to str index.\n", + " return dispatch(args[0].__class__)(*args, **kw)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[########################################] | 100% Completed | 926.66 ms\n", + "[########################################] | 100% Completed | 23.50 s\n", + "[########################################] | 100% Completed | 262.74 s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing source to cell data idcs: 100%|██████████| 
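Note on the step above, illustrative only and not part of the patch: the threaded writer stores each source/target distribution as its own array under src_cell_data/ and tgt_cell_data/, chunked and sharded along the cell axis. A minimal read-back sketch (zarr v3 assumed; the path and group names are the ones used in this notebook) shows what that layout buys at sampling time:

import zarr

# Open the store lazily; nothing is read until an array is actually sliced.
zgroup = zarr.open_group("/lustre/groups/ml01/workspace/100mil/tahoe.zarr", mode="r")

# One array per source distribution, shape (n_cells_in_distribution, n_features).
arr = zgroup["src_cell_data"]["0"]
print(arr.shape, arr.dtype, arr.chunks)

# A batch-sized slice only touches the chunks overlapping rows 0..1023,
# which is why chunking/sharding along the first axis matters for samplers.
batch = arr[:1024]
print(batch.shape)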
50/50 [00:00<00:00, 60.94it/s]\n", + "Computing target to cell data idcs: 100%|██████████| 56827/56827 [01:05<00:00, 864.75it/s]\n", + "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 8124.09it/s]\n", + "Computing target to cell data: 100%|██████████| 56827/56827 [00:06<00:00, 9053.27it/s] \n" + ] + } + ], + "source": [ + "import anndata as ad\n", + "import h5py\n", + "import zarr\n", + "from cellflow.data._utils import write_sharded\n", + "from anndata.experimental import read_lazy\n", + "from cellflow.data import DataManager\n", + "import cupy as cp\n", + "import tqdm\n", + "import dask\n", + "import numpy as np\n", + "\n", + "print(\"loading data\")\n", + "with h5py.File(\"/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad\", \"r\") as f:\n", + " adata_all = ad.AnnData(\n", + " obs=ad.io.read_elem(f[\"obs\"]),\n", + " var=read_lazy(f[\"var\"]),\n", + " uns = read_lazy(f[\"uns\"]),\n", + " obsm = read_lazy(f[\"obsm\"]),\n", + " )\n", + "\n", + "dm = DataManager(adata_all, \n", + " sample_rep=\"X_pca\",\n", + " control_key=\"control\",\n", + " perturbation_covariates={\"drugs\": (\"drug\",), \"dosage\": (\"dosage\",)},\n", + " perturbation_covariate_reps={\"drugs\": \"drug_embeddings\"},\n", + " sample_covariates=[\"cell_line\"],\n", + " sample_covariate_reps={\"cell_line\": \"cell_line_embeddings\"},\n", + " split_covariates=[\"cell_line\"],\n", + " max_combination_length=None,\n", + " null_value=0.0\n", + ")\n", + "\n", + "cond_data = dm._get_condition_data(adata=adata_all)\n", + "cell_data = dm._get_cell_data(adata_all)\n", + "\n", + "\n", + "\n", + "n_source_dists = len(cond_data.split_idx_to_covariates)\n", + "n_target_dists = len(cond_data.perturbation_idx_to_covariates)\n", + "\n", + "tgt_cell_data = {}\n", + "src_cell_data = {}\n", + "gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask)\n", + "gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask)\n", + "\n", + "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data idcs\"):\n", + " mask = gpu_spl_cov_mask == src_idx\n", + " src_cell_data[str(src_idx)] = {\n", + " \"cell_data_index\": cp.where(mask)[0].get(),\n", + " }\n", + "\n", + "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data idcs\"):\n", + " mask = gpu_per_cov_mask == tgt_idx\n", + " tgt_cell_data[str(tgt_idx)] = {\n", + " \"cell_data_index\": cp.where(mask)[0].get(),\n", + " }\n", + "\n", + "\n", + "\n", + "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data\"):\n", + " indices = src_cell_data[str(src_idx)][\"cell_data_index\"]\n", + " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", + " src_cell_data[str(src_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n", + "\n", + "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data\"):\n", + " indices = tgt_cell_data[str(tgt_idx)][\"cell_data_index\"]\n", + " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", + " tgt_cell_data[str(tgt_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n", + "\n", + "\n", + "split_covariates_mask = np.asarray(cond_data.split_covariates_mask)\n", + "perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)\n", + "condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}\n", + 
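For orientation on the cell above: the two masks assign every cell a distribution id (split_covariates_mask for source/control distributions, perturbation_covariates_mask for target/perturbation distributions), and the CuPy loops simply group row indices by that id before the rows are gathered. A NumPy-only sketch of the same grouping, with toy masks rather than real data, assuming the usual convention that a sentinel value of -1 marks cells outside any group:

import numpy as np

# Toy masks for 6 cells: two source (control) distributions, three target distributions.
# Assumption: -1 marks cells that belong to no distribution of that kind.
split_covariates_mask = np.array([0, 0, -1, 1, -1, -1])
perturbation_covariates_mask = np.array([-1, -1, 0, -1, 1, 2])

# Group row indices by distribution id; cp.where(mask == idx)[0].get() in the
# notebook does the same on the GPU and pulls the result back to host memory.
src_idx_to_rows = {
    str(i): np.flatnonzero(split_covariates_mask == i)
    for i in range(split_covariates_mask.max() + 1)
}
tgt_idx_to_rows = {
    str(i): np.flatnonzero(perturbation_covariates_mask == i)
    for i in range(perturbation_covariates_mask.max() + 1)
}
print(src_idx_to_rows)  # {'0': array([0, 1]), '1': array([3])}
print(tgt_idx_to_rows)  # {'0': array([2]), '1': array([4]), '2': array([5])}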
"control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}\n", + "split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}\n", + "perturbation_idx_to_covariates = {\n", + " str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()\n", + "}\n", + "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}\n", + "\n", + "train_data_dict = {\n", + " \"split_covariates_mask\": split_covariates_mask,\n", + " \"perturbation_covariates_mask\": perturbation_covariates_mask,\n", + " \"split_idx_to_covariates\": split_idx_to_covariates,\n", + " \"perturbation_idx_to_covariates\": perturbation_idx_to_covariates,\n", + " \"perturbation_idx_to_id\": perturbation_idx_to_id,\n", + " \"condition_data\": condition_data,\n", + " \"control_to_perturbation\": control_to_perturbation,\n", + " \"max_combination_length\": int(cond_data.max_combination_length),\n", + " # \"src_cell_data\": src_cell_data,\n", + " # \"tgt_cell_data\": tgt_cell_data,\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8f224240", + "metadata": {}, + "outputs": [], + "source": [ + "path = \"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\"\n", + "zgroup = zarr.open_group(path, mode=\"w\")\n", + "chunk_size = 65536\n", + "shard_size = chunk_size * 16\n", + "\n", + "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n", + "\n", + "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", + " shard_size_used = shard_size\n", + " chunk_size_used = chunk_size\n", + " if chunk_size > shape[0] or shard_size > shape[0]:\n", + " chunk_size_used = shard_size_used = shape[0]\n", + " return chunk_size_used, shard_size_used\n", + "\n", + "import dask.array as da\n", + "from dask.diagnostics import ProgressBar\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e8aedd3b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(60135, 300)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "src_cell_data[str(0)][\"cell_data\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "710434e7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Writing src cell data: 0%| | 0/50 [00:00 dict[str, int | list | dict]:\n", + " \"\"\"Calculate memory cost in bytes for a given source index and its target distributions.\n", + " \n", + " Parameters\n", + " ----------\n", + " data\n", + " The training data.\n", + " src_idx\n", + " The source distribution index.\n", + " include_condition_data\n", + " Whether to include condition data in memory calculations.\n", + " \n", + " Returns\n", + " -------\n", + " Dictionary with memory statistics in bytes for the source and its targets.\n", + " \"\"\"\n", + " if src_idx not in data.control_to_perturbation:\n", + " raise ValueError(f\"Source index {src_idx} not found in control_to_perturbation mapping\")\n", + " \n", + " # Get target indices for this source\n", + " target_indices = data.control_to_perturbation[src_idx]\n", + " \n", + " # Calculate memory for source cells\n", + " source_mask = data.split_covariates_mask == src_idx\n", + " n_source_cells = np.sum(source_mask)\n", + " source_memory = n_source_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", + " \n", + " # Calculate 
memory for target cells\n", + " target_memories = {}\n", + " total_target_memory = 0\n", + " \n", + " for target_idx in target_indices:\n", + " target_mask = data.perturbation_covariates_mask == target_idx\n", + " n_target_cells = np.sum(target_mask)\n", + " target_memory = n_target_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", + " target_memories[f\"target_{target_idx}\"] = target_memory\n", + " total_target_memory += target_memory\n", + " \n", + " # Calculate condition data memory if available and requested\n", + " condition_memory = 0\n", + " condition_details = {}\n", + " if include_condition_data and data.condition_data is not None:\n", + " for cond_name, cond_array in data.condition_data.items():\n", + " # Condition data is indexed by target indices\n", + " relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize\n", + " condition_details[f\"condition_{cond_name}\"] = relevant_condition_size\n", + " condition_memory += relevant_condition_size\n", + " \n", + " # Calculate total memory\n", + " total_memory = source_memory + total_target_memory + condition_memory\n", + " \n", + " # Calculate average target memory\n", + " avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0\n", + " \n", + " result = {\n", + " \"source_idx\": src_idx,\n", + " \"target_indices\": target_indices.tolist(),\n", + " \"source_memory\": source_memory,\n", + " \"source_cell_count\": int(n_source_cells),\n", + " \"total_target_memory\": total_target_memory,\n", + " \"avg_target_memory\": avg_target_memory,\n", + " \"condition_memory\": condition_memory,\n", + " \"total_memory\": total_memory,\n", + " \"target_details\": target_memories,\n", + " }\n", + " \n", + " if condition_details:\n", + " result[\"condition_details\"] = condition_details\n", + " \n", + " return result\n", + "\n", + "def format_memory_stats(memory_stats: dict, unit: str = \"auto\", summary: bool = False) -> str:\n", + " \"\"\"Format memory statistics into a human-readable string.\n", + " \n", + " Parameters\n", + " ----------\n", + " memory_stats\n", + " Dictionary with memory statistics from calculate_memory_cost.\n", + " unit\n", + " Memory unit to use for display. 
Options: 'B', 'KB', 'MB', 'GB', 'auto'.\n", + " If 'auto', the most appropriate unit will be chosen automatically.\n", + " summary\n", + " If True, includes a summary with average, min, and max target memory statistics\n", + " and omits detailed per-target breakdown.\n", + " \n", + " Returns\n", + " -------\n", + " Human-readable string representation of memory statistics.\n", + " \"\"\"\n", + " def format_bytes(bytes_value, unit=\"auto\"):\n", + " if unit == \"auto\":\n", + " # Choose appropriate unit\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " elif unit == \"KB\":\n", + " bytes_value /= 1024\n", + " elif unit == \"MB\":\n", + " bytes_value /= (1024 * 1024)\n", + " elif unit == \"GB\":\n", + " bytes_value /= (1024 * 1024 * 1024)\n", + " \n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " src_idx = memory_stats[\"source_idx\"]\n", + " target_indices = memory_stats[\"target_indices\"]\n", + " \n", + " # Base information\n", + " lines = [\n", + " f\"Memory statistics for source index {src_idx} with {len(target_indices)} targets:\",\n", + " f\"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}\",\n", + " f\"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}\",\n", + " ]\n", + " \n", + " # Calculate min and max target memory if summary is requested\n", + " if summary and memory_stats[\"target_details\"]:\n", + " target_memories = list(memory_stats[\"target_details\"].values())\n", + " min_target = min(target_memories)\n", + " max_target = max(target_memories)\n", + " \n", + " lines.extend([\n", + " \"\\nTarget memory summary:\",\n", + " f\"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}\",\n", + " f\"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}\",\n", + " f\"- Min: {format_bytes(min_target, unit)}\",\n", + " f\"- Max: {format_bytes(max_target, unit)}\",\n", + " f\"- Range: {format_bytes(max_target - min_target, unit)}\"\n", + " ])\n", + " \n", + " # Add condition memory summary if available\n", + " if memory_stats[\"condition_memory\"] > 0:\n", + " lines.append(f\"\\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}\")\n", + " else:\n", + " # Detailed output (original format)\n", + " lines.extend([\n", + " f\"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target\",\n", + " f\"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}\",\n", + " \"\\nTarget details:\"\n", + " ])\n", + " \n", + " for target_key, target_memory in memory_stats[\"target_details\"].items():\n", + " target_id = target_key.split(\"_\")[1]\n", + " lines.append(f\" - Target {target_id}: {format_bytes(target_memory, unit)}\")\n", + " \n", + " if \"condition_details\" in memory_stats:\n", + " lines.append(\"\\nCondition details:\")\n", + " for cond_key, cond_memory in memory_stats[\"condition_details\"].items():\n", + " cond_name = cond_key.split(\"_\", 1)[1]\n", + " lines.append(f\" - {cond_name}: {format_bytes(cond_memory, unit)}\")\n", + " \n", + " return \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "316e3a6a", + "metadata": {}, + "outputs": [], + "source": [ + "ztd = ZarrTrainingData.read_zarr(data_paths[0])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3d101216", + 
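calculate_memory_cost above is plain byte accounting: number of cells times number of features times the dtype itemsize, summed over a source distribution and its targets (format_bytes divides by 1024 at each step, so the reported "MB" are effectively MiB). A quick sanity check of the 68.82 MB reported for source index 0 further down, assuming float32 cell data with 300 PCA components as in this dataset:

import numpy as np

n_cells, n_features = 60135, 300           # source distribution 0, X_pca
itemsize = np.dtype(np.float32).itemsize   # 4 bytes
source_bytes = n_cells * n_features * itemsize
print(source_bytes)            # 72162000
print(source_bytes / 1024**2)  # ~68.82, matching the printed "68.82 MB"

The same accounting is what makes the pool-based sampler further down easy to budget: scaling the average per-source footprint printed below (423.18 MB across 50 sources) by the pool_size of 20 used for TrainSamplerWithPool suggests roughly 8.3 GiB resident, if the pool caches whole source distributions in memory as the init_pool_n_cache call appears to imply.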
"metadata": {}, + "outputs": [], + "source": [ + "stats = calculate_memory_cost(ztd, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a79f9fc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics for source index 0 with 194 targets:\n", + "- Source cells: 60135 cells, 68.82 MB\n", + "- Total memory: 548.11 MB\n", + "\n", + "Target memory summary:\n", + "- Total: 479.28 MB\n", + "- Average: 2.47 MB\n", + "- Min: 44.53 KB\n", + "- Max: 6.35 MB\n", + "- Range: 6.31 MB\n", + "\n", + "Condition memory: 4.55 KB\n" + ] + } + ], + "source": [ + "print(format_memory_stats(stats, summary=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8c400080", + "metadata": {}, + "outputs": [], + "source": [ + "ztd_stats = {}\n", + "for i in range(ztd.n_controls):\n", + " ztd_stats[i] = calculate_memory_cost(ztd, i)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "710fb69d", + "metadata": {}, + "outputs": [], + "source": [ + "def print_average_memory_per_source(stats_dict):\n", + " \"\"\"Print the average total memory per source index.\n", + " \n", + " Parameters\n", + " ----------\n", + " stats_dict\n", + " Optional pre-calculated memory statistics dictionary.\n", + " If None, statistics will be calculated for all source indices.\n", + " \"\"\"\n", + " \n", + " \n", + " # Extract total memory for each source index\n", + " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", + " \n", + " # Calculate statistics\n", + " avg_memory = np.mean(total_memories)\n", + " min_memory = np.min(total_memories)\n", + " max_memory = np.max(total_memories)\n", + " median_memory = np.median(total_memories)\n", + " \n", + " # Format the output\n", + " def format_bytes(bytes_value):\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", + " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", + " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", + " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", + " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", + " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", + " \n", + " # Identify source indices with min and max memory\n", + " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " \n", + " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", + " print(f\"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "e2f8f809", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics across 50 source indices:\n", + "- Average total memory per source: 423.18 MB\n", + "- Minimum total memory: 4.33 MB\n", + "- Maximum total memory: 1.29 GB\n", + "- Median total memory: 404.51 MB\n", + "- Range: 1.28 GB\n", + "\n", + "Source index with minimum memory: 39 (4.33 MB)\n", + "Source index with maximum memory: 22 (1.29 GB)\n" + ] + } + ], + "source": [ + "print_average_memory_per_source(ztd_stats)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 11, + "id": "91207483", + "metadata": {}, + "outputs": [], + "source": [ + "from cellflow.data import TrainSamplerWithPool\n", + "import numpy as np\n", + "rng = np.random.default_rng(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "17f1fc6c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing target to cell data idcs: 100%|██████████| 9980/9980 [00:11<00:00, 891.95it/s] \n", + "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 1232.06it/s]\n" + ] + } + ], + "source": [ + "tswp = TrainSamplerWithPool(ztd, batch_size=1024, pool_size=20, replacement_prob=0.01)\n", + "tswp.init_pool_n_cache(rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "782380b2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81017ffd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "replaced 47 with 34\n", + "replaced 32 with 30\n" + ] + } + ], + "source": [ + "import time\n", + "iter_times = []\n", + "rng = np.random.default_rng(0)\n", + "start_time = time.time()\n", + "for iter in range(40):\n", + " batch = tswp.sample(rng)\n", + " end_time = time.time()\n", + " iter_times.append(end_time - start_time)\n", + " start_time = end_time\n", + "\n", + "print(\"average time per iteration: \", np.mean(iter_times))\n", + "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" + ] + }, + { + "cell_type": "markdown", + "id": "fe14be13", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "001e842a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'pool_size': 20,\n", + " 'avg_usage': 1.95,\n", + " 'unique_sources': 20,\n", + " 'pool_elements': array([31, 18, 47, 34, 12, 35, 29, 23, 32, 14, 6, 41, 25, 3, 1, 49, 24,\n", + " 10, 46, 33]),\n", + " 'usage_counts': array([2, 2, 3, 2, 1, 0, 2, 2, 2, 0, 3, 1, 2, 0, 3, 3, 2, 6, 1, 2])}" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tswp.get_pool_stats()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f07c55d9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/600_trainsampler copy.ipynb b/docs/notebooks/600_trainsampler copy.ipynb new file mode 100644 index 00000000..a2eb6b44 --- /dev/null +++ b/docs/notebooks/600_trainsampler copy.ipynb @@ -0,0 +1,454 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 21, + "id": "5765bb6c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. 
To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from cellflow.data import MappedCellData" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5e77bb94", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "data_path = Path(\"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "cb38a3f8", + "metadata": {}, + "outputs": [], + "source": [ + "mcd = MappedCellData.read_zarr(data_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "675044bc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: ,\n", + " 1: ,\n", + " 10: ,\n", + " 11: ,\n", + " 12: ,\n", + " 13: ,\n", + " 14: ,\n", + " 15: ,\n", + " 16: ,\n", + " 17: ,\n", + " 18: ,\n", + " 19: ,\n", + " 2: ,\n", + " 20: ,\n", + " 21: ,\n", + " 22: ,\n", + " 23: ,\n", + " 24: ,\n", + " 25: ,\n", + " 26: ,\n", + " 27: ,\n", + " 28: ,\n", + " 29: ,\n", + " 3: ,\n", + " 30: ,\n", + " 31: ,\n", + " 32: ,\n", + " 33: ,\n", + " 34: ,\n", + " 35: ,\n", + " 36: ,\n", + " 37: ,\n", + " 38: ,\n", + " 39: ,\n", + " 4: ,\n", + " 40: ,\n", + " 41: ,\n", + " 42: ,\n", + " 43: ,\n", + " 44: ,\n", + " 45: ,\n", + " 46: ,\n", + " 47: ,\n", + " 48: ,\n", + " 49: ,\n", + " 5: ,\n", + " 6: ,\n", + " 7: ,\n", + " 8: ,\n", + " 9: }" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mcd.src_cell_data" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "33793ea8", + "metadata": {}, + "outputs": [], + "source": [ + "from cellflow.data import ReservoirSampler" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05bd4946", + "metadata": {}, + "outputs": [], + "source": [ + "rs = ReservoirSampler(mcd, batch_size=1024, pool_size=3, replacement_prob=0.01)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b40a9520", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "rng = np.random.default_rng(0)\n", + "rs.init_pool(rng)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "799aad1f", + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "np.int64(37)", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[40]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 4\u001b[39m start_time = time.time()\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m \u001b[38;5;28miter\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m40\u001b[39m):\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m batch = \u001b[43mrs\u001b[49m\u001b[43m.\u001b[49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrng\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 7\u001b[39m end_time = time.time()\n\u001b[32m 8\u001b[39m iter_times.append(end_time - start_time)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_dataloader.py:114\u001b[39m, in \u001b[36msample\u001b[39m\u001b[34m(self, rng)\u001b[39m\n\u001b[32m 111\u001b[39m target_dist_idx = \u001b[38;5;28mself\u001b[39m._sample_target_dist_idx(rng, source_dist_idx)\n\u001b[32m 113\u001b[39m 
\u001b[38;5;66;03m# Sample source and target cells\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m source_batch = \u001b[38;5;28mself\u001b[39m._sample_source_cells(rng, source_dist_idx)\n\u001b[32m 115\u001b[39m target_batch = \u001b[38;5;28mself\u001b[39m._sample_target_cells(rng, source_dist_idx, target_dist_idx)\n\u001b[32m 117\u001b[39m res = {\u001b[33m\"\u001b[39m\u001b[33msrc_cell_data\u001b[39m\u001b[33m\"\u001b[39m: source_batch, \u001b[33m\"\u001b[39m\u001b[33mtgt_cell_data\u001b[39m\u001b[33m\"\u001b[39m: target_batch}\n", + "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_dataloader.py:303\u001b[39m, in \u001b[36m_sample_target_cells\u001b[39m\u001b[34m(self, rng, source_dist_idx, target_dist_idx)\u001b[39m\n\u001b[32m 296\u001b[39m \u001b[38;5;28mself\u001b[39m.perturbation_to_control = \u001b[38;5;28mself\u001b[39m._get_perturbation_to_control(val_data)\n\u001b[32m 297\u001b[39m \u001b[38;5;28mself\u001b[39m.n_conditions_on_log_iteration = (\n\u001b[32m 298\u001b[39m val_data.n_conditions_on_log_iteration\n\u001b[32m 299\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m val_data.n_conditions_on_log_iteration \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 300\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m val_data.n_perturbations\n\u001b[32m 301\u001b[39m )\n\u001b[32m 302\u001b[39m \u001b[38;5;28mself\u001b[39m.n_conditions_on_train_end = (\n\u001b[32m--> \u001b[39m\u001b[32m303\u001b[39m val_data.n_conditions_on_train_end\n\u001b[32m 304\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m val_data.n_conditions_on_train_end \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 305\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m val_data.n_perturbations\n\u001b[32m 306\u001b[39m )\n\u001b[32m 307\u001b[39m \u001b[38;5;28mself\u001b[39m.rng = np.random.default_rng(seed)\n\u001b[32m 308\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._data.condition_data \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "\u001b[31mKeyError\u001b[39m: np.int64(37)" + ] + } + ], + "source": [ + "import time\n", + "iter_times = []\n", + "rng = np.random.default_rng(0)\n", + "start_time = time.time()\n", + "for iter in range(40):\n", + " batch = rs.sample(rng)\n", + " end_time = time.time()\n", + " iter_times.append(end_time - start_time)\n", + " start_time = end_time\n", + "\n", + "print(\"average time per iteration: \", np.mean(iter_times))\n", + "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "316e3a6a", + "metadata": {}, + "outputs": [], + "source": [ + "ztd = ZarrTrainingData.read_zarr(data_paths[0])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3d101216", + "metadata": {}, + "outputs": [], + "source": [ + "stats = calculate_memory_cost(ztd, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a79f9fc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics for source index 0 with 194 targets:\n", + "- Source cells: 60135 cells, 68.82 MB\n", + "- Total memory: 548.11 MB\n", + "\n", + "Target memory summary:\n", + "- Total: 479.28 MB\n", + "- Average: 2.47 MB\n", + "- Min: 44.53 KB\n", + "- Max: 6.35 MB\n", + "- Range: 6.31 MB\n", + "\n", + "Condition 
memory: 4.55 KB\n" + ] + } + ], + "source": [ + "print(format_memory_stats(stats, summary=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8c400080", + "metadata": {}, + "outputs": [], + "source": [ + "ztd_stats = {}\n", + "for i in range(ztd.n_controls):\n", + " ztd_stats[i] = calculate_memory_cost(ztd, i)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "710fb69d", + "metadata": {}, + "outputs": [], + "source": [ + "def print_average_memory_per_source(stats_dict):\n", + " \"\"\"Print the average total memory per source index.\n", + " \n", + " Parameters\n", + " ----------\n", + " stats_dict\n", + " Optional pre-calculated memory statistics dictionary.\n", + " If None, statistics will be calculated for all source indices.\n", + " \"\"\"\n", + " \n", + " \n", + " # Extract total memory for each source index\n", + " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", + " \n", + " # Calculate statistics\n", + " avg_memory = np.mean(total_memories)\n", + " min_memory = np.min(total_memories)\n", + " max_memory = np.max(total_memories)\n", + " median_memory = np.median(total_memories)\n", + " \n", + " # Format the output\n", + " def format_bytes(bytes_value):\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", + " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", + " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", + " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", + " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", + " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", + " \n", + " # Identify source indices with min and max memory\n", + " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " \n", + " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", + " print(f\"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "e2f8f809", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics across 50 source indices:\n", + "- Average total memory per source: 423.18 MB\n", + "- Minimum total memory: 4.33 MB\n", + "- Maximum total memory: 1.29 GB\n", + "- Median total memory: 404.51 MB\n", + "- Range: 1.28 GB\n", + "\n", + "Source index with minimum memory: 39 (4.33 MB)\n", + "Source index with maximum memory: 22 (1.29 GB)\n" + ] + } + ], + "source": [ + "print_average_memory_per_source(ztd_stats)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "91207483", + "metadata": {}, + "outputs": [], + "source": [ + "from cellflow.data import TrainSamplerWithPool\n", + "import numpy as np\n", + "rng = np.random.default_rng(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "17f1fc6c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing target to cell data idcs: 100%|██████████| 9980/9980 [00:11<00:00, 891.95it/s] \n", + "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 
1232.06it/s]\n" + ] + } + ], + "source": [ + "tswp = TrainSamplerWithPool(ztd, batch_size=1024, pool_size=20, replacement_prob=0.01)\n", + "tswp.init_pool_n_cache(rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "782380b2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81017ffd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "replaced 47 with 34\n", + "replaced 32 with 30\n" + ] + } + ], + "source": [ + "import time\n", + "iter_times = []\n", + "rng = np.random.default_rng(0)\n", + "start_time = time.time()\n", + "for iter in range(40):\n", + " batch = tswp.sample(rng)\n", + " end_time = time.time()\n", + " iter_times.append(end_time - start_time)\n", + " start_time = end_time\n", + "\n", + "print(\"average time per iteration: \", np.mean(iter_times))\n", + "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" + ] + }, + { + "cell_type": "markdown", + "id": "fe14be13", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "001e842a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'pool_size': 20,\n", + " 'avg_usage': 1.95,\n", + " 'unique_sources': 20,\n", + " 'pool_elements': array([31, 18, 47, 34, 12, 35, 29, 23, 32, 14, 6, 41, 25, 3, 1, 49, 24,\n", + " 10, 46, 33]),\n", + " 'usage_counts': array([2, 2, 3, 2, 1, 0, 2, 2, 2, 0, 3, 1, 2, 0, 3, 3, 2, 6, 1, 2])}" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tswp.get_pool_stats()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f07c55d9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/600_trainsampler.ipynb b/docs/notebooks/600_trainsampler.ipynb index d8fcf0bf..54640851 100644 --- a/docs/notebooks/600_trainsampler.ipynb +++ b/docs/notebooks/600_trainsampler.ipynb @@ -8,9 +8,7 @@ "outputs": [], "source": [ "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import cellflow as cf\n" + "%autoreload 2\n" ] }, { @@ -36,11 +34,16 @@ } ], "source": [ - "from cellflow.model import CellFlow\n", "import anndata as ad\n", "import h5py\n", - "\n", + "import zarr\n", + "from cellflow.data._utils import write_sharded\n", "from anndata.experimental import read_lazy\n", + "from cellflow.data import DataManager\n", + "import cupy as cp\n", + "import tqdm\n", + "import dask\n", + "import numpy as np\n", "\n", "print(\"loading data\")\n", "with h5py.File(\"/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad\", \"r\") as f:\n", @@ -49,73 +52,58 @@ " var=read_lazy(f[\"var\"]),\n", " uns = read_lazy(f[\"uns\"]),\n", " obsm = read_lazy(f[\"obsm\"]),\n", - " )" + " )\n", + "\n", + "dm = DataManager(adata_all, \n", + " sample_rep=\"X_pca\",\n", + " control_key=\"control\",\n", + " perturbation_covariates={\"drugs\": (\"drug\",), \"dosage\": (\"dosage\",)},\n", + " perturbation_covariate_reps={\"drugs\": \"drug_embeddings\"},\n", + " sample_covariates=[\"cell_line\"],\n", + " 
sample_covariate_reps={\"cell_line\": \"cell_line_embeddings\"},\n", + " split_covariates=[\"cell_line\"],\n", + " max_combination_length=None,\n", + " null_value=0.0\n", + ")\n" ] }, { "cell_type": "code", "execution_count": 3, - "id": "8f224240", + "id": "37ac0f75", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[########################################] | 100% Completed | 908.17 ms\n", - "[########################################] | 100% Completed | 21.42 s\n", - "[########################################] | 100% Completed | 375.38 s\n" + "[########################################] | 100% Completed | 910.75 ms\n", + "[########################################] | 100% Completed | 23.67 s\n", + "[########################################] | 100% Completed | 252.54 s\n" ] } ], "source": [ - "from cellflow.data import DataManager\n", - "dm = DataManager(adata_all, \n", - " sample_rep=\"X_pca\",\n", - " control_key=\"control\",\n", - " perturbation_covariates={\"drugs\": (\"drug\",), \"dosage\": (\"dosage\",)},\n", - " perturbation_covariate_reps={\"drugs\": \"drug_embeddings\"},\n", - " sample_covariates=[\"cell_line\"],\n", - " sample_covariate_reps={\"cell_line\": \"cell_line_embeddings\"},\n", - " split_covariates=[\"cell_line\"],\n", - " max_combination_length=None,\n", - " null_value=0.0\n", - ")\n", - "\n", "cond_data = dm._get_condition_data(adata=adata_all)\n", "cell_data = dm._get_cell_data(adata_all)" ] }, { "cell_type": "code", - "execution_count": 27, - "id": "c41b2a3b", + "execution_count": 4, + "id": "e9adbd71", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 75.74it/s]\n", - "Computing target to cell data idcs: 14%|█▍ | 8030/56827 [00:27<02:48, 289.15it/s]\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[27]\u001b[39m\u001b[32m, line 21\u001b[39m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m tgt_idx \u001b[38;5;129;01min\u001b[39;00m tqdm.tqdm(\u001b[38;5;28mrange\u001b[39m(n_target_dists), desc=\u001b[33m\"\u001b[39m\u001b[33mComputing target to cell data idcs\u001b[39m\u001b[33m\"\u001b[39m):\n\u001b[32m 19\u001b[39m mask = gpu_per_cov_mask == tgt_idx\n\u001b[32m 20\u001b[39m tgt_cell_data[\u001b[38;5;28mstr\u001b[39m(tgt_idx)] = {\n\u001b[32m---> \u001b[39m\u001b[32m21\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcell_data_index\u001b[39m\u001b[33m\"\u001b[39m: \u001b[43mcp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmask\u001b[49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[32m 22\u001b[39m }\n", - "\u001b[31mKeyboardInterrupt\u001b[39m: " + "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 62.77it/s]\n", + "Computing target to cell data idcs: 100%|██████████| 56827/56827 [01:05<00:00, 863.54it/s]\n" ] } ], "source": [ - "import cupy as cp\n", - "import tqdm\n", - "\n", "n_source_dists = len(cond_data.split_idx_to_covariates)\n", "n_target_dists = len(cond_data.perturbation_idx_to_covariates)\n", "\n", @@ -134,144 
+122,88 @@ " mask = gpu_per_cov_mask == tgt_idx\n", " tgt_cell_data[str(tgt_idx)] = {\n", " \"cell_data_index\": cp.where(mask)[0].get(),\n", - " }\n" + " }" ] }, { "cell_type": "code", - "execution_count": null, - "id": "5a352c69", + "execution_count": 5, + "id": "dad2d31c", "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
[dask array HTML repr stripped during extraction; recoverable info: Bytes 106.87 GiB total / 1.14 MiB per chunk, Shape (95624334, 300) / chunk shape (1000, 300), 95625 chunks in 1 graph layer, dtype float32 numpy.ndarray]
" - ], - "text/plain": [ - "dask.array" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cell_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26c512d3", - "metadata": {}, - "outputs": [ + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 22329.13it/s]\n", + "Computing target to cell data: 6%|▌ | 3184/56827 [00:00<00:01, 31833.81it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computing target to cell data: 100%|██████████| 56827/56827 [00:02<00:00, 23426.54it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[##################### ] | 52% Completed | 36m 17ss" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 246.04it/s]\n", - "Computing target to cell data: 100%|██████████| 56827/56827 [00:08<00:00, 6554.24it/s]\n" + "IOStream.flush timed out\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[########################################] | 100% Completed | 73m 49s\n" ] } ], "source": [ - "import dask\n", "\n", + "import dask.array as da\n", + "from dask.diagnostics import ProgressBar\n", "\n", + "src_delayed_objs = []\n", "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data\"):\n", " indices = src_cell_data[str(src_idx)][\"cell_data_index\"]\n", " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", - " src_cell_data[str(src_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n", + " src_delayed_objs.append((str(src_idx), delayed_obj))\n", "\n", + "tgt_delayed_objs = []\n", "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data\"):\n", " indices = tgt_cell_data[str(tgt_idx)][\"cell_data_index\"]\n", " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", - " tgt_cell_data[str(tgt_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n" + " tgt_delayed_objs.append((str(tgt_idx), delayed_obj))\n", + "\n", + "src_results = []\n", + "tgt_results = []\n", + "with ProgressBar():\n", + " src_results, tgt_results = dask.compute(src_delayed_objs, tgt_delayed_objs)\n", + "\n", + "for k, v in src_results:\n", + " src_cell_data[k][\"cell_data\"] = v\n", + "\n", + "for k, v in tgt_results:\n", + " tgt_cell_data[k][\"cell_data\"] = v\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "6d01e392", + "execution_count": 6, + "id": "d6c007b8", "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "\n", "split_covariates_mask = np.asarray(cond_data.split_covariates_mask)\n", "perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)\n", @@ -281,16 +213,8 @@ "perturbation_idx_to_covariates = {\n", " str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()\n", "}\n", - "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e49deaf9", - "metadata": {}, - "outputs": [], - "source": [ + "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}\n", + "\n", "train_data_dict = {\n", " \"split_covariates_mask\": 
split_covariates_mask,\n", " \"perturbation_covariates_mask\": perturbation_covariates_mask,\n", @@ -300,488 +224,179 @@ " \"condition_data\": condition_data,\n", " \"control_to_perturbation\": control_to_perturbation,\n", " \"max_combination_length\": int(cond_data.max_combination_length),\n", - " \"src_cell_data\": src_cell_data,\n", - " \"tgt_cell_data\": tgt_cell_data,\n", - "}" + " # \"src_cell_data\": src_cell_data,\n", + " # \"tgt_cell_data\": tgt_cell_data,\n", + "}\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "32e27b1f", + "execution_count": 7, + "id": "8f224240", "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[26]\u001b[39m\u001b[32m, line 8\u001b[39m\n\u001b[32m 6\u001b[39m chunk_size = \u001b[32m65536\u001b[39m\n\u001b[32m 7\u001b[39m shard_size = chunk_size * \u001b[32m16\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m8\u001b[39m \u001b[43mwrite_sharded\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 9\u001b[39m \u001b[43m \u001b[49m\u001b[43mzgroup\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 10\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_data_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 11\u001b[39m \u001b[43m \u001b[49m\u001b[43mchunk_size\u001b[49m\u001b[43m=\u001b[49m\u001b[43mchunk_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 12\u001b[39m \u001b[43m \u001b[49m\u001b[43mshard_size\u001b[49m\u001b[43m=\u001b[49m\u001b[43mshard_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 13\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompressors\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 14\u001b[39m \u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_utils.py:65\u001b[39m, in \u001b[36mwrite_sharded\u001b[39m\u001b[34m(group, data, chunk_size, shard_size, compressors)\u001b[39m\n\u001b[32m 56\u001b[39m dataset_kwargs = {\n\u001b[32m 57\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mshards\u001b[39m\u001b[33m\"\u001b[39m: (shard_size,),\n\u001b[32m 58\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mchunks\u001b[39m\u001b[33m\"\u001b[39m: (chunk_size,),\n\u001b[32m 59\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcompressors\u001b[39m\u001b[33m\"\u001b[39m: compressors,\n\u001b[32m 60\u001b[39m **dataset_kwargs,\n\u001b[32m 61\u001b[39m }\n\u001b[32m 63\u001b[39m func(g, k, elem, dataset_kwargs=dataset_kwargs)\n\u001b[32m---> \u001b[39m\u001b[32m65\u001b[39m \u001b[43mad\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexperimental\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwrite_dispatched\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgroup\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m/\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 66\u001b[39m zarr.consolidate_metadata(group.store)\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/anndata/experimental/_dispatch_io.py:77\u001b[39m, in 
- "[... traceback frames elided: KeyboardInterrupt raised while anndata.experimental.write_dispatched / the cellflow write_sharded callback was writing string arrays into the sharded Zarr store (zarr Array.__setitem__ blocked on the async write) ...]\n",
\u001b[43mwaiter\u001b[49m\u001b[43m.\u001b[49m\u001b[43mevent\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 306\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m fs:\n\u001b[32m 307\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m f._condition:\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/threading.py:655\u001b[39m, in \u001b[36mEvent.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 653\u001b[39m signaled = \u001b[38;5;28mself\u001b[39m._flag\n\u001b[32m 654\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m signaled:\n\u001b[32m--> \u001b[39m\u001b[32m655\u001b[39m signaled = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_cond\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 656\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m signaled\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/threading.py:355\u001b[39m, in \u001b[36mCondition.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 353\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[32m 354\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m355\u001b[39m \u001b[43mwaiter\u001b[49m\u001b[43m.\u001b[49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 356\u001b[39m gotit = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m 357\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "\u001b[31mKeyboardInterrupt\u001b[39m: " - ] - } - ], + "outputs": [], "source": [ - "import zarr\n", - "from cellflow.data._utils import write_sharded\n", - "\n", - "path = \"test.zarr\"\n", + "path = \"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\"\n", "zgroup = zarr.open_group(path, mode=\"w\")\n", "chunk_size = 65536\n", "shard_size = chunk_size * 16\n", - "write_sharded(\n", - " zgroup,\n", - " train_data_dict,\n", - " chunk_size=chunk_size,\n", - " shard_size=shard_size,\n", - " compressors=None,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3ed731bd", - "metadata": {}, - "outputs": [], - "source": [ - "from cellflow.data import TrainSamplerWithPool, ZarrTrainingData" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "62955dea", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", "\n", - "def calculate_memory_cost(\n", - " data: ZarrTrainingData,\n", - " src_idx: int,\n", - " include_condition_data: bool = True\n", - ") -> dict[str, int | list | dict]:\n", - " \"\"\"Calculate memory cost in bytes for a given source index and its target distributions.\n", - " \n", - " Parameters\n", - " ----------\n", - " data\n", - " The training data.\n", - " src_idx\n", - " The source distribution index.\n", - " include_condition_data\n", - " Whether to include condition data in memory calculations.\n", - " \n", - " Returns\n", - " -------\n", - " Dictionary with memory statistics in bytes for the source and its targets.\n", - " \"\"\"\n", - " if src_idx not in data.control_to_perturbation:\n", - " raise ValueError(f\"Source index {src_idx} not 
found in control_to_perturbation mapping\")\n", - " \n", - " # Get target indices for this source\n", - " target_indices = data.control_to_perturbation[src_idx]\n", - " \n", - " # Calculate memory for source cells\n", - " source_mask = data.split_covariates_mask == src_idx\n", - " n_source_cells = np.sum(source_mask)\n", - " source_memory = n_source_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", - " \n", - " # Calculate memory for target cells\n", - " target_memories = {}\n", - " total_target_memory = 0\n", - " \n", - " for target_idx in target_indices:\n", - " target_mask = data.perturbation_covariates_mask == target_idx\n", - " n_target_cells = np.sum(target_mask)\n", - " target_memory = n_target_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", - " target_memories[f\"target_{target_idx}\"] = target_memory\n", - " total_target_memory += target_memory\n", - " \n", - " # Calculate condition data memory if available and requested\n", - " condition_memory = 0\n", - " condition_details = {}\n", - " if include_condition_data and data.condition_data is not None:\n", - " for cond_name, cond_array in data.condition_data.items():\n", - " # Condition data is indexed by target indices\n", - " relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize\n", - " condition_details[f\"condition_{cond_name}\"] = relevant_condition_size\n", - " condition_memory += relevant_condition_size\n", - " \n", - " # Calculate total memory\n", - " total_memory = source_memory + total_target_memory + condition_memory\n", - " \n", - " # Calculate average target memory\n", - " avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0\n", - " \n", - " result = {\n", - " \"source_idx\": src_idx,\n", - " \"target_indices\": target_indices.tolist(),\n", - " \"source_memory\": source_memory,\n", - " \"source_cell_count\": int(n_source_cells),\n", - " \"total_target_memory\": total_target_memory,\n", - " \"avg_target_memory\": avg_target_memory,\n", - " \"condition_memory\": condition_memory,\n", - " \"total_memory\": total_memory,\n", - " \"target_details\": target_memories,\n", - " }\n", - " \n", - " if condition_details:\n", - " result[\"condition_details\"] = condition_details\n", - " \n", - " return result\n", + "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n", + "\n", + "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", + " shard_size_used = shard_size\n", + " chunk_size_used = chunk_size\n", + " if chunk_size > shape[0] or shard_size > shape[0]:\n", + " chunk_size_used = shard_size_used = shape[0]\n", + " return chunk_size_used, shard_size_used\n", "\n", - "def format_memory_stats(memory_stats: dict, unit: str = \"auto\", summary: bool = False) -> str:\n", - " \"\"\"Format memory statistics into a human-readable string.\n", - " \n", - " Parameters\n", - " ----------\n", - " memory_stats\n", - " Dictionary with memory statistics from calculate_memory_cost.\n", - " unit\n", - " Memory unit to use for display. 
Options: 'B', 'KB', 'MB', 'GB', 'auto'.\n", - " If 'auto', the most appropriate unit will be chosen automatically.\n", - " summary\n", - " If True, includes a summary with average, min, and max target memory statistics\n", - " and omits detailed per-target breakdown.\n", - " \n", - " Returns\n", - " -------\n", - " Human-readable string representation of memory statistics.\n", - " \"\"\"\n", - " def format_bytes(bytes_value, unit=\"auto\"):\n", - " if unit == \"auto\":\n", - " # Choose appropriate unit\n", - " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", - " if bytes_value < 1024 or unit == \"GB\":\n", - " break\n", - " bytes_value /= 1024\n", - " elif unit == \"KB\":\n", - " bytes_value /= 1024\n", - " elif unit == \"MB\":\n", - " bytes_value /= (1024 * 1024)\n", - " elif unit == \"GB\":\n", - " bytes_value /= (1024 * 1024 * 1024)\n", - " \n", - " return f\"{bytes_value:.2f} {unit}\"\n", - " \n", - " src_idx = memory_stats[\"source_idx\"]\n", - " target_indices = memory_stats[\"target_indices\"]\n", - " \n", - " # Base information\n", - " lines = [\n", - " f\"Memory statistics for source index {src_idx} with {len(target_indices)} targets:\",\n", - " f\"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}\",\n", - " f\"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}\",\n", - " ]\n", - " \n", - " # Calculate min and max target memory if summary is requested\n", - " if summary and memory_stats[\"target_details\"]:\n", - " target_memories = list(memory_stats[\"target_details\"].values())\n", - " min_target = min(target_memories)\n", - " max_target = max(target_memories)\n", - " \n", - " lines.extend([\n", - " \"\\nTarget memory summary:\",\n", - " f\"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}\",\n", - " f\"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}\",\n", - " f\"- Min: {format_bytes(min_target, unit)}\",\n", - " f\"- Max: {format_bytes(max_target, unit)}\",\n", - " f\"- Range: {format_bytes(max_target - min_target, unit)}\"\n", - " ])\n", - " \n", - " # Add condition memory summary if available\n", - " if memory_stats[\"condition_memory\"] > 0:\n", - " lines.append(f\"\\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}\")\n", - " else:\n", - " # Detailed output (original format)\n", - " lines.extend([\n", - " f\"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target\",\n", - " f\"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}\",\n", - " \"\\nTarget details:\"\n", - " ])\n", - " \n", - " for target_key, target_memory in memory_stats[\"target_details\"].items():\n", - " target_id = target_key.split(\"_\")[1]\n", - " lines.append(f\" - Target {target_id}: {format_bytes(target_memory, unit)}\")\n", - " \n", - " if \"condition_details\" in memory_stats:\n", - " lines.append(\"\\nCondition details:\")\n", - " for cond_key, cond_memory in memory_stats[\"condition_details\"].items():\n", - " cond_name = cond_key.split(\"_\", 1)[1]\n", - " lines.append(f\" - {cond_name}: {format_bytes(cond_memory, unit)}\")\n", - " \n", - " return \"\\n\".join(lines)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "316e3a6a", - "metadata": {}, - "outputs": [], - "source": [ - "ztd = ZarrTrainingData.read_zarr(data_paths[0])\n", "\n" ] }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3d101216", - 
"metadata": {}, - "outputs": [], - "source": [ - "stats = calculate_memory_cost(ztd, 0)" - ] - }, { "cell_type": "code", "execution_count": 8, - "id": "a79f9fc2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Memory statistics for source index 0 with 194 targets:\n", - "- Source cells: 60135 cells, 68.82 MB\n", - "- Total memory: 548.11 MB\n", - "\n", - "Target memory summary:\n", - "- Total: 479.28 MB\n", - "- Average: 2.47 MB\n", - "- Min: 44.53 KB\n", - "- Max: 6.35 MB\n", - "- Range: 6.31 MB\n", - "\n", - "Condition memory: 4.55 KB\n" - ] - } - ], - "source": [ - "print(format_memory_stats(stats, summary=True))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "8c400080", - "metadata": {}, - "outputs": [], - "source": [ - "ztd_stats = {}\n", - "for i in range(ztd.n_controls):\n", - " ztd_stats[i] = calculate_memory_cost(ztd, i)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "710fb69d", - "metadata": {}, - "outputs": [], - "source": [ - "def print_average_memory_per_source(stats_dict):\n", - " \"\"\"Print the average total memory per source index.\n", - " \n", - " Parameters\n", - " ----------\n", - " stats_dict\n", - " Optional pre-calculated memory statistics dictionary.\n", - " If None, statistics will be calculated for all source indices.\n", - " \"\"\"\n", - " \n", - " \n", - " # Extract total memory for each source index\n", - " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", - " \n", - " # Calculate statistics\n", - " avg_memory = np.mean(total_memories)\n", - " min_memory = np.min(total_memories)\n", - " max_memory = np.max(total_memories)\n", - " median_memory = np.median(total_memories)\n", - " \n", - " # Format the output\n", - " def format_bytes(bytes_value):\n", - " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", - " if bytes_value < 1024 or unit == \"GB\":\n", - " break\n", - " bytes_value /= 1024\n", - " return f\"{bytes_value:.2f} {unit}\"\n", - " \n", - " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", - " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", - " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", - " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", - " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", - " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", - " \n", - " # Identify source indices with min and max memory\n", - " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", - " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", - " \n", - " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", - " print(f\"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})\")" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "e2f8f809", + "id": "710434e7", "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "Memory statistics across 50 source indices:\n", - "- Average total memory per source: 423.18 MB\n", - "- Minimum total memory: 4.33 MB\n", - "- Maximum total memory: 1.29 GB\n", - "- Median total memory: 404.51 MB\n", - "- Range: 1.28 GB\n", - "\n", - "Source index with minimum memory: 39 (4.33 MB)\n", - "Source index with maximum memory: 22 (1.29 GB)\n" + "Allocating cell data: 14%|█▍ | 7/50 [00:45<04:47, 6.69s/it]" ] - } 
- ], - "source": [ - "print_average_memory_per_source(ztd_stats)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "91207483", - "metadata": {}, - "outputs": [], - "source": [ - "from cellflow.data import TrainSamplerWithPool\n", - "import numpy as np\n", - "rng = np.random.default_rng(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "17f1fc6c", - "metadata": {}, - "outputs": [ + }, { "name": "stderr", "output_type": "stream", "text": [ - "Computing target to cell data idcs: 100%|██████████| 9980/9980 [00:11<00:00, 891.95it/s] \n", - "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 1232.06it/s]\n" + "Allocating cell data: 100%|██████████| 50/50 [06:47<00:00, 8.15s/it]\n", + "Allocating cell data: 100%|██████████| 56827/56827 [41:54<00:00, 22.60it/s] \n" ] } ], "source": [ - "tswp = TrainSamplerWithPool(ztd, batch_size=1024, pool_size=20, replacement_prob=0.01)\n", - "tswp.init_pool_n_cache(rng)" + "\n", + "def write_arr(z_arr, arr, k):\n", + " z_arr[:] = arr\n", + " return k\n", + "\n", + "def allocate_cell_data(group, cell_data, chunk_size, shard_size):\n", + " delayed_objs = []\n", + "\n", + " for k in tqdm.tqdm(cell_data.keys(), desc=\"Allocating cell data\"):\n", + " chunk_size_used, shard_size_used = get_size(cell_data[k][\"cell_data\"].shape, chunk_size, shard_size)\n", + " arr = cell_data[k][\"cell_data\"]\n", + "\n", + " z_arr = group.create_array(\n", + " name=k,\n", + " shape=arr.shape,\n", + " chunks=(chunk_size_used, arr.shape[1]),\n", + " shards=(shard_size_used, arr.shape[1]),\n", + " compressors=None,\n", + " dtype=arr.dtype,\n", + " )\n", + "\n", + " delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))\n", + " \n", + " return delayed_objs\n", + "\n", + "\n", + "src_group = zgroup.create_group(\"src_cell_data\", overwrite=True)\n", + "tgt_group = zgroup.create_group(\"tgt_cell_data\", overwrite=True)\n", + "\n", + "\n", + "src_delayed_objs = allocate_cell_data(src_group, src_cell_data, chunk_size, shard_size)\n", + "tgt_delayed_objs = allocate_cell_data(tgt_group, tgt_cell_data, chunk_size, shard_size)\n", + "\n", + "\n", + "\n", + "# for k in tqdm.tqdm(src_cell_data.keys(), desc=\"Writing src cell data\"):\n", + "# chunk_size_used, shard_size_used = get_size(src_cell_data[k][\"cell_data\"].shape, chunk_size, shard_size)\n", + "# arr = src_cell_data[k][\"cell_data\"]\n", + "\n", + "# z_arr = src_group.create_array(\n", + "# name=k,\n", + "# shape=arr.shape,\n", + "# chunks=(chunk_size_used, arr.shape[1]),\n", + "# shards=(shard_size_used, arr.shape[1]),\n", + "# compressors=None,\n", + "# dtype=arr.dtype,\n", + "# )\n", + " \n", + "# delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))\n", + "\n", + "# for k in tqdm.tqdm(tgt_cell_data.keys(), desc=\"Writing tgt cell data\"):\n", + "# chunk_size_used, shard_size_used = get_size(tgt_cell_data[k][\"cell_data\"].shape, chunk_size, shard_size)\n", + "# arr = tgt_cell_data[k][\"cell_data\"]\n", + "# z_arr = tgt_group.create_array(\n", + "# name=k,\n", + "# shape=arr.shape,\n", + "# chunks=(chunk_size_used, arr.shape[1]),\n", + "# shards=(shard_size_used, arr.shape[1]),\n", + "# compressors=None,\n", + "# dtype=arr.dtype,\n", + "# )\n", + " \n", + " \n", + "# delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))" ] }, { "cell_type": "code", - "execution_count": null, - "id": "782380b2", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "81017ffd", + "execution_count": 1, + "id": 
"c41b2a3b", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "replaced 47 with 34\n", - "replaced 32 with 30\n" + "ename": "NameError", + "evalue": "name 'ProgressBar' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mProgressBar\u001b[49m():\n\u001b[32m 2\u001b[39m res = dask.compute(tgt_delayed_objs)\n", + "\u001b[31mNameError\u001b[39m: name 'ProgressBar' is not defined" ] } ], "source": [ - "import time\n", - "iter_times = []\n", - "rng = np.random.default_rng(0)\n", - "start_time = time.time()\n", - "for iter in range(40):\n", - " batch = tswp.sample(rng)\n", - " end_time = time.time()\n", - " iter_times.append(end_time - start_time)\n", - " start_time = end_time\n", - "\n", - "print(\"average time per iteration: \", np.mean(iter_times))\n", - "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" + "\n", + "with ProgressBar():\n", + " res = dask.compute(tgt_delayed_objs)" ] }, { "cell_type": "code", - "execution_count": 64, - "id": "001e842a", + "execution_count": 9, + "id": "28f54507", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "{'pool_size': 20,\n", - " 'avg_usage': 1.95,\n", - " 'unique_sources': 20,\n", - " 'pool_elements': array([31, 18, 47, 34, 12, 35, 29, 23, 32, 14, 6, 41, 25, 3, 1, 49, 24,\n", - " 10, 46, 33]),\n", - " 'usage_counts': array([2, 2, 3, 2, 1, 0, 2, 2, 2, 0, 3, 1, 2, 0, 3, 3, 2, 6, 1, 2])}" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" + "ename": "TypeError", + "evalue": "create_group() takes 1 positional argument but 2 were given", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m mapping_data = \u001b[43mzarr\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcreate_group\u001b[49m\u001b[43m(\u001b[49m\u001b[43mzgroup\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mmapping_data\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m write_sharded(\n\u001b[32m 4\u001b[39m mapping_data,\n\u001b[32m 5\u001b[39m train_data_dict,\n\u001b[32m (...)\u001b[39m\u001b[32m 8\u001b[39m compressors=\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m 9\u001b[39m )\n\u001b[32m 10\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mdone\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mTypeError\u001b[39m: create_group() takes 1 positional argument but 2 were given" + ] } ], "source": [ - "tswp.get_pool_stats()" + "\n", + "mapping_data = zarr.create_group(zgroup, \"mapping_data\")\n", + "\n", + "write_sharded(\n", + " mapping_data,\n", + " train_data_dict,\n", + " chunk_size=chunk_size,\n", + " shard_size=shard_size,\n", + " compressors=None,\n", + ")\n", + "print(\"done\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f07c55d9", - "metadata": {}, - "outputs": [], - "source": [] } ], 
"metadata": { diff --git a/scripts/create_tahoe.py b/scripts/process_tahoe.py similarity index 50% rename from scripts/create_tahoe.py rename to scripts/process_tahoe.py index 636069bd..5edd1445 100644 --- a/scripts/create_tahoe.py +++ b/scripts/process_tahoe.py @@ -1,3 +1,9 @@ +# %% +# %load_ext autoreload +# %autoreload 2 + + +# %% import anndata as ad import h5py import zarr @@ -7,7 +13,11 @@ import cupy as cp import tqdm import dask +import concurrent.futures +from functools import partial import numpy as np +import dask.array as da +from dask.diagnostics import ProgressBar print("loading data") with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f: @@ -29,12 +39,13 @@ max_combination_length=None, null_value=0.0 ) +print("data loaded") +# %% cond_data = dm._get_condition_data(adata=adata_all) cell_data = dm._get_cell_data(adata_all) - - +# %% n_source_dists = len(cond_data.split_idx_to_covariates) n_target_dists = len(cond_data.perturbation_idx_to_covariates) @@ -55,19 +66,23 @@ "cell_data_index": cp.where(mask)[0].get(), } +# %% +print("Computing cell data") +cell_data = cell_data.compute() +print("cell data computed") for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data"): indices = src_cell_data[str(src_idx)]["cell_data_index"] - delayed_obj = dask.delayed(lambda x: cell_data[x])(indices) - src_cell_data[str(src_idx)]["cell_data"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype) + src_cell_data[str(src_idx)]["cell_data"] = cell_data[indices] for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data"): indices = tgt_cell_data[str(tgt_idx)]["cell_data_index"] - delayed_obj = dask.delayed(lambda x: cell_data[x])(indices) - tgt_cell_data[str(tgt_idx)]["cell_data"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype) + tgt_cell_data[str(tgt_idx)]["cell_data"] = cell_data[indices] +# %% + split_covariates_mask = np.asarray(cond_data.split_covariates_mask) perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask) condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()} @@ -87,21 +102,109 @@ "condition_data": condition_data, "control_to_perturbation": control_to_perturbation, "max_combination_length": int(cond_data.max_combination_length), - "src_cell_data": src_cell_data, - "tgt_cell_data": tgt_cell_data, + # "src_cell_data": src_cell_data, + # "tgt_cell_data": tgt_cell_data, } - -print("writing data") +print("prepared train_data_dict") +# %% path = "/lustre/groups/ml01/workspace/100mil/tahoe.zarr" zgroup = zarr.open_group(path, mode="w") -chunk_size = 65536 -shard_size = chunk_size * 16 +chunk_size = 131072 +shard_size = chunk_size * 8 + +ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr + +def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]: + shard_size_used = shard_size + chunk_size_used = chunk_size + if chunk_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + elif chunk_size < shape[0] or shard_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + return chunk_size_used, shard_size_used + + + + +def write_single_array(group, key, arr, idxs, chunk_size, shard_size): + """Write a single array - designed for threading""" + chunk_size_used, shard_size_used = get_size(arr.shape, chunk_size, shard_size) + + group.create_array( + name=key, + data=arr, + 
chunks=(chunk_size_used, arr.shape[1]), + shards=(shard_size_used, arr.shape[1]), + compressors=None, + ) + + group.create_array( + name=f"{key}_index", + data=idxs, + chunks=(len(idxs),), + shards=(len(idxs),), + compressors=None, + ) + return key + +def write_cell_data_threaded(group, cell_data, chunk_size, shard_size, max_workers=8): + """Write cell data using threading for I/O parallelism""" + + write_func = partial(write_single_array, group, chunk_size=chunk_size, shard_size=shard_size) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all write tasks + future_to_key = { + executor.submit(write_single_array, group, k, cell_data[k]["cell_data"], cell_data[k]["cell_data_index"], chunk_size, shard_size): k + for k in cell_data.keys() + } + + # Process results with progress bar + for future in tqdm.tqdm( + concurrent.futures.as_completed(future_to_key), + total=len(future_to_key), + desc=f"Writing {group.name}" + ): + key = future_to_key[future] + try: + future.result() # This will raise any exceptions + except Exception as exc: + print(f'Array {key} generated an exception: {exc}') + raise + +# %% + + +src_group = zgroup.create_group("src_cell_data", overwrite=True) +tgt_group = zgroup.create_group("tgt_cell_data", overwrite=True) + + +# Use the fast threaded approach you already implemented +write_cell_data_threaded(src_group, src_cell_data, chunk_size, shard_size, max_workers=24) +print("done writing src_cell_data") +write_cell_data_threaded(tgt_group, tgt_cell_data, chunk_size, shard_size, max_workers=24) +print("done writing tgt_cell_data") + + + + + + +# %% + +print("Writing mapping data") +mapping_data = zgroup.create_group("mapping_data", overwrite=True) + + write_sharded( - zgroup, - train_data_dict, + group=mapping_data, + name="mapping_data", + data=train_data_dict, chunk_size=chunk_size, shard_size=shard_size, compressors=None, ) -print("done") \ No newline at end of file +print("done") + + diff --git a/scripts/process_tahoe.sbatch b/scripts/process_tahoe.sbatch new file mode 100644 index 00000000..fecb5f55 --- /dev/null +++ b/scripts/process_tahoe.sbatch @@ -0,0 +1,17 @@ +#!/bin/zsh + +#SBATCH -o logs/process_tahoe.out +#SBATCH -e logs/process_tahoe.err +#SBATCH -J process_tahoe +#SBATCH --nice=1 +#SBATCH --time=23:00:00 +#SBATCH --partition=gpu_p +#SBATCH --qos=gpu_normal +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=26 +#SBATCH --mem=500G + +source /home/icb/selman.ozleyen/.zshrc + +mamba activate lpert +python /home/icb/selman.ozleyen/projects/CellFlow2/scripts/process_tahoe.py diff --git a/src/cellflow/data/__init__.py b/src/cellflow/data/__init__.py index 5121b1c3..d54e4b67 100644 --- a/src/cellflow/data/__init__.py +++ b/src/cellflow/data/__init__.py @@ -4,12 +4,12 @@ PredictionData, TrainingData, ValidationData, - ZarrTrainingData, + MappedCellData, ) from cellflow.data._dataloader import ( PredictionSampler, TrainSampler, - TrainSamplerWithPool, + ReservoirSampler, ValidationSampler, ) from cellflow.data._datamanager import DataManager @@ -23,11 +23,11 @@ "PredictionData", "TrainingData", "ValidationData", - "ZarrTrainingData", + "MappedCellData", "TrainSampler", "ValidationSampler", "PredictionSampler", "TorchCombinedTrainSampler", "JaxOutOfCoreTrainSampler", - "TrainSamplerWithPool", + "ReservoirSampler", ] diff --git a/src/cellflow/data/_data.py b/src/cellflow/data/_data.py index 97d73ea6..219dbeae 100644 --- a/src/cellflow/data/_data.py +++ b/src/cellflow/data/_data.py @@ -17,7 +17,7 @@ "PredictionData", 
"TrainingData", "ValidationData", - "ZarrTrainingData", + "MappedCellData", ] @@ -235,9 +235,7 @@ class ValidationData(BaseDataMixin): n_conditions_on_train_end: int | None = None -def _read_dict(zgroup: zarr.Group, key: str) -> dict[int, Any]: - keys = zgroup[key].keys() - return {k: zgroup[key][k] for k in keys} + @dataclass @@ -275,7 +273,7 @@ class PredictionData(BaseDataMixin): @dataclass -class ZarrTrainingData(BaseDataMixin): +class MappedCellData(BaseDataMixin): """Lazy, Zarr-backed variant of :class:`TrainingData`. Fields mirror those in :class:`TrainingData`, but array-like members are @@ -288,7 +286,10 @@ class ZarrTrainingData(BaseDataMixin): # Note: annotations use Any to allow zarr.Array and zarr groups without # importing zarr at module import time. - cell_data: Any + src_cell_data: dict[str, Any] + tgt_cell_data: dict[str, Any] + src_cell_idx: dict[str, Any] + tgt_cell_idx: dict[str, Any] split_covariates_mask: Any perturbation_covariates_mask: Any split_idx_to_covariates: dict[int, tuple[Any, ...]] @@ -297,24 +298,44 @@ class ZarrTrainingData(BaseDataMixin): condition_data: dict[str, Any] control_to_perturbation: dict[int, Any] max_combination_length: int + mapping_data_full_cached: bool = False def __post_init__(self): # load everything except cell_data to memory # load masks as numpy arrays - self.split_covariates_mask = self.split_covariates_mask[...] - self.perturbation_covariates_mask = self.perturbation_covariates_mask[...] - self.condition_data = {k: np.asarray(v) for k, v in self.condition_data.items()} self.control_to_perturbation = {int(k): np.asarray(v) for k, v in self.control_to_perturbation.items()} - self.perturbation_idx_to_id = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_id.items()} - self.perturbation_idx_to_covariates = { - int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items() - } - self.split_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.split_idx_to_covariates.items()} + if self.mapping_data_full_cached: + # used in validation usually + self.perturbation_idx_to_id = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_id.items()} + self.perturbation_idx_to_covariates = { + int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items() + } + # not used in nested structure + self.src_cell_idx = self.src_cell_idx[...] + self.tgt_cell_idx = self.tgt_cell_idx[...] + self.split_covariates_mask = self.split_covariates_mask[...] + self.perturbation_covariates_mask = self.perturbation_covariates_mask[...] 
+ self.split_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.split_idx_to_covariates.items()} + + @staticmethod + def _get_mapping_data(group: zarr.Group) -> dict[str, Any]: + return group["mapping_data"]["mapping_data"] + + @staticmethod + def _read_dict(zgroup: zarr.Group, key: str) -> dict[int, Any]: + keys = zgroup[key].keys() + return {k: zgroup[key][k] for k in keys} + + @staticmethod + def _read_cell_data(zgroup: zarr.Group, key: str) -> dict[int, Any]: + keys = sorted(zgroup[key].keys()) + data_key = [k for k in keys if not k.endswith("_index")] + return {int(k): zgroup[key][k] for k in data_key}, {int(k): zgroup[key][f"{k}_index"] for k in data_key} @classmethod - def read_zarr(cls, path: str) -> ZarrTrainingData: + def read_zarr(cls, path: str) -> MappedCellData: if isinstance(path, str): path = LocalStore(path, read_only=True) group = zarr.open_group(path, mode="r") @@ -327,14 +348,21 @@ def read_zarr(cls, path: str) -> ZarrTrainingData: except Exception: # noqa: BLE001 max_combination_length = int(max_len_node) + mapping_group = cls._get_mapping_data(group) + + src_cell_data, src_cell_idx = cls._read_cell_data(group, "src_cell_data") + tgt_cell_data, tgt_cell_idx = cls._read_cell_data(group, "tgt_cell_data") return cls( - cell_data=group["cell_data"], - split_covariates_mask=group["split_covariates_mask"], - perturbation_covariates_mask=group["perturbation_covariates_mask"], - split_idx_to_covariates=_read_dict(group, "split_idx_to_covariates"), - perturbation_idx_to_covariates=_read_dict(group, "perturbation_idx_to_covariates"), - perturbation_idx_to_id=_read_dict(group, "perturbation_idx_to_id"), - condition_data=_read_dict(group, "condition_data"), - control_to_perturbation=_read_dict(group, "control_to_perturbation"), + tgt_cell_data=tgt_cell_data, + tgt_cell_idx=tgt_cell_idx, + src_cell_data=src_cell_data, + src_cell_idx=src_cell_idx, + split_covariates_mask=mapping_group["split_covariates_mask"], + perturbation_covariates_mask=mapping_group["perturbation_covariates_mask"], + split_idx_to_covariates=cls._read_dict(mapping_group, "split_idx_to_covariates"), + perturbation_idx_to_covariates=cls._read_dict(mapping_group, "perturbation_idx_to_covariates"), + perturbation_idx_to_id=cls._read_dict(mapping_group, "perturbation_idx_to_id"), + condition_data=cls._read_dict(mapping_group, "condition_data"), + control_to_perturbation=cls._read_dict(mapping_group, "control_to_perturbation"), max_combination_length=max_combination_length, ) diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py index 4061e0bc..b104359c 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/cellflow/data/_dataloader.py @@ -8,13 +8,14 @@ PredictionData, TrainingData, ValidationData, - ZarrTrainingData, + MappedCellData, ) __all__ = [ "TrainSampler", "ValidationSampler", "PredictionSampler", + "ReservoirSampler", ] @@ -30,7 +31,7 @@ class TrainSampler: """ - def __init__(self, data: TrainingData | ZarrTrainingData, batch_size: int = 1024): + def __init__(self, data: TrainingData, batch_size: int = 1024): self._data = data self._data_idcs = np.arange(data.cell_data.shape[0]) self.batch_size = batch_size @@ -120,12 +121,12 @@ def sample(self, rng) -> dict[str, Any]: return res @property - def data(self) -> TrainingData | ZarrTrainingData: + def data(self) -> TrainingData: """The training data.""" return self._data -class TrainSamplerWithPool(TrainSampler): +class ReservoirSampler(TrainSampler): """Data sampler with gradual pool replacement using reservoir sampling. 
This approach replaces pool elements one by one rather than refreshing @@ -150,35 +151,24 @@ class TrainSamplerWithPool(TrainSampler): def __init__( self, - data: TrainingData | ZarrTrainingData, + data: MappedCellData, batch_size: int = 1024, pool_size: int = 100, replacement_prob: float = 0.01, ): - super().__init__(data, batch_size) + self.batch_size = batch_size + self.n_source_dists = data.n_controls + self.n_target_dists = data.n_perturbations + self._data = data + + self._control_to_perturbation_keys = sorted(data.control_to_perturbation.keys()) + self._has_condition_data = data.condition_data is not None self._pool_size = pool_size self._replacement_prob = replacement_prob - self._src_idx_pool = np.empty(self._pool_size, dtype=int) self._pool_usage_count = np.zeros(self.n_source_dists, dtype=int) self._initialized = False - def _compute_idx_mappings(self): - import cupy as cp - - self._tgt_to_cell_data_idcs = [None] * self.n_target_dists - gpu_per_cov_mask = cp.asarray(self._data.perturbation_covariates_mask) - gpu_spl_cov_mask = cp.asarray(self._data.split_covariates_mask) - - for tgt_idx in tqdm.tqdm(range(self.n_target_dists), desc="Computing target to cell data idcs"): - mask = gpu_per_cov_mask == tgt_idx - self._tgt_to_cell_data_idcs[tgt_idx] = cp.where(mask)[0].get() - self._src_to_cell_data_idcs = [None] * self.n_source_dists - for src_idx in tqdm.tqdm(range(self.n_source_dists), desc="Computing source to cell data idcs"): - mask = gpu_spl_cov_mask == src_idx - self._src_to_cell_data_idcs[src_idx] = cp.where(mask)[0].get() - - def init_pool_n_cache(self, rng): - self._compute_idx_mappings() + def init_pool(self, rng): self._init_pool(rng) self._init_cache_pool_elements() @@ -191,57 +181,9 @@ def _get_target_idx_pool(src_idx_pool: np.ndarray, control_to_perturbation: dict def _init_cache_pool_elements(self): if not self._initialized: - raise ValueError("Pool not initialized. 
Call init_pool_n_cache(rng) first.") - - # Build concatenated row indices and slice maps for sources - src_concat = [] - src_slices: dict[int, slice] = {} - offset = 0 - for src_idx in self._src_idx_pool: - idcs = self._src_to_cell_data_idcs[src_idx] - n = len(idcs) - src_slices[src_idx] = slice(offset, offset + n) - src_concat.append(idcs) - offset += n - src_concat = np.concatenate(src_concat) if len(src_concat) else np.empty((0,), dtype=int) - - # Build concatenated row indices and slice maps for targets - tgt_pool = TrainSamplerWithPool._get_target_idx_pool(self._src_idx_pool, self._data.control_to_perturbation) - tgt_concat = [] - tgt_slices: dict[int, slice] = {} - offset = 0 - for tgt_idx in tqdm.tqdm(sorted(tgt_pool), desc="Caching target cells"): - idcs = self._tgt_to_cell_data_idcs[tgt_idx] - n = len(idcs) - tgt_slices[tgt_idx] = slice(offset, offset + n) - tgt_concat.append(idcs) - offset += n - tgt_concat = np.concatenate(tgt_concat) if len(tgt_concat) else np.empty((0,), dtype=int) - - # Single orthogonal-index reads (fast path) - self._src_block = ( - self._data.cell_data.oindex[src_concat, :] - if src_concat.size - else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) - ) - self._tgt_block = ( - self._data.cell_data.oindex[tgt_concat, :] - if tgt_concat.size - else np.empty((0, self._data.cell_data.shape[1]), dtype=self._data.cell_data.dtype) - ) - - # Views into the blocks (no extra copies) - self._cached_srcs = {src_idx: self._src_block[sli] for src_idx, sli in src_slices.items()} - tgt_views = {tgt_idx: self._tgt_block[sli] for tgt_idx, sli in tgt_slices.items()} - self._cached_tgts = { - src_idx: { - tgt_idx: tgt_views[tgt_idx] - for tgt_idx in self._data.control_to_perturbation[src_idx] - if tgt_idx in tgt_views - } - for src_idx in self._src_idx_pool - } - self._initialized = True + raise ValueError("Pool not initialized. Call init_pool(rng) first.") + self._cached_srcs = {i: np.asarray(self._data.src_cell_data[i]) for i in self._src_idx_pool} + self._cached_tgts = {j: np.asarray(self._data.tgt_cell_data[j]) for i in self._src_idx_pool for j in self._data.control_to_perturbation[i]} def _init_pool(self, rng): """Initialize the pool with random source distribution indices.""" @@ -251,11 +193,9 @@ def _init_pool(self, rng): def _sample_source_dist_idx(self, rng) -> int: """Sample a source distribution index with gradual pool replacement.""" if not self._initialized: - self._init_pool(rng) - + raise ValueError("Pool not initialized. 
Call init_pool(rng) first.") # Sample from current pool - pool_idx = rng.choice(self._pool_size) - source_idx = self._src_idx_pool[pool_idx] + source_idx = rng.choice(sorted(self._cached_srcs.keys())) # Increment usage count for monitoring self._pool_usage_count[source_idx] += 1 @@ -282,8 +222,21 @@ def _replace_pool_element(self, rng): least_used_weight /= least_used_weight.sum() new_pool_idx = rng.choice(self.n_source_dists, p=least_used_weight) self._src_idx_pool[in_pool_idx] = new_pool_idx + self._update_cache(replaced_pool_idx, new_pool_idx) print(f"replaced {replaced_pool_idx} with {new_pool_idx}") + def _update_cache(self, replaced_pool_idx: int, new_pool_idx: int): + print(f"updating cache for {replaced_pool_idx} and {new_pool_idx}") + del self._cached_srcs[replaced_pool_idx] + for k in self._data.control_to_perturbation[replaced_pool_idx]: + del self._cached_tgts[k] + self._cached_srcs[new_pool_idx] = np.asarray(self._data.src_cell_data[new_pool_idx]) + for k in self._data.control_to_perturbation[new_pool_idx]: + self._cached_tgts[k] = np.asarray(self._data.tgt_cell_data[k]) + print(f"updated cache for {replaced_pool_idx} and {new_pool_idx}") + + + def get_pool_stats(self) -> dict: """Get statistics about the current pool state.""" if self._src_idx_pool is None: @@ -300,7 +253,7 @@ def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: return rng.choice(self._cached_srcs[source_dist_idx], size=self.batch_size, replace=True) def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: - return rng.choice(self._cached_tgts[source_dist_idx][target_dist_idx], size=self.batch_size, replace=True) + return rng.choice(self._cached_tgts[target_dist_idx], size=self.batch_size, replace=True) class BaseValidSampler(abc.ABC): diff --git a/src/cellflow/data/_jax_dataloader.py b/src/cellflow/data/_jax_dataloader.py index b0c40358..1c181243 100644 --- a/src/cellflow/data/_jax_dataloader.py +++ b/src/cellflow/data/_jax_dataloader.py @@ -8,7 +8,6 @@ from cellflow.data._data import ( TrainingData, - ZarrTrainingData, ) from cellflow.data._dataloader import TrainSampler @@ -83,7 +82,7 @@ class JaxOutOfCoreTrainSampler: """ - data: TrainingData | ZarrTrainingData + data: TrainingData seed: int batch_size: int = 1024 num_workers: int = 4 diff --git a/src/cellflow/data/_torch_dataloader.py b/src/cellflow/data/_torch_dataloader.py index 22560ee2..61f12040 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ b/src/cellflow/data/_torch_dataloader.py @@ -4,8 +4,8 @@ import numpy as np from cellflow.compat import TorchIterableDataset -from cellflow.data._data import ZarrTrainingData -from cellflow.data._dataloader import TrainSampler +from cellflow.data._data import MappedCellData +from cellflow.data._dataloader import TrainSampler, ReservoirSampler def _worker_init_fn_helper(worker_id, random_generators): @@ -74,8 +74,8 @@ def combine_zarr_training_samplers( seq = np.random.SeedSequence(seed) random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] worker_init_fn = partial(_worker_init_fn_helper, random_generators=random_generators) - data = [ZarrTrainingData.read_zarr(path) for path in data_paths] - samplers = [TrainSampler(data[i], batch_size) for i in range(len(data))] + data = [MappedCellData.read_zarr(path) for path in data_paths] + samplers = [ReservoirSampler(data[i], batch_size) for i in range(len(data))] combined_sampler = cls(samplers, weights=weights, dataset_names=dataset_names) return torch.utils.data.DataLoader( 
combined_sampler, diff --git a/src/cellflow/data/_utils.py b/src/cellflow/data/_utils.py index d7e4a728..d741300b 100644 --- a/src/cellflow/data/_utils.py +++ b/src/cellflow/data/_utils.py @@ -10,6 +10,7 @@ def write_sharded( group: zarr.Group, data: dict[str, Any], + name: str, chunk_size: int = 4096, shard_size: int = 65536, compressors: Iterable[BytesBytesCodec] = ( @@ -37,6 +38,15 @@ def write_sharded( # https://github.com/laminlabs/arrayloaders/blob/main/arrayloaders/io/store_creation.py ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr + def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]: + shard_size_used = shard_size + chunk_size_used = chunk_size + if chunk_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + elif chunk_size < shape[0] or shard_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + return chunk_size_used, shard_size_used + def callback( func: ad.experimental.Write, g: zarr.Group, @@ -46,9 +56,14 @@ def callback( iospec: ad.experimental.IOSpec, ): if iospec.encoding_type in {"array"}: + # Calculate greatest common divisor for first dimension + # or use smallest dimension as chunk size + + chunk_size_used, shard_size_used = get_size(elem.shape, chunk_size, shard_size) + dataset_kwargs = { - "shards": (shard_size,) + (elem.shape[1:]), # only shard over 1st dim - "chunks": (chunk_size,) + (elem.shape[1:]), # only chunk over 1st dim + "shards": (shard_size_used,) + (elem.shape[1:]), # only shard over 1st dim + "chunks": (chunk_size_used,) + (elem.shape[1:]), # only chunk over 1st dim "compressors": compressors, **dataset_kwargs, } @@ -62,7 +77,7 @@ def callback( func(g, k, elem, dataset_kwargs=dataset_kwargs) - ad.experimental.write_dispatched(group, "/", data, callback=callback) + ad.experimental.write_dispatched(group, name, data, callback=callback) zarr.consolidate_metadata(group.store) From dce9c3401500a738222093e877a13a90bd135ce7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 29 Sep 2025 14:06:40 +0200 Subject: [PATCH 29/35] nonblocking version --- docs/notebooks/600_trainsampler copy 2.ipynb | 1853 ------------------ docs/notebooks/600_trainsampler copy.ipynb | 270 +-- docs/notebooks/600_trainsampler.ipynb | 423 ---- docs/notebooks/tahoe_sizes.ipynb | 348 ++++ src/cellflow/data/_dataloader.py | 134 +- 5 files changed, 459 insertions(+), 2569 deletions(-) delete mode 100644 docs/notebooks/600_trainsampler copy 2.ipynb delete mode 100644 docs/notebooks/600_trainsampler.ipynb create mode 100644 docs/notebooks/tahoe_sizes.ipynb diff --git a/docs/notebooks/600_trainsampler copy 2.ipynb b/docs/notebooks/600_trainsampler copy 2.ipynb deleted file mode 100644 index 10e98336..00000000 --- a/docs/notebooks/600_trainsampler copy 2.ipynb +++ /dev/null @@ -1,1853 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "5765bb6c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loading data\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/functools.py:912: ImplicitModificationWarning: Transforming to str index.\n", - " return dispatch(args[0].__class__)(*args, **kw)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "data loaded\n", - "[########################################] | 100% Completed | 1.11 sms\n", - "[########################################] | 100% Completed | 
25.93 s\n", - "[########################################] | 100% Completed | 294.61 s\n" - ] - } - ], - "source": [ - "%%\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "\n", - "# %%\n", - "import anndata as ad\n", - "import h5py\n", - "import zarr\n", - "from cellflow.data._utils import write_sharded\n", - "from anndata.experimental import read_lazy\n", - "from cellflow.data import DataManager\n", - "import cupy as cp\n", - "import tqdm\n", - "import dask\n", - "import concurrent.futures\n", - "from functools import partial\n", - "import numpy as np\n", - "import dask.array as da\n", - "from dask.diagnostics import ProgressBar\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05bd4946", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 87.79it/s]\n", - "Computing target to cell data idcs: 68%|██████▊ | 38602/56827 [00:45<00:21, 854.71it/s]" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing target to cell data idcs: 100%|██████████| 56827/56827 [01:06<00:00, 852.83it/s]\n" - ] - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 91, - "id": "35894e7d", - "metadata": {}, - "outputs": [], - "source": [ - "cell_data = cell_data.compute()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "310df180", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Array Chunk
Bytes 106.87 GiB 1.14 MiB
Shape (95624334, 300) (1000, 300)
Dask graph 95625 chunks in 3 graph layers
Data type float32 numpy.ndarray
\n", - "
\n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - "\n", - " \n", - " 300\n", - " 95624334\n", - "\n", - "
" - ], - "text/plain": [ - "dask.array" - ] - }, - "execution_count": 76, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cell_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6303064c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(2, 300)" - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "da.take(cell_data, , axis=0).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "9a9e4693", - "metadata": {}, - "outputs": [], - "source": [ - "cell_data_batch = cell_data[:100_000].compute()" - ] - }, - { - "cell_type": "code", - "execution_count": 70, - "id": "8b45658e", - "metadata": {}, - "outputs": [], - "source": [ - "spl_cov_mask_batch = gpu_spl_cov_mask[:100_000]" - ] - }, - { - "cell_type": "code", - "execution_count": 74, - "id": "ee2e0760", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,\n", - " 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,\n", - " 33, 34, 35, 37, 38, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],\n", - " dtype=int32)" - ] - }, - "execution_count": 74, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cp.unique(spl_cov_mask_batch)" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "id": "55ded52a", - "metadata": {}, - "outputs": [], - "source": [ - "mapping = (gpu_per_cov_mask-gpu_spl_cov_mask+50)" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "id": "99ee8158", - "metadata": {}, - "outputs": [], - "source": [ - "sorted_indices = cp.argsort(mapping)\n", - "ordered_mapping = mapping[sorted_indices]" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "id": "615b59c9", - "metadata": {}, - "outputs": [], - "source": [ - "unique_values, inverse_indices = cp.unique(ordered_mapping, return_inverse=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "id": "cd48c25a", - "metadata": {}, - "outputs": [], - "source": [ - "ord_cell_data = da.take(cell_data,sorted_indices.get(),axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": 89, - "id": "ffe82904", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Array Chunk
Bytes 106.87 GiB 1.21 MiB
Shape (95624334, 300) (1053, 300)
Dask graph 95720 chunks in 4 graph layers
Data type float32 numpy.ndarray
\n", - "
\n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - "\n", - " \n", - " 300\n", - " 95624334\n", - "\n", - "
" - ], - "text/plain": [ - "dask.array" - ] - }, - "execution_count": 89, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ord_cell_data" - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "id": "a636955a", - "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[90]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m ord_cell_data = \u001b[43mord_cell_data\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/base.py:373\u001b[39m, in \u001b[36mDaskMethodsMixin.compute\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 349\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcompute\u001b[39m(\u001b[38;5;28mself\u001b[39m, **kwargs):\n\u001b[32m 350\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Compute this dask collection\u001b[39;00m\n\u001b[32m 351\u001b[39m \n\u001b[32m 352\u001b[39m \u001b[33;03m This turns a lazy Dask collection into its in-memory equivalent.\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 371\u001b[39m \u001b[33;03m dask.compute\u001b[39;00m\n\u001b[32m 372\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m373\u001b[39m (result,) = \u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraverse\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 374\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/base.py:681\u001b[39m, in \u001b[36mcompute\u001b[39m\u001b[34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[39m\n\u001b[32m 678\u001b[39m expr = expr.optimize()\n\u001b[32m 679\u001b[39m keys = \u001b[38;5;28mlist\u001b[39m(flatten(expr.__dask_keys__()))\n\u001b[32m--> \u001b[39m\u001b[32m681\u001b[39m results = \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 683\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m repack(results)\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/_task_spec.py:272\u001b[39m, in \u001b[36mconvert_legacy_graph\u001b[39m\u001b[34m(dsk, all_keys)\u001b[39m\n\u001b[32m 270\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(t, GraphNode):\n\u001b[32m 271\u001b[39m t = DataNode(k, t)\n\u001b[32m--> \u001b[39m\u001b[32m272\u001b[39m new_dsk[k] = t\n\u001b[32m 273\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m new_dsk\n", - 
"\u001b[31mKeyboardInterrupt\u001b[39m: " - ] - } - ], - "source": [ - "ord_cell_data = ord_cell_data.compute()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c48f9c4c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([-1, -1, -1, ..., 32, 31, 38], shape=(95624334,), dtype=int32)" - ] - }, - "execution_count": 69, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "gpu_spl_cov_mask[:100_000].unique()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1fcc7c7", - "metadata": {}, - "outputs": [], - "source": [ - "for k in list(src_cell_data.keys()):\n", - " idx = src_cell_data[k][\"cell_data_index\"]\n", - " src_cell_data[k][\"cell_data\"] = da.take(cell_data, idx, axis=0)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "df755b79", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Exception ignored in: >\n", - "Traceback (most recent call last):\n", - " File \"/home/icb/selman.ozleyen/.local/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 790, in _clean_thread_parent_frames\n", - " active_threads = {thread.ident for thread in threading.enumerate()}\n", - " ^^^^^^^^^^^^\n", - "KeyboardInterrupt: \n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[62]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(tgt_cell_data.keys()):\n\u001b[32m 2\u001b[39m idx = tgt_cell_data[k][\u001b[33m\"\u001b[39m\u001b[33mcell_data_index\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m tgt_cell_data[k][\u001b[33m\"\u001b[39m\u001b[33mcell_data\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[43mda\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtake\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcell_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/array/routines.py:2013\u001b[39m, in \u001b[36mtake\u001b[39m\u001b[34m(a, indices, axis)\u001b[39m\n\u001b[32m 2011\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m _take_dask_array_from_numpy(a, indices, axis)\n\u001b[32m 2012\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2013\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43ma\u001b[49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mslice\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n", - "\u001b[36mFile 
\u001b[39m (remaining dask-internal frames elided)\n",
-    "[KeyboardInterrupt raised while materializing cell_data via da.take(cell_data, idx, axis=0) / .compute(); full traceback omitted]\n",
-    
"\u001b[31mKeyboardInterrupt\u001b[39m: " - ] - } - ], - "source": [ - "for k in list(tgt_cell_data.keys()):\n", - " idx = tgt_cell_data[k][\"cell_data_index\"]\n", - " tgt_cell_data[k][\"cell_data\"] = da.take(cell_data, idx, axis=0)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb2e3a2c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Availible mem 34.36\n", - "Using batch size: 2,300,000 rows\n", - "Estimated memory per batch: 5.52 GB\n" - ] - } - ], - "source": [ - "mempool = cp.get_default_memory_pool()\n", - "mempool.set_limit(40 * 1024**3) # Set limit to 40 GB\n", - "batch_size = 2_300_000\n", - "gpu_fraction = 0.8\n", - "available_memory = mempool.get_limit() * gpu_fraction\n", - "\n", - "# Calculate optimal batch size based on memory\n", - "bytes_per_element = cell_data.dtype.itemsize\n", - "elements_per_row = cell_data.shape[1]\n", - "bytes_per_row = bytes_per_element * elements_per_row\n", - "\n", - "# Reserve memory for both input and output\n", - "max_batch_size = int(available_memory / (bytes_per_row * 2))\n", - "actual_batch_size = min(batch_size, max_batch_size)\n", - "\n", - "print(f\"Using batch size: {actual_batch_size:,} rows\")\n", - "print(f\"Estimated memory per batch: {(actual_batch_size * bytes_per_row * 2) / 1e9:.2f} GB\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b36da19", - "metadata": {}, - "outputs": [], - "source": [ - "def process_indices_gpu(indices_dict: Dict, description: str) -> Dict:\n", - " \"\"\"Process a dictionary of indices on GPU\"\"\"\n", - " results = {}\n", - " \n", - " for key in tqdm.tqdm(indices_dict.keys(), desc=description):\n", - " indices = indices_dict[key][\"cell_data_index\"]\n", - " \n", - " if len(indices) == 0:\n", - " results[key] = {\"cell_data\": np.empty((0, cell_data.shape[1]), dtype=cell_data.dtype)}\n", - " continue\n", - " \n", - " # Process in batches if indices are large\n", - " if len(indices) <= actual_batch_size:\n", - " # Small enough to process at once\n", - " gpu_result = process_single_batch_gpu(cell_data, indices)\n", - " results[key] = {\"cell_data\": gpu_result}\n", - " else:\n", - " # Process in multiple batches\n", - " batched_results = []\n", - " n_batches = (len(indices) + actual_batch_size - 1) // actual_batch_size\n", - " \n", - " for batch_idx in range(n_batches):\n", - " start_idx = batch_idx * actual_batch_size\n", - " end_idx = min((batch_idx + 1) * actual_batch_size, len(indices))\n", - " batch_indices = indices[start_idx:end_idx]\n", - " \n", - " batch_result = process_single_batch_gpu(cell_data, batch_indices)\n", - " batched_results.append(batch_result)\n", - " \n", - " # Clear GPU memory between batches\n", - " cp.get_default_memory_pool().free_all_blocks()\n", - " \n", - " # Concatenate results\n", - " final_result = np.concatenate(batched_results, axis=0)\n", - " results[key] = {\"cell_data\": final_result}\n", - " \n", - " return results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "67262e7f", - "metadata": {}, - "outputs": [], - "source": [ - "import cupy as cp\n", - "import numpy as np\n", - "from typing import Dict, List, Tuple\n", - "import gc\n", - "\n", - "def process_cell_data_gpu_batched(\n", - " cell_data,\n", - " src_cell_data: Dict,\n", - " tgt_cell_data: Dict,\n", - " batch_size: int = 20000, # Adjust based on GPU memory\n", - " gpu_memory_fraction: float = 0.8\n", - ") -> Tuple[Dict, Dict]:\n", - " \"\"\"\n", - " Process cell data 
indexing on GPU in batches to manage memory efficiently.\n", - " \n", - " Parameters\n", - " ----------\n", - " cell_data : dask.array or numpy array\n", - " The main cell data array\n", - " src_cell_data : dict\n", - " Dictionary containing source cell data indices\n", - " tgt_cell_data : dict\n", - " Dictionary containing target cell data indices\n", - " batch_size : int\n", - " Number of rows to process per batch\n", - " gpu_memory_fraction : float\n", - " Fraction of GPU memory to use for cell data\n", - " \n", - " Returns\n", - " -------\n", - " Tuple of updated src_cell_data and tgt_cell_data dictionaries\n", - " \"\"\"\n", - " \n", - " # Get available GPU memory\n", - " mempool = cp.get_default_memory_pool()\n", - " available_memory = mempool.get_limit() * gpu_memory_fraction\n", - " \n", - " # Calculate optimal batch size based on memory\n", - " bytes_per_element = cell_data.dtype.itemsize\n", - " elements_per_row = cell_data.shape[1]\n", - " bytes_per_row = bytes_per_element * elements_per_row\n", - " \n", - " # Reserve memory for both input and output\n", - " max_batch_size = int(available_memory / (bytes_per_row * 2))\n", - " actual_batch_size = min(batch_size, max_batch_size)\n", - " \n", - " \n", - " def process_indices_gpu(indices_dict: Dict, description: str) -> Dict:\n", - " \"\"\"Process a dictionary of indices on GPU\"\"\"\n", - " results = {}\n", - " \n", - " for key in tqdm.tqdm(indices_dict.keys(), desc=description):\n", - " indices = indices_dict[key][\"cell_data_index\"]\n", - " \n", - " if len(indices) == 0:\n", - " results[key] = {\"cell_data\": np.empty((0, cell_data.shape[1]), dtype=cell_data.dtype)}\n", - " continue\n", - " \n", - " # Process in batches if indices are large\n", - " if len(indices) <= actual_batch_size:\n", - " # Small enough to process at once\n", - " gpu_result = process_single_batch_gpu(cell_data, indices)\n", - " results[key] = {\"cell_data\": gpu_result}\n", - " else:\n", - " # Process in multiple batches\n", - " batched_results = []\n", - " n_batches = (len(indices) + actual_batch_size - 1) // actual_batch_size\n", - " \n", - " for batch_idx in range(n_batches):\n", - " start_idx = batch_idx * actual_batch_size\n", - " end_idx = min((batch_idx + 1) * actual_batch_size, len(indices))\n", - " batch_indices = indices[start_idx:end_idx]\n", - " \n", - " batch_result = process_single_batch_gpu(cell_data, batch_indices)\n", - " batched_results.append(batch_result)\n", - " \n", - " # Clear GPU memory between batches\n", - " cp.get_default_memory_pool().free_all_blocks()\n", - " \n", - " # Concatenate results\n", - " final_result = np.concatenate(batched_results, axis=0)\n", - " results[key] = {\"cell_data\": final_result}\n", - " \n", - " return results\n", - " \n", - " def process_single_batch_gpu(data, indices):\n", - " \"\"\"Process a single batch of indices on GPU\"\"\"\n", - " # Move indices to GPU\n", - " gpu_indices = cp.asarray(indices)\n", - " \n", - " # Move data batch to GPU (only the needed rows)\n", - " if hasattr(data, 'compute'): # Dask array\n", - " # For dask arrays, compute only the needed slices\n", - " cpu_batch = data[indices].compute()\n", - " else: # Regular numpy array\n", - " cpu_batch = data[indices]\n", - " \n", - " # Move to GPU and back to CPU\n", - " gpu_batch = cp.asarray(cpu_batch)\n", - " result = cp.asnumpy(gpu_batch)\n", - " \n", - " # Clean up GPU memory\n", - " del gpu_batch, gpu_indices\n", - " cp.get_default_memory_pool().free_all_blocks()\n", - " \n", - " return result\n", - " \n", - " # Process source and 
target data\n", - " print(\"Processing source cell data on GPU...\")\n", - " src_results = process_indices_gpu(src_cell_data, \"Processing source indices on GPU\")\n", - " \n", - " print(\"Processing target cell data on GPU...\")\n", - " tgt_results = process_indices_gpu(tgt_cell_data, \"Processing target indices on GPU\")\n", - " \n", - " # Update original dictionaries\n", - " for key in src_results:\n", - " src_cell_data[key].update(src_results[key])\n", - " \n", - " for key in tgt_results:\n", - " tgt_cell_data[key].update(tgt_results[key])\n", - " \n", - " # Final memory cleanup\n", - " cp.get_default_memory_pool().free_all_blocks()\n", - " gc.collect()\n", - " \n", - " return src_cell_data, tgt_cell_data\n" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "6d1848e8", - "metadata": {}, - "outputs": [], - "source": [ - "for k in list(src_cell_data.keys()):\n", - " idx = src_cell_data[k][\"cell_data_index\"]\n", - " src_cell_data[k][\"cell_data\"] = cell_data[idx]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "id": "dba951a7", - "metadata": {}, - "outputs": [], - "source": [ - "for k in list(tgt_cell_data.keys()):\n", - " idx = tgt_cell_data[k][\"cell_data_index\"]\n", - " tgt_cell_data[k][\"cell_data\"] = cell_data[idx]" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "010bd308", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Writing /src_cell_data: 100%|██████████| 50/50 [00:03<00:00, 13.89it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "done writing src_cell_data\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Writing /tgt_cell_data: 100%|██████████| 56827/56827 [22:05<00:00, 42.86it/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "done writing tgt_cell_data\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "path = \"/lustre/groups/ml01/workspace/100mil/tahoe2.zarr\"\n", - "zgroup = zarr.open_group(path, mode=\"w\")\n", - "chunk_size = 131072\n", - "shard_size = chunk_size * 8\n", - "\n", - "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n", - "\n", - "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", - " shard_size_used = shard_size\n", - " chunk_size_used = chunk_size\n", - " if chunk_size > shape[0]:\n", - " chunk_size_used = shard_size_used = shape[0]\n", - " elif chunk_size < shape[0] or shard_size > shape[0]:\n", - " chunk_size_used = shard_size_used = shape[0]\n", - " return chunk_size_used, shard_size_used\n", - "\n", - "\n", - "\n", - "\n", - "def write_single_array(group, key, arr, chunk_size, shard_size):\n", - " \"\"\"Write a single array - designed for threading\"\"\"\n", - " chunk_size_used, shard_size_used = get_size(arr.shape, chunk_size, shard_size)\n", - " \n", - " group.create_array(\n", - " name=key,\n", - " data=arr,\n", - " chunks=(chunk_size_used, arr.shape[1]),\n", - " shards=(shard_size_used, arr.shape[1]),\n", - " compressors=None,\n", - " )\n", - " return key\n", - "\n", - "def write_cell_data_threaded(group, cell_data, chunk_size, shard_size, max_workers=8):\n", - " \"\"\"Write cell data using threading for I/O parallelism\"\"\"\n", - " \n", - " write_func = partial(write_single_array, group, chunk_size=chunk_size, shard_size=shard_size)\n", - " \n", - " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as 
executor:\n", - " # Submit all write tasks\n", - " future_to_key = {\n", - " executor.submit(write_single_array, group, k, cell_data[k][\"cell_data\"], chunk_size, shard_size): k \n", - " for k in cell_data.keys()\n", - " }\n", - " \n", - " # Process results with progress bar\n", - " for future in tqdm.tqdm(\n", - " concurrent.futures.as_completed(future_to_key), \n", - " total=len(future_to_key),\n", - " desc=f\"Writing {group.name}\"\n", - " ):\n", - " key = future_to_key[future]\n", - " try:\n", - " future.result() # This will raise any exceptions\n", - " except Exception as exc:\n", - " print(f'Array {key} generated an exception: {exc}')\n", - " raise\n", - "\n", - "# %%\n", - "\n", - "\n", - "\n", - "src_group = zgroup.create_group(\"src_cell_data\", overwrite=True)\n", - "tgt_group = zgroup.create_group(\"tgt_cell_data\", overwrite=True)\n", - "\n", - "\n", - "# Use the fast threaded approach you already implemented\n", - "write_cell_data_threaded(src_group, src_cell_data, chunk_size, shard_size, max_workers=14)\n", - "print(\"done writing src_cell_data\")\n", - "write_cell_data_threaded(tgt_group, tgt_cell_data, chunk_size, shard_size, max_workers=14)\n", - "print(\"done writing tgt_cell_data\")\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd842ac9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hi\n" - ] - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 93, - "id": "3bc2cc9d", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 11695.68it/s]\n", - "Computing target to cell data: 100%|██████████| 56827/56827 [00:02<00:00, 27235.88it/s]\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[93]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 14\u001b[39m tgt_results = []\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ProgressBar():\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m src_results, tgt_results = \u001b[43mdask\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43msrc_delayed_objs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtgt_delayed_objs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m src_results:\n\u001b[32m 19\u001b[39m src_cell_data[k][\u001b[33m\"\u001b[39m\u001b[33mcell_data\u001b[39m\u001b[33m\"\u001b[39m] = v\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/site-packages/dask/base.py:681\u001b[39m, in \u001b[36mcompute\u001b[39m\u001b[34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[39m\n\u001b[32m 678\u001b[39m expr = expr.optimize()\n\u001b[32m 679\u001b[39m keys = \u001b[38;5;28mlist\u001b[39m(flatten(expr.__dask_keys__()))\n\u001b[32m--> \u001b[39m\u001b[32m681\u001b[39m results = \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m**kwargs)\u001b[39m (remaining dask-internal frames elided)\n",
-    "[KeyboardInterrupt raised during dask.compute of the delayed src/tgt cell-data slices; full traceback omitted]\n",
-    
tgt_delayed_objs.append((str(tgt_idx), delayed_obj))\n", - "\n", - "src_results = []\n", - "tgt_results = []\n", - "with ProgressBar():\n", - " src_results, tgt_results = dask.compute(src_delayed_objs, tgt_delayed_objs)\n", - "\n", - "for k, v in src_results:\n", - " src_cell_data[k][\"cell_data\"] = v\n", - "\n", - "for k, v in tgt_results:\n", - " tgt_cell_data[k][\"cell_data\"] = v\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "721747c3", - "metadata": {}, - "outputs": [], - "source": [ - "src_results = []\n", - "tgt_results = []\n", - "with ProgressBar():\n", - " src_results, tgt_results = dask.compute(src_delayed_objs, tgt_delayed_objs)\n", - "\n", - "for k, v in src_results:\n", - " src_cell_data[k][\"cell_data\"] = v\n", - "\n", - "for k, v in tgt_results:\n", - " tgt_cell_data[k][\"cell_data\"] = v" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3f9ad908", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# %%\n", - "\n", - "split_covariates_mask = np.asarray(cond_data.split_covariates_mask)\n", - "perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)\n", - "condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}\n", - "control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}\n", - "split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}\n", - "perturbation_idx_to_covariates = {\n", - " str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()\n", - "}\n", - "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}\n", - "\n", - "train_data_dict = {\n", - " \"split_covariates_mask\": split_covariates_mask,\n", - " \"perturbation_covariates_mask\": perturbation_covariates_mask,\n", - " \"split_idx_to_covariates\": split_idx_to_covariates,\n", - " \"perturbation_idx_to_covariates\": perturbation_idx_to_covariates,\n", - " \"perturbation_idx_to_id\": perturbation_idx_to_id,\n", - " \"condition_data\": condition_data,\n", - " \"control_to_perturbation\": control_to_perturbation,\n", - " \"max_combination_length\": int(cond_data.max_combination_length),\n", - " # \"src_cell_data\": src_cell_data,\n", - " # \"tgt_cell_data\": tgt_cell_data,\n", - "}\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "402c899c", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "print(\"prepared train_data_dict\")\n", - "# %%\n", - "path = \"/lustre/groups/ml01/workspace/100mil/tahoe2.zarr\"\n", - "zgroup = zarr.open_group(path, mode=\"w\")\n", - "chunk_size = 131072\n", - "shard_size = chunk_size * 8\n", - "\n", - "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2ff428dd", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", - " shard_size_used = shard_size\n", - " chunk_size_used = chunk_size\n", - " if chunk_size > shape[0]:\n", - " chunk_size_used = shard_size_used = shape[0]\n", - " elif chunk_size < shape[0] or shard_size > shape[0]:\n", - " chunk_size_used = shard_size_used = shape[0]\n", - " return chunk_size_used, shard_size_used\n", - "\n", - "\n", - "\n", - "\n", - "def write_single_array(group, key, arr, chunk_size, shard_size):\n", - " \"\"\"Write a 
single array - designed for threading\"\"\"\n", - " chunk_size_used, shard_size_used = get_size(arr.shape, chunk_size, shard_size)\n", - " \n", - " group.create_array(\n", - " name=key,\n", - " data=arr,\n", - " chunks=(chunk_size_used, arr.shape[1]),\n", - " shards=(shard_size_used, arr.shape[1]),\n", - " compressors=None,\n", - " dtype=arr.dtype,\n", - " )\n", - " return key\n", - "\n", - "def write_cell_data_threaded(group, cell_data, chunk_size, shard_size, max_workers=8):\n", - " \"\"\"Write cell data using threading for I/O parallelism\"\"\"\n", - " \n", - " write_func = partial(write_single_array, group, chunk_size=chunk_size, shard_size=shard_size)\n", - " \n", - " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n", - " # Submit all write tasks\n", - " future_to_key = {\n", - " executor.submit(write_single_array, group, k, cell_data[k][\"cell_data\"], chunk_size, shard_size): k \n", - " for k in cell_data.keys()\n", - " }\n", - " \n", - " # Process results with progress bar\n", - " for future in tqdm.tqdm(\n", - " concurrent.futures.as_completed(future_to_key), \n", - " total=len(future_to_key),\n", - " desc=f\"Writing {group.name}\"\n", - " ):\n", - " key = future_to_key[future]\n", - " try:\n", - " future.result() # This will raise any exceptions\n", - " except Exception as exc:\n", - " print(f'Array {key} generated an exception: {exc}')\n", - " raise\n", - "\n", - "# %%\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68a0c4d1", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "src_group = zgroup.create_group(\"src_cell_data\", overwrite=True)\n", - "tgt_group = zgroup.create_group(\"tgt_cell_data\", overwrite=True)\n", - "\n", - "\n", - "# Use the fast threaded approach you already implemented\n", - "write_cell_data_threaded(src_group, src_cell_data, chunk_size, shard_size, max_workers=14)\n", - "print(\"done writing src_cell_data\")\n", - "write_cell_data_threaded(tgt_group, tgt_cell_data, chunk_size, shard_size, max_workers=14)\n", - "print(\"done writing tgt_cell_data\")\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "754d6fa7", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "# %%\n", - "\n", - "print(\"Writing mapping data\")\n", - "mapping_data = zgroup.create_group(\"mapping_data\", overwrite=True)\n", - "\n", - "write_sharded(\n", - " mapping_data,\n", - " train_data_dict,\n", - " chunk_size=chunk_size,\n", - " shard_size=shard_size,\n", - " compressors=None,\n", - ")\n", - "print(\"done\")\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "e94414bb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loading data\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/functools.py:912: ImplicitModificationWarning: Transforming to str index.\n", - " return dispatch(args[0].__class__)(*args, **kw)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[########################################] | 100% Completed | 926.66 ms\n", - "[########################################] | 100% Completed | 23.50 s\n", - "[########################################] | 100% Completed | 262.74 s\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing source to cell data idcs: 100%|██████████| 
50/50 [00:00<00:00, 60.94it/s]\n", - "Computing target to cell data idcs: 100%|██████████| 56827/56827 [01:05<00:00, 864.75it/s]\n", - "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 8124.09it/s]\n", - "Computing target to cell data: 100%|██████████| 56827/56827 [00:06<00:00, 9053.27it/s] \n" - ] - } - ], - "source": [ - "import anndata as ad\n", - "import h5py\n", - "import zarr\n", - "from cellflow.data._utils import write_sharded\n", - "from anndata.experimental import read_lazy\n", - "from cellflow.data import DataManager\n", - "import cupy as cp\n", - "import tqdm\n", - "import dask\n", - "import numpy as np\n", - "\n", - "print(\"loading data\")\n", - "with h5py.File(\"/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad\", \"r\") as f:\n", - " adata_all = ad.AnnData(\n", - " obs=ad.io.read_elem(f[\"obs\"]),\n", - " var=read_lazy(f[\"var\"]),\n", - " uns = read_lazy(f[\"uns\"]),\n", - " obsm = read_lazy(f[\"obsm\"]),\n", - " )\n", - "\n", - "dm = DataManager(adata_all, \n", - " sample_rep=\"X_pca\",\n", - " control_key=\"control\",\n", - " perturbation_covariates={\"drugs\": (\"drug\",), \"dosage\": (\"dosage\",)},\n", - " perturbation_covariate_reps={\"drugs\": \"drug_embeddings\"},\n", - " sample_covariates=[\"cell_line\"],\n", - " sample_covariate_reps={\"cell_line\": \"cell_line_embeddings\"},\n", - " split_covariates=[\"cell_line\"],\n", - " max_combination_length=None,\n", - " null_value=0.0\n", - ")\n", - "\n", - "cond_data = dm._get_condition_data(adata=adata_all)\n", - "cell_data = dm._get_cell_data(adata_all)\n", - "\n", - "\n", - "\n", - "n_source_dists = len(cond_data.split_idx_to_covariates)\n", - "n_target_dists = len(cond_data.perturbation_idx_to_covariates)\n", - "\n", - "tgt_cell_data = {}\n", - "src_cell_data = {}\n", - "gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask)\n", - "gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask)\n", - "\n", - "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data idcs\"):\n", - " mask = gpu_spl_cov_mask == src_idx\n", - " src_cell_data[str(src_idx)] = {\n", - " \"cell_data_index\": cp.where(mask)[0].get(),\n", - " }\n", - "\n", - "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data idcs\"):\n", - " mask = gpu_per_cov_mask == tgt_idx\n", - " tgt_cell_data[str(tgt_idx)] = {\n", - " \"cell_data_index\": cp.where(mask)[0].get(),\n", - " }\n", - "\n", - "\n", - "\n", - "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data\"):\n", - " indices = src_cell_data[str(src_idx)][\"cell_data_index\"]\n", - " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", - " src_cell_data[str(src_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n", - "\n", - "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data\"):\n", - " indices = tgt_cell_data[str(tgt_idx)][\"cell_data_index\"]\n", - " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", - " tgt_cell_data[str(tgt_idx)][\"cell_data\"] = dask.array.from_delayed(delayed_obj, shape=(len(indices), cell_data.shape[1]), dtype=cell_data.dtype)\n", - "\n", - "\n", - "split_covariates_mask = np.asarray(cond_data.split_covariates_mask)\n", - "perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)\n", - "condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}\n", - 
"control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}\n", - "split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}\n", - "perturbation_idx_to_covariates = {\n", - " str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()\n", - "}\n", - "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}\n", - "\n", - "train_data_dict = {\n", - " \"split_covariates_mask\": split_covariates_mask,\n", - " \"perturbation_covariates_mask\": perturbation_covariates_mask,\n", - " \"split_idx_to_covariates\": split_idx_to_covariates,\n", - " \"perturbation_idx_to_covariates\": perturbation_idx_to_covariates,\n", - " \"perturbation_idx_to_id\": perturbation_idx_to_id,\n", - " \"condition_data\": condition_data,\n", - " \"control_to_perturbation\": control_to_perturbation,\n", - " \"max_combination_length\": int(cond_data.max_combination_length),\n", - " # \"src_cell_data\": src_cell_data,\n", - " # \"tgt_cell_data\": tgt_cell_data,\n", - "}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "8f224240", - "metadata": {}, - "outputs": [], - "source": [ - "path = \"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\"\n", - "zgroup = zarr.open_group(path, mode=\"w\")\n", - "chunk_size = 65536\n", - "shard_size = chunk_size * 16\n", - "\n", - "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n", - "\n", - "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", - " shard_size_used = shard_size\n", - " chunk_size_used = chunk_size\n", - " if chunk_size > shape[0] or shard_size > shape[0]:\n", - " chunk_size_used = shard_size_used = shape[0]\n", - " return chunk_size_used, shard_size_used\n", - "\n", - "import dask.array as da\n", - "from dask.diagnostics import ProgressBar\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "e8aedd3b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(60135, 300)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "src_cell_data[str(0)][\"cell_data\"].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "710434e7", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Writing src cell data: 0%| | 0/50 [00:00 dict[str, int | list | dict]:\n", - " \"\"\"Calculate memory cost in bytes for a given source index and its target distributions.\n", - " \n", - " Parameters\n", - " ----------\n", - " data\n", - " The training data.\n", - " src_idx\n", - " The source distribution index.\n", - " include_condition_data\n", - " Whether to include condition data in memory calculations.\n", - " \n", - " Returns\n", - " -------\n", - " Dictionary with memory statistics in bytes for the source and its targets.\n", - " \"\"\"\n", - " if src_idx not in data.control_to_perturbation:\n", - " raise ValueError(f\"Source index {src_idx} not found in control_to_perturbation mapping\")\n", - " \n", - " # Get target indices for this source\n", - " target_indices = data.control_to_perturbation[src_idx]\n", - " \n", - " # Calculate memory for source cells\n", - " source_mask = data.split_covariates_mask == src_idx\n", - " n_source_cells = np.sum(source_mask)\n", - " source_memory = n_source_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", - " \n", - " # Calculate 
memory for target cells\n", - " target_memories = {}\n", - " total_target_memory = 0\n", - " \n", - " for target_idx in target_indices:\n", - " target_mask = data.perturbation_covariates_mask == target_idx\n", - " n_target_cells = np.sum(target_mask)\n", - " target_memory = n_target_cells * data.cell_data.shape[1] * data.cell_data.dtype.itemsize\n", - " target_memories[f\"target_{target_idx}\"] = target_memory\n", - " total_target_memory += target_memory\n", - " \n", - " # Calculate condition data memory if available and requested\n", - " condition_memory = 0\n", - " condition_details = {}\n", - " if include_condition_data and data.condition_data is not None:\n", - " for cond_name, cond_array in data.condition_data.items():\n", - " # Condition data is indexed by target indices\n", - " relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize\n", - " condition_details[f\"condition_{cond_name}\"] = relevant_condition_size\n", - " condition_memory += relevant_condition_size\n", - " \n", - " # Calculate total memory\n", - " total_memory = source_memory + total_target_memory + condition_memory\n", - " \n", - " # Calculate average target memory\n", - " avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0\n", - " \n", - " result = {\n", - " \"source_idx\": src_idx,\n", - " \"target_indices\": target_indices.tolist(),\n", - " \"source_memory\": source_memory,\n", - " \"source_cell_count\": int(n_source_cells),\n", - " \"total_target_memory\": total_target_memory,\n", - " \"avg_target_memory\": avg_target_memory,\n", - " \"condition_memory\": condition_memory,\n", - " \"total_memory\": total_memory,\n", - " \"target_details\": target_memories,\n", - " }\n", - " \n", - " if condition_details:\n", - " result[\"condition_details\"] = condition_details\n", - " \n", - " return result\n", - "\n", - "def format_memory_stats(memory_stats: dict, unit: str = \"auto\", summary: bool = False) -> str:\n", - " \"\"\"Format memory statistics into a human-readable string.\n", - " \n", - " Parameters\n", - " ----------\n", - " memory_stats\n", - " Dictionary with memory statistics from calculate_memory_cost.\n", - " unit\n", - " Memory unit to use for display. 
Options: 'B', 'KB', 'MB', 'GB', 'auto'.\n", - " If 'auto', the most appropriate unit will be chosen automatically.\n", - " summary\n", - " If True, includes a summary with average, min, and max target memory statistics\n", - " and omits detailed per-target breakdown.\n", - " \n", - " Returns\n", - " -------\n", - " Human-readable string representation of memory statistics.\n", - " \"\"\"\n", - " def format_bytes(bytes_value, unit=\"auto\"):\n", - " if unit == \"auto\":\n", - " # Choose appropriate unit\n", - " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", - " if bytes_value < 1024 or unit == \"GB\":\n", - " break\n", - " bytes_value /= 1024\n", - " elif unit == \"KB\":\n", - " bytes_value /= 1024\n", - " elif unit == \"MB\":\n", - " bytes_value /= (1024 * 1024)\n", - " elif unit == \"GB\":\n", - " bytes_value /= (1024 * 1024 * 1024)\n", - " \n", - " return f\"{bytes_value:.2f} {unit}\"\n", - " \n", - " src_idx = memory_stats[\"source_idx\"]\n", - " target_indices = memory_stats[\"target_indices\"]\n", - " \n", - " # Base information\n", - " lines = [\n", - " f\"Memory statistics for source index {src_idx} with {len(target_indices)} targets:\",\n", - " f\"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}\",\n", - " f\"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}\",\n", - " ]\n", - " \n", - " # Calculate min and max target memory if summary is requested\n", - " if summary and memory_stats[\"target_details\"]:\n", - " target_memories = list(memory_stats[\"target_details\"].values())\n", - " min_target = min(target_memories)\n", - " max_target = max(target_memories)\n", - " \n", - " lines.extend([\n", - " \"\\nTarget memory summary:\",\n", - " f\"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}\",\n", - " f\"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}\",\n", - " f\"- Min: {format_bytes(min_target, unit)}\",\n", - " f\"- Max: {format_bytes(max_target, unit)}\",\n", - " f\"- Range: {format_bytes(max_target - min_target, unit)}\"\n", - " ])\n", - " \n", - " # Add condition memory summary if available\n", - " if memory_stats[\"condition_memory\"] > 0:\n", - " lines.append(f\"\\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}\")\n", - " else:\n", - " # Detailed output (original format)\n", - " lines.extend([\n", - " f\"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target\",\n", - " f\"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}\",\n", - " \"\\nTarget details:\"\n", - " ])\n", - " \n", - " for target_key, target_memory in memory_stats[\"target_details\"].items():\n", - " target_id = target_key.split(\"_\")[1]\n", - " lines.append(f\" - Target {target_id}: {format_bytes(target_memory, unit)}\")\n", - " \n", - " if \"condition_details\" in memory_stats:\n", - " lines.append(\"\\nCondition details:\")\n", - " for cond_key, cond_memory in memory_stats[\"condition_details\"].items():\n", - " cond_name = cond_key.split(\"_\", 1)[1]\n", - " lines.append(f\" - {cond_name}: {format_bytes(cond_memory, unit)}\")\n", - " \n", - " return \"\\n\".join(lines)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "316e3a6a", - "metadata": {}, - "outputs": [], - "source": [ - "ztd = ZarrTrainingData.read_zarr(data_paths[0])\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3d101216", - 
"metadata": {}, - "outputs": [], - "source": [ - "stats = calculate_memory_cost(ztd, 0)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a79f9fc2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Memory statistics for source index 0 with 194 targets:\n", - "- Source cells: 60135 cells, 68.82 MB\n", - "- Total memory: 548.11 MB\n", - "\n", - "Target memory summary:\n", - "- Total: 479.28 MB\n", - "- Average: 2.47 MB\n", - "- Min: 44.53 KB\n", - "- Max: 6.35 MB\n", - "- Range: 6.31 MB\n", - "\n", - "Condition memory: 4.55 KB\n" - ] - } - ], - "source": [ - "print(format_memory_stats(stats, summary=True))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "8c400080", - "metadata": {}, - "outputs": [], - "source": [ - "ztd_stats = {}\n", - "for i in range(ztd.n_controls):\n", - " ztd_stats[i] = calculate_memory_cost(ztd, i)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "710fb69d", - "metadata": {}, - "outputs": [], - "source": [ - "def print_average_memory_per_source(stats_dict):\n", - " \"\"\"Print the average total memory per source index.\n", - " \n", - " Parameters\n", - " ----------\n", - " stats_dict\n", - " Optional pre-calculated memory statistics dictionary.\n", - " If None, statistics will be calculated for all source indices.\n", - " \"\"\"\n", - " \n", - " \n", - " # Extract total memory for each source index\n", - " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", - " \n", - " # Calculate statistics\n", - " avg_memory = np.mean(total_memories)\n", - " min_memory = np.min(total_memories)\n", - " max_memory = np.max(total_memories)\n", - " median_memory = np.median(total_memories)\n", - " \n", - " # Format the output\n", - " def format_bytes(bytes_value):\n", - " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", - " if bytes_value < 1024 or unit == \"GB\":\n", - " break\n", - " bytes_value /= 1024\n", - " return f\"{bytes_value:.2f} {unit}\"\n", - " \n", - " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", - " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", - " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", - " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", - " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", - " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", - " \n", - " # Identify source indices with min and max memory\n", - " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", - " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", - " \n", - " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", - " print(f\"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})\")" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "e2f8f809", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Memory statistics across 50 source indices:\n", - "- Average total memory per source: 423.18 MB\n", - "- Minimum total memory: 4.33 MB\n", - "- Maximum total memory: 1.29 GB\n", - "- Median total memory: 404.51 MB\n", - "- Range: 1.28 GB\n", - "\n", - "Source index with minimum memory: 39 (4.33 MB)\n", - "Source index with maximum memory: 22 (1.29 GB)\n" - ] - } - ], - "source": [ - "print_average_memory_per_source(ztd_stats)" - ] - }, - { - 
"cell_type": "code", - "execution_count": 11, - "id": "91207483", - "metadata": {}, - "outputs": [], - "source": [ - "from cellflow.data import TrainSamplerWithPool\n", - "import numpy as np\n", - "rng = np.random.default_rng(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "17f1fc6c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing target to cell data idcs: 100%|██████████| 9980/9980 [00:11<00:00, 891.95it/s] \n", - "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 1232.06it/s]\n" - ] - } - ], - "source": [ - "tswp = TrainSamplerWithPool(ztd, batch_size=1024, pool_size=20, replacement_prob=0.01)\n", - "tswp.init_pool_n_cache(rng)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "782380b2", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "81017ffd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "replaced 47 with 34\n", - "replaced 32 with 30\n" - ] - } - ], - "source": [ - "import time\n", - "iter_times = []\n", - "rng = np.random.default_rng(0)\n", - "start_time = time.time()\n", - "for iter in range(40):\n", - " batch = tswp.sample(rng)\n", - " end_time = time.time()\n", - " iter_times.append(end_time - start_time)\n", - " start_time = end_time\n", - "\n", - "print(\"average time per iteration: \", np.mean(iter_times))\n", - "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" - ] - }, - { - "cell_type": "markdown", - "id": "fe14be13", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "001e842a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'pool_size': 20,\n", - " 'avg_usage': 1.95,\n", - " 'unique_sources': 20,\n", - " 'pool_elements': array([31, 18, 47, 34, 12, 35, 29, 23, 32, 14, 6, 41, 25, 3, 1, 49, 24,\n", - " 10, 46, 33]),\n", - " 'usage_counts': array([2, 2, 3, 2, 1, 0, 2, 2, 2, 0, 3, 1, 2, 0, 3, 3, 2, 6, 1, 2])}" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tswp.get_pool_stats()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f07c55d9", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/notebooks/600_trainsampler copy.ipynb b/docs/notebooks/600_trainsampler copy.ipynb index a2eb6b44..56b04788 100644 --- a/docs/notebooks/600_trainsampler copy.ipynb +++ b/docs/notebooks/600_trainsampler copy.ipynb @@ -126,12 +126,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 93, "id": "05bd4946", "metadata": {}, "outputs": [], "source": [ - "rs = ReservoirSampler(mcd, batch_size=1024, pool_size=3, replacement_prob=0.01)\n" + "rs = ReservoirSampler(mcd, batch_size=1024, pool_size=40, replacement_prob=0.01)\n" ] }, { @@ -146,241 +146,29 @@ "rs.init_pool(rng)\n" ] }, - { - "cell_type": "code", - "execution_count": 40, - "id": "799aad1f", - "metadata": {}, - "outputs": [ - { - "ename": "KeyError", - "evalue": 
"np.int64(37)", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mKeyError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[40]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 4\u001b[39m start_time = time.time()\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m \u001b[38;5;28miter\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m40\u001b[39m):\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m batch = \u001b[43mrs\u001b[49m\u001b[43m.\u001b[49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrng\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 7\u001b[39m end_time = time.time()\n\u001b[32m 8\u001b[39m iter_times.append(end_time - start_time)\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_dataloader.py:114\u001b[39m, in \u001b[36msample\u001b[39m\u001b[34m(self, rng)\u001b[39m\n\u001b[32m 111\u001b[39m target_dist_idx = \u001b[38;5;28mself\u001b[39m._sample_target_dist_idx(rng, source_dist_idx)\n\u001b[32m 113\u001b[39m \u001b[38;5;66;03m# Sample source and target cells\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m source_batch = \u001b[38;5;28mself\u001b[39m._sample_source_cells(rng, source_dist_idx)\n\u001b[32m 115\u001b[39m target_batch = \u001b[38;5;28mself\u001b[39m._sample_target_cells(rng, source_dist_idx, target_dist_idx)\n\u001b[32m 117\u001b[39m res = {\u001b[33m\"\u001b[39m\u001b[33msrc_cell_data\u001b[39m\u001b[33m\"\u001b[39m: source_batch, \u001b[33m\"\u001b[39m\u001b[33mtgt_cell_data\u001b[39m\u001b[33m\"\u001b[39m: target_batch}\n", - "\u001b[36mFile \u001b[39m\u001b[32m/ictstr01/home/icb/selman.ozleyen/projects/CellFlow2/src/cellflow/data/_dataloader.py:303\u001b[39m, in \u001b[36m_sample_target_cells\u001b[39m\u001b[34m(self, rng, source_dist_idx, target_dist_idx)\u001b[39m\n\u001b[32m 296\u001b[39m \u001b[38;5;28mself\u001b[39m.perturbation_to_control = \u001b[38;5;28mself\u001b[39m._get_perturbation_to_control(val_data)\n\u001b[32m 297\u001b[39m \u001b[38;5;28mself\u001b[39m.n_conditions_on_log_iteration = (\n\u001b[32m 298\u001b[39m val_data.n_conditions_on_log_iteration\n\u001b[32m 299\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m val_data.n_conditions_on_log_iteration \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 300\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m val_data.n_perturbations\n\u001b[32m 301\u001b[39m )\n\u001b[32m 302\u001b[39m \u001b[38;5;28mself\u001b[39m.n_conditions_on_train_end = (\n\u001b[32m--> \u001b[39m\u001b[32m303\u001b[39m val_data.n_conditions_on_train_end\n\u001b[32m 304\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m val_data.n_conditions_on_train_end \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 305\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m val_data.n_perturbations\n\u001b[32m 306\u001b[39m )\n\u001b[32m 307\u001b[39m \u001b[38;5;28mself\u001b[39m.rng = np.random.default_rng(seed)\n\u001b[32m 308\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._data.condition_data \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "\u001b[31mKeyError\u001b[39m: np.int64(37)" - ] - } - ], - "source": [ - "import time\n", - "iter_times = []\n", - "rng = 
np.random.default_rng(0)\n", - "start_time = time.time()\n", - "for iter in range(40):\n", - " batch = rs.sample(rng)\n", - " end_time = time.time()\n", - " iter_times.append(end_time - start_time)\n", - " start_time = end_time\n", - "\n", - "print(\"average time per iteration: \", np.mean(iter_times))\n", - "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "316e3a6a", - "metadata": {}, - "outputs": [], - "source": [ - "ztd = ZarrTrainingData.read_zarr(data_paths[0])\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3d101216", - "metadata": {}, - "outputs": [], - "source": [ - "stats = calculate_memory_cost(ztd, 0)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a79f9fc2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Memory statistics for source index 0 with 194 targets:\n", - "- Source cells: 60135 cells, 68.82 MB\n", - "- Total memory: 548.11 MB\n", - "\n", - "Target memory summary:\n", - "- Total: 479.28 MB\n", - "- Average: 2.47 MB\n", - "- Min: 44.53 KB\n", - "- Max: 6.35 MB\n", - "- Range: 6.31 MB\n", - "\n", - "Condition memory: 4.55 KB\n" - ] - } - ], - "source": [ - "print(format_memory_stats(stats, summary=True))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "8c400080", - "metadata": {}, - "outputs": [], - "source": [ - "ztd_stats = {}\n", - "for i in range(ztd.n_controls):\n", - " ztd_stats[i] = calculate_memory_cost(ztd, i)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "710fb69d", - "metadata": {}, - "outputs": [], - "source": [ - "def print_average_memory_per_source(stats_dict):\n", - " \"\"\"Print the average total memory per source index.\n", - " \n", - " Parameters\n", - " ----------\n", - " stats_dict\n", - " Optional pre-calculated memory statistics dictionary.\n", - " If None, statistics will be calculated for all source indices.\n", - " \"\"\"\n", - " \n", - " \n", - " # Extract total memory for each source index\n", - " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", - " \n", - " # Calculate statistics\n", - " avg_memory = np.mean(total_memories)\n", - " min_memory = np.min(total_memories)\n", - " max_memory = np.max(total_memories)\n", - " median_memory = np.median(total_memories)\n", - " \n", - " # Format the output\n", - " def format_bytes(bytes_value):\n", - " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", - " if bytes_value < 1024 or unit == \"GB\":\n", - " break\n", - " bytes_value /= 1024\n", - " return f\"{bytes_value:.2f} {unit}\"\n", - " \n", - " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", - " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", - " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", - " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", - " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", - " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", - " \n", - " # Identify source indices with min and max memory\n", - " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", - " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", - " \n", - " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", - " print(f\"Source index with maximum memory: {max_idx} 
({format_bytes(max_memory)})\")" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "e2f8f809", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Memory statistics across 50 source indices:\n", - "- Average total memory per source: 423.18 MB\n", - "- Minimum total memory: 4.33 MB\n", - "- Maximum total memory: 1.29 GB\n", - "- Median total memory: 404.51 MB\n", - "- Range: 1.28 GB\n", - "\n", - "Source index with minimum memory: 39 (4.33 MB)\n", - "Source index with maximum memory: 22 (1.29 GB)\n" - ] - } - ], - "source": [ - "print_average_memory_per_source(ztd_stats)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "91207483", - "metadata": {}, - "outputs": [], - "source": [ - "from cellflow.data import TrainSamplerWithPool\n", - "import numpy as np\n", - "rng = np.random.default_rng(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "17f1fc6c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing target to cell data idcs: 100%|██████████| 9980/9980 [00:11<00:00, 891.95it/s] \n", - "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 1232.06it/s]\n" - ] - } - ], - "source": [ - "tswp = TrainSamplerWithPool(ztd, batch_size=1024, pool_size=20, replacement_prob=0.01)\n", - "tswp.init_pool_n_cache(rng)" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "782380b2", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "81017ffd", + "id": "799aad1f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "replaced 47 with 34\n", - "replaced 32 with 30\n" + "scheduled replacement of 40 with 10 (slot 2)\n", + "scheduled replacement of 49 with 16 (slot 0)\n", + "average time per iteration: 0.0003986350893974304\n", + "iterations per second: 2508.5599000117672\n" ] } ], "source": [ "import time\n", "iter_times = []\n", - "rng = np.random.default_rng(0)\n", "start_time = time.time()\n", - "for iter in range(40):\n", - " batch = tswp.sample(rng)\n", + "for iter in range(4000):\n", + " batch = rs.sample(rng)\n", " end_time = time.time()\n", " iter_times.append(end_time - start_time)\n", " start_time = end_time\n", @@ -388,46 +176,6 @@ "print(\"average time per iteration: \", np.mean(iter_times))\n", "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" ] - }, - { - "cell_type": "markdown", - "id": "fe14be13", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "001e842a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'pool_size': 20,\n", - " 'avg_usage': 1.95,\n", - " 'unique_sources': 20,\n", - " 'pool_elements': array([31, 18, 47, 34, 12, 35, 29, 23, 32, 14, 6, 41, 25, 3, 1, 49, 24,\n", - " 10, 46, 33]),\n", - " 'usage_counts': array([2, 2, 3, 2, 1, 0, 2, 2, 2, 0, 3, 1, 2, 0, 3, 3, 2, 6, 1, 2])}" - ] - }, - "execution_count": 64, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tswp.get_pool_stats()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f07c55d9", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/notebooks/600_trainsampler.ipynb b/docs/notebooks/600_trainsampler.ipynb deleted file mode 100644 index 54640851..00000000 --- a/docs/notebooks/600_trainsampler.ipynb +++ /dev/null @@ -1,423 +0,0 @@ -{ - "cells": [ - { - "cell_type": 
"code", - "execution_count": 1, - "id": "5765bb6c", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "e94414bb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loading data\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/ictstr01/home/icb/selman.ozleyen/.local/share/mamba/envs/lpert/lib/python3.12/functools.py:912: ImplicitModificationWarning: Transforming to str index.\n", - " return dispatch(args[0].__class__)(*args, **kw)\n" - ] - } - ], - "source": [ - "import anndata as ad\n", - "import h5py\n", - "import zarr\n", - "from cellflow.data._utils import write_sharded\n", - "from anndata.experimental import read_lazy\n", - "from cellflow.data import DataManager\n", - "import cupy as cp\n", - "import tqdm\n", - "import dask\n", - "import numpy as np\n", - "\n", - "print(\"loading data\")\n", - "with h5py.File(\"/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad\", \"r\") as f:\n", - " adata_all = ad.AnnData(\n", - " obs=ad.io.read_elem(f[\"obs\"]),\n", - " var=read_lazy(f[\"var\"]),\n", - " uns = read_lazy(f[\"uns\"]),\n", - " obsm = read_lazy(f[\"obsm\"]),\n", - " )\n", - "\n", - "dm = DataManager(adata_all, \n", - " sample_rep=\"X_pca\",\n", - " control_key=\"control\",\n", - " perturbation_covariates={\"drugs\": (\"drug\",), \"dosage\": (\"dosage\",)},\n", - " perturbation_covariate_reps={\"drugs\": \"drug_embeddings\"},\n", - " sample_covariates=[\"cell_line\"],\n", - " sample_covariate_reps={\"cell_line\": \"cell_line_embeddings\"},\n", - " split_covariates=[\"cell_line\"],\n", - " max_combination_length=None,\n", - " null_value=0.0\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "37ac0f75", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[########################################] | 100% Completed | 910.75 ms\n", - "[########################################] | 100% Completed | 23.67 s\n", - "[########################################] | 100% Completed | 252.54 s\n" - ] - } - ], - "source": [ - "cond_data = dm._get_condition_data(adata=adata_all)\n", - "cell_data = dm._get_cell_data(adata_all)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "e9adbd71", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing source to cell data idcs: 100%|██████████| 50/50 [00:00<00:00, 62.77it/s]\n", - "Computing target to cell data idcs: 100%|██████████| 56827/56827 [01:05<00:00, 863.54it/s]\n" - ] - } - ], - "source": [ - "n_source_dists = len(cond_data.split_idx_to_covariates)\n", - "n_target_dists = len(cond_data.perturbation_idx_to_covariates)\n", - "\n", - "tgt_cell_data = {}\n", - "src_cell_data = {}\n", - "gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask)\n", - "gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask)\n", - "\n", - "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data idcs\"):\n", - " mask = gpu_spl_cov_mask == src_idx\n", - " src_cell_data[str(src_idx)] = {\n", - " \"cell_data_index\": cp.where(mask)[0].get(),\n", - " }\n", - "\n", - "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data idcs\"):\n", - " mask = gpu_per_cov_mask == tgt_idx\n", - " tgt_cell_data[str(tgt_idx)] = {\n", - " \"cell_data_index\": cp.where(mask)[0].get(),\n", 
- " }" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "dad2d31c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing source to cell data: 100%|██████████| 50/50 [00:00<00:00, 22329.13it/s]\n", - "Computing target to cell data: 6%|▌ | 3184/56827 [00:00<00:01, 31833.81it/s]" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Computing target to cell data: 100%|██████████| 56827/56827 [00:02<00:00, 23426.54it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[##################### ] | 52% Completed | 36m 17ss" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "IOStream.flush timed out\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[########################################] | 100% Completed | 73m 49s\n" - ] - } - ], - "source": [ - "\n", - "import dask.array as da\n", - "from dask.diagnostics import ProgressBar\n", - "\n", - "src_delayed_objs = []\n", - "for src_idx in tqdm.tqdm(range(n_source_dists), desc=\"Computing source to cell data\"):\n", - " indices = src_cell_data[str(src_idx)][\"cell_data_index\"]\n", - " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", - " src_delayed_objs.append((str(src_idx), delayed_obj))\n", - "\n", - "tgt_delayed_objs = []\n", - "for tgt_idx in tqdm.tqdm(range(n_target_dists), desc=\"Computing target to cell data\"):\n", - " indices = tgt_cell_data[str(tgt_idx)][\"cell_data_index\"]\n", - " delayed_obj = dask.delayed(lambda x: cell_data[x])(indices)\n", - " tgt_delayed_objs.append((str(tgt_idx), delayed_obj))\n", - "\n", - "src_results = []\n", - "tgt_results = []\n", - "with ProgressBar():\n", - " src_results, tgt_results = dask.compute(src_delayed_objs, tgt_delayed_objs)\n", - "\n", - "for k, v in src_results:\n", - " src_cell_data[k][\"cell_data\"] = v\n", - "\n", - "for k, v in tgt_results:\n", - " tgt_cell_data[k][\"cell_data\"] = v\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "d6c007b8", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "split_covariates_mask = np.asarray(cond_data.split_covariates_mask)\n", - "perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask)\n", - "condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()}\n", - "control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()}\n", - "split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()}\n", - "perturbation_idx_to_covariates = {\n", - " str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items()\n", - "}\n", - "perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()}\n", - "\n", - "train_data_dict = {\n", - " \"split_covariates_mask\": split_covariates_mask,\n", - " \"perturbation_covariates_mask\": perturbation_covariates_mask,\n", - " \"split_idx_to_covariates\": split_idx_to_covariates,\n", - " \"perturbation_idx_to_covariates\": perturbation_idx_to_covariates,\n", - " \"perturbation_idx_to_id\": perturbation_idx_to_id,\n", - " \"condition_data\": condition_data,\n", - " \"control_to_perturbation\": control_to_perturbation,\n", - " \"max_combination_length\": int(cond_data.max_combination_length),\n", - " # \"src_cell_data\": src_cell_data,\n", - " # \"tgt_cell_data\": tgt_cell_data,\n", - "}\n" - ] - }, - { - 
"cell_type": "code", - "execution_count": 7, - "id": "8f224240", - "metadata": {}, - "outputs": [], - "source": [ - "path = \"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\"\n", - "zgroup = zarr.open_group(path, mode=\"w\")\n", - "chunk_size = 65536\n", - "shard_size = chunk_size * 16\n", - "\n", - "ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr\n", - "\n", - "def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]:\n", - " shard_size_used = shard_size\n", - " chunk_size_used = chunk_size\n", - " if chunk_size > shape[0] or shard_size > shape[0]:\n", - " chunk_size_used = shard_size_used = shape[0]\n", - " return chunk_size_used, shard_size_used\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "710434e7", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Allocating cell data: 14%|█▍ | 7/50 [00:45<04:47, 6.69s/it]" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Allocating cell data: 100%|██████████| 50/50 [06:47<00:00, 8.15s/it]\n", - "Allocating cell data: 100%|██████████| 56827/56827 [41:54<00:00, 22.60it/s] \n" - ] - } - ], - "source": [ - "\n", - "def write_arr(z_arr, arr, k):\n", - " z_arr[:] = arr\n", - " return k\n", - "\n", - "def allocate_cell_data(group, cell_data, chunk_size, shard_size):\n", - " delayed_objs = []\n", - "\n", - " for k in tqdm.tqdm(cell_data.keys(), desc=\"Allocating cell data\"):\n", - " chunk_size_used, shard_size_used = get_size(cell_data[k][\"cell_data\"].shape, chunk_size, shard_size)\n", - " arr = cell_data[k][\"cell_data\"]\n", - "\n", - " z_arr = group.create_array(\n", - " name=k,\n", - " shape=arr.shape,\n", - " chunks=(chunk_size_used, arr.shape[1]),\n", - " shards=(shard_size_used, arr.shape[1]),\n", - " compressors=None,\n", - " dtype=arr.dtype,\n", - " )\n", - "\n", - " delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))\n", - " \n", - " return delayed_objs\n", - "\n", - "\n", - "src_group = zgroup.create_group(\"src_cell_data\", overwrite=True)\n", - "tgt_group = zgroup.create_group(\"tgt_cell_data\", overwrite=True)\n", - "\n", - "\n", - "src_delayed_objs = allocate_cell_data(src_group, src_cell_data, chunk_size, shard_size)\n", - "tgt_delayed_objs = allocate_cell_data(tgt_group, tgt_cell_data, chunk_size, shard_size)\n", - "\n", - "\n", - "\n", - "# for k in tqdm.tqdm(src_cell_data.keys(), desc=\"Writing src cell data\"):\n", - "# chunk_size_used, shard_size_used = get_size(src_cell_data[k][\"cell_data\"].shape, chunk_size, shard_size)\n", - "# arr = src_cell_data[k][\"cell_data\"]\n", - "\n", - "# z_arr = src_group.create_array(\n", - "# name=k,\n", - "# shape=arr.shape,\n", - "# chunks=(chunk_size_used, arr.shape[1]),\n", - "# shards=(shard_size_used, arr.shape[1]),\n", - "# compressors=None,\n", - "# dtype=arr.dtype,\n", - "# )\n", - " \n", - "# delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))\n", - "\n", - "# for k in tqdm.tqdm(tgt_cell_data.keys(), desc=\"Writing tgt cell data\"):\n", - "# chunk_size_used, shard_size_used = get_size(tgt_cell_data[k][\"cell_data\"].shape, chunk_size, shard_size)\n", - "# arr = tgt_cell_data[k][\"cell_data\"]\n", - "# z_arr = tgt_group.create_array(\n", - "# name=k,\n", - "# shape=arr.shape,\n", - "# chunks=(chunk_size_used, arr.shape[1]),\n", - "# shards=(shard_size_used, arr.shape[1]),\n", - "# compressors=None,\n", - "# dtype=arr.dtype,\n", - "# )\n", - " \n", - " \n", - "# 
delayed_objs.append(dask.delayed(write_arr)(z_arr, arr, k))" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "c41b2a3b", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'ProgressBar' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mProgressBar\u001b[49m():\n\u001b[32m 2\u001b[39m res = dask.compute(tgt_delayed_objs)\n", - "\u001b[31mNameError\u001b[39m: name 'ProgressBar' is not defined" - ] - } - ], - "source": [ - "\n", - "with ProgressBar():\n", - " res = dask.compute(tgt_delayed_objs)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "28f54507", - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "create_group() takes 1 positional argument but 2 were given", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m mapping_data = \u001b[43mzarr\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcreate_group\u001b[49m\u001b[43m(\u001b[49m\u001b[43mzgroup\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mmapping_data\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m write_sharded(\n\u001b[32m 4\u001b[39m mapping_data,\n\u001b[32m 5\u001b[39m train_data_dict,\n\u001b[32m (...)\u001b[39m\u001b[32m 8\u001b[39m compressors=\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m 9\u001b[39m )\n\u001b[32m 10\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mdone\u001b[39m\u001b[33m\"\u001b[39m)\n", - "\u001b[31mTypeError\u001b[39m: create_group() takes 1 positional argument but 2 were given" - ] - } - ], - "source": [ - "\n", - "mapping_data = zarr.create_group(zgroup, \"mapping_data\")\n", - "\n", - "write_sharded(\n", - " mapping_data,\n", - " train_data_dict,\n", - " chunk_size=chunk_size,\n", - " shard_size=shard_size,\n", - " compressors=None,\n", - ")\n", - "print(\"done\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/notebooks/tahoe_sizes.ipynb b/docs/notebooks/tahoe_sizes.ipynb new file mode 100644 index 00000000..57e27865 --- /dev/null +++ b/docs/notebooks/tahoe_sizes.ipynb @@ -0,0 +1,348 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "3ed731bd", + "metadata": {}, + "outputs": [], + "source": [ + "from cellflow.data import MappedCellData" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "62955dea", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def 
calculate_memory_cost(\n", + " data: MappedCellData,\n", + " src_idx: int,\n", + " include_condition_data: bool = True\n", + ") -> dict[str, int | list | dict]:\n", + " \"\"\"Calculate memory cost in bytes for a given source index and its target distributions.\n", + " \n", + " Parameters\n", + " ----------\n", + " data\n", + " The training data.\n", + " src_idx\n", + " The source distribution index.\n", + " include_condition_data\n", + " Whether to include condition data in memory calculations.\n", + " \n", + " Returns\n", + " -------\n", + " Dictionary with memory statistics in bytes for the source and its targets.\n", + " \"\"\"\n", + " if src_idx not in data.control_to_perturbation:\n", + " raise ValueError(f\"Source index {src_idx} not found in control_to_perturbation mapping\")\n", + " \n", + " # Get target indices for this source\n", + " target_indices = data.control_to_perturbation[src_idx]\n", + " \n", + " # Calculate memory for source cells\n", + " source_mask = data.split_covariates_mask == src_idx\n", + " n_source_cells = data.src_cell_idx[src_idx].shape[0]\n", + " source_memory = data.src_cell_data[src_idx].nbytes\n", + " \n", + " # Calculate memory for target cells\n", + " target_memories = {}\n", + " total_target_memory = 0\n", + " \n", + " for target_idx in target_indices:\n", + " n_target_cells = data.tgt_cell_idx[target_idx].shape[0]\n", + " target_memory = data.tgt_cell_data[target_idx].nbytes\n", + " target_memories[f\"target_{target_idx}\"] = target_memory\n", + " total_target_memory += target_memory\n", + " \n", + " # Calculate condition data memory if available and requested\n", + " condition_memory = 0\n", + " condition_details = {}\n", + " if include_condition_data and data.condition_data is not None:\n", + " for cond_name, cond_array in data.condition_data.items():\n", + " # Condition data is indexed by target indices\n", + " relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize\n", + " condition_details[f\"condition_{cond_name}\"] = relevant_condition_size\n", + " condition_memory += relevant_condition_size\n", + " \n", + " # Calculate total memory\n", + " total_memory = source_memory + total_target_memory + condition_memory\n", + " \n", + " # Calculate average target memory\n", + " avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0\n", + " \n", + " result = {\n", + " \"source_idx\": src_idx,\n", + " \"target_indices\": target_indices.tolist(),\n", + " \"source_memory\": source_memory,\n", + " \"source_cell_count\": int(n_source_cells),\n", + " \"total_target_memory\": total_target_memory,\n", + " \"avg_target_memory\": avg_target_memory,\n", + " \"condition_memory\": condition_memory,\n", + " \"total_memory\": total_memory,\n", + " \"target_details\": target_memories,\n", + " }\n", + " \n", + " if condition_details:\n", + " result[\"condition_details\"] = condition_details\n", + " \n", + " return result\n", + "\n", + "def format_memory_stats(memory_stats: dict, unit: str = \"auto\", summary: bool = False) -> str:\n", + " \"\"\"Format memory statistics into a human-readable string.\n", + " \n", + " Parameters\n", + " ----------\n", + " memory_stats\n", + " Dictionary with memory statistics from calculate_memory_cost.\n", + " unit\n", + " Memory unit to use for display. 
Options: 'B', 'KB', 'MB', 'GB', 'auto'.\n", + " If 'auto', the most appropriate unit will be chosen automatically.\n", + " summary\n", + " If True, includes a summary with average, min, and max target memory statistics\n", + " and omits detailed per-target breakdown.\n", + " \n", + " Returns\n", + " -------\n", + " Human-readable string representation of memory statistics.\n", + " \"\"\"\n", + " def format_bytes(bytes_value, unit=\"auto\"):\n", + " if unit == \"auto\":\n", + " # Choose appropriate unit\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " elif unit == \"KB\":\n", + " bytes_value /= 1024\n", + " elif unit == \"MB\":\n", + " bytes_value /= (1024 * 1024)\n", + " elif unit == \"GB\":\n", + " bytes_value /= (1024 * 1024 * 1024)\n", + " \n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " src_idx = memory_stats[\"source_idx\"]\n", + " target_indices = memory_stats[\"target_indices\"]\n", + " \n", + " # Base information\n", + " lines = [\n", + " f\"Memory statistics for source index {src_idx} with {len(target_indices)} targets:\",\n", + " f\"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}\",\n", + " f\"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}\",\n", + " ]\n", + " \n", + " # Calculate min and max target memory if summary is requested\n", + " if summary and memory_stats[\"target_details\"]:\n", + " target_memories = list(memory_stats[\"target_details\"].values())\n", + " min_target = min(target_memories)\n", + " max_target = max(target_memories)\n", + " \n", + " lines.extend([\n", + " \"\\nTarget memory summary:\",\n", + " f\"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}\",\n", + " f\"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}\",\n", + " f\"- Min: {format_bytes(min_target, unit)}\",\n", + " f\"- Max: {format_bytes(max_target, unit)}\",\n", + " f\"- Range: {format_bytes(max_target - min_target, unit)}\"\n", + " ])\n", + " \n", + " # Add condition memory summary if available\n", + " if memory_stats[\"condition_memory\"] > 0:\n", + " lines.append(f\"\\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}\")\n", + " else:\n", + " # Detailed output (original format)\n", + " lines.extend([\n", + " f\"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target\",\n", + " f\"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}\",\n", + " \"\\nTarget details:\"\n", + " ])\n", + " \n", + " for target_key, target_memory in memory_stats[\"target_details\"].items():\n", + " target_id = target_key.split(\"_\")[1]\n", + " lines.append(f\" - Target {target_id}: {format_bytes(target_memory, unit)}\")\n", + " \n", + " if \"condition_details\" in memory_stats:\n", + " lines.append(\"\\nCondition details:\")\n", + " for cond_key, cond_memory in memory_stats[\"condition_details\"].items():\n", + " cond_name = cond_key.split(\"_\", 1)[1]\n", + " lines.append(f\" - {cond_name}: {format_bytes(cond_memory, unit)}\")\n", + " \n", + " return \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "316e3a6a", + "metadata": {}, + "outputs": [], + "source": [ + "data = MappedCellData.read_zarr(\n", + " \"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\"\n", + ")" + ] + }, + { + "cell_type": "code", + 
"execution_count": 9, + "id": "3d101216", + "metadata": {}, + "outputs": [], + "source": [ + "stats = calculate_memory_cost(data, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a79f9fc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics for source index 0 with 194 targets:\n", + "- Source cells: 60135 cells, 68.82 MB\n", + "- Total memory: 548.11 MB\n", + "\n", + "Target memory summary:\n", + "- Total: 479.28 MB\n", + "- Average: 2.47 MB\n", + "- Min: 44.53 KB\n", + "- Max: 6.35 MB\n", + "- Range: 6.31 MB\n", + "\n", + "Condition memory: 4.55 KB\n" + ] + } + ], + "source": [ + "print(format_memory_stats(stats, summary=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8c400080", + "metadata": {}, + "outputs": [], + "source": [ + "data_stats = {}\n", + "for i in range(data.n_controls):\n", + " data_stats[i] = calculate_memory_cost(data, i)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "710fb69d", + "metadata": {}, + "outputs": [], + "source": [ + "def print_average_memory_per_source(stats_dict):\n", + " \"\"\"Print the average total memory per source index.\n", + " \n", + " Parameters\n", + " ----------\n", + " stats_dict\n", + " Optional pre-calculated memory statistics dictionary.\n", + " If None, statistics will be calculated for all source indices.\n", + " \"\"\"\n", + " \n", + " \n", + " # Extract total memory for each source index\n", + " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", + " \n", + " # Calculate statistics\n", + " avg_memory = np.mean(total_memories)\n", + " min_memory = np.min(total_memories)\n", + " max_memory = np.max(total_memories)\n", + " median_memory = np.median(total_memories)\n", + " \n", + " # Format the output\n", + " def format_bytes(bytes_value):\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", + " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", + " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", + " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", + " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", + " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", + " \n", + " # Identify source indices with min and max memory\n", + " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " \n", + " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", + " print(f\"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e2f8f809", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics across 50 source indices:\n", + "- Average total memory per source: 2.14 GB\n", + "- Minimum total memory: 21.01 MB\n", + "- Maximum total memory: 6.75 GB\n", + "- Median total memory: 2.05 GB\n", + "- Range: 6.73 GB\n", + "\n", + "Source index with minimum memory: 39 (21.01 MB)\n", + "Source index with maximum memory: 22 (6.75 GB)\n" + ] + } + ], + "source": [ + 
"print_average_memory_per_source(data_stats)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f07c55d9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py index b104359c..614a5004 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/cellflow/data/_dataloader.py @@ -3,6 +3,8 @@ import numpy as np import tqdm +import threading +from concurrent.futures import ThreadPoolExecutor, Future from cellflow.data._data import ( PredictionData, @@ -125,7 +127,6 @@ def data(self) -> TrainingData: """The training data.""" return self._data - class ReservoirSampler(TrainSampler): """Data sampler with gradual pool replacement using reservoir sampling. @@ -168,6 +169,12 @@ def __init__( self._pool_usage_count = np.zeros(self.n_source_dists, dtype=int) self._initialized = False + # Concurrency primitives + self._lock = threading.RLock() + self._executor = ThreadPoolExecutor(max_workers=2) + # Map pool position -> {"old": int, "new": int, "future": Future} + self._pending_replacements: dict[int, dict[str, Any]] = {} + def init_pool(self, rng): self._init_pool(rng) self._init_cache_pool_elements() @@ -182,8 +189,13 @@ def _get_target_idx_pool(src_idx_pool: np.ndarray, control_to_perturbation: dict def _init_cache_pool_elements(self): if not self._initialized: raise ValueError("Pool not initialized. Call init_pool(rng) first.") - self._cached_srcs = {i: np.asarray(self._data.src_cell_data[i]) for i in self._src_idx_pool} - self._cached_tgts = {j: np.asarray(self._data.tgt_cell_data[j]) for i in self._src_idx_pool for j in self._data.control_to_perturbation[i]} + with self._lock: + self._cached_srcs = {i: self._data.src_cell_data[i][...] for i in self._src_idx_pool} + self._cached_tgts = { + j: self._data.tgt_cell_data[j][...] + for i in self._src_idx_pool + for j in self._data.control_to_perturbation[i] + } def _init_pool(self, rng): """Initialize the pool with random source distribution indices.""" @@ -194,48 +206,102 @@ def _sample_source_dist_idx(self, rng) -> int: """Sample a source distribution index with gradual pool replacement.""" if not self._initialized: raise ValueError("Pool not initialized. 
Call init_pool(rng) first.") + + # Opportunistically apply any ready replacements (non-blocking) + self._apply_ready_replacements() + # Sample from current pool - source_idx = rng.choice(sorted(self._cached_srcs.keys())) + with self._lock: + source_idx = rng.choice(sorted(self._cached_srcs.keys())) # Increment usage count for monitoring self._pool_usage_count[source_idx] += 1 - # Gradually replace elements based on replacement probability + # Gradually replace elements based on replacement probability (schedule only) if rng.random() < self._replacement_prob: - self._replace_pool_element(rng) + self._schedule_replacement(rng) return source_idx - def _replace_pool_element(self, rng): - """Replace a single pool element with a new one.""" - # instead sample weighted by usage count - # let's only consider the pool_usage_count.min() for least used - # and the pool_usage_count.max() for most used + def _schedule_replacement(self, rng): + """Schedule a single pool element replacement without blocking.""" + # weights same as previous logic most_used_weight = (self._pool_usage_count == self._pool_usage_count.max()).astype(float) + if most_used_weight.sum() == 0: + return most_used_weight /= most_used_weight.sum() - - # weight by most used replaced_pool_idx = rng.choice(self.n_source_dists, p=most_used_weight) - if replaced_pool_idx in set(self._src_idx_pool): - in_pool_idx = np.where(self._src_idx_pool == replaced_pool_idx)[0][0] - least_used_weight = (self._pool_usage_count == self._pool_usage_count.min()).astype(float) - least_used_weight /= least_used_weight.sum() - new_pool_idx = rng.choice(self.n_source_dists, p=least_used_weight) - self._src_idx_pool[in_pool_idx] = new_pool_idx - self._update_cache(replaced_pool_idx, new_pool_idx) - print(f"replaced {replaced_pool_idx} with {new_pool_idx}") - def _update_cache(self, replaced_pool_idx: int, new_pool_idx: int): - print(f"updating cache for {replaced_pool_idx} and {new_pool_idx}") - del self._cached_srcs[replaced_pool_idx] - for k in self._data.control_to_perturbation[replaced_pool_idx]: - del self._cached_tgts[k] - self._cached_srcs[new_pool_idx] = np.asarray(self._data.src_cell_data[new_pool_idx]) - for k in self._data.control_to_perturbation[new_pool_idx]: - self._cached_tgts[k] = np.asarray(self._data.tgt_cell_data[k]) - print(f"updated cache for {replaced_pool_idx} and {new_pool_idx}") + with self._lock: + pool_set = set(self._src_idx_pool.tolist()) + if replaced_pool_idx not in pool_set: + return + in_pool_idx = int(np.where(self._src_idx_pool == replaced_pool_idx)[0][0]) + # If there's already a pending replacement for this pool slot, skip + if in_pool_idx in self._pending_replacements: + return + least_used_weight = (self._pool_usage_count == self._pool_usage_count.min()).astype(float) + if least_used_weight.sum() == 0: + return + least_used_weight /= least_used_weight.sum() + new_pool_idx = int(rng.choice(self.n_source_dists, p=least_used_weight)) + + # Kick off background load for new indices + fut: Future = self._executor.submit(self._load_new_cache, new_pool_idx) + self._pending_replacements[in_pool_idx] = { + "old": replaced_pool_idx, + "new": new_pool_idx, + "future": fut, + } + print(f"scheduled replacement of {replaced_pool_idx} with {new_pool_idx} (slot {in_pool_idx})") + + def _apply_ready_replacements(self): + """Apply any finished background loads; non-blocking.""" + to_apply: list[int] = [] + with self._lock: + for slot, info in self._pending_replacements.items(): + fut: Future = info["future"] + if fut.done() and not 
fut.cancelled(): + to_apply.append(slot) + + for slot in to_apply: + with self._lock: + info = self._pending_replacements.pop(slot, None) + if info is None: + continue + old_idx = int(info["old"]) + new_idx = int(info["new"]) + fut: Future = info["future"] + try: + prepared = fut.result(timeout=0) # already done + except Exception as e: + print(f"background load failed for {new_idx}: {e}") + continue + + # Swap pool index + self._src_idx_pool[slot] = new_idx + + # Add new entries first + self._cached_srcs[new_idx] = prepared["src"] + for k, arr in prepared["tgts"].items(): + self._cached_tgts[k] = arr + + # Remove old entries + if old_idx in self._cached_srcs: + del self._cached_srcs[old_idx] + for k in self._data.control_to_perturbation[old_idx]: + if k in self._cached_tgts: + del self._cached_tgts[k] + + print(f"applied replacement: {old_idx} -> {new_idx} (slot {slot})") + + def _load_new_cache(self, src_idx: int) -> dict[str, Any]: + """Load new src and corresponding tgt arrays in the background.""" + src_arr = self._data.src_cell_data[src_idx][...] + tgt_dict = {k: self._data.tgt_cell_data[k][...] for k in self._data.control_to_perturbation[src_idx]} + return {"src": src_arr, "tgts": tgt_dict} def get_pool_stats(self) -> dict: """Get statistics about the current pool state.""" @@ -250,10 +316,14 @@ def get_pool_stats(self) -> dict: } def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: - return rng.choice(self._cached_srcs[source_dist_idx], size=self.batch_size, replace=True) + with self._lock: + arr = self._cached_srcs[source_dist_idx] + return rng.choice(arr, size=self.batch_size, replace=True) def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: - return rng.choice(self._cached_tgts[target_dist_idx], size=self.batch_size, replace=True) + with self._lock: + arr = self._cached_tgts[target_dist_idx] + return rng.choice(arr, size=self.batch_size, replace=True) class BaseValidSampler(abc.ABC): From 4e89f5060feb02f831da13ce503ea077deedd1ab Mon Sep 17 00:00:00 2001 From: AlejandroTL Date: Mon, 29 Sep 2025 14:52:02 +0200 Subject: [PATCH 30/35] name change to avoid conflicts --- pyproject.toml | 6 +- src/cellflow/__init__.py | 4 - src/cellflow/external/__init__.py | 6 - src/cellflow/model/__init__.py | 3 - src/cellflow/plotting/__init__.py | 3 - src/cellflow/preprocessing/__init__.py | 9 - src/cellflow/solvers/__init__.py | 4 - src/cfp/__init__.py | 2 +- src/scaleflow/__init__.py | 4 + src/{cellflow => scaleflow}/_constants.py | 0 src/{cellflow => scaleflow}/_logging.py | 0 src/{cellflow => scaleflow}/_optional.py | 0 src/{cellflow => scaleflow}/_types.py | 0 .../compat/__init__.py | 0 src/{cellflow => scaleflow}/compat/torch_.py | 2 +- src/{cellflow => scaleflow}/data/__init__.py | 10 +- src/{cellflow => scaleflow}/data/_data.py | 4 +- .../data/_dataloader.py | 31 +- .../data/_datamanager.py | 10 +- .../data/_jax_dataloader.py | 7 +- .../data/_torch_dataloader.py | 6 +- src/{cellflow => scaleflow}/data/_utils.py | 0 src/{cellflow => scaleflow}/datasets.py | 2 +- src/scaleflow/external/__init__.py | 6 + src/{cellflow => scaleflow}/external/_scvi.py | 4 +- .../external/_scvi_utils.py | 0 .../metrics/__init__.py | 2 +- .../metrics/_metrics.py | 0 src/scaleflow/model/__init__.py | 3 + .../model/_cellflow.py | 156 ++++--- src/{cellflow => scaleflow}/model/_utils.py | 2 +- .../networks/__init__.py | 6 +- .../networks/_set_encoders.py | 4 +- .../networks/_utils.py | 2 +- .../networks/_velocity_field.py | 114 ++++- 
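The new ReservoirSampler logic above splits cache replacement into two non-blocking steps: _schedule_replacement submits the load to a ThreadPoolExecutor and records it in _pending_replacements, while _apply_ready_replacements swaps finished loads into the cache the next time a sample is drawn, so sampling never waits on disk I/O. The following is a minimal, self-contained sketch of that schedule-then-apply pattern; the TinyPool class, _slow_load, and the string payloads are illustrative stand-ins, not part of the patch.

import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor


class TinyPool:
    """Toy cache mirroring the schedule/apply split used in ReservoirSampler."""

    def __init__(self, items):
        self._cache = {i: f"data-{i}" for i in items}
        self._lock = threading.RLock()
        self._executor = ThreadPoolExecutor(max_workers=2)
        # new index -> (index it will replace, background load)
        self._pending: dict[int, tuple[int, Future]] = {}

    def _slow_load(self, idx):
        time.sleep(0.1)  # stands in for reading a Zarr-backed array from disk
        return f"data-{idx}"

    def schedule_replacement(self, old_idx, new_idx):
        """Non-blocking: submit the load and remember the future."""
        with self._lock:
            if new_idx in self._pending or new_idx in self._cache:
                return
            fut = self._executor.submit(self._slow_load, new_idx)
            self._pending[new_idx] = (old_idx, fut)

    def apply_ready(self):
        """Swap in loads that have finished; never wait on pending ones."""
        with self._lock:
            ready = [i for i, (_, fut) in self._pending.items() if fut.done()]
            for new_idx in ready:
                old_idx, fut = self._pending.pop(new_idx)
                self._cache[new_idx] = fut.result()
                self._cache.pop(old_idx, None)

    def sample(self):
        self.apply_ready()  # opportunistic, as in _sample_source_dist_idx
        with self._lock:
            return next(iter(self._cache.values()))


pool = TinyPool([0, 1, 2])
pool.schedule_replacement(old_idx=0, new_idx=3)
print(pool.sample())  # returns immediately; index 3 is swapped in on a later call, once loaded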
src/scaleflow/plotting/__init__.py | 3 + .../plotting/_plotting.py | 6 +- .../plotting/_utils.py | 2 +- src/scaleflow/preprocessing/__init__.py | 9 + .../preprocessing/_gene_emb.py | 2 +- .../preprocessing/_pca.py | 2 +- .../preprocessing/_preprocessing.py | 6 +- .../preprocessing/_wknn.py | 4 +- src/scaleflow/solvers/__init__.py | 4 + src/{cellflow => scaleflow}/solvers/_genot.py | 10 +- src/scaleflow/solvers/_multitask_otfm.py | 399 ++++++++++++++++++ src/{cellflow => scaleflow}/solvers/_otfm.py | 16 +- src/{cellflow => scaleflow}/solvers/utils.py | 0 .../training/__init__.py | 4 +- .../training/_callbacks.py | 36 +- .../training/_trainer.py | 35 +- .../training/_utils.py | 0 src/{cellflow => scaleflow}/utils.py | 0 53 files changed, 756 insertions(+), 194 deletions(-) delete mode 100644 src/cellflow/__init__.py delete mode 100644 src/cellflow/external/__init__.py delete mode 100644 src/cellflow/model/__init__.py delete mode 100644 src/cellflow/plotting/__init__.py delete mode 100644 src/cellflow/preprocessing/__init__.py delete mode 100644 src/cellflow/solvers/__init__.py create mode 100644 src/scaleflow/__init__.py rename src/{cellflow => scaleflow}/_constants.py (100%) rename src/{cellflow => scaleflow}/_logging.py (100%) rename src/{cellflow => scaleflow}/_optional.py (100%) rename src/{cellflow => scaleflow}/_types.py (100%) rename src/{cellflow => scaleflow}/compat/__init__.py (100%) rename src/{cellflow => scaleflow}/compat/torch_.py (86%) rename src/{cellflow => scaleflow}/data/__init__.py (66%) rename src/{cellflow => scaleflow}/data/_data.py (99%) rename src/{cellflow => scaleflow}/data/_dataloader.py (92%) rename src/{cellflow => scaleflow}/data/_datamanager.py (99%) rename src/{cellflow => scaleflow}/data/_jax_dataloader.py (94%) rename src/{cellflow => scaleflow}/data/_torch_dataloader.py (94%) rename src/{cellflow => scaleflow}/data/_utils.py (100%) rename src/{cellflow => scaleflow}/datasets.py (98%) create mode 100644 src/scaleflow/external/__init__.py rename src/{cellflow => scaleflow}/external/_scvi.py (98%) rename src/{cellflow => scaleflow}/external/_scvi_utils.py (100%) rename src/{cellflow => scaleflow}/metrics/__init__.py (92%) rename src/{cellflow => scaleflow}/metrics/_metrics.py (100%) create mode 100644 src/scaleflow/model/__init__.py rename src/{cellflow => scaleflow}/model/_cellflow.py (85%) rename src/{cellflow => scaleflow}/model/_utils.py (96%) rename src/{cellflow => scaleflow}/networks/__init__.py (70%) rename src/{cellflow => scaleflow}/networks/_set_encoders.py (98%) rename src/{cellflow => scaleflow}/networks/_utils.py (99%) rename src/{cellflow => scaleflow}/networks/_velocity_field.py (85%) create mode 100644 src/scaleflow/plotting/__init__.py rename src/{cellflow => scaleflow}/plotting/_plotting.py (96%) rename src/{cellflow => scaleflow}/plotting/_utils.py (98%) create mode 100644 src/scaleflow/preprocessing/__init__.py rename src/{cellflow => scaleflow}/preprocessing/_gene_emb.py (99%) rename src/{cellflow => scaleflow}/preprocessing/_pca.py (99%) rename src/{cellflow => scaleflow}/preprocessing/_preprocessing.py (98%) rename src/{cellflow => scaleflow}/preprocessing/_wknn.py (99%) create mode 100644 src/scaleflow/solvers/__init__.py rename src/{cellflow => scaleflow}/solvers/_genot.py (97%) create mode 100644 src/scaleflow/solvers/_multitask_otfm.py rename src/{cellflow => scaleflow}/solvers/_otfm.py (95%) rename src/{cellflow => scaleflow}/solvers/utils.py (100%) rename src/{cellflow => scaleflow}/training/__init__.py (79%) rename src/{cellflow 
=> scaleflow}/training/_callbacks.py (93%) rename src/{cellflow => scaleflow}/training/_trainer.py (80%) rename src/{cellflow => scaleflow}/training/_utils.py (100%) rename src/{cellflow => scaleflow}/utils.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 152c690b..dbcf7619 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ build-backend = "hatchling.build" requires = [ "hatch-vcs", "hatchling" ] [project] -name = "cellflow-tools" +name = "scaleflow-tools" description = "Modeling complex perturbations with flow matching at single-cell resolution" readme = "README.md" license = "PolyForm-Noncommercial-1.0.0" @@ -103,7 +103,7 @@ urls.Home-page = "https://github.com/theislab/cellflow" urls.Source = "https://github.com/theislab/cellflow" [tool.hatch.build.targets.wheel] -packages = [ 'src/cellflow' ] +packages = [ 'src/scaleflow' ] [tool.hatch.version] source = "vcs" @@ -201,7 +201,7 @@ extras = test,pp,external,embedding pass_env = PYTEST_*,CI commands = coverage run -m pytest {tty:--color=yes} {posargs: \ - --cov={env_site_packages_dir}{/}cellflow --cov-config={tox_root}{/}pyproject.toml \ + --cov={env_site_packages_dir}{/}scaleflow --cov-config={tox_root}{/}pyproject.toml \ --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered} [testenv:lint-code] diff --git a/src/cellflow/__init__.py b/src/cellflow/__init__.py deleted file mode 100644 index 526fc741..00000000 --- a/src/cellflow/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from importlib import metadata - -import cellflow.preprocessing as pp -from cellflow import data, datasets, metrics, model, networks, solvers, training, utils diff --git a/src/cellflow/external/__init__.py b/src/cellflow/external/__init__.py deleted file mode 100644 index 7a03a1c8..00000000 --- a/src/cellflow/external/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -try: - from cellflow.external._scvi import CFJaxSCVI -except ImportError as e: - raise ImportError( - "cellflow.external requires more dependencies. 
Please install via pip install 'cellflow[external]'" - ) from e diff --git a/src/cellflow/model/__init__.py b/src/cellflow/model/__init__.py deleted file mode 100644 index 8731f241..00000000 --- a/src/cellflow/model/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellflow.model._cellflow import CellFlow - -__all__ = ["CellFlow"] diff --git a/src/cellflow/plotting/__init__.py b/src/cellflow/plotting/__init__.py deleted file mode 100644 index c7fd387e..00000000 --- a/src/cellflow/plotting/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellflow.plotting._plotting import plot_condition_embedding - -__all__ = ["plot_condition_embedding"] diff --git a/src/cellflow/preprocessing/__init__.py b/src/cellflow/preprocessing/__init__.py deleted file mode 100644 index 21eaa993..00000000 --- a/src/cellflow/preprocessing/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from cellflow.preprocessing._gene_emb import ( - GeneInfo, - get_esm_embedding, - prot_sequence_from_ensembl, - protein_features_from_genes, -) -from cellflow.preprocessing._pca import centered_pca, project_pca, reconstruct_pca -from cellflow.preprocessing._preprocessing import annotate_compounds, encode_onehot, get_molecular_fingerprints -from cellflow.preprocessing._wknn import compute_wknn, transfer_labels diff --git a/src/cellflow/solvers/__init__.py b/src/cellflow/solvers/__init__.py deleted file mode 100644 index a02a5510..00000000 --- a/src/cellflow/solvers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from cellflow.solvers._genot import GENOT -from cellflow.solvers._otfm import OTFlowMatching - -__all__ = ["GENOT", "OTFlowMatching"] diff --git a/src/cfp/__init__.py b/src/cfp/__init__.py index 76f0742f..aeca0ffd 100644 --- a/src/cfp/__init__.py +++ b/src/cfp/__init__.py @@ -1 +1 @@ -from cellflow import * # noqa: F403 +from scaleflow import * # noqa: F403 diff --git a/src/scaleflow/__init__.py b/src/scaleflow/__init__.py new file mode 100644 index 00000000..60891e49 --- /dev/null +++ b/src/scaleflow/__init__.py @@ -0,0 +1,4 @@ +from importlib import metadata + +import scaleflow.preprocessing as pp +from scaleflow import data, datasets, metrics, model, networks, solvers, training, utils diff --git a/src/cellflow/_constants.py b/src/scaleflow/_constants.py similarity index 100% rename from src/cellflow/_constants.py rename to src/scaleflow/_constants.py diff --git a/src/cellflow/_logging.py b/src/scaleflow/_logging.py similarity index 100% rename from src/cellflow/_logging.py rename to src/scaleflow/_logging.py diff --git a/src/cellflow/_optional.py b/src/scaleflow/_optional.py similarity index 100% rename from src/cellflow/_optional.py rename to src/scaleflow/_optional.py diff --git a/src/cellflow/_types.py b/src/scaleflow/_types.py similarity index 100% rename from src/cellflow/_types.py rename to src/scaleflow/_types.py diff --git a/src/cellflow/compat/__init__.py b/src/scaleflow/compat/__init__.py similarity index 100% rename from src/cellflow/compat/__init__.py rename to src/scaleflow/compat/__init__.py diff --git a/src/cellflow/compat/torch_.py b/src/scaleflow/compat/torch_.py similarity index 86% rename from src/cellflow/compat/torch_.py rename to src/scaleflow/compat/torch_.py index 5a51fa5e..b79f134e 100644 --- a/src/cellflow/compat/torch_.py +++ b/src/scaleflow/compat/torch_.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from cellflow._optional import OptionalDependencyNotAvailable, torch_required_msg +from scaleflow._optional import OptionalDependencyNotAvailable, torch_required_msg try: from torch.utils.data import 
IterableDataset as TorchIterableDataset # type: ignore diff --git a/src/cellflow/data/__init__.py b/src/scaleflow/data/__init__.py similarity index 66% rename from src/cellflow/data/__init__.py rename to src/scaleflow/data/__init__.py index 5121b1c3..43266d8e 100644 --- a/src/cellflow/data/__init__.py +++ b/src/scaleflow/data/__init__.py @@ -1,4 +1,4 @@ -from cellflow.data._data import ( +from scaleflow.data._data import ( BaseDataMixin, ConditionData, PredictionData, @@ -6,15 +6,15 @@ ValidationData, ZarrTrainingData, ) -from cellflow.data._dataloader import ( +from scaleflow.data._dataloader import ( PredictionSampler, TrainSampler, TrainSamplerWithPool, ValidationSampler, ) -from cellflow.data._datamanager import DataManager -from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler -from cellflow.data._torch_dataloader import TorchCombinedTrainSampler +from scaleflow.data._datamanager import DataManager +from scaleflow.data._jax_dataloader import JaxOutOfCoreTrainSampler +from scaleflow.data._torch_dataloader import TorchCombinedTrainSampler __all__ = [ "DataManager", diff --git a/src/cellflow/data/_data.py b/src/scaleflow/data/_data.py similarity index 99% rename from src/cellflow/data/_data.py rename to src/scaleflow/data/_data.py index 97d73ea6..ea727097 100644 --- a/src/cellflow/data/_data.py +++ b/src/scaleflow/data/_data.py @@ -8,8 +8,8 @@ import zarr from zarr.storage import LocalStore -from cellflow._types import ArrayLike -from cellflow.data._utils import write_sharded +from scaleflow._types import ArrayLike +from scaleflow.data._utils import write_sharded __all__ = [ "BaseDataMixin", diff --git a/src/cellflow/data/_dataloader.py b/src/scaleflow/data/_dataloader.py similarity index 92% rename from src/cellflow/data/_dataloader.py rename to src/scaleflow/data/_dataloader.py index 4061e0bc..cef1eac2 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/scaleflow/data/_dataloader.py @@ -4,7 +4,7 @@ import numpy as np import tqdm -from cellflow.data._data import ( +from scaleflow.data._data import ( PredictionData, TrainingData, ValidationData, @@ -19,7 +19,7 @@ class TrainSampler: - """Data sampler for :class:`~cellflow.data.TrainingData`. + """Data sampler for :class:`~scaleflow.data.TrainingData`. Parameters ---------- @@ -326,7 +326,7 @@ def _get_condition_data(self, cond_idx: int) -> dict[str, np.ndarray]: class ValidationSampler(BaseValidSampler): - """Data sampler for :class:`~cellflow.data.ValidationData`. + """Data sampler for :class:`~scaleflow.data.ValidationData`. Parameters ---------- @@ -334,9 +334,12 @@ class ValidationSampler(BaseValidSampler): The validation data. seed Random seed. + validation_batch_size + Maximum number of cells to sample per condition during validation. + If None, uses all available cells. 
""" - def __init__(self, val_data: ValidationData, seed: int = 0) -> None: + def __init__(self, val_data: ValidationData, seed: int = 0, validation_batch_size: int | None = None) -> None: self._data = val_data self.perturbation_to_control = self._get_perturbation_to_control(val_data) self.n_conditions_on_log_iteration = ( @@ -349,6 +352,7 @@ def __init__(self, val_data: ValidationData, seed: int = 0) -> None: if val_data.n_conditions_on_train_end is not None else val_data.n_perturbations ) + self.validation_batch_size = validation_batch_size self.rng = np.random.default_rng(seed) if self._data.condition_data is None: raise NotImplementedError("Validation data must have condition data.") @@ -373,6 +377,12 @@ def sample(self, mode: Literal["on_log_iteration", "on_train_end"]) -> Any: source_cells = [self._data.cell_data[mask] for mask in source_cells_mask] target_cells_mask = [cond_idx == self._data.perturbation_covariates_mask for cond_idx in condition_idcs] target_cells = [self._data.cell_data[mask] for mask in target_cells_mask] + + # Apply validation batch size if specified + if self.validation_batch_size is not None: + source_cells = self._subsample_cells(source_cells) + target_cells = self._subsample_cells(target_cells) + conditions = [self._get_condition_data(cond_idx) for cond_idx in condition_idcs] cell_rep_dict = {} cond_dict = {} @@ -385,6 +395,17 @@ def sample(self, mode: Literal["on_log_iteration", "on_train_end"]) -> Any: return {"source": cell_rep_dict, "condition": cond_dict, "target": true_dict} + def _subsample_cells(self, cells_list: list[np.ndarray]) -> list[np.ndarray]: + """Subsample cells from each condition to validation_batch_size.""" + subsampled_cells = [] + for cells in cells_list: + if len(cells) > self.validation_batch_size: + indices = self.rng.choice(len(cells), size=self.validation_batch_size, replace=False) + subsampled_cells.append(cells[indices]) + else: + subsampled_cells.append(cells) + return subsampled_cells + @property def data(self) -> ValidationData: """The validation data.""" @@ -392,7 +413,7 @@ def data(self) -> ValidationData: class PredictionSampler(BaseValidSampler): - """Data sampler for :class:`~cellflow.data.PredictionData`. + """Data sampler for :class:`~scaleflow.data.PredictionData`. Parameters ---------- diff --git a/src/cellflow/data/_datamanager.py b/src/scaleflow/data/_datamanager.py similarity index 99% rename from src/cellflow/data/_datamanager.py rename to src/scaleflow/data/_datamanager.py index 065cddd2..c83b462a 100644 --- a/src/cellflow/data/_datamanager.py +++ b/src/scaleflow/data/_datamanager.py @@ -13,9 +13,9 @@ from dask.diagnostics import ProgressBar from pandas.api.types import is_numeric_dtype -from cellflow._logging import logger -from cellflow._types import ArrayLike -from cellflow.data._data import ConditionData, PredictionData, ReturnData, TrainingData, ValidationData +from scaleflow._logging import logger +from scaleflow._types import ArrayLike +from scaleflow.data._data import ConditionData, PredictionData, ReturnData, TrainingData, ValidationData from ._utils import _flatten_list, _to_list @@ -223,8 +223,8 @@ def get_prediction_data( is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. covariate_data A :class:`~pandas.DataFrame` with columns defining the covariates as - in :meth:`cellflow.model.CellFlow.prepare_data` and stored in - :attr:`cellflow.model.CellFlow.data_manager`. + in :meth:`scaleflow.model.CellFlow.prepare_data` and stored in + :attr:`scaleflow.model.CellFlow.data_manager`. 
rep_dict Dictionary with representations of the covariates. If not provided, :attr:`~anndata.AnnData.uns` is used. diff --git a/src/cellflow/data/_jax_dataloader.py b/src/scaleflow/data/_jax_dataloader.py similarity index 94% rename from src/cellflow/data/_jax_dataloader.py rename to src/scaleflow/data/_jax_dataloader.py index b0c40358..9cde0fec 100644 --- a/src/cellflow/data/_jax_dataloader.py +++ b/src/scaleflow/data/_jax_dataloader.py @@ -6,11 +6,11 @@ import numpy as np -from cellflow.data._data import ( +from scaleflow.data._data import ( TrainingData, ZarrTrainingData, ) -from cellflow.data._dataloader import TrainSampler +from scaleflow.data._dataloader import TrainSampler def _prefetch_to_device( @@ -95,7 +95,8 @@ def __post_init__(self): def set_sampler(self, num_iterations: int) -> None: self._iterator = _prefetch_to_device( - sampler=self.inner, seed=self.seed, num_iterations=num_iterations, prefetch_factor=self.prefetch_factor + sampler=self.inner, seed=self.seed, num_iterations=num_iterations, + prefetch_factor=self.prefetch_factor, num_workers=self.num_workers ) def sample(self, rng=None) -> dict[str, Any]: diff --git a/src/cellflow/data/_torch_dataloader.py b/src/scaleflow/data/_torch_dataloader.py similarity index 94% rename from src/cellflow/data/_torch_dataloader.py rename to src/scaleflow/data/_torch_dataloader.py index 22560ee2..b70248b5 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ b/src/scaleflow/data/_torch_dataloader.py @@ -3,9 +3,9 @@ import numpy as np -from cellflow.compat import TorchIterableDataset -from cellflow.data._data import ZarrTrainingData -from cellflow.data._dataloader import TrainSampler +from scaleflow.compat import TorchIterableDataset +from scaleflow.data._data import ZarrTrainingData +from scaleflow.data._dataloader import TrainSampler def _worker_init_fn_helper(worker_id, random_generators): diff --git a/src/cellflow/data/_utils.py b/src/scaleflow/data/_utils.py similarity index 100% rename from src/cellflow/data/_utils.py rename to src/scaleflow/data/_utils.py diff --git a/src/cellflow/datasets.py b/src/scaleflow/datasets.py similarity index 98% rename from src/cellflow/datasets.py rename to src/scaleflow/datasets.py index d07ce340..7b2bbbd3 100644 --- a/src/cellflow/datasets.py +++ b/src/scaleflow/datasets.py @@ -4,7 +4,7 @@ import anndata as ad from scanpy.readwrite import _check_datafile_present_and_download -from cellflow._types import PathLike +from scaleflow._types import PathLike __all__ = [ "ineurons", diff --git a/src/scaleflow/external/__init__.py b/src/scaleflow/external/__init__.py new file mode 100644 index 00000000..eba74ed0 --- /dev/null +++ b/src/scaleflow/external/__init__.py @@ -0,0 +1,6 @@ +try: + from scaleflow.external._scvi import CFJaxSCVI +except ImportError as e: + raise ImportError( + "scaleflow.external requires more dependencies. 
Please install via pip install 'cellflow[external]'" + ) from e diff --git a/src/cellflow/external/_scvi.py b/src/scaleflow/external/_scvi.py similarity index 98% rename from src/cellflow/external/_scvi.py rename to src/scaleflow/external/_scvi.py index f979d93c..e84a912f 100644 --- a/src/cellflow/external/_scvi.py +++ b/src/scaleflow/external/_scvi.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import numpy as np -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike if TYPE_CHECKING: from typing import Literal @@ -25,7 +25,7 @@ class CFJaxSCVI(JaxSCVI): - from cellflow.external._scvi_utils import CFJaxVAE + from scaleflow.external._scvi_utils import CFJaxVAE _module_cls = CFJaxVAE diff --git a/src/cellflow/external/_scvi_utils.py b/src/scaleflow/external/_scvi_utils.py similarity index 100% rename from src/cellflow/external/_scvi_utils.py rename to src/scaleflow/external/_scvi_utils.py diff --git a/src/cellflow/metrics/__init__.py b/src/scaleflow/metrics/__init__.py similarity index 92% rename from src/cellflow/metrics/__init__.py rename to src/scaleflow/metrics/__init__.py index 79cb1738..63a2aa52 100644 --- a/src/cellflow/metrics/__init__.py +++ b/src/scaleflow/metrics/__init__.py @@ -1,4 +1,4 @@ -from cellflow.metrics._metrics import ( +from scaleflow.metrics._metrics import ( compute_e_distance, compute_e_distance_fast, compute_mean_metrics, diff --git a/src/cellflow/metrics/_metrics.py b/src/scaleflow/metrics/_metrics.py similarity index 100% rename from src/cellflow/metrics/_metrics.py rename to src/scaleflow/metrics/_metrics.py diff --git a/src/scaleflow/model/__init__.py b/src/scaleflow/model/__init__.py new file mode 100644 index 00000000..427f3234 --- /dev/null +++ b/src/scaleflow/model/__init__.py @@ -0,0 +1,3 @@ +from scaleflow.model._cellflow import CellFlow + +__all__ = ["CellFlow"] diff --git a/src/cellflow/model/_cellflow.py b/src/scaleflow/model/_cellflow.py similarity index 85% rename from src/cellflow/model/_cellflow.py rename to src/scaleflow/model/_cellflow.py index b8c63b87..93cbbcaf 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/scaleflow/model/_cellflow.py @@ -15,18 +15,18 @@ import pandas as pd from ott.neural.methods.flows import dynamics -from cellflow import _constants -from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t -from cellflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler -from cellflow.data._data import ConditionData, TrainingData, ValidationData -from cellflow.data._datamanager import DataManager -from cellflow.model._utils import _write_predictions -from cellflow.networks import _velocity_field -from cellflow.plotting import _utils -from cellflow.solvers import _genot, _otfm -from cellflow.training._callbacks import BaseCallback -from cellflow.training._trainer import CellFlowTrainer -from cellflow.utils import match_linear +from scaleflow import _constants +from scaleflow._types import ArrayLike, Layers_separate_input_t, Layers_t +from scaleflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler +from scaleflow.data._data import ConditionData, TrainingData, ValidationData +from scaleflow.data._datamanager import DataManager +from scaleflow.model._utils import _write_predictions +from scaleflow.networks import _velocity_field +from scaleflow.plotting import _utils +from scaleflow.solvers import _genot, _otfm +from scaleflow.training._callbacks import BaseCallback +from scaleflow.training._trainer import CellFlowTrainer 
+from scaleflow.utils import match_linear __all__ = ["CellFlow"] @@ -73,19 +73,19 @@ def prepare_data( max_combination_length: int | None = None, null_value: float = 0.0, ) -> None: - """Prepare the dataloader for training from :attr:`~cellflow.model.CellFlow.adata`. + """Prepare the dataloader for training from :attr:`~scaleflow.model.CellFlow.adata`. Parameters ---------- sample_rep - Key in :attr:`~anndata.AnnData.obsm` of :attr:`cellflow.model.CellFlow.adata` where + Key in :attr:`~anndata.AnnData.obsm` of :attr:`scaleflow.model.CellFlow.adata` where the sample representation is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. control_key Key of a boolean column in :attr:`~anndata.AnnData.obs` of - :attr:`cellflow.model.CellFlow.adata` that defines the control samples. + :attr:`scaleflow.model.CellFlow.adata` that defines the control samples. perturbation_covariates A dictionary where the keys indicate the name of the covariate group and the values are - keys in :attr:`~anndata.AnnData.obs` of :attr:`cellflow.model.CellFlow.adata`. The + keys in :attr:`~anndata.AnnData.obs` of :attr:`scaleflow.model.CellFlow.adata`. The corresponding columns can be of the following types: - categorial: The column contains categories whose representation is stored in @@ -126,8 +126,8 @@ def prepare_data( ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.data_manager` - the :class:`cellflow.data.DataManager` object. - - :attr:`cellflow.model.CellFlow.train_data` - the training data. + - :attr:`scaleflow.model.CellFlow.data_manager` - the :class:`scaleflow.data.DataManager` object. + - :attr:`scaleflow.model.CellFlow.train_data` - the training data. Example ------- @@ -203,7 +203,7 @@ def prepare_validation_data( An :class:`~anndata.AnnData` object. name Name of the validation data defining the key in - :attr:`cellflow.model.CellFlow.validation_data`. + :attr:`scaleflow.model.CellFlow.validation_data`. n_conditions_on_log_iteration Number of conditions to use for computation callbacks at each logged iteration. If :obj:`None`, use all conditions. @@ -212,14 +212,14 @@ def prepare_validation_data( If :obj:`None`, use all conditions. predict_kwargs Keyword arguments for the prediction function - :func:`cellflow.solvers._otfm.OTFlowMatching.predict` or - :func:`cellflow.solvers._genot.GENOT.predict` used during validation. + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict` or + :func:`scaleflow.solvers._genot.GENOT.predict` used during validation. Returns ------- :obj:`None`, and updates the following fields: - - :attr:`cellflow.model.CellFlow.validation_data` - a dictionary with the validation data. + - :attr:`scaleflow.model.CellFlow.validation_data` - a dictionary with the validation data. """ if self.train_data is None: @@ -238,7 +238,7 @@ def prepare_validation_data( predict_kwargs = predict_kwargs or {} # Check if predict_kwargs is alreday provided from an earlier call if "predict_kwargs" in self._validation_data and len(predict_kwargs): - predict_kwargs = self._validation_data["predict_kwargs"].update(predict_kwargs) + self._validation_data["predict_kwargs"].update(predict_kwargs) # Set batched prediction to False if split_val is True if split_val: predict_kwargs["batched"] = False @@ -279,7 +279,7 @@ def prepare_model( """Prepare the model for training. This function sets up the neural network architecture and specificities of the - :attr:`solver`. When :attr:`solver` is an instance of :class:`cellflow.solvers._genot.GENOT`, + :attr:`solver`. 
When :attr:`solver` is an instance of :class:`scaleflow.solvers._genot.GENOT`, the following arguments have to be passed to ``'condition_encoder_kwargs'``: @@ -309,9 +309,9 @@ def prepare_model( pooling_kwargs Keyword arguments for the pooling method corresponding to: - - :class:`cellflow.networks.TokenAttentionPooling` if ``'pooling'`` is + - :class:`scaleflow.networks.TokenAttentionPooling` if ``'pooling'`` is ``'attention_token'``. - - :class:`cellflow.networks.SeedAttentionPooling` if ``'pooling'`` is ``'attention_seed'``. + - :class:`scaleflow.networks.SeedAttentionPooling` if ``'pooling'`` is ``'attention_seed'``. layers_before_pool Layers applied to the condition embeddings before pooling. Can be of type @@ -320,8 +320,8 @@ def prepare_model( - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be ``'mlp'`` or ``'self_attention'``. - - Further keyword arguments for the layer type :class:`cellflow.networks.MLPBlock` or - :class:`cellflow.networks.SelfAttentionBlock`. + - Further keyword arguments for the layer type :class:`scaleflow.networks.MLPBlock` or + :class:`scaleflow.networks.SelfAttentionBlock`. - :class:`dict` with keys corresponding to perturbation covariate keys, and values correspondinng to the above mentioned tuples. @@ -333,16 +333,16 @@ def prepare_model( - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be ``'mlp'`` or ``'self_attention'``. - - Further keys depend on the layer type, either for :class:`cellflow.networks.MLPBlock` or - for :class:`cellflow.networks.SelfAttentionBlock`. + - Further keys depend on the layer type, either for :class:`scaleflow.networks.MLPBlock` or + for :class:`scaleflow.networks.SelfAttentionBlock`. condition_embedding_dim Dimensions of the condition embedding, i.e. the last layer of the - :class:`cellflow.networks.ConditionEncoder`. + :class:`scaleflow.networks.ConditionEncoder`. cond_output_dropout - Dropout rate for the last layer of the :class:`cellflow.networks.ConditionEncoder`. + Dropout rate for the last layer of the :class:`scaleflow.networks.ConditionEncoder`. condition_encoder_kwargs - Keyword arguments for the :class:`cellflow.networks.ConditionEncoder`. + Keyword arguments for the :class:`scaleflow.networks.ConditionEncoder`. pool_sample_covariates Whether to include sample covariates in the pooling. time_freqs @@ -350,17 +350,17 @@ def prepare_model( (:func:`ott.neural.networks.layers.sinusoidal_time_encoder`). time_max_period Controls the frequency of the time embeddings, see - :func:`cellflow.networks.utils.sinusoidal_time_encoder`. + :func:`scaleflow.networks.utils.sinusoidal_time_encoder`. time_encoder_dims Dimensions of the layers processing the time embedding in - :attr:`cellflow.networks.ConditionalVelocityField.time_encoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. time_encoder_dropout - Dropout rate for the :attr:`cellflow.networks.ConditionalVelocityField.time_encoder`. + Dropout rate for the :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. hidden_dims Dimensions of the layers processing the input to the velocity field - via :attr:`cellflow.networks.ConditionalVelocityField.x_encoder`. + via :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. hidden_dropout - Dropout rate for :attr:`cellflow.networks.ConditionalVelocityField.x_encoder`. + Dropout rate for :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. 
conditioning Conditioning method, should be one of: @@ -373,18 +373,18 @@ def prepare_model( Keyword arguments for the conditioning method. decoder_dims Dimensions of the output layers in - :attr:`cellflow.networks.ConditionalVelocityField.decoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. decoder_dropout Dropout rate for the output layer - :attr:`cellflow.networks.ConditionalVelocityField.decoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. vf_act_fn - Activation function of the :class:`cellflow.networks.ConditionalVelocityField`. + Activation function of the :class:`scaleflow.networks.ConditionalVelocityField`. vf_kwargs Additional keyword arguments for the solver-specific vector field. For instance, when ``'solver==genot'``, the following keyword argument can be passed: - ``'genot_source_dims'`` of type :class:`tuple` with the dimensions - of the :class:`cellflow.networks.MLPBlock` processing the source cell. + of the :class:`scaleflow.networks.MLPBlock` processing the source cell. - ``'genot_source_dropout'`` of type :class:`float` indicating the dropout rate for the source cell processing. probability_path @@ -397,12 +397,12 @@ def prepare_model( match_fn Matching function between unperturbed and perturbed cells. Should take as input source and target data and return the optimal transport matrix, see e.g. - :func:`cellflow.utils.match_linear`. + :func:`scaleflow.utils.match_linear`. optimizer Optimizer used for training. solver_kwargs - Keyword arguments for the solver :class:`cellflow.solvers.OTFlowMatching` or - :class:`cellflow.solvers.GENOT`. + Keyword arguments for the solver :class:`scaleflow.solvers.OTFlowMatching` or + :class:`scaleflow.solvers.GENOT`. layer_norm_before_concatenation If :obj:`True`, applies layer normalization before concatenating the embedded time, embedded data, and condition embeddings. @@ -416,16 +416,19 @@ def prepare_model( ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.velocity_field` - an instance of the - :class:`cellflow.networks.ConditionalVelocityField`. - - :attr:`cellflow.model.CellFlow.solver` - an instance of :class:`cellflow.solvers.OTFlowMatching` - or :class:`cellflow.solvers.GENOT`. - - :attr:`cellflow.model.CellFlow.trainer` - an instance of the - :class:`cellflow.training.CellFlowTrainer`. + - :attr:`scaleflow.model.CellFlow.velocity_field` - an instance of the + :class:`scaleflow.networks.ConditionalVelocityField`. + - :attr:`scaleflow.model.CellFlow.solver` - an instance of :class:`scaleflow.solvers.OTFlowMatching` + or :class:`scaleflow.solvers.GENOT`. + - :attr:`scaleflow.model.CellFlow.trainer` - an instance of the + :class:`scaleflow.training.CellFlowTrainer`. """ if self.train_data is None: raise ValueError("Dataloader not initialized. Please call `prepare_data` first.") + # Store the seed for use in train method + self._seed = seed + if condition_mode == "stochastic": if regularization == 0.0: raise ValueError("Stochastic condition embeddings require `regularization`>0.") @@ -517,9 +520,12 @@ def train( num_iterations: int, batch_size: int = 1024, valid_freq: int = 1000, + validation_batch_size: int | None = None, callbacks: Sequence[BaseCallback] = [], monitor_metrics: Sequence[str] = [], out_of_core_dataloading: bool = False, + num_workers: int = 8, # Increased from default 4 + prefetch_factor: int = 4, # Increased from default 2 ) -> None: """Train the model. @@ -539,21 +545,21 @@ def train( callbacks Callbacks to perform at each validation step. 
There are two types of callbacks: - Callbacks for computations should inherit from - :class:`~cellflow.training.ComputationCallback` see e.g. :class:`cellflow.training.Metrics`. - - Callbacks for logging should inherit from :class:`~cellflow.training.LoggingCallback` see - e.g. :class:`~cellflow.training.WandbLogger`. + :class:`~scaleflow.training.ComputationCallback` see e.g. :class:`scaleflow.training.Metrics`. + - Callbacks for logging should inherit from :class:`~scaleflow.training.LoggingCallback` see + e.g. :class:`~scaleflow.training.WandbLogger`. monitor_metrics Metrics to monitor. out_of_core_dataloading - If :obj:`True`, use out-of-core dataloading. Uses the :class:`cellflow.data.JaxOutOfCoreTrainSampler` + If :obj:`True`, use out-of-core dataloading. Uses the :class:`scaleflow.data.JaxOutOfCoreTrainSampler` to load data that does not fit into GPU memory. Returns ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.dataloader` - the training dataloader. - - :attr:`cellflow.model.CellFlow.solver` - the trained solver. + - :attr:`scaleflow.model.CellFlow.dataloader` - the training dataloader. + - :attr:`scaleflow.model.CellFlow.solver` - the trained solver. """ if self.train_data is None: raise ValueError("Data not initialized. Please call `prepare_data` first.") @@ -562,10 +568,22 @@ def train( raise ValueError("Model not initialized. Please call `prepare_model` first.") if out_of_core_dataloading: - self._dataloader = JaxOutOfCoreTrainSampler(data=self.train_data, batch_size=batch_size) + self._dataloader = JaxOutOfCoreTrainSampler( + data=self.train_data, + batch_size=batch_size, + seed=self._seed, + num_workers=num_workers, + prefetch_factor=prefetch_factor + ) else: self._dataloader = TrainSampler(data=self.train_data, batch_size=batch_size) - validation_loaders = {k: ValidationSampler(v) for k, v in self.validation_data.items() if k != "predict_kwargs"} + + # Pass validation_batch_size to ValidationSampler + validation_loaders = { + k: ValidationSampler(v, validation_batch_size=validation_batch_size) + for k, v in self.validation_data.items() + if k != "predict_kwargs" + } self._solver = self.trainer.train( dataloader=self._dataloader, @@ -595,8 +613,8 @@ def predict( covariate_data Covariate data defining the condition to predict. This :class:`~pandas.DataFrame` should have the same columns as :attr:`~anndata.AnnData.obs` of - :attr:`cellflow.model.CellFlow.adata`, and as registered in - :attr:`cellflow.model.CellFlow.data_manager`. + :attr:`scaleflow.model.CellFlow.adata`, and as registered in + :attr:`scaleflow.model.CellFlow.data_manager`. sample_rep Key in :attr:`~anndata.AnnData.obsm` where the sample representation is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. If :obj:`None`, the key is assumed to be @@ -608,12 +626,12 @@ def predict( If :obj:`None`, the predictions are not stored, and the predictions are returned as a :class:`dict`. rng - Random number generator. If :obj:`None` and :attr:`cellflow.model.CellFlow.conditino_mode` + Random number generator. If :obj:`None` and :attr:`scaleflow.model.CellFlow.conditino_mode` is ``'stochastic'``, the condition vector will be the mean of the learnt distributions, otherwise samples from the distribution. kwargs Keyword arguments for the predict function, i.e. - :meth:`cellflow.solvers.OTFlowMatching.predict` or :meth:`cellflow.solvers.GENOT.predict`. + :meth:`scaleflow.solvers.OTFlowMatching.predict` or :meth:`scaleflow.solvers.GENOT.predict`. 
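A hypothetical call sketch for the extended train signature above, assuming cf is a CellFlow instance on which prepare_data and prepare_model have already been run; the argument values are illustrative only.

cf.train(
    num_iterations=100_000,
    batch_size=1024,
    valid_freq=1000,
    validation_batch_size=512,      # new: cap cells sampled per condition during validation
    out_of_core_dataloading=True,   # use JaxOutOfCoreTrainSampler for data larger than GPU memory
    num_workers=8,                  # new: forwarded to the out-of-core sampler
    prefetch_factor=4,              # new: forwarded to the out-of-core sampler
)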
Returns ------- @@ -679,7 +697,7 @@ def get_condition_embedding( """Get the embedding of the conditions. Outputs the mean and variance of the learnt embeddings - generated by the :class:`~cellflow.networks.ConditionEncoder`. + generated by the :class:`~scaleflow.networks.ConditionEncoder`. Parameters ---------- @@ -687,8 +705,8 @@ def get_condition_embedding( Can be one of - a :class:`~pandas.DataFrame` defining the conditions with the same columns as the - :class:`~anndata.AnnData` used for the initialisation of :class:`~cellflow.model.CellFlow`. - - an instance of :class:`~cellflow.data.ConditionData`. + :class:`~anndata.AnnData` used for the initialisation of :class:`~scaleflow.model.CellFlow`. + - an instance of :class:`~scaleflow.data.ConditionData`. rep_dict Dictionary containing the representations of the perturbation covariates. Will be considered an @@ -756,7 +774,7 @@ def save( """ Save the model. - Pickles the :class:`~cellflow.model.CellFlow` object. + Pickles the :class:`~scaleflow.model.CellFlow` object. Parameters ---------- @@ -789,7 +807,7 @@ def load( filename: str, ) -> "CellFlow": """ - Load a :class:`~cellflow.model.CellFlow` model from a saved instance. + Load a :class:`~scaleflow.model.CellFlow` model from a saved instance. Parameters ---------- @@ -837,7 +855,7 @@ def validation_data(self) -> dict[str, ValidationData]: @property def data_manager(self) -> DataManager: - """The data manager, initialised with :attr:`cellflow.model.CellFlow.adata`.""" + """The data manager, initialised with :attr:`scaleflow.model.CellFlow.adata`.""" return self._dm @property diff --git a/src/cellflow/model/_utils.py b/src/scaleflow/model/_utils.py similarity index 96% rename from src/cellflow/model/_utils.py rename to src/scaleflow/model/_utils.py index 76384b38..920bbf77 100644 --- a/src/cellflow/model/_utils.py +++ b/src/scaleflow/model/_utils.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike def _multivariate_normal( diff --git a/src/cellflow/networks/__init__.py b/src/scaleflow/networks/__init__.py similarity index 70% rename from src/cellflow/networks/__init__.py rename to src/scaleflow/networks/__init__.py index e8051b1c..2716121d 100644 --- a/src/cellflow/networks/__init__.py +++ b/src/scaleflow/networks/__init__.py @@ -1,7 +1,7 @@ -from cellflow.networks._set_encoders import ( +from scaleflow.networks._set_encoders import ( ConditionEncoder, ) -from cellflow.networks._utils import ( +from scaleflow.networks._utils import ( FilmBlock, MLPBlock, ResNetBlock, @@ -10,7 +10,7 @@ SelfAttentionBlock, TokenAttentionPooling, ) -from cellflow.networks._velocity_field import ConditionalVelocityField, GENOTConditionalVelocityField +from scaleflow.networks._velocity_field import ConditionalVelocityField, GENOTConditionalVelocityField __all__ = [ "ConditionalVelocityField", diff --git a/src/cellflow/networks/_set_encoders.py b/src/scaleflow/networks/_set_encoders.py similarity index 98% rename from src/cellflow/networks/_set_encoders.py rename to src/scaleflow/networks/_set_encoders.py index 74279872..8c334233 100644 --- a/src/cellflow/networks/_set_encoders.py +++ b/src/scaleflow/networks/_set_encoders.py @@ -9,8 +9,8 @@ from flax.training import train_state from flax.typing import FrozenDict -from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t -from cellflow.networks import _utils as nn_utils +from scaleflow._types import ArrayLike, Layers_separate_input_t, Layers_t +from 
scaleflow.networks import _utils as nn_utils __all__ = [ "ConditionEncoder", diff --git a/src/cellflow/networks/_utils.py b/src/scaleflow/networks/_utils.py similarity index 99% rename from src/cellflow/networks/_utils.py rename to src/scaleflow/networks/_utils.py index 3441330c..a6a72da5 100644 --- a/src/cellflow/networks/_utils.py +++ b/src/scaleflow/networks/_utils.py @@ -7,7 +7,7 @@ from flax import linen as nn from flax.linen import initializers -from cellflow._types import Layers_t +from scaleflow._types import Layers_t __all__ = [ "SelfAttention", diff --git a/src/cellflow/networks/_velocity_field.py b/src/scaleflow/networks/_velocity_field.py similarity index 85% rename from src/cellflow/networks/_velocity_field.py rename to src/scaleflow/networks/_velocity_field.py index 157ad4d8..f97bcde7 100644 --- a/src/cellflow/networks/_velocity_field.py +++ b/src/scaleflow/networks/_velocity_field.py @@ -9,11 +9,11 @@ from flax import linen as nn from flax.training import train_state -from cellflow._types import Layers_separate_input_t, Layers_t -from cellflow.networks._set_encoders import ConditionEncoder -from cellflow.networks._utils import FilmBlock, MLPBlock, ResNetBlock, sinusoidal_time_encoder +from scaleflow._types import Layers_separate_input_t, Layers_t +from scaleflow.networks._set_encoders import ConditionEncoder +from scaleflow.networks._utils import FilmBlock, MLPBlock, ResNetBlock, sinusoidal_time_encoder -__all__ = ["ConditionalVelocityField", "GENOTConditionalVelocityField"] +__all__ = ["ConditionalVelocityField", "GENOTConditionalVelocityField", "MultiTaskConditionalVelocityField"] class ConditionalVelocityField(nn.Module): @@ -238,7 +238,7 @@ def get_condition_embedding(self, condition: dict[str, jnp.ndarray]) -> tuple[jn Returns ------- Learnt mean and log-variance of the condition embedding. - If :attr:`cellflow.model.CellFlow.condition_mode` is ``'deterministic'``, the log-variance + If :attr:`scaleflow.model.CellFlow.condition_mode` is ``'deterministic'``, the log-variance is set to zero. """ condition_mean, condition_logvar = self.condition_encoder(condition, training=False) @@ -587,3 +587,107 @@ def create_train_state( train=False, )["params"] return train_state.TrainState.create(apply_fn=self.apply, params=params, tx=optimizer) + + +class MultiTaskConditionalVelocityField(ConditionalVelocityField): + """Extended ConditionalVelocityField with phenotype prediction capability. + + This class extends the standard velocity field to support both flow matching + and phenotype prediction tasks, enabling transfer learning between single-cell + and phenotypic data through shared condition encodings. + + Parameters + ---------- + phenotype_head_dims + Dimensions of the phenotype prediction head layers. + phenotype_output_dim + Output dimension for phenotype prediction (typically 1 for scalar outputs). + phenotype_dropout + Dropout rate for the phenotype prediction head. + + All other parameters are inherited from ConditionalVelocityField. + """ + + phenotype_head_dims: tuple[int, ...] 
= (128, 64, 32) + phenotype_output_dim: int = 1 + phenotype_dropout: float = 0.1 + + def setup(self): + """Initialize both flow matching and phenotype prediction components.""" + # Initialize parent components for flow matching + super().setup() + + # Add phenotype prediction head + from scaleflow.networks._utils import MLPBlock + self.phenotype_head = MLPBlock( + dims=self.phenotype_head_dims, + act_fn=self.act_fn, + dropout_rate=self.phenotype_dropout, + act_last_layer=True, + ) + + def __call__( + self, + t: jnp.ndarray, + x_t: jnp.ndarray, + cond: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + train: bool = True, + ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Forward pass computing both flow matching and phenotype outputs. + + Returns + ------- + tuple + (flow_output, cond_mean, cond_logvar, phenotype_output) + """ + # Get original flow matching outputs + flow_output, cond_mean, cond_logvar = super().__call__(t, x_t, cond, encoder_noise, train) + + # Compute phenotype prediction using the same condition embedding + if self.condition_mode == "deterministic": + cond_embedding = cond_mean + else: + cond_embedding = cond_mean + encoder_noise * jnp.exp(cond_logvar / 2.0) + + # Apply dropout and phenotype head + phenotype_output = self.phenotype_head(cond_embedding, training=train) + + return flow_output, cond_mean, cond_logvar, phenotype_output + + def predict_phenotype( + self, + cond: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray | None = None, + train: bool = False + ) -> jnp.ndarray: + """Predict phenotype values directly from conditions. + + Parameters + ---------- + cond + Condition dictionary. + encoder_noise + Noise for stochastic condition encoding. If None, uses zeros. + train + Whether in training mode. + + Returns + ------- + Phenotype predictions. + """ + # Get condition embeddings + cond_mean, cond_logvar = self.condition_encoder(cond, training=train) + + # Handle encoder noise + if encoder_noise is None: + encoder_noise = jnp.zeros((cond_mean.shape[0], self.condition_embedding_dim)) + + # Compute condition embedding + if self.condition_mode == "deterministic": + cond_embedding = cond_mean + else: + cond_embedding = cond_mean + encoder_noise * jnp.exp(cond_logvar / 2.0) + + # Apply dropout and phenotype head + return self.phenotype_head(cond_embedding, training=train) diff --git a/src/scaleflow/plotting/__init__.py b/src/scaleflow/plotting/__init__.py new file mode 100644 index 00000000..45364b19 --- /dev/null +++ b/src/scaleflow/plotting/__init__.py @@ -0,0 +1,3 @@ +from scaleflow.plotting._plotting import plot_condition_embedding + +__all__ = ["plot_condition_embedding"] diff --git a/src/cellflow/plotting/_plotting.py b/src/scaleflow/plotting/_plotting.py similarity index 96% rename from src/cellflow/plotting/_plotting.py rename to src/scaleflow/plotting/_plotting.py index 389f37d0..6cfbeef7 100644 --- a/src/cellflow/plotting/_plotting.py +++ b/src/scaleflow/plotting/_plotting.py @@ -7,8 +7,8 @@ import seaborn as sns from adjustText import adjust_text -from cellflow import _constants -from cellflow.plotting._utils import ( +from scaleflow import _constants +from scaleflow.plotting._utils import ( _compute_kernel_pca_from_df, _compute_pca_from_df, _compute_umap_from_df, @@ -38,7 +38,7 @@ def plot_condition_embedding( df A :class:`pandas.DataFrame` with embedding and metadata. Column names of embedding dimensions should be consecutive integers starting from 0, - e.g. 
as output from :meth:`~cellflow.model.CellFlow.get_condition_embedding`, and + e.g. as output from :meth:`~scaleflow.model.CellFlow.get_condition_embedding`, and metadata should be in columns with strings. embedding Embedding to plot. Options are "raw_embedding", "UMAP", "PCA", "Kernel_PCA". diff --git a/src/cellflow/plotting/_utils.py b/src/scaleflow/plotting/_utils.py similarity index 98% rename from src/cellflow/plotting/_utils.py rename to src/scaleflow/plotting/_utils.py index 6378122d..e89b805d 100644 --- a/src/cellflow/plotting/_utils.py +++ b/src/scaleflow/plotting/_utils.py @@ -9,7 +9,7 @@ from sklearn.decomposition import KernelPCA from sklearn.metrics.pairwise import cosine_similarity -from cellflow import _constants, _logging +from scaleflow import _constants, _logging def set_plotting_vars( diff --git a/src/scaleflow/preprocessing/__init__.py b/src/scaleflow/preprocessing/__init__.py new file mode 100644 index 00000000..36e1bff1 --- /dev/null +++ b/src/scaleflow/preprocessing/__init__.py @@ -0,0 +1,9 @@ +from scaleflow.preprocessing._gene_emb import ( + GeneInfo, + get_esm_embedding, + prot_sequence_from_ensembl, + protein_features_from_genes, +) +from scaleflow.preprocessing._pca import centered_pca, project_pca, reconstruct_pca +from scaleflow.preprocessing._preprocessing import annotate_compounds, encode_onehot, get_molecular_fingerprints +from scaleflow.preprocessing._wknn import compute_wknn, transfer_labels diff --git a/src/cellflow/preprocessing/_gene_emb.py b/src/scaleflow/preprocessing/_gene_emb.py similarity index 99% rename from src/cellflow/preprocessing/_gene_emb.py rename to src/scaleflow/preprocessing/_gene_emb.py index cbddb59f..e2cbfb3e 100644 --- a/src/cellflow/preprocessing/_gene_emb.py +++ b/src/scaleflow/preprocessing/_gene_emb.py @@ -7,7 +7,7 @@ import anndata as ad import pandas as pd -from cellflow._logging import logger +from scaleflow._logging import logger try: import requests # type: ignore[import-untyped] diff --git a/src/cellflow/preprocessing/_pca.py b/src/scaleflow/preprocessing/_pca.py similarity index 99% rename from src/cellflow/preprocessing/_pca.py rename to src/scaleflow/preprocessing/_pca.py index b6b72238..6a0dc886 100644 --- a/src/cellflow/preprocessing/_pca.py +++ b/src/scaleflow/preprocessing/_pca.py @@ -3,7 +3,7 @@ import scanpy as sc from scipy.sparse import csr_matrix -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike __all__ = ["centered_pca", "reconstruct_pca", "project_pca"] diff --git a/src/cellflow/preprocessing/_preprocessing.py b/src/scaleflow/preprocessing/_preprocessing.py similarity index 98% rename from src/cellflow/preprocessing/_preprocessing.py rename to src/scaleflow/preprocessing/_preprocessing.py index a12bd627..96149d01 100644 --- a/src/cellflow/preprocessing/_preprocessing.py +++ b/src/scaleflow/preprocessing/_preprocessing.py @@ -5,9 +5,9 @@ import numpy as np import sklearn.preprocessing as preprocessing -from cellflow._logging import logger -from cellflow._types import ArrayLike -from cellflow.data._utils import _to_list +from scaleflow._logging import logger +from scaleflow._types import ArrayLike +from scaleflow.data._utils import _to_list __all__ = ["encode_onehot", "annotate_compounds", "get_molecular_fingerprints"] diff --git a/src/cellflow/preprocessing/_wknn.py b/src/scaleflow/preprocessing/_wknn.py similarity index 99% rename from src/cellflow/preprocessing/_wknn.py rename to src/scaleflow/preprocessing/_wknn.py index 222a9dcf..5430f926 100644 --- 
a/src/cellflow/preprocessing/_wknn.py +++ b/src/scaleflow/preprocessing/_wknn.py @@ -6,8 +6,8 @@ import pandas as pd from scipy import sparse -from cellflow._logging import logger -from cellflow._types import ArrayLike +from scaleflow._logging import logger +from scaleflow._types import ArrayLike __all__ = ["compute_wknn", "transfer_labels"] diff --git a/src/scaleflow/solvers/__init__.py b/src/scaleflow/solvers/__init__.py new file mode 100644 index 00000000..35ff8cb8 --- /dev/null +++ b/src/scaleflow/solvers/__init__.py @@ -0,0 +1,4 @@ +from scaleflow.solvers._genot import GENOT +from scaleflow.solvers._otfm import OTFlowMatching + +__all__ = ["GENOT", "OTFlowMatching"] diff --git a/src/cellflow/solvers/_genot.py b/src/scaleflow/solvers/_genot.py similarity index 97% rename from src/cellflow/solvers/_genot.py rename to src/scaleflow/solvers/_genot.py index 7270ad7f..079ab922 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/scaleflow/solvers/_genot.py @@ -11,9 +11,9 @@ from ott.neural.networks import velocity_field from ott.solvers import utils as solver_utils -from cellflow import utils -from cellflow._types import ArrayLike -from cellflow.model._utils import _multivariate_normal +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.model._utils import _multivariate_normal __all__ = ["GENOT"] @@ -240,7 +240,7 @@ def predict( """Generate the push-forward of ``x`` under condition ``condition``. This function solves the ODE learnt with - the :class:`~cellflow.networks.ConditionalVelocityField`. + the :class:`~scaleflow.networks.ConditionalVelocityField`. Parameters ---------- @@ -257,7 +257,7 @@ def predict( batched Whether to use batched prediction. This is only supported if the input has the same number of cells for each condition. For example, this works when using - :class:`~cellflow.data.ValidationSampler` to sample the validation data. + :class:`~scaleflow.data.ValidationSampler` to sample the validation data. kwargs Keyword arguments for :func:`diffrax.diffeqsolve`. diff --git a/src/scaleflow/solvers/_multitask_otfm.py b/src/scaleflow/solvers/_multitask_otfm.py new file mode 100644 index 00000000..c8903df6 --- /dev/null +++ b/src/scaleflow/solvers/_multitask_otfm.py @@ -0,0 +1,399 @@ +from collections.abc import Callable +from functools import partial +from typing import Any + +import diffrax +import jax +import jax.numpy as jnp +import numpy as np +from flax.core import frozen_dict +from flax.training import train_state +from ott.neural.methods.flows import dynamics +from ott.solvers import utils as solver_utils + +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.networks._velocity_field import MultiTaskConditionalVelocityField +from scaleflow.solvers.utils import ema_update + +__all__ = ["MultiTaskOTFlowMatching"] + + +class MultiTaskOTFlowMatching: + """Multi-task OT Flow Matching for both single-cell and phenotype prediction. + + This solver extends the standard OT Flow Matching to handle both flow matching + for single-cell data and phenotype prediction tasks, enabling transfer learning + between the two modalities through shared condition encodings. + + Parameters + ---------- + vf + Multi-task velocity field parameterized by a neural network. + probability_path + Probability path between the source and the target distributions. + match_fn + Function to match samples from the source and the target distributions. + time_sampler + Time sampler with a ``(rng, n_samples) -> time`` signature. 
+ phenotype_loss_weight + Weight for the phenotype prediction loss relative to flow matching loss. + ema + Exponential moving average parameter for inference state. + kwargs + Keyword arguments for velocity field initialization. + """ + + def __init__( + self, + vf: MultiTaskConditionalVelocityField, + probability_path: dynamics.BaseFlow, + match_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None, + time_sampler: Callable[[jax.Array, int], jnp.ndarray] = solver_utils.uniform_sampler, + phenotype_loss_weight: float = 1.0, + ema: float = 0.999, + **kwargs: Any, + ): + self._is_trained: bool = False + self.vf = vf + self.condition_encoder_mode = self.vf.condition_mode + self.condition_encoder_regularization = self.vf.regularization + self.probability_path = probability_path + self.match_fn = match_fn + self.time_sampler = time_sampler + self.phenotype_loss_weight = phenotype_loss_weight + self.ema = ema + + self.vf_state = self.vf.create_train_state(**kwargs) + self.vf_state_inference = self.vf_state + self.vf_step_fn = self._get_vf_step_fn() + + def _get_vf_step_fn(self) -> Callable: + @jax.jit + def vf_step_fn( + rng: jax.Array, + vf_state: train_state.TrainState, + time: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + ): + def loss_fn( + params: jnp.ndarray, + t: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + rng: jax.Array, + ) -> jnp.ndarray: + rng_flow, rng_encoder, rng_dropout = jax.random.split(rng, 3) + x_t = self.probability_path.compute_xt(rng_flow, t, source, target) + v_t, mean_cond, logvar_cond, _ = vf_state.apply_fn( + {"params": params}, + t, + x_t, + conditions, + encoder_noise=encoder_noise, + rngs={"dropout": rng_dropout, "condition_encoder": rng_encoder}, + ) + u_t = self.probability_path.compute_ut(t, x_t, source, target) + flow_matching_loss = jnp.mean((v_t - u_t) ** 2) + condition_mean_regularization = 0.5 * jnp.mean(mean_cond**2) + condition_var_regularization = -0.5 * jnp.mean(1 + logvar_cond - jnp.exp(logvar_cond)) + if self.condition_encoder_mode == "stochastic": + encoder_loss = condition_mean_regularization + condition_var_regularization + elif (self.condition_encoder_mode == "deterministic") and (self.condition_encoder_regularization > 0): + encoder_loss = condition_mean_regularization + else: + encoder_loss = 0.0 + return flow_matching_loss + encoder_loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(vf_state.params, time, source, target, conditions, encoder_noise, rng) + return vf_state.apply_gradients(grads=grads), loss + + return vf_step_fn + + def _get_phenotype_step_fn(self) -> Callable: + @jax.jit + def phenotype_step_fn( + rng: jax.Array, + vf_state: train_state.TrainState, + conditions: dict[str, jnp.ndarray], + phenotype_targets: jnp.ndarray, + encoder_noise: jnp.ndarray, + ): + def phenotype_loss_fn( + params: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + phenotype_targets: jnp.ndarray, + encoder_noise: jnp.ndarray, + rng: jax.Array, + ) -> jnp.ndarray: + rng_encoder, rng_dropout = jax.random.split(rng, 2) + + # Create dummy inputs for flow matching components + n = phenotype_targets.shape[0] + dummy_t = jnp.zeros(n) + dummy_x = jnp.zeros((n, self.vf.output_dim)) + + # Forward pass through multi-task velocity field + _, mean_cond, logvar_cond, phenotype_pred = vf_state.apply_fn( + {"params": params}, + dummy_t, + dummy_x, + conditions, + 
encoder_noise=encoder_noise, + rngs={"dropout": rng_dropout, "condition_encoder": rng_encoder}, + ) + + # Phenotype prediction loss (MSE for regression) + phenotype_loss = jnp.mean((phenotype_pred.squeeze() - phenotype_targets) ** 2) + + # Same condition regularization as flow matching + condition_mean_regularization = 0.5 * jnp.mean(mean_cond**2) + condition_var_regularization = -0.5 * jnp.mean(1 + logvar_cond - jnp.exp(logvar_cond)) + if self.condition_encoder_mode == "stochastic": + encoder_loss = condition_mean_regularization + condition_var_regularization + elif (self.condition_encoder_mode == "deterministic") and (self.condition_encoder_regularization > 0): + encoder_loss = condition_mean_regularization + else: + encoder_loss = 0.0 + + return self.phenotype_loss_weight * phenotype_loss + encoder_loss + + grad_fn = jax.value_and_grad(phenotype_loss_fn) + loss, grads = grad_fn(vf_state.params, conditions, phenotype_targets, encoder_noise, rng) + return vf_state.apply_gradients(grads=grads), loss + + return phenotype_step_fn + + def step_fn( + self, + rng: jnp.ndarray, + batch: dict[str, ArrayLike], + ) -> float: + """Single step function handling both flow matching and phenotype tasks. + + Parameters + ---------- + rng + Random number generator. + batch + Data batch. For flow matching: ``src_cell_data``, ``tgt_cell_data``, ``condition``. + For phenotype: ``condition``, ``phenotype_target``, ``task``. + + Returns + ------- + Loss value. + """ + task = batch.get("task", "flow_matching") + + if task == "phenotype": + return self._phenotype_step(rng, batch) + else: + return self._flow_matching_step(rng, batch) + + def _flow_matching_step(self, rng: jnp.ndarray, batch: dict[str, ArrayLike]) -> float: + """Handle flow matching step.""" + src, tgt = batch["src_cell_data"], batch["tgt_cell_data"] + condition = batch.get("condition") + rng_resample, rng_time, rng_step_fn, rng_encoder_noise = jax.random.split(rng, 4) + n = src.shape[0] + time = self.time_sampler(rng_time, n) + encoder_noise = jax.random.normal(rng_encoder_noise, (n, self.vf.condition_embedding_dim)) + + if self.match_fn is not None: + tmat = self.match_fn(src, tgt) + src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat) + src, tgt = src[src_ixs], tgt[tgt_ixs] + + self.vf_state, loss = self.vf_step_fn( + rng_step_fn, + self.vf_state, + time, + src, + tgt, + condition, + encoder_noise, + ) + + if self.ema == 1.0: + self.vf_state_inference = self.vf_state + else: + self.vf_state_inference = self.vf_state_inference.replace( + params=ema_update(self.vf_state_inference.params, self.vf_state.params, self.ema) + ) + return loss + + def _phenotype_step(self, rng: jnp.ndarray, batch: dict[str, ArrayLike]) -> float: + """Handle phenotype prediction step.""" + condition = batch["condition"] + phenotype_target = batch["phenotype_target"] + rng_step_fn, rng_encoder_noise = jax.random.split(rng, 2) + n = phenotype_target.shape[0] + encoder_noise = jax.random.normal(rng_encoder_noise, (n, self.vf.condition_embedding_dim)) + + phenotype_step_fn = self._get_phenotype_step_fn() + self.vf_state, loss = phenotype_step_fn( + rng_step_fn, + self.vf_state, + condition, + phenotype_target, + encoder_noise, + ) + + if self.ema == 1.0: + self.vf_state_inference = self.vf_state + else: + self.vf_state_inference = self.vf_state_inference.replace( + params=ema_update(self.vf_state_inference.params, self.vf_state.params, self.ema) + ) + return loss + + def get_condition_embedding(self, condition: dict[str, ArrayLike], return_as_numpy=True) -> 
ArrayLike: + """Get learnt embeddings of the conditions.""" + cond_mean, cond_logvar = self.vf.apply( + {"params": self.vf_state_inference.params}, + condition, + method="get_condition_embedding", + ) + if return_as_numpy: + return np.asarray(cond_mean), np.asarray(cond_logvar) + return cond_mean, cond_logvar + + def predict( + self, + x: ArrayLike | dict[str, ArrayLike], + condition: dict[str, ArrayLike] | dict[str, dict[str, ArrayLike]], + rng: jax.Array | None = None, + batched: bool = False, + task: str = "flow_matching", + **kwargs: Any, + ) -> ArrayLike | dict[str, ArrayLike]: + """Predict either flow matching or phenotype outcomes. + + Parameters + ---------- + x + Input data (ignored for phenotype prediction). + condition + Condition dictionary. + rng + Random number generator. + batched + Whether to use batched prediction. + task + Either "flow_matching" or "phenotype". + kwargs + Additional arguments for ODE solver. + + Returns + ------- + Predictions based on the specified task. + """ + if task == "phenotype": + return self._predict_phenotype(condition, rng) + else: + return self._predict_flow_matching(x, condition, rng, batched, **kwargs) + + def _predict_phenotype( + self, + condition: dict[str, ArrayLike], + rng: jax.Array | None = None + ) -> ArrayLike: + """Predict phenotype values.""" + use_mean = rng is None or self.condition_encoder_mode == "deterministic" + rng = utils.default_prng_key(rng) + + # Get condition shape + first_cond = next(iter(condition.values())) + n_samples = first_cond.shape[0] + + encoder_noise = jnp.zeros((n_samples, self.vf.condition_embedding_dim)) if use_mean else \ + jax.random.normal(rng, (n_samples, self.vf.condition_embedding_dim)) + + phenotype_pred = self.vf_state_inference.apply_fn( + {"params": self.vf_state_inference.params}, + method="predict_phenotype", + cond=condition, + encoder_noise=encoder_noise, + train=False + ) + return np.array(phenotype_pred) + + def _predict_flow_matching( + self, + x: ArrayLike | dict[str, ArrayLike], + condition: dict[str, ArrayLike] | dict[str, dict[str, ArrayLike]], + rng: jax.Array | None = None, + batched: bool = False, + **kwargs: Any, + ) -> ArrayLike | dict[str, ArrayLike]: + """Predict flow matching outcomes (same as original OTFM).""" + if batched and not x: + return {} + + if batched: + keys = sorted(x.keys()) + condition_keys = sorted(set().union(*(condition[k].keys() for k in keys))) + _predict_jit = jax.jit(lambda x, condition: self._predict_jit(x, condition, rng, **kwargs)) + batched_predict = jax.vmap(_predict_jit, in_axes=(0, dict.fromkeys(condition_keys, 0))) + n_cells = x[keys[0]].shape[0] + for k in keys: + assert x[k].shape[0] == n_cells, "The number of cells must be the same for each condition" + src_inputs = jnp.stack([x[k] for k in keys], axis=0) + batched_conditions = {} + for cond_key in condition_keys: + batched_conditions[cond_key] = jnp.stack([condition[k][cond_key] for k in keys]) + pred_targets = batched_predict(src_inputs, batched_conditions) + return {k: pred_targets[i] for i, k in enumerate(keys)} + else: + x_pred = self._predict_jit(x, condition, rng, **kwargs) + return np.array(x_pred) + + def _predict_jit( + self, x: ArrayLike, condition: dict[str, ArrayLike], rng: jax.Array | None = None, **kwargs: Any + ) -> ArrayLike: + """JIT-compiled prediction for flow matching.""" + kwargs.setdefault("dt0", None) + kwargs.setdefault("solver", diffrax.Tsit5()) + kwargs.setdefault("stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5)) + + noise_dim = (1, 
self.vf.condition_embedding_dim) + use_mean = rng is None or self.condition_encoder_mode == "deterministic" + rng = utils.default_prng_key(rng) + encoder_noise = jnp.zeros(noise_dim) if use_mean else jax.random.normal(rng, noise_dim) + + def vf(t: jnp.ndarray, x: jnp.ndarray, args: tuple[dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray: + params = self.vf_state_inference.params + condition, encoder_noise = args + # Only use flow matching output (first element) + return self.vf_state_inference.apply_fn({"params": params}, t, x, condition, encoder_noise, train=False)[0] + + def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + ode_term = diffrax.ODETerm(vf) + result = diffrax.diffeqsolve( + ode_term, + t0=0.0, + t1=1.0, + y0=x, + args=(condition, encoder_noise), + **kwargs, + ) + return result.ys[0] + + x_pred = jax.jit(jax.vmap(solve_ode, in_axes=[0, None, None]))(x, condition, encoder_noise) + return x_pred + + @property + def is_trained(self) -> bool: + """Whether the model has been trained.""" + return self._is_trained + + @is_trained.setter + def is_trained(self, value: bool) -> None: + """Set the trained status.""" + self._is_trained = value diff --git a/src/cellflow/solvers/_otfm.py b/src/scaleflow/solvers/_otfm.py similarity index 95% rename from src/cellflow/solvers/_otfm.py rename to src/scaleflow/solvers/_otfm.py index 31114a6b..e987b8e1 100644 --- a/src/cellflow/solvers/_otfm.py +++ b/src/scaleflow/solvers/_otfm.py @@ -11,10 +11,10 @@ from ott.neural.methods.flows import dynamics from ott.solvers import utils as solver_utils -from cellflow import utils -from cellflow._types import ArrayLike -from cellflow.networks._velocity_field import ConditionalVelocityField -from cellflow.solvers.utils import ema_update +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.networks._velocity_field import ConditionalVelocityField +from scaleflow.solvers.utils import ema_update __all__ = ["OTFlowMatching"] @@ -34,14 +34,14 @@ class OTFlowMatching: match_fn Function to match samples from the source and the target distributions. It has a ``(src, tgt) -> matching`` signature, - see e.g. :func:`cellflow.utils.match_linear`. If :obj:`None`, no + see e.g. :func:`scaleflow.utils.match_linear`. If :obj:`None`, no matching is performed, and pure probability_path matching :cite:`lipman:22` is applied. time_sampler Time sampler with a ``(rng, n_samples) -> time`` signature, see e.g. :func:`ott.solvers.utils.uniform_sampler`. kwargs - Keyword arguments for :meth:`cellflow.networks.ConditionalVelocityField.create_train_state`. + Keyword arguments for :meth:`scaleflow.networks.ConditionalVelocityField.create_train_state`. """ def __init__( @@ -231,7 +231,7 @@ def predict( """Predict the translated source ``x`` under condition ``condition``. This function solves the ODE learnt with - the :class:`~cellflow.networks.ConditionalVelocityField`. + the :class:`~scaleflow.networks.ConditionalVelocityField`. Parameters ---------- @@ -249,7 +249,7 @@ def predict( batched Whether to use batched prediction. This is only supported if the input has the same number of cells for each condition. For example, this works when using - :class:`~cellflow.data.ValidationSampler` to sample the validation data. + :class:`~scaleflow.data.ValidationSampler` to sample the validation data. kwargs Keyword arguments for :func:`diffrax.diffeqsolve`. 
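The batch keys used by `MultiTaskOTFlowMatching.step_fn` follow directly from the code above: `src_cell_data`, `tgt_cell_data` and `condition` for flow matching, and `condition`, `phenotype_target` and `task` for the phenotype head. A hedged sketch of feeding both batch types, assuming a `solver` instance of `MultiTaskOTFlowMatching` has already been constructed (velocity field, probability path and optimizer omitted) and with illustrative array shapes and key names:

```python
import jax
import jax.numpy as jnp

# `solver` is assumed to be an already constructed MultiTaskOTFlowMatching.
cond = {"treatment": jnp.zeros((1, 1, 1280))}  # illustrative condition dict

flow_batch = {
    "task": "flow_matching",                 # also the default when "task" is absent
    "src_cell_data": jnp.zeros((256, 50)),   # control cells
    "tgt_cell_data": jnp.zeros((256, 50)),   # perturbed cells
    "condition": cond,
}
phenotype_batch = {
    "task": "phenotype",
    "condition": cond,
    "phenotype_target": jnp.zeros((1,)),     # one scalar readout per condition
}

rng = jax.random.PRNGKey(0)
for batch in (flow_batch, phenotype_batch):
    rng, step_rng = jax.random.split(rng)
    loss = solver.step_fn(step_rng, batch)   # dispatches on batch["task"]

# Phenotype inference reuses the shared condition embedding; `x` is ignored.
y_hat = solver.predict(x=None, condition=cond, task="phenotype")
```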
diff --git a/src/cellflow/solvers/utils.py b/src/scaleflow/solvers/utils.py similarity index 100% rename from src/cellflow/solvers/utils.py rename to src/scaleflow/solvers/utils.py diff --git a/src/cellflow/training/__init__.py b/src/scaleflow/training/__init__.py similarity index 79% rename from src/cellflow/training/__init__.py rename to src/scaleflow/training/__init__.py index 387411d2..c19a50dd 100644 --- a/src/cellflow/training/__init__.py +++ b/src/scaleflow/training/__init__.py @@ -1,4 +1,4 @@ -from cellflow.training._callbacks import ( +from scaleflow.training._callbacks import ( BaseCallback, CallbackRunner, ComputationCallback, @@ -8,7 +8,7 @@ VAEDecodedMetrics, WandbLogger, ) -from cellflow.training._trainer import CellFlowTrainer +from scaleflow.training._trainer import CellFlowTrainer __all__ = [ "CellFlowTrainer", diff --git a/src/cellflow/training/_callbacks.py b/src/scaleflow/training/_callbacks.py similarity index 93% rename from src/cellflow/training/_callbacks.py rename to src/scaleflow/training/_callbacks.py index 5b65f33f..92a4524d 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/scaleflow/training/_callbacks.py @@ -7,14 +7,14 @@ import jax.tree_util as jtu import numpy as np -from cellflow._types import ArrayLike -from cellflow.metrics._metrics import ( +from scaleflow._types import ArrayLike +from scaleflow.metrics._metrics import ( compute_e_distance_fast, compute_r_squared, compute_scalar_mmd, compute_sinkhorn_div, ) -from cellflow.solvers import _genot, _otfm +from scaleflow.solvers import _genot, _otfm __all__ = [ "BaseCallback", @@ -42,7 +42,7 @@ class BaseCallback(abc.ABC): - """Base class for callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self, *args: Any, **kwargs: Any) -> None: @@ -61,7 +61,7 @@ def on_train_end(self, *args: Any, **kwargs: Any) -> Any: class LoggingCallback(BaseCallback, abc.ABC): - """Base class for logging callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for logging callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self) -> Any: @@ -92,7 +92,7 @@ def on_train_end(self, dict_to_log: dict[str, Any]) -> Any: class ComputationCallback(BaseCallback, abc.ABC): - """Base class for computation callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for computation callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self) -> Any: @@ -118,7 +118,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -146,7 +146,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. 
Returns @@ -205,7 +205,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -240,7 +240,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -299,7 +299,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -328,7 +328,7 @@ class VAEDecodedMetrics(Metrics): ---------- vae A VAE model object with a ``'get_reconstruction'`` method, can be an instance - of :class:`cellflow.external.CFJaxSCVI`. + of :class:`scaleflow.external.CFJaxSCVI`. adata An :class:`~anndata.AnnData` object in the same format as the ``vae``. metrics @@ -374,7 +374,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -477,14 +477,14 @@ def on_train_end(self, dict_to_log: dict[str, float]) -> Any: class CallbackRunner: - """Runs a set of computational and logging callbacks in the :class:`~cellflow.training.CellFlowTrainer` + """Runs a set of computational and logging callbacks in the :class:`~scaleflow.training.CellFlowTrainer` Parameters ---------- callbacks List of callbacks to run. Callbacks should be of type - :class:`~cellflow.training.ComputationCallback` or - :class:`~cellflow.training.LoggingCallback` + :class:`~scaleflow.training.ComputationCallback` or + :class:`~scaleflow.training.LoggingCallback` Returns ------- @@ -529,7 +529,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -565,7 +565,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. 
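For reference, the logging side of this callback API comes down to three hooks. A minimal sketch of a custom logger, assuming `LoggingCallback` is re-exported from `scaleflow.training` as in the package `__init__` above (otherwise it can be imported from `scaleflow.training._callbacks`):

```python
from typing import Any

from scaleflow.training import LoggingCallback


class PrintLogger(LoggingCallback):
    """Print the aggregated metrics at every log iteration."""

    def on_train_begin(self) -> Any:
        print("training started")

    def on_log_iteration(self, dict_to_log: dict[str, Any]) -> Any:
        # dict_to_log holds the metrics produced by the computation callbacks.
        print(dict_to_log)

    def on_train_end(self, dict_to_log: dict[str, Any]) -> Any:
        print("training finished:", dict_to_log)
```

Such a logger would be passed together with a computation callback (for example `Metrics`) to the `CallbackRunner`.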
Returns diff --git a/src/cellflow/training/_trainer.py b/src/scaleflow/training/_trainer.py similarity index 80% rename from src/cellflow/training/_trainer.py rename to src/scaleflow/training/_trainer.py index 5a359df4..125c750a 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/scaleflow/training/_trainer.py @@ -6,9 +6,9 @@ from numpy.typing import ArrayLike from tqdm import tqdm -from cellflow.data import JaxOutOfCoreTrainSampler, TrainSampler, ValidationSampler -from cellflow.solvers import _genot, _otfm -from cellflow.training._callbacks import BaseCallback, CallbackRunner +from scaleflow.data import JaxOutOfCoreTrainSampler, TrainSampler, ValidationSampler +from scaleflow.solvers import _genot, _otfm +from scaleflow.training._callbacks import BaseCallback, CallbackRunner class CellFlowTrainer: @@ -19,12 +19,12 @@ class CellFlowTrainer: dataloader Data sampler. solver - :class:`~cellflow.solvers._otfm.OTFlowMatching` or - :class:`~cellflow.solvers._genot.GENOT` solver with a conditional velocity field. + :class:`~scaleflow.solvers._otfm.OTFlowMatching` or + :class:`~scaleflow.solvers._genot.GENOT` solver with a conditional velocity field. predict_kwargs Keyword arguments for the prediction functions - :func:`cellflow.solvers._otfm.OTFlowMatching.predict` or - :func:`cellflow.solvers._genot.GENOT.predict` used during validation. + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict` or + :func:`scaleflow.solvers._genot.GENOT.predict` used during validation. seed Random seed for subsampling validation data. @@ -61,15 +61,34 @@ def _validation_step( valid_source_data: dict[str, dict[str, ArrayLike]] = {} valid_pred_data: dict[str, dict[str, ArrayLike]] = {} valid_true_data: dict[str, dict[str, ArrayLike]] = {} - for val_key, vdl in val_data.items(): + + # Add progress bar for validation + val_pbar = tqdm(val_data.items(), desc="Validation", leave=False) + for val_key, vdl in val_pbar: batch = vdl.sample(mode=mode) src = batch["source"] + print(len(src)) + key0 = list(src.keys())[0] + key1 = list(src.keys())[1] + key2 = list(src.keys())[2] + print(key0) + print(key1) + print(key2) + print(src[key0].shape) + print(src[key1].shape) + print(src[key2].shape) + print(batch["condition"][key0]) + print(batch["condition"][key1]) condition = batch.get("condition", None) true_tgt = batch["target"] valid_source_data[val_key] = src valid_pred_data[val_key] = self.solver.predict(src, condition=condition, **self.predict_kwargs) valid_true_data[val_key] = true_tgt + print("Predictions done") + # Update progress bar description with current validation set + val_pbar.set_description(f"Validation ({val_key})") + return valid_source_data, valid_true_data, valid_pred_data def _update_logs(self, logs: dict[str, Any]) -> None: diff --git a/src/cellflow/training/_utils.py b/src/scaleflow/training/_utils.py similarity index 100% rename from src/cellflow/training/_utils.py rename to src/scaleflow/training/_utils.py diff --git a/src/cellflow/utils.py b/src/scaleflow/utils.py similarity index 100% rename from src/cellflow/utils.py rename to src/scaleflow/utils.py From 5dd738a10032c23efadcdbab8e068b324395429c Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 29 Sep 2025 15:51:35 +0200 Subject: [PATCH 31/35] idk if this works --- README.md | 20 +-- .../{cellflow_dark.png => scaleflow_dark.png} | Bin ...cellflow_light.png => scaleflow_light.png} | Bin docs/conf.py | 20 +-- docs/developer.rst | 4 +- docs/index.rst | 10 +- docs/installation.rst | 12 +- docs/notebooks/100_pbmc.ipynb | 74 +++++----- 
docs/notebooks/200_zebrafish.ipynb | 44 +++--- docs/notebooks/201_zebrafish_continuous.ipynb | 44 +++--- docs/notebooks/300_ineuron_tutorial.ipynb | 36 ++--- docs/notebooks/500_combosciplex.ipynb | 50 +++---- docs/notebooks/600_trainsampler copy.ipynb | 4 +- docs/notebooks/tahoe_sizes.ipynb | 2 +- docs/user/datasets.rst | 4 +- docs/user/external.rst | 4 +- docs/user/index.rst | 2 +- docs/user/metrics.rst | 4 +- docs/user/model.rst | 4 +- docs/user/networks.rst | 4 +- docs/user/plotting.rst | 4 +- docs/user/preprocessing.rst | 4 +- docs/user/solvers.rst | 4 +- docs/user/training.rst | 4 +- docs/user/utils.rst | 4 +- pyproject.toml | 20 +-- scripts/process_tahoe.py | 4 +- src/cellflow/__init__.py | 4 - src/cellflow/external/__init__.py | 6 - src/cellflow/model/__init__.py | 3 - src/cellflow/plotting/__init__.py | 3 - src/cellflow/preprocessing/__init__.py | 9 -- src/cellflow/solvers/__init__.py | 4 - src/scaleflow/__init__.py | 4 + src/{cellflow => scaleflow}/_constants.py | 4 +- src/{cellflow => scaleflow}/_logging.py | 0 src/{cellflow => scaleflow}/_optional.py | 2 +- src/{cellflow => scaleflow}/_types.py | 0 .../compat/__init__.py | 0 src/{cellflow => scaleflow}/compat/torch_.py | 2 +- src/{cellflow => scaleflow}/data/__init__.py | 10 +- src/{cellflow => scaleflow}/data/_data.py | 4 +- .../data/_dataloader.py | 8 +- .../data/_datamanager.py | 10 +- .../data/_jax_dataloader.py | 4 +- .../data/_torch_dataloader.py | 6 +- src/{cellflow => scaleflow}/data/_utils.py | 0 src/{cellflow => scaleflow}/datasets.py | 8 +- src/scaleflow/external/__init__.py | 6 + src/{cellflow => scaleflow}/external/_scvi.py | 4 +- .../external/_scvi_utils.py | 0 .../metrics/__init__.py | 2 +- .../metrics/_metrics.py | 0 src/scaleflow/model/__init__.py | 3 + .../model/_scaleflow.py} | 132 +++++++++--------- src/{cellflow => scaleflow}/model/_utils.py | 2 +- .../networks/__init__.py | 6 +- .../networks/_set_encoders.py | 4 +- .../networks/_utils.py | 2 +- .../networks/_velocity_field.py | 8 +- src/scaleflow/plotting/__init__.py | 3 + .../plotting/_plotting.py | 6 +- .../plotting/_utils.py | 2 +- src/scaleflow/preprocessing/__init__.py | 9 ++ .../preprocessing/_gene_emb.py | 4 +- .../preprocessing/_pca.py | 2 +- .../preprocessing/_preprocessing.py | 6 +- .../preprocessing/_wknn.py | 4 +- src/scaleflow/solvers/__init__.py | 4 + src/{cellflow => scaleflow}/solvers/_genot.py | 10 +- src/{cellflow => scaleflow}/solvers/_otfm.py | 16 +-- src/{cellflow => scaleflow}/solvers/utils.py | 0 .../training/__init__.py | 4 +- .../training/_callbacks.py | 36 ++--- .../training/_trainer.py | 14 +- .../training/_utils.py | 0 src/{cellflow => scaleflow}/utils.py | 0 tests/conftest.py | 2 +- tests/data/test_cfsampler.py | 12 +- tests/data/test_datamanager.py | 22 +-- tests/data/test_old_get_condition_data.py | 4 +- tests/data/test_torch_dataloader.py | 6 +- tests/external/test_CFJaxSCVI.py | 2 +- tests/metrics/test_metrics.py | 26 ++-- .../{test_cellflow.py => test_scaleflow.py} | 34 ++--- tests/networks/test_aggregators.py | 4 +- tests/networks/test_condencoder.py | 4 +- tests/networks/test_velocityfield.py | 2 +- tests/plotting/test_plotting.py | 2 +- tests/preprocessing/test_gene_emb.py | 2 +- tests/preprocessing/test_pca.py | 22 +-- tests/preprocessing/test_preprocessing.py | 12 +- tests/preprocessing/test_wknn.py | 18 +-- tests/solver/test_solver.py | 16 +-- tests/trainer/test_callbacks.py | 6 +- tests/trainer/test_trainer.py | 30 ++-- 96 files changed, 496 insertions(+), 496 deletions(-) rename 
docs/_static/images/{cellflow_dark.png => scaleflow_dark.png} (100%) rename docs/_static/images/{cellflow_light.png => scaleflow_light.png} (100%) delete mode 100644 src/cellflow/__init__.py delete mode 100644 src/cellflow/external/__init__.py delete mode 100644 src/cellflow/model/__init__.py delete mode 100644 src/cellflow/plotting/__init__.py delete mode 100644 src/cellflow/preprocessing/__init__.py delete mode 100644 src/cellflow/solvers/__init__.py create mode 100644 src/scaleflow/__init__.py rename src/{cellflow => scaleflow}/_constants.py (57%) rename src/{cellflow => scaleflow}/_logging.py (100%) rename src/{cellflow => scaleflow}/_optional.py (67%) rename src/{cellflow => scaleflow}/_types.py (100%) rename src/{cellflow => scaleflow}/compat/__init__.py (100%) rename src/{cellflow => scaleflow}/compat/torch_.py (86%) rename src/{cellflow => scaleflow}/data/__init__.py (66%) rename src/{cellflow => scaleflow}/data/_data.py (99%) rename src/{cellflow => scaleflow}/data/_dataloader.py (98%) rename src/{cellflow => scaleflow}/data/_datamanager.py (99%) rename src/{cellflow => scaleflow}/data/_jax_dataloader.py (97%) rename src/{cellflow => scaleflow}/data/_torch_dataloader.py (94%) rename src/{cellflow => scaleflow}/data/_utils.py (100%) rename src/{cellflow => scaleflow}/datasets.py (94%) create mode 100644 src/scaleflow/external/__init__.py rename src/{cellflow => scaleflow}/external/_scvi.py (98%) rename src/{cellflow => scaleflow}/external/_scvi_utils.py (100%) rename src/{cellflow => scaleflow}/metrics/__init__.py (92%) rename src/{cellflow => scaleflow}/metrics/_metrics.py (100%) create mode 100644 src/scaleflow/model/__init__.py rename src/{cellflow/model/_cellflow.py => scaleflow/model/_scaleflow.py} (87%) rename src/{cellflow => scaleflow}/model/_utils.py (96%) rename src/{cellflow => scaleflow}/networks/__init__.py (70%) rename src/{cellflow => scaleflow}/networks/_set_encoders.py (98%) rename src/{cellflow => scaleflow}/networks/_utils.py (99%) rename src/{cellflow => scaleflow}/networks/_velocity_field.py (98%) create mode 100644 src/scaleflow/plotting/__init__.py rename src/{cellflow => scaleflow}/plotting/_plotting.py (96%) rename src/{cellflow => scaleflow}/plotting/_utils.py (98%) create mode 100644 src/scaleflow/preprocessing/__init__.py rename src/{cellflow => scaleflow}/preprocessing/_gene_emb.py (99%) rename src/{cellflow => scaleflow}/preprocessing/_pca.py (99%) rename src/{cellflow => scaleflow}/preprocessing/_preprocessing.py (98%) rename src/{cellflow => scaleflow}/preprocessing/_wknn.py (99%) create mode 100644 src/scaleflow/solvers/__init__.py rename src/{cellflow => scaleflow}/solvers/_genot.py (97%) rename src/{cellflow => scaleflow}/solvers/_otfm.py (95%) rename src/{cellflow => scaleflow}/solvers/utils.py (100%) rename src/{cellflow => scaleflow}/training/__init__.py (79%) rename src/{cellflow => scaleflow}/training/_callbacks.py (93%) rename src/{cellflow => scaleflow}/training/_trainer.py (91%) rename src/{cellflow => scaleflow}/training/_utils.py (100%) rename src/{cellflow => scaleflow}/utils.py (100%) rename tests/model/{test_cellflow.py => test_scaleflow.py} (95%) diff --git a/README.md b/README.md index e0b0f9d3..3cac74ea 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ -CellFlow +CellFlow -[![PyPI](https://img.shields.io/pypi/v/cellflow-tools.svg)](https://pypi.org/project/cellflow-tools/) -[![Downloads](https://static.pepy.tech/badge/cellflow-tools)](https://pepy.tech/project/cellflow-tools) 
-[![CI](https://img.shields.io/github/actions/workflow/status/theislab/cellflow/test.yaml?branch=main)](https://github.com/theislab/cellflow/actions) +[![PyPI](https://img.shields.io/pypi/v/scaleflow-tools.svg)](https://pypi.org/project/scaleflow-tools/) +[![Downloads](https://static.pepy.tech/badge/scaleflow-tools)](https://pepy.tech/project/scaleflow-tools) +[![CI](https://img.shields.io/github/actions/workflow/status/theislab/scaleflow/test.yaml?branch=main)](https://github.com/theislab/scaleflow/actions) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/theislab/CellFlow/main.svg)](https://results.pre-commit.ci/latest/github/theislab/CellFlow/main) -[![Codecov](https://codecov.io/gh/theislab/cellflow/branch/main/graph/badge.svg?token=Rgtm5Tsblo)](https://codecov.io/gh/theislab/cellflow) -[![Docs](https://img.shields.io/readthedocs/cellflow)](https://cellflow.readthedocs.io/en/latest/) +[![Codecov](https://codecov.io/gh/theislab/scaleflow/branch/main/graph/badge.svg?token=Rgtm5Tsblo)](https://codecov.io/gh/theislab/scaleflow) +[![Docs](https://img.shields.io/readthedocs/scaleflow)](https://scaleflow.readthedocs.io/en/latest/) CellFlow - Modeling Complex Perturbations with Flow Matching ============================================================ @@ -21,20 +21,20 @@ Check out the [preprint](https://www.biorxiv.org/content/10.1101/2025.04.11.6482 - Modeling the development of perturbed organisms - Cell fate engineering - Optimizing protocols for growing organoids -- ... and more; check out the [documentation](https://cellflow.readthedocs.io) for more information. +- ... and more; check out the [documentation](https://scaleflow.readthedocs.io) for more information. Installation ------------ Install **CellFlow** by running:: - pip install cellflow-tools + pip install scaleflow-tools In order to install **CellFlow** in editable mode, run:: - git clone https://github.com/theislab/cellflow - cd cellflow + git clone https://github.com/theislab/scaleflow + cd scaleflow pip install -e . For further instructions how to install jax, please refer to https://github.com/google/jax. diff --git a/docs/_static/images/cellflow_dark.png b/docs/_static/images/scaleflow_dark.png similarity index 100% rename from docs/_static/images/cellflow_dark.png rename to docs/_static/images/scaleflow_dark.png diff --git a/docs/_static/images/cellflow_light.png b/docs/_static/images/scaleflow_light.png similarity index 100% rename from docs/_static/images/cellflow_light.png rename to docs/_static/images/scaleflow_light.png diff --git a/docs/conf.py b/docs/conf.py index b21eedfd..46cdf374 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,15 +13,15 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
-import cellflow +import scaleflow sys.path.insert(0, str(Path(__file__).parent / "extensions")) # -- Project information ----------------------------------------------------- -project = cellflow.__name__ -author = "CellFlow team" -version = ilm.version("cellflow-tools") +project = scaleflow.__name__ +author = "ScaleFlow team" +version = ilm.version("scaleflow-tools") copyright = f"{datetime.now():%Y}, Theislab" # -- General configuration --------------------------------------------------- @@ -67,9 +67,9 @@ ] # TODO(michalk8): remove once typing has been cleaned-up nitpick_ignore_regex = [ - (r"py:class", r"cellflow\..*(K|B|O)"), - (r"py:class", r"cellflow\._typing.*"), - (r"py:class", r"cellflow\..*Protocol.*"), + (r"py:class", r"scaleflow\..*(K|B|O)"), + (r"py:class", r"scaleflow\._typing.*"), + (r"py:class", r"scaleflow\..*Protocol.*"), ] @@ -152,8 +152,8 @@ html_show_sourcelink = False html_theme_options = { "sidebar_hide_name": True, - "light_logo": "images/cellflow_dark.png", - "dark_logo": "images/cellflow_dark.png", + "light_logo": "images/scaleflow_dark.png", + "dark_logo": "images/scaleflow_dark.png", "light_css_variables": { "color-brand-primary": "#003262", "color-brand-content": "#003262", @@ -164,7 +164,7 @@ "footer_icons": [ { "name": "GitHub", - "url": "https://github.com/theislab/cellflow", + "url": "https://github.com/theislab/scaleflow", "html": "", "class": "fab fa-github", }, diff --git a/docs/developer.rst b/docs/developer.rst index a8e16530..8eecafa1 100644 --- a/docs/developer.rst +++ b/docs/developer.rst @@ -4,8 +4,8 @@ Developer API CellFlow model ~~~~~~~~~~~~~~ -.. module:: cellflow.data -.. currentmodule:: cellflow.data +.. module:: scaleflow.data +.. currentmodule:: scaleflow.data .. autosummary:: :toctree: genapi diff --git a/docs/index.rst b/docs/index.rst index a1e64452..0291acc6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,11 +1,11 @@ CellFlow ===================== -.. module:: cellflow +.. module:: scaleflow -:mod:`cellflow` is a framework for modeling single-cell perturbation screens. CellFlow is very flexible and enables researchers to systematically explore how cells respond to a wide range of experimental interventions, including drug treatments, genetic modifications, cytokine stimulation, morphogen pathway modulation or even entire organoid protocols. +:mod:`scaleflow` is a framework for modeling single-cell perturbation screens. CellFlow is very flexible and enables researchers to systematically explore how cells respond to a wide range of experimental interventions, including drug treatments, genetic modifications, cytokine stimulation, morphogen pathway modulation or even entire organoid protocols. -:note: This is a work in progress. We are actively working on extending the documentation of :mod:`cellflow` with more tutorials to cover a wide range of use cases. If you have any questions or suggestions, please feel free to reach out to us. +:note: This is a work in progress. We are actively working on extending the documentation of :mod:`scaleflow` with more tutorials to cover a wide range of use cases. If you have any questions or suggestions, please feel free to reach out to us. .. grid:: 3 @@ -15,13 +15,13 @@ CellFlow :link: installation :link-type: doc - Learn how to install :mod:`cellflow`. + Learn how to install :mod:`scaleflow`. .. grid-item-card:: User API :link: user/index :link-type: doc - The API reference with all the details on how to use :mod:`cellflow` functions. 
+ The API reference with all the details on how to use :mod:`scaleflow` functions. .. grid-item-card:: Manuscript diff --git a/docs/installation.rst b/docs/installation.rst index 9e8f5711..4baf5d48 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,22 +1,22 @@ Installation ============ -:mod:`cellflow` requires Python version >= 3.11 to run. +:mod:`scaleflow` requires Python version >= 3.11 to run. PyPI ---- -Install :mod:`cellflow` by running:: +Install :mod:`scaleflow` by running:: - pip install cellflow-tools + pip install scaleflow-tools Installing `rapids-singlecell` and `cuml`: -While it's not necessary to install :mod:`cellflow` with `rapids-singlecell` and `cuml`, +While it's not necessary to install :mod:`scaleflow` with `rapids-singlecell` and `cuml`, it is recommended to do so for faster preprocessing or downstream functions. -To install :mod:`cellflow` with `rapids-singlecell` and `cuml`, please refer to +To install :mod:`scaleflow` with `rapids-singlecell` and `cuml`, please refer to `instructions how to install rapids `_. Development version ------------------- -To install :mod:`cellflow` from `GitHub `_, run:: +To install :mod:`scaleflow` from `GitHub `_, run:: pip install git+https://github.com/theislab/CellFlow.git@main diff --git a/docs/notebooks/100_pbmc.ipynb b/docs/notebooks/100_pbmc.ipynb index 4260e0cd..b28b9b68 100644 --- a/docs/notebooks/100_pbmc.ipynb +++ b/docs/notebooks/100_pbmc.ipynb @@ -42,7 +42,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/icb/dominik.klein/mambaforge/envs/scaleflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -66,13 +66,13 @@ "import rapids_singlecell as rsc\n", "import flax.linen as nn\n", "import optax\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata = cellflow.datasets.pbmc_cytokines()" + "adata = scaleflow.datasets.pbmc_cytokines()" ] }, { @@ -444,7 +444,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to setup the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to setup the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. 
We select the solver `\"otfm\"`, which deterministically maps a cell to its perturbed equivalent. If we wanted to incorporate stochasticity on single-cell level, we would select `\"genot\"`." ] @@ -464,7 +464,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -476,9 +476,9 @@ "\n", "The `perturbation_covariates` indicates the external intervention, i.e. the cytokine treatment. We define a key (of arbitrary name) `\"cytokine_treatment\"` for this, and have the values be tuples with the perturbation and potential perturbation covariates. As we don't have a perturbation covariate (e.g. always the same dose), we only have one tuple, and as we don't observe combinations of treatments, the tuple has length 1. We use ESM2 embeddings for representing the cytokines, which we have precomputed already for the purpose of this notebook, saved in {attr}`uns['esm2_embeddings'] `. Thus, we pass the information that `\"esm2_embeddings\"` stores embeddings of the {attr}`obs['cytokine'] ` treatments via `perturbation_covariate_reps`.\n", "\n", - "The sample covariate describes the cellular context independent of the perturbation. In our case, these are donors, and given in the {attr}`obs['donor'] ` column. We use the mean of the control sample as donor representation, precomputed and saved in {attr}`uns['donor_embeddings'] `. We thus pass this piece of information to {class}`~cellflow.model.CellFlow` via `sample_covariate_reps`. \n", + "The sample covariate describes the cellular context independent of the perturbation. In our case, these are donors, and given in the {attr}`obs['donor'] ` column. We use the mean of the control sample as donor representation, precomputed and saved in {attr}`uns['donor_embeddings'] `. We thus pass this piece of information to {class}`~scaleflow.model.CellFlow` via `sample_covariate_reps`. \n", "\n", - "It remains to define `split_covariates`, according to which {class}`~cellflow.model.CellFlow` trains and predicts perturbations. In effect, `split_covariates` defines how to split the control distributions, and often coincides with `sample_covariates`. This ensure that we don't learn a mapping from the control distribution of donor A to a perturbed population of donor B, but only within the same donor. \n", + "It remains to define `split_covariates`, according to which {class}`~scaleflow.model.CellFlow` trains and predicts perturbations. In effect, `split_covariates` defines how to split the control distributions, and often coincides with `sample_covariates`. This ensure that we don't learn a mapping from the control distribution of donor A to a perturbed population of donor B, but only within the same donor. \n", "\n", "Finally, we can pass `max_combination_length` and `null_value`. These are relevant for combinations of treatments, which doesn't apply for this use case, as we don't want to predict combinationatorial effects of cytokines. In particular, `max_combination_length` is the maximum number of combinations of cytokines which we train on or we want to eventually predict for. The null value is the token representing no treatment, e.g. relevant when we have a treatment with fewer interventions than `max_combination_length`, see tutorials with combinatorial treatments as examples." 
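Putting the cell above into code, a hedged sketch of the corresponding `prepare_data` call. The covariate and embedding names (`cytokine`, `donor`, `esm2_embeddings`, `donor_embeddings`) are the ones described in this tutorial, whereas the `sample_rep` and `control_key` values are assumptions about this dataset rather than confirmed keys:

```python
import scaleflow
from scaleflow.model import CellFlow

adata = scaleflow.datasets.pbmc_cytokines()

cf = CellFlow(adata, solver="otfm")
cf.prepare_data(
    sample_rep="X_pca",      # assumed: obsm key holding the cell representation
    control_key="control",   # assumed: boolean obs column flagging control cells
    perturbation_covariates={"cytokine_treatment": ("cytokine",)},
    perturbation_covariate_reps={"cytokine_treatment": "esm2_embeddings"},
    sample_covariates=["donor"],
    sample_covariate_reps={"donor": "donor_embeddings"},
    split_covariates=["donor"],
)
```

Since no combination treatments are used here, `max_combination_length` and `null_value` are left at their defaults.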
] @@ -548,7 +548,7 @@ "id": "a1fc1515-30d3-4ee9-92bb-0c8299f94d21", "metadata": {}, "source": [ - "We can now prepare the data for validation using {meth}`~cellflow.model.CellFlow.prepare_validation_data`. We can pass arbitrary splits, which we define with the `name` parameter. The corresponding {class}`adata ` object has to contain the true value, such that during evaluation, we can compare the generated with the true cells.\n", + "We can now prepare the data for validation using {meth}`~scaleflow.model.CellFlow.prepare_validation_data`. We can pass arbitrary splits, which we define with the `name` parameter. The corresponding {class}`adata ` object has to contain the true value, such that during evaluation, we can compare the generated with the true cells.\n", "\n", "Note that inference takes relatively long due to solving a neural ODE, hence we might not want to evaluate on the full {class}`adata ` objects, but only on a subset of conditions, the number of which we define using `n_conditions_on_log_iteration` and `n_conditions_on_train_end`. The number of cells we generate for each condition corresponds to the number of control cells, in our case to the number of control cells specific to each donor. As in this dataset the number of control cells is relatively large, we now first subsample the {class}`adata ` object to accelerate inference. " ] @@ -642,7 +642,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -650,36 +650,36 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", "We walk through the parameters one by one:\n", "\n", - "- `condition_mode` defines the structure of the learnt condition embedding space. We will use `deterministic` mode with `regularization=0.0`, which means we learn point estimates of the condition embedding. If we added `regularization>0.0`, this would mean we impose some regularization with respect to the L2-norm of the embeddings. `condition_mode=\"stochastic\"` parameterizes the embeddings space as like a decoder-free variational auto-encoder, i.e. we set a normal isotropic prior on the embeddings, this allows to learn a stochastic mapping and evaluate the uncertainty of predictions on a distributional level (rather than on a single-cell level which can be done with {class}`~cellflow.solvers.GENOT`).\n", + "- `condition_mode` defines the structure of the learnt condition embedding space. We will use `deterministic` mode with `regularization=0.0`, which means we learn point estimates of the condition embedding. If we added `regularization>0.0`, this would mean we impose some regularization with respect to the L2-norm of the embeddings. `condition_mode=\"stochastic\"` parameterizes the embeddings space as like a decoder-free variational auto-encoder, i.e. 
we set a normal isotropic prior on the embeddings, this allows to learn a stochastic mapping and evaluate the uncertainty of predictions on a distributional level (rather than on a single-cell level which can be done with {class}`~scaleflow.solvers.GENOT`).\n", "- `regularization`, as mentioned above, is a tradeoff between the flow matching loss (which also implicitly learns the condition embeddings space), and the regularization of the mean and potentially the variance of the embedding space. Here, we learn point-wise estimates without any prior on the embeddings space, thus setting `regularization` to 0.0.\n", "- `pooling` defines how we aggregate combinations of conditions, which doesn't apply here. Putting `\"mean\"` thus has no effect, while `\"attention_token\"` or `\"attention_seed\"` would reduce to self-attention.\n", - "- `pooling_kwargs` specifies further keyword arguments for {class}`~cellflow.networks.TokenAttentionPooling` if `pooling` is\n", - " `\"attention_token\"` or {class}`~cellflow.networks.SeedAttentionPooling` if `pooling` is `\"attention_seed\"`.\n", - "- `layers_before_pool` specifies the layers processing the perturbation variables, i.e. perturbations, perturbation covariates, and sample covariates. It must be a dictionary with keys corresponding to the keys we used in {meth}`~cellflow.model.CellFlow.prepare_data`. In this case, this means that we have keys `\"cytokine_treatment\"` and `\"donor_embeddings\"`, with values specifying the architecture, e.g. the type of the module (`\"mlp\"` or `\"self_attention\"`) and layer specifications like number of layers, width, and dropout rate.\n", + "- `pooling_kwargs` specifies further keyword arguments for {class}`~scaleflow.networks.TokenAttentionPooling` if `pooling` is\n", + " `\"attention_token\"` or {class}`~scaleflow.networks.SeedAttentionPooling` if `pooling` is `\"attention_seed\"`.\n", + "- `layers_before_pool` specifies the layers processing the perturbation variables, i.e. perturbations, perturbation covariates, and sample covariates. It must be a dictionary with keys corresponding to the keys we used in {meth}`~scaleflow.model.CellFlow.prepare_data`. In this case, this means that we have keys `\"cytokine_treatment\"` and `\"donor_embeddings\"`, with values specifying the architecture, e.g. the type of the module (`\"mlp\"` or `\"self_attention\"`) and layer specifications like number of layers, width, and dropout rate.\n", "- `layers_before_pool` specifies the architecture of the module after the pooling has been performed.\n", "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder. We set it to 64.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding, we recommend to set it relatively high, especially if the `condition_embedding_dim` is large.\n", - "- `condition_encoder_kwargs` specify the architecture of the {class}`~cellflow.networks.ConditionEncoder`. Here, we don't apply any more specifications.\n", + "- `condition_encoder_kwargs` specify the architecture of the {class}`~scaleflow.networks.ConditionEncoder`. Here, we don't apply any more specifications.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", "- `time_freqs` thus (deterministically) embeds the time component before being processed by a feed-forward neural network. 
This choice is relatively independent of the data. \n", "- `time_encoder_dims` specifies the architecture how to process the time embedding needed for the neural ODE. Note that we pre-encode the time with a sinusoidal embedding of dimension `time_freqs`. This choice is relatively independent of the data. \n", "- `time_encoder_dropout` denotes the dropout applied to the layers processing the time component. This choice is relatively independent of the data. \n", "- `hidden_dims` specifies the layers processing the control cells. The choice depends on the dimensionality of the cell embedding.\n", "- `hidden_dropout` specifies the dropout in the layers defined by `hidden_dims`.\n", - "- `conditioning` specifies the method we use to integrate the different embeddings into the model. Here, we use `\"concatenation\"`, which simply concatenates the time, condition and data embeddings into a single array. Alternative options for `conditioning` are `\"film\"`, which conditions using a {class}`~cellflow.networks.FilmBlock` based on [Perez et al.](https://arxiv.org/abs/1709.07871) and `\"resnet\"` which conditions using a {class}`~cellflow.networks.ResNetBlock` based on [He et al.](https://arxiv.org/abs/1512.03385). \n", - "- `conditioning_kwargs` provides further keyword arguments when the conditioning is not `\"concatenation\"`, e.g. it provides keywords for {class}`~cellflow.networks.FilmBlock` and {class}`~cellflow.networks.ResNetBlock`, which we don't require for this use case.\n", + "- `conditioning` specifies the method we use to integrate the different embeddings into the model. Here, we use `\"concatenation\"`, which simply concatenates the time, condition and data embeddings into a single array. Alternative options for `conditioning` are `\"film\"`, which conditions using a {class}`~scaleflow.networks.FilmBlock` based on [Perez et al.](https://arxiv.org/abs/1709.07871) and `\"resnet\"` which conditions using a {class}`~scaleflow.networks.ResNetBlock` based on [He et al.](https://arxiv.org/abs/1512.03385). \n", + "- `conditioning_kwargs` provides further keyword arguments when the conditioning is not `\"concatenation\"`, e.g. it provides keywords for {class}`~scaleflow.networks.FilmBlock` and {class}`~scaleflow.networks.ResNetBlock`, which we don't require for this use case.\n", "- `decoder_dims` specifies the layers processing the embedding of the condition, the embedding of the cell, and the embedding of the time. It depends on the dimensionality of the cell representation, i.e. the higher-dimensional the cell representation, the higher `decoder_dims` should be chosen.\n", "- `decoder_dropout` sets the dropout rate of the layers processing `decoder_dims`.\n", - "- `vf_act_fn` sets the activation function in the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` if not specified otherwise.\n", - "- `vf_kwargs` provides further keyword arguments when the solver is not `\"otfm\"`, e.g. it provides keywords for {class}`~cellflow.networks._velocity_field.GENOTConditionalVelocityField`, which we don't require for this use case.\n", - "- `probability_path` defines the path between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, which internally applies {class}`~ott.neural.methods.flows.dynamics.ConstantNoiseFlow`. This means that the paths are augmented with random normal noise. Note that the magnitude should depend on the support / variance of the cell embedding. 
The higher the noise, the more the data is augmented, but the less the marginal distributions are fitted. To maintain convergence on the marginals, one can use `{\"bridge\"}\n", + "- `vf_act_fn` sets the activation function in the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` if not specified otherwise.\n", + "- `vf_kwargs` provides further keyword arguments when the solver is not `\"otfm\"`, e.g. it provides keywords for {class}`~scaleflow.networks._velocity_field.GENOTConditionalVelocityField`, which we don't require for this use case.\n", + "- `probability_path` defines the path between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, which internally applies {class}`~ott.neural.methods.flows.dynamics.ConstantNoiseFlow`. This means that the paths are augmented with random normal noise. Note that the magnitude should depend on the support / variance of the cell embedding. The higher the noise, the more the data is augmented, but the less the marginal distributions are fitted. To maintain convergence on the marginals, one can use `{\"bridge\"}\n", "- `match_fn` defines how to sample pairs batch-wise. If we have largely heterogeneous populations (e.g. whole embryos), we should choose a small entropic regularisation, while for homoegeneous cell populations like cell lines, a large entropic regularisation parameter is sufficient. Moreover, we can select the hyperparameters `tau_a` and `tau_b` determining the extent of unbalancedness in the learnt coupling, see e.g. [moscot](moscot-tools.org) for an in-depth discussion of optimal transport parameters.\n", "- `optimizer` should be used with gradient averaging to have parameter updates after having seen also multiple conditions, not only multiple cells. We found 20 to be a good value, but we recommend to perform a hyperparameter search. \n", - "- `solver_kwargs` is primarily necessary for using a different {attr}`~cellflow.model.CellFlow.solver` than {class}`~cellflow.solvers.OTFlowMatching`, e.g. when using {class}`~cellflow.solvers.GENOT`. So this doesn't apply here.\n", + "- `solver_kwargs` is primarily necessary for using a different {attr}`~scaleflow.model.CellFlow.solver` than {class}`~scaleflow.solvers.OTFlowMatching`, e.g. when using {class}`~scaleflow.solvers.GENOT`. So this doesn't apply here.\n", "- `layer_norm_before_concatenation` determines whether to apply a linear layer before concatenating the condition embedding, the time embedding, and the cell embedding. It can be hyperparameterized over, but we generally found it to not significantly help.\n", "- `linear_projection_before_concatenation` applies linear layers to the embeddings of the condition, the time, and the cell. It can be hyperparameterized over, but we generally found it to not significantly help.\n", "- `seed` sets the seed for solvers." @@ -779,11 +779,11 @@ "source": [ "## Computing and logging metrics during training \n", "\n", - "For computing metrics during training, we provide callbacks. We divide callbacks into two categories: The first one performs computations, thus is an instance of {class}`~cellflow.training.ComputationCallback`; the second one are instances of {class}`~cellflow.training.LoggingCallback` and is used for logging. Users can either provide their own callbacks, or make use of existing ones, including {class}`~cellflow.training.Metrics` for computing metrics in the space which the cells are generated in, e.g. 
in PCA or VAE-space. For computing metrics in gene space, we can use {class}`~cellflow.training.PCADecodedMetrics` in case cells are PCA-embedded, or {class}`~cellflow.training.VAEDecodedMetrics` in case cells are embedding using {class}`~cellflow.external.CFJaxSCVI`. For computing metrics, we can provide user-defined ones, or metrics provided by CellFlow, which we will do below.\n", + "For computing metrics during training, we provide callbacks. We divide callbacks into two categories: The first one performs computations, thus is an instance of {class}`~scaleflow.training.ComputationCallback`; the second one are instances of {class}`~scaleflow.training.LoggingCallback` and is used for logging. Users can either provide their own callbacks, or make use of existing ones, including {class}`~scaleflow.training.Metrics` for computing metrics in the space which the cells are generated in, e.g. in PCA or VAE-space. For computing metrics in gene space, we can use {class}`~scaleflow.training.PCADecodedMetrics` in case cells are PCA-embedded, or {class}`~scaleflow.training.VAEDecodedMetrics` in case cells are embedding using {class}`~scaleflow.external.CFJaxSCVI`. For computing metrics, we can provide user-defined ones, or metrics provided by CellFlow, which we will do below.\n", "\n", - "For logging, we recommend using [Weights and Biases](https://wandb.ai), for which we provide a callback: {class}`~cellflow.training.WandbLogger`.\n", + "For logging, we recommend using [Weights and Biases](https://wandb.ai), for which we provide a callback: {class}`~scaleflow.training.WandbLogger`.\n", "\n", - "As our cells live in PCA-space, we use the {class}`~cellflow.training.PCADecodedMetrics` callback, which takes as input also an {class}`adata ` object which contains the PCs computed from the training data." + "As our cells live in PCA-space, we use the {class}`~scaleflow.training.PCADecodedMetrics` callback, which takes as input also an {class}`adata ` object which contains the PCs computed from the training data." ] }, { @@ -793,9 +793,9 @@ "metadata": {}, "outputs": [], "source": [ - "metrics_callback = cellflow.training.Metrics(metrics=[\"r_squared\", \"mmd\", \"e_distance\"])\n", - "decoded_metrics_callback = cellflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", - "wandb_callback = cellflow.training.WandbLogger(project=\"cellflow_tutorials\", out_dir=\"~\", config={\"name\": \"100m_pbmc\"})\n", + "metrics_callback = scaleflow.training.Metrics(metrics=[\"r_squared\", \"mmd\", \"e_distance\"])\n", + "decoded_metrics_callback = scaleflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", + "wandb_callback = scaleflow.training.WandbLogger(project=\"scaleflow_tutorials\", out_dir=\"~\", config={\"name\": \"100m_pbmc\"})\n", "\n", "# we don't pass the wandb_callback as it requires a user-specific account, but recommend setting it up\n", "callbacks = [metrics_callback, decoded_metrics_callback]\n" @@ -839,7 +839,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [ - "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`." + "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`." ] }, { @@ -926,7 +926,7 @@ "id": "c1e33895-4d76-4113-97b1-8f46a3de9037", "metadata": {}, "source": [ - "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. 
Therefore, we have to provide a {class}`~pandas.DataFrame` with the same structure of {attr}`adata.obs ` (at least the columns which we used for {meth}`~cellflow.model.CellFlow.prepare_data`). Note that the embedding is independent of the cells, we thus don't need to pass the cellular representation. Moreover, {meth}`~cellflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"stochastic\"`, hence we now only visualize the learnt mean. \n",
+ "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. Therefore, we have to provide a {class}`~pandas.DataFrame` with the same structure as {attr}`adata.obs ` (at least the columns which we used for {meth}`~scaleflow.model.CellFlow.prepare_data`). Note that the embedding is independent of the cells, so we don't need to pass the cellular representation. Moreover, {meth}`~scaleflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"deterministic\"`, hence we now only visualize the learnt mean. \n",
 "For now, let's use all conditions, but indicate whether they're seen during training or not:" ] }, { @@ -983,7 +983,7 @@ "id": "078cfff2-f938-44f1-9630-491a4db408ca", "metadata": {}, "source": [
- "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~cellflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionaly reduction methods as well. We can see that the unseen conditions integrate well."
+ "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~scaleflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionality reduction methods as well. We can see that the unseen conditions integrate well."
 ] }, { @@ -1746,7 +1746,7 @@ ], "metadata": { "kernelspec": {
- "display_name": "cellflow_mod",
+ "display_name": "scaleflow_mod",
 "language": "python",
 "name": "python3"
 }, diff --git a/docs/notebooks/200_zebrafish.ipynb b/docs/notebooks/200_zebrafish.ipynb index c566c686..e1164d57 100644 --- a/docs/notebooks/200_zebrafish.ipynb +++ b/docs/notebooks/200_zebrafish.ipynb @@ -13,7 +13,7 @@ "id": "bc061b8c-aaab-413f-8f20-a0f07c812dde", "metadata": {}, "source": [
- "In this tutorial, we predict perturbations on embryo-scale. Therefore, we consider [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2), which captures up to 23 perturbations at 5 different time points, resulting in 71 perturbed phenotypes. The experimental design is sparse, hence we investigate to what extent we can fill missing measurements with {class}`~cellflow.model.CellFlow`'s predictions."
+ "In this tutorial, we predict perturbations at embryo scale. Therefore, we consider [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2), which captures up to 23 perturbations at 5 different time points, resulting in 71 perturbed phenotypes. The experimental design is sparse, hence we investigate to what extent we can fill missing measurements with {class}`~scaleflow.model.CellFlow`'s predictions."
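To make the condition-embedding visualization described above concrete, here is a minimal, hypothetical sketch (not the notebook's own code): it assumes a trained model `cf`, assumes that `get_condition_embedding` accepts a covariate DataFrame and returns mean and log-variance arrays, and performs the kernel PCA manually with scikit-learn instead of `plot_condition_embedding`.

```python
# Hypothetical sketch of the embedding inspection step; the column names and the
# exact get_condition_embedding signature are assumptions, not the verified API.
import matplotlib.pyplot as plt
from sklearn.decomposition import KernelPCA

covariate_df = adata.obs[["cytokine_treatment", "donor"]].drop_duplicates()  # assumed columns
cond_mean, cond_logvar = cf.get_condition_embedding(covariate_df)  # logvariance is 0 in deterministic mode

# Project the 256-dimensional mean embeddings to 2D with a kernel PCA, as in the tutorial.
emb2d = KernelPCA(n_components=2, kernel="rbf").fit_transform(cond_mean)
plt.scatter(emb2d[:, 0], emb2d[:, 1], s=10)
plt.xlabel("kernel PC 1")
plt.ylabel("kernel PC 2")
plt.show()
```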
] }, { @@ -42,7 +42,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/icb/dominik.klein/mambaforge/envs/scaleflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -66,13 +66,13 @@ "import rapids_singlecell as rsc\n", "import flax.linen as nn\n", "import optax\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata = cellflow.datasets.zesta()" + "adata = scaleflow.datasets.zesta()" ] }, { @@ -484,7 +484,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to setup the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to setup the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. We select the default solver `\"otfm\"`." ] @@ -504,7 +504,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -557,7 +557,7 @@ "id": "a1fc1515-30d3-4ee9-92bb-0c8299f94d21", "metadata": {}, "source": [ - "We now prepare the data validation data using {meth}`~cellflow.model.CellFlow.prepare_validation_data`. \n", + "We now prepare the data validation data using {meth}`~scaleflow.model.CellFlow.prepare_validation_data`. \n", "\n", "As for some conditions, and in particular for control cells, we have a large number of measurements, we subsample for inference to be faster. However, due to the heterogeneity of the cellular distribution, covering hundreds of cell types, we should not subsample by too much." 
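As a rough illustration of the validation set-up just described, the sketch below subsamples the held-out AnnData with scanpy before registering it; the object names and the concrete numbers are assumptions, while the parameter names follow the ones mentioned in this tutorial.

```python
# Illustrative only: keep inference on the neural ODE fast by subsampling the
# validation cells, but not so aggressively that rare cell types disappear.
import scanpy as sc

adata_val_sub = adata_val.copy()                      # assumed held-out AnnData with true cells
sc.pp.subsample(adata_val_sub, fraction=0.2, random_state=0)

cf.prepare_validation_data(
    adata_val_sub,
    name="zebrafish_val",             # arbitrary split name
    n_conditions_on_log_iteration=5,  # evaluate only a few conditions during training
    n_conditions_on_train_end=20,     # evaluate more conditions once training has finished
)
```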
] @@ -635,7 +635,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -643,9 +643,9 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", - "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", + "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", "\n", "- We use `condition_mode=\"deterministic\"` to learn point estimates of condition embeddings, and thus have a fully deterministic mapping. We set `regularization=0.0`, thus don't regularize the learnt latent space. \n", "- `pooling` defines how we aggregate combinations of conditions in a permutation-invariant manner, which we choose to do learning a class token indicated by `\"attention_token\"`.\n", @@ -653,7 +653,7 @@ "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding, we recommend to set it relatively high, especially if the `condition_embedding_dim` is large.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", - "- `probability_path` defines the path between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", + "- `probability_path` defines the path between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", "- `match_fn` defines how to sample pairs between the control and the perturbed cells. As we have a strongly heterogeneous population, we choose a higher batch size of 2048. We don't expect large outliers, and are not interested in the trajectory of a single cell, hence we choose `tau_a=tau_b=1.0`." 
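A hedged sketch of how the parameters listed above might be passed to `prepare_model`; the dimensions and the `match_linear` keyword names are assumptions rather than values taken from the notebook.

```python
# Sketch only: wire the architecture choices discussed above into prepare_model.
import functools
from scaleflow.utils import match_linear

# Balanced coupling (tau_a = tau_b = 1.0); the keyword names are assumed.
match_fn = functools.partial(match_linear, tau_a=1.0, tau_b=1.0)

cf.prepare_model(
    condition_mode="deterministic",
    regularization=0.0,
    pooling="attention_token",                 # permutation-invariant pooling via a class token
    condition_embedding_dim=256,               # assumed embedding dimension
    cond_output_dropout=0.5,                   # relatively high dropout on the condition embedding
    pool_sample_covariates=True,
    probability_path={"constant_noise": 0.5},
    match_fn=match_fn,
)
```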
] }, @@ -754,7 +754,7 @@ "metadata": {}, "outputs": [], "source": [
- "metrics_callback = cellflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n",
+ "metrics_callback = scaleflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n",
 "callbacks = [metrics_callback]\n" ] }, @@ -796,7 +796,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [
- "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`."
+ "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`."
 ] }, { @@ -966,7 +966,7 @@ "source": [ "## Predicting with CellFlow\n", "\n",
- "Predictions with {class}`~cellflow.model.CellFlow` require an {class}`adata ` object with control cells. As we only want to generate cells corresponding the unseen perturbation cdx4 and cdx1a, we only need control cells for time point 36. Moreover, we require `covariate_data` to store the information about what we would like to predict. "
+ "Predictions with {class}`~scaleflow.model.CellFlow` require an {class}`adata ` object with control cells. As we only want to generate cells corresponding to the unseen perturbations cdx4 and cdx1a, we only need control cells for time point 36. Moreover, we require `covariate_data` to store the information about what we would like to predict. "
 ] }, { @@ -1227,9 +1227,9 @@ ], "metadata": { "kernelspec": {
- "display_name": "cellflow",
+ "display_name": "scaleflow",
 "language": "python",
- "name": "cellflow"
+ "name": "scaleflow"
 }, "language_info": { "codemirror_mode": { diff --git a/docs/notebooks/201_zebrafish_continuous.ipynb b/docs/notebooks/201_zebrafish_continuous.ipynb index 60f78516..38a74c6d 100644 --- a/docs/notebooks/201_zebrafish_continuous.ipynb +++ b/docs/notebooks/201_zebrafish_continuous.ipynb @@ -13,7 +13,7 @@ "id": "bc061b8c-aaab-413f-8f20-a0f07c812dde", "metadata": {}, "source": [
- "Similary to {doc}`200_zebrafish_continuous`, we make use of the [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2) dataset, which captures up to 23 perturbations at 5 different time points. Here, we leverage {class}`~cellflow.model.CellFlow` to interpolate the perturbed development at densely sampled time points."
+ "Similarly to {doc}`200_zebrafish`, we make use of the [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2) dataset, which captures up to 23 perturbations at 5 different time points. Here, we leverage {class}`~scaleflow.model.CellFlow` to interpolate the perturbed development at densely sampled time points."
 ] }, { @@ -34,7 +34,7 @@ "name": "stderr", "output_type": "stream", "text": [
- "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -58,13 +58,13 @@ "import rapids_singlecell as rsc\n", "import flax.linen as nn\n", "import optax\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -74,7 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata = cellflow.datasets.zesta()" + "adata = scaleflow.datasets.zesta()" ] }, { @@ -532,7 +532,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to setup the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to setup the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. We select the default solver `\"otfm\"`." ] @@ -552,7 +552,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -632,7 +632,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -640,9 +640,9 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", - "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", + "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", "\n", "- We use `condition_mode=\"deterministic\"` to learn point estimates of condition embeddings, and thus have a fully deterministic mapping. We set `regularization=0.0`, thus don't regularize the learnt latent space. 
\n", "- `pooling` defines how we aggregate combinations of conditions in a permutation-invariant manner, which we choose to do learning a class token indicated by `\"attention_token\"`.\n", @@ -650,7 +650,7 @@ "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding, we recommend to set it relatively high, especially if the `condition_embedding_dim` is large.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", - "- `probability_path` defines the reference vector field between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", + "- `probability_path` defines the reference vector field between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", "- `match_fn` defines how to sample pairs between the control and the perturbed cells. As we have a strongly heterogeneous population, we choose a higher batch size of 2048. We don't expect large outliers, and are not interested in the trajectory of a single cell, hence we choose `tau_a=tau_b=1.0`." ] }, @@ -751,7 +751,7 @@ "metadata": {}, "outputs": [], "source": [ - "metrics_callback = cellflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", + "metrics_callback = scaleflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", "callbacks = [metrics_callback]\n" ] }, @@ -808,7 +808,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [ - "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`." + "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`." ] }, { @@ -867,7 +867,7 @@ "id": "c1e33895-4d76-4113-97b1-8f46a3de9037", "metadata": {}, "source": [ - "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. Note that {meth}`~cellflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"stochastic\"`, hence we now only visualize the learnt mean. \n", + "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. Note that {meth}`~scaleflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"stochastic\"`, hence we now only visualize the learnt mean. 
\n", "For now, let's use all conditions, but indicate whether they're seen during training or not:" ] }, @@ -903,7 +903,7 @@ "id": "078cfff2-f938-44f1-9630-491a4db408ca", "metadata": {}, "source": [ - "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~cellflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionaly reduction methods as well. We can see that the unseen conditions integrate well." + "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~scaleflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionaly reduction methods as well. We can see that the unseen conditions integrate well." ] }, { @@ -1416,9 +1416,9 @@ ], "metadata": { "kernelspec": { - "display_name": "cellflow", + "display_name": "scaleflow", "language": "python", - "name": "cellflow" + "name": "scaleflow" }, "language_info": { "codemirror_mode": { diff --git a/docs/notebooks/300_ineuron_tutorial.ipynb b/docs/notebooks/300_ineuron_tutorial.ipynb index 25d87eef..26c3bc6b 100644 --- a/docs/notebooks/300_ineuron_tutorial.ipynb +++ b/docs/notebooks/300_ineuron_tutorial.ipynb @@ -7,7 +7,7 @@ "source": [ "# Neuron fate prediction from combinatorial morphogen treatment\n", "\n", - "In this notebook, we show how {class}`~cellflow.model.CellFlow` can be used to predict the outcome of **neuron fate programming experiments**. We use the the dataset from [Lin, Janssens et al.](https://www.biorxiv.org/content/10.1101/2023.12.12.571318v2), which contains scRNA-seq data from an morphogen screen in NGN2-induced neurons (iNeurons). The treatment conditions comprised combinations of modulators of anterior-posterior (AP) patterning (RA, CHIR99021, XAV-939, FGF8) with modulators of dorso-ventral (DV) patterning (BMP4, SHH), each applied in multiple concentrations. We use CellFlow to predict neuron distributions for held-out combinations of morphogens. \n", + "In this notebook, we show how {class}`~scaleflow.model.CellFlow` can be used to predict the outcome of **neuron fate programming experiments**. We use the the dataset from [Lin, Janssens et al.](https://www.biorxiv.org/content/10.1101/2023.12.12.571318v2), which contains scRNA-seq data from an morphogen screen in NGN2-induced neurons (iNeurons). The treatment conditions comprised combinations of modulators of anterior-posterior (AP) patterning (RA, CHIR99021, XAV-939, FGF8) with modulators of dorso-ventral (DV) patterning (BMP4, SHH), each applied in multiple concentrations. We use CellFlow to predict neuron distributions for held-out combinations of morphogens. 
\n", "\n", "## Preparing the data" ] @@ -33,8 +33,8 @@ "from scipy.sparse import csr_matrix\n", "from sklearn.preprocessing import OneHotEncoder\n", "\n", - "import cellflow\n", - "import cellflow.preprocessing as cfpp" + "import scaleflow\n", + "import scaleflow.preprocessing as cfpp" ] }, { @@ -59,7 +59,7 @@ } ], "source": [ - "adata = cellflow.datasets.ineurons()\n", + "adata = scaleflow.datasets.ineurons()\n", "print(adata)" ] }, @@ -222,7 +222,7 @@ "metadata": {}, "outputs": [], "source": [ - "cf = cellflow.model.CellFlow(adata_train_full, solver=\"otfm\")" + "cf = scaleflow.model.CellFlow(adata_train_full, solver=\"otfm\")" ] }, { @@ -230,7 +230,7 @@ "id": "d8849016", "metadata": {}, "source": [ - "### Preparing CellFlow’s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`\n", + "### Preparing CellFlow’s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`\n", "We set up the data as follows:\n", "- We use `.obsm[\"X_pca\"]` as the cellular representation (`sample_rep`)\n", "- `\"CTRL\"` indicated the source distribution we constructed earlier\n", @@ -266,16 +266,16 @@ "id": "3cdd1103", "metadata": {}, "source": [ - "### Preparing CellFlow’s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`\n", - "Now we can set up the architecture of the CellFlow model. For a detailed description of all hyperparameters, please have a look at {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`. \n", + "### Preparing CellFlow’s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`\n", + "Now we can set up the architecture of the CellFlow model. For a detailed description of all hyperparameters, please have a look at {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`. \n", "\n", "While there is some intuition behind which parameter settings to use, we generally we use hyperparameter optimization on a separate validation set to find the best hyperparameters for each task. These are some of the most relevant parameters for this task:\n", "\n", "- `layers_before_pool` and `layers_after_pool` define the networks before and after permutation-invariant pooling of combinatorial conditions. Here, we only define a network before pooling to encode the one-hot-encoded morphogen representations and no `layers_after_pool` to use only one layer to transform the pooled representation into the condition embedding.\n", "- We found that that pooling the combinations by their mean (`pooling_type=\"mean\"`) works best for this task. This might be due to the fact that the morphogen combination conditions are *relatively* simple and their total number is somewhat small, which might make it harder to learn a more complex attention-based pooling.\n", "- `match_fn` defines how to sample pairs between the source and the perturbed cells. Here, the source distribution is a random distribution rather than a control condition, but because the output distributions are relatively complex an might contain outliers, we still found some unbalancedness to be useful for this task, so we se `tau_a=tau_b=0.99`.\n", - "- `flow` defines the reference vector field between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. 
We don't use any noise here as our cell population is highly heterogenous.\n",
- "- We found that sometimes the relationship between sizes of the condition embedding as well as encoded `x`, and `time` in the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` can matter quite a bit to the model. We are not sure exactly why this is tha case, but we found it to be especially important for iNeuron and organoid applications, where we generate from noise into a complex output distribution. We therefore set the `hidden_dims=[2048] * 2 + [128]` to transform the `x` embedding into a smaller dimension with the last layer."
+ "- `flow` defines the reference vector field between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. We don't use any noise here as our cell population is highly heterogeneous.\n",
+ "- We found that sometimes the relationship between the sizes of the condition embedding, the encoded `x`, and `time` in the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` can matter quite a bit to the model. We are not sure exactly why this is the case, but we found it to be especially important for iNeuron and organoid applications, where we generate from noise into a complex output distribution. We therefore set `hidden_dims=[2048] * 2 + [128]` to transform the `x` embedding into a smaller dimension with the last layer."
 ] }, { @@ -339,7 +339,7 @@ "id": "c806a95a", "metadata": {}, "source": [
- "Now we can train the model. To make training quicker, we here don't compute validation metrics during training, but only evaluare predictions afterwards. If you are running a model for the fist time, we recommend to monitor training behaviour with validation data through {meth}`~cellflow.model.CellFlow.prepare_validation_data` as explained in {doc}`100_pbmc`."
+ "Now we can train the model. To make training quicker, we don't compute validation metrics during training here, but only evaluate predictions afterwards. If you are running a model for the first time, we recommend monitoring training behaviour with validation data through {meth}`~scaleflow.model.CellFlow.prepare_validation_data` as explained in {doc}`100_pbmc`."
 ] }, { @@ -365,7 +365,7 @@ "id": "87a4848b", "metadata": {}, "source": [
- "After training, we can save the model to disk with {meth}`~cellflow.model.CellFlow.save_model` and load it again with {meth}`~cellflow.model.CellFlow.load_model`. "
+ "After training, we can save the model to disk with {meth}`~scaleflow.model.CellFlow.save_model` and load it again with {meth}`~scaleflow.model.CellFlow.load_model`. "
 ] }, { @@ -375,9 +375,9 @@ "metadata": {}, "outputs": [], "source": [
- "cf.save(\"cellflow_model/\", overwrite=True)\n",
- "cf = cellflow.model.CellFlow.load(\n",
- " \"cellflow_model/\"\n",
+ "cf.save(\"scaleflow_model/\", overwrite=True)\n",
+ "cf = scaleflow.model.CellFlow.load(\n",
+ " \"scaleflow_model/\"\n",
 ")" ] }, { @@ -387,7 +387,7 @@ "metadata": {}, "source": [ "### Making predictions\n",
- "Now we can finally check out the predictions. we use {meth}`~cellflow.model.CellFlow.predict` to generate predictions for the held-out conditions in the validation dataset."
+ "Now we can finally check out the predictions. We use {meth}`~scaleflow.model.CellFlow.predict` to generate predictions for the held-out conditions in the validation dataset."
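To sketch the prediction step (and the conversion into an AnnData that follows), here is a hypothetical snippet; the `predict` call signature and the `covariate_df` layout are assumptions.

```python
# Sketch: predict held-out conditions and stack the returned dictionary into an AnnData.
import anndata as ad
import numpy as np
import pandas as pd

preds = cf.predict(adata_ctrl, covariate_data=covariate_df)  # assumed call; returns {condition: array}

adata_pred = ad.AnnData(
    X=np.concatenate(list(preds.values()), axis=0),
    obs=pd.DataFrame(
        {"condition": np.repeat(list(preds.keys()), [v.shape[0] for v in preds.values()])}
    ),
)
```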
] }, { @@ -428,7 +428,7 @@ "id": "d037aadd", "metadata": {}, "source": [ - "{meth}`~cellflow.model.CellFlow.predict` returns a dictionaly with predictions for each condition. We now convert the predictions into an {class}`adata ` object." + "{meth}`~scaleflow.model.CellFlow.predict` returns a dictionaly with predictions for each condition. We now convert the predictions into an {class}`adata ` object." ] }, { @@ -456,7 +456,7 @@ "id": "deadfa4e", "metadata": {}, "source": [ - "To obtain gene expression values for our predictions, we use {meth}`cellflow.preprocessing.reconstruct_pca` to reconstruct the PCA space where the predictions were made. We then reproject the predictions into a new PCA space with the full ground truth data. " + "To obtain gene expression values for our predictions, we use {meth}`scaleflow.preprocessing.reconstruct_pca` to reconstruct the PCA space where the predictions were made. We then reproject the predictions into a new PCA space with the full ground truth data. " ] }, { diff --git a/docs/notebooks/500_combosciplex.ipynb b/docs/notebooks/500_combosciplex.ipynb index 8e42b30b..879ba0a9 100644 --- a/docs/notebooks/500_combosciplex.ipynb +++ b/docs/notebooks/500_combosciplex.ipynb @@ -34,7 +34,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/icb/dominik.klein/mambaforge/envs/scaleflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -63,13 +63,13 @@ "import flax.linen as nn\n", "import optax\n", "import pertpy\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca, annotate_compounds, get_molecular_fingerprints\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca, annotate_compounds, get_molecular_fingerprints\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -356,7 +356,7 @@ "id": "82a74985-1d8b-43e5-97a6-a3cf220e33e5", "metadata": {}, "source": [ - "We require embeddings for the drugs. While we encourage users to try different ones, we use molecular fingerprints in the following. Therefore, we first annotate the drugs, i.e. we retrieve the SMILES and [PubChem](https://pubchem.ncbi.nlm.nih.gov/) metadata using {func}`~cellflow.preprocessing.annotate_compounds`:" + "We require embeddings for the drugs. While we encourage users to try different ones, we use molecular fingerprints in the following. Therefore, we first annotate the drugs, i.e. 
we retrieve the SMILES and [PubChem](https://pubchem.ncbi.nlm.nih.gov/) metadata using {func}`~scaleflow.preprocessing.annotate_compounds`:" ] }, { @@ -634,7 +634,7 @@ "id": "7438db45-50af-4c97-8be9-c546a0982f70", "metadata": {}, "source": [ - "Among others, this gave us the SMILES strings, such that we can now get the molecular fingerprints for the SMILES strings using {func}`~cellflow.preprocessing.get_molecular_fingerprints`. We have {attr}`uns['fingerprints'] ` added, and see that all drugs have been assigned a fingerprint." + "Among others, this gave us the SMILES strings, such that we can now get the molecular fingerprints for the SMILES strings using {func}`~scaleflow.preprocessing.get_molecular_fingerprints`. We have {attr}`uns['fingerprints'] ` added, and see that all drugs have been assigned a fingerprint." ] }, { @@ -664,7 +664,7 @@ "id": "abe53039-afc8-42fc-a255-40c94d8e79b4", "metadata": {}, "source": [ - "We now add a zero token which is going to be ignored during training for \"filling\" the second drug in case of single drug perturbations. Note that this zero token will be specified later in {meth}`~cellflow.model.CellFlow.prepare_data`." + "We now add a zero token which is going to be ignored during training for \"filling\" the second drug in case of single drug perturbations. Note that this zero token will be specified later in {meth}`~scaleflow.model.CellFlow.prepare_data`." ] }, { @@ -770,7 +770,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to setup the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to setup the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. We select the default solver `\"otfm\"`." ] @@ -790,7 +790,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -837,7 +837,7 @@ "id": "a1fc1515-30d3-4ee9-92bb-0c8299f94d21", "metadata": {}, "source": [ - "We now prepare the data validation data using {meth}`~cellflow.model.CellFlow.prepare_validation_data`. \n", + "We now prepare the data validation data using {meth}`~scaleflow.model.CellFlow.prepare_validation_data`. \n", "\n", "As for some conditions, and in particular for control cells, we have a large number of measurements, we subsample for inference to be faster. However, due to the heterogeneity of the cellular distribution, covering hundreds of cell types, we should not subsample by too much." 
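Sketching the drug-embedding preparation and the zero token described above; the keyword arguments and the fingerprint key names are assumptions, not the documented signatures.

```python
# Hypothetical sketch: annotate drugs, compute fingerprints, and add a zero token
# that pads single-drug conditions (later declared as the null value in prepare_data).
import numpy as np
import scaleflow.preprocessing as cfpp

cfpp.annotate_compounds(adata, compound_keys=["drug1", "drug2"])          # assumed keyword
cfpp.get_molecular_fingerprints(adata, compound_keys=["drug1", "drug2"])  # fills adata.uns["fingerprints"]

fp_dim = len(next(iter(adata.uns["fingerprints"].values())))
adata.uns["fingerprints"]["no_drug"] = np.zeros(fp_dim)  # zero token for the missing second drug
```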
] @@ -878,7 +878,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -886,9 +886,9 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", - "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", + "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", "\n", "- We use `condition_mode=\"deterministic\"` to learn point estimates of condition embeddings, and thus have a fully deterministic mapping. We set `regularization=0.0`, thus don't regularize the learnt latent space. \n", "- `pooling` defines how we aggregate combinations of conditions in a permutation-invariant manner, which we choose to do learning a class token indicated by `\"attention_token\"`.\n", @@ -896,7 +896,7 @@ "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding, we recommend to set it relatively high, especially if the `condition_embedding_dim` is large.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", - "- `flow` defines the reference vector field between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", + "- `flow` defines the reference vector field between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", "- `match_fn` defines how to sample pairs between the control and the perturbed cells. As we have a strongly heterogeneous population, we choose a higher batch size of 2048. We don't expect large outliers, and are not interested in the trajectory of a single cell, hence we choose `tau_a=tau_b=1.0`." 
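Before the callbacks defined next, a small self-contained sketch of calling the metric functions that those callbacks rely on; the `(true, predicted)` argument order is an assumption, and the random arrays merely stand in for cell embeddings.

```python
# Minimal sketch of the distributional metrics computed during evaluation.
import numpy as np
from scaleflow.metrics import compute_e_distance, compute_r_squared

rng = np.random.default_rng(0)
x_true = rng.normal(size=(500, 50))   # placeholder: ground-truth cells in PCA space
x_pred = rng.normal(size=(500, 50))   # placeholder: predicted cells

print("r_squared:", compute_r_squared(x_true, x_pred))
print("e_distance:", compute_e_distance(x_true, x_pred))
```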
] }, @@ -998,8 +998,8 @@ "metadata": {}, "outputs": [], "source": [ - "metrics_callback = cellflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", - "decoded_metrics_callback = cellflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", + "metrics_callback = scaleflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", + "decoded_metrics_callback = scaleflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", "callbacks = [metrics_callback, decoded_metrics_callback]\n" ] }, @@ -1041,7 +1041,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [ - "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`." + "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`." ] }, { @@ -1178,7 +1178,7 @@ "source": [ "## Predicting with CellFlow\n", "\n", - "Predictions with {class}`~cellflow.model.CellFlow` require an {class}`adata ` object with control cells. Moreover, we need `covariate_data` to store the information about what we would like to predict. " + "Predictions with {class}`~scaleflow.model.CellFlow` require an {class}`adata ` object with control cells. Moreover, we need `covariate_data` to store the information about what we would like to predict. " ] }, { @@ -1613,7 +1613,7 @@ "id": "2b290bcf-30c7-4077-b3ac-959dabed0aee", "metadata": {}, "source": [ - "We also compute the metrics of CellFlow with respect to the ground truth data going through the encoder-decoder in order to separate CellFlow's model performance from the encoder-decoder. Note that this is what is computed during training with {class}`~cellflow.training.PCADecodedMetrics`. " + "We also compute the metrics of CellFlow with respect to the ground truth data going through the encoder-decoder in order to separate CellFlow's model performance from the encoder-decoder. Note that this is what is computed during training with {class}`~scaleflow.training.PCADecodedMetrics`. " ] }, { @@ -1989,9 +1989,9 @@ ], "metadata": { "kernelspec": { - "display_name": "cellflow", + "display_name": "scaleflow", "language": "python", - "name": "cellflow" + "name": "scaleflow" }, "language_info": { "codemirror_mode": { diff --git a/docs/notebooks/600_trainsampler copy.ipynb b/docs/notebooks/600_trainsampler copy.ipynb index 56b04788..79a4072e 100644 --- a/docs/notebooks/600_trainsampler copy.ipynb +++ b/docs/notebooks/600_trainsampler copy.ipynb @@ -19,7 +19,7 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "from cellflow.data import MappedCellData" + "from scaleflow.data import MappedCellData" ] }, { @@ -121,7 +121,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cellflow.data import ReservoirSampler" + "from scaleflow.data import ReservoirSampler" ] }, { diff --git a/docs/notebooks/tahoe_sizes.ipynb b/docs/notebooks/tahoe_sizes.ipynb index 57e27865..ea311b09 100644 --- a/docs/notebooks/tahoe_sizes.ipynb +++ b/docs/notebooks/tahoe_sizes.ipynb @@ -7,7 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cellflow.data import MappedCellData" + "from scaleflow.data import MappedCellData" ] }, { diff --git a/docs/user/datasets.rst b/docs/user/datasets.rst index b674e7db..8c70f2f7 100644 --- a/docs/user/datasets.rst +++ b/docs/user/datasets.rst @@ -1,7 +1,7 @@ Datasets ~~~~~~~~ -.. module:: cellflow.datasets -.. currentmodule:: cellflow +.. module:: scaleflow.datasets +.. currentmodule:: scaleflow .. 
autosummary:: :toctree: genapi diff --git a/docs/user/external.rst b/docs/user/external.rst index 124c16b2..2556f735 100644 --- a/docs/user/external.rst +++ b/docs/user/external.rst @@ -1,7 +1,7 @@ External ~~~~~~~~ -.. module:: cellflow.external -.. currentmodule:: cellflow +.. module:: scaleflow.external +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/index.rst b/docs/user/index.rst index af12c42e..ea9e5584 100644 --- a/docs/user/index.rst +++ b/docs/user/index.rst @@ -1,7 +1,7 @@ User API ######## -.. module:: cellflow.user +.. module:: scaleflow.user .. toctree:: :maxdepth: 2 diff --git a/docs/user/metrics.rst b/docs/user/metrics.rst index 3f1b558e..19927d7f 100644 --- a/docs/user/metrics.rst +++ b/docs/user/metrics.rst @@ -1,7 +1,7 @@ Metrics ~~~~~~~ -.. module:: cellflow.metrics -.. currentmodule:: cellflow +.. module:: scaleflow.metrics +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/model.rst b/docs/user/model.rst index aae3774f..9f8c2aad 100644 --- a/docs/user/model.rst +++ b/docs/user/model.rst @@ -1,7 +1,7 @@ Model ~~~~~ -.. module:: cellflow.model -.. currentmodule:: cellflow +.. module:: scaleflow.model +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/networks.rst b/docs/user/networks.rst index 76302a4b..d6148740 100644 --- a/docs/user/networks.rst +++ b/docs/user/networks.rst @@ -1,7 +1,7 @@ Networks ~~~~~~~~ -.. module:: cellflow.networks -.. currentmodule:: cellflow +.. module:: scaleflow.networks +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/plotting.rst b/docs/user/plotting.rst index 88a0389b..9414813e 100644 --- a/docs/user/plotting.rst +++ b/docs/user/plotting.rst @@ -1,7 +1,7 @@ Plotting ~~~~~~~~ -.. module:: cellflow.plotting -.. currentmodule:: cellflow +.. module:: scaleflow.plotting +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/preprocessing.rst b/docs/user/preprocessing.rst index ada1ca2b..8c897774 100644 --- a/docs/user/preprocessing.rst +++ b/docs/user/preprocessing.rst @@ -1,7 +1,7 @@ Preprocessing ~~~~~~~~~~~~~ -.. module:: cellflow.preprocessing -.. currentmodule:: cellflow +.. module:: scaleflow.preprocessing +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/solvers.rst b/docs/user/solvers.rst index f1991264..4f1dc4dc 100644 --- a/docs/user/solvers.rst +++ b/docs/user/solvers.rst @@ -1,8 +1,8 @@ Solvers ~~~~~~~ -.. module:: cellflow.solvers -.. currentmodule:: cellflow +.. module:: scaleflow.solvers +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/training.rst b/docs/user/training.rst index a2741424..dd257024 100644 --- a/docs/user/training.rst +++ b/docs/user/training.rst @@ -1,7 +1,7 @@ Training ~~~~~~~~ -.. module:: cellflow.training -.. currentmodule:: cellflow +.. module:: scaleflow.training +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/utils.rst b/docs/user/utils.rst index 28be6224..ec1f4d19 100644 --- a/docs/user/utils.rst +++ b/docs/user/utils.rst @@ -1,7 +1,7 @@ Utils ~~~~~ -.. module:: cellflow.utils -.. currentmodule:: cellflow +.. module:: scaleflow.utils +.. currentmodule:: scaleflow .. 
autosummary:: :toctree: genapi diff --git a/pyproject.toml b/pyproject.toml index 152c690b..eda7b744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ build-backend = "hatchling.build" requires = [ "hatch-vcs", "hatchling" ] [project] -name = "cellflow-tools" +name = "scaleflow-tools" description = "Modeling complex perturbations with flow matching at single-cell resolution" readme = "README.md" license = "PolyForm-Noncommercial-1.0.0" @@ -88,9 +88,9 @@ optional-dependencies.pp = [ "rdkit", ] optional-dependencies.test = [ - "cellflow-tools[embedding]", - "cellflow-tools[external]", - "cellflow-tools[pp]", + "scaleflow-tools[embedding]", + "scaleflow-tools[external]", + "scaleflow-tools[pp]", "coverage[toml]>=7", "pytest", "pytest-cov>=6", @@ -98,12 +98,12 @@ optional-dependencies.test = [ "pytest-xdist>=3", ] -urls.Documentation = "https://cellflow.readthedocs.io/" -urls.Home-page = "https://github.com/theislab/cellflow" -urls.Source = "https://github.com/theislab/cellflow" +urls.Documentation = "https://scaleflow.readthedocs.io/" +urls.Home-page = "https://github.com/theislab/scaleflow" +urls.Source = "https://github.com/theislab/scaleflow" [tool.hatch.build.targets.wheel] -packages = [ 'src/cellflow' ] +packages = [ 'src/scaleflow' ] [tool.hatch.version] source = "vcs" @@ -201,7 +201,7 @@ extras = test,pp,external,embedding pass_env = PYTEST_*,CI commands = coverage run -m pytest {tty:--color=yes} {posargs: \ - --cov={env_site_packages_dir}{/}cellflow --cov-config={tox_root}{/}pyproject.toml \ + --cov={env_site_packages_dir}{/}scaleflow --cov-config={tox_root}{/}pyproject.toml \ --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered} [testenv:lint-code] @@ -236,7 +236,7 @@ deps = leidenalg changedir = {tox_root}{/}docs commands = - python -m ipykernel install --user --name=cellflow + python -m ipykernel install --user --name=scaleflow bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks [testenv:clean-docs] diff --git a/scripts/process_tahoe.py b/scripts/process_tahoe.py index 5edd1445..8d1669c5 100644 --- a/scripts/process_tahoe.py +++ b/scripts/process_tahoe.py @@ -7,9 +7,9 @@ import anndata as ad import h5py import zarr -from cellflow.data._utils import write_sharded +from scaleflow.data._utils import write_sharded from anndata.experimental import read_lazy -from cellflow.data import DataManager +from scaleflow.data import DataManager import cupy as cp import tqdm import dask diff --git a/src/cellflow/__init__.py b/src/cellflow/__init__.py deleted file mode 100644 index 526fc741..00000000 --- a/src/cellflow/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from importlib import metadata - -import cellflow.preprocessing as pp -from cellflow import data, datasets, metrics, model, networks, solvers, training, utils diff --git a/src/cellflow/external/__init__.py b/src/cellflow/external/__init__.py deleted file mode 100644 index 7a03a1c8..00000000 --- a/src/cellflow/external/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -try: - from cellflow.external._scvi import CFJaxSCVI -except ImportError as e: - raise ImportError( - "cellflow.external requires more dependencies. 
Please install via pip install 'cellflow[external]'" - ) from e diff --git a/src/cellflow/model/__init__.py b/src/cellflow/model/__init__.py deleted file mode 100644 index 8731f241..00000000 --- a/src/cellflow/model/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellflow.model._cellflow import CellFlow - -__all__ = ["CellFlow"] diff --git a/src/cellflow/plotting/__init__.py b/src/cellflow/plotting/__init__.py deleted file mode 100644 index c7fd387e..00000000 --- a/src/cellflow/plotting/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellflow.plotting._plotting import plot_condition_embedding - -__all__ = ["plot_condition_embedding"] diff --git a/src/cellflow/preprocessing/__init__.py b/src/cellflow/preprocessing/__init__.py deleted file mode 100644 index 21eaa993..00000000 --- a/src/cellflow/preprocessing/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from cellflow.preprocessing._gene_emb import ( - GeneInfo, - get_esm_embedding, - prot_sequence_from_ensembl, - protein_features_from_genes, -) -from cellflow.preprocessing._pca import centered_pca, project_pca, reconstruct_pca -from cellflow.preprocessing._preprocessing import annotate_compounds, encode_onehot, get_molecular_fingerprints -from cellflow.preprocessing._wknn import compute_wknn, transfer_labels diff --git a/src/cellflow/solvers/__init__.py b/src/cellflow/solvers/__init__.py deleted file mode 100644 index a02a5510..00000000 --- a/src/cellflow/solvers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from cellflow.solvers._genot import GENOT -from cellflow.solvers._otfm import OTFlowMatching - -__all__ = ["GENOT", "OTFlowMatching"] diff --git a/src/scaleflow/__init__.py b/src/scaleflow/__init__.py new file mode 100644 index 00000000..60891e49 --- /dev/null +++ b/src/scaleflow/__init__.py @@ -0,0 +1,4 @@ +from importlib import metadata + +import scaleflow.preprocessing as pp +from scaleflow import data, datasets, metrics, model, networks, solvers, training, utils diff --git a/src/cellflow/_constants.py b/src/scaleflow/_constants.py similarity index 57% rename from src/cellflow/_constants.py rename to src/scaleflow/_constants.py index 3b782c8d..92f38201 100644 --- a/src/cellflow/_constants.py +++ b/src/scaleflow/_constants.py @@ -1,4 +1,4 @@ -CONTROL_HELPER = "_cellflow_control" +CONTROL_HELPER = "_scaleflow_control" CONDITION_EMBEDDING = "condition_embedding" -CELLFLOW_KEY = "cellflow" +CELLFLOW_KEY = "scaleflow" GENOT_CELL_KEY = "cell_embedding_condition" diff --git a/src/cellflow/_logging.py b/src/scaleflow/_logging.py similarity index 100% rename from src/cellflow/_logging.py rename to src/scaleflow/_logging.py diff --git a/src/cellflow/_optional.py b/src/scaleflow/_optional.py similarity index 67% rename from src/cellflow/_optional.py rename to src/scaleflow/_optional.py index 53ea8e31..c05ccec0 100644 --- a/src/cellflow/_optional.py +++ b/src/scaleflow/_optional.py @@ -5,5 +5,5 @@ class OptionalDependencyNotAvailable(ImportError): def torch_required_msg() -> str: return ( "Optional dependency 'torch' is required for this feature.\n" - "Install it via: pip install torch # or pip install 'cellflow-tools[torch]'" + "Install it via: pip install torch # or pip install 'scaleflow-tools[torch]'" ) diff --git a/src/cellflow/_types.py b/src/scaleflow/_types.py similarity index 100% rename from src/cellflow/_types.py rename to src/scaleflow/_types.py diff --git a/src/cellflow/compat/__init__.py b/src/scaleflow/compat/__init__.py similarity index 100% rename from src/cellflow/compat/__init__.py rename to src/scaleflow/compat/__init__.py diff 
--git a/src/cellflow/compat/torch_.py b/src/scaleflow/compat/torch_.py similarity index 86% rename from src/cellflow/compat/torch_.py rename to src/scaleflow/compat/torch_.py index 5a51fa5e..b79f134e 100644 --- a/src/cellflow/compat/torch_.py +++ b/src/scaleflow/compat/torch_.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from cellflow._optional import OptionalDependencyNotAvailable, torch_required_msg +from scaleflow._optional import OptionalDependencyNotAvailable, torch_required_msg try: from torch.utils.data import IterableDataset as TorchIterableDataset # type: ignore diff --git a/src/cellflow/data/__init__.py b/src/scaleflow/data/__init__.py similarity index 66% rename from src/cellflow/data/__init__.py rename to src/scaleflow/data/__init__.py index d54e4b67..23cf6b76 100644 --- a/src/cellflow/data/__init__.py +++ b/src/scaleflow/data/__init__.py @@ -1,4 +1,4 @@ -from cellflow.data._data import ( +from scaleflow.data._data import ( BaseDataMixin, ConditionData, PredictionData, @@ -6,15 +6,15 @@ ValidationData, MappedCellData, ) -from cellflow.data._dataloader import ( +from scaleflow.data._dataloader import ( PredictionSampler, TrainSampler, ReservoirSampler, ValidationSampler, ) -from cellflow.data._datamanager import DataManager -from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler -from cellflow.data._torch_dataloader import TorchCombinedTrainSampler +from scaleflow.data._datamanager import DataManager +from scaleflow.data._jax_dataloader import JaxOutOfCoreTrainSampler +from scaleflow.data._torch_dataloader import TorchCombinedTrainSampler __all__ = [ "DataManager", diff --git a/src/cellflow/data/_data.py b/src/scaleflow/data/_data.py similarity index 99% rename from src/cellflow/data/_data.py rename to src/scaleflow/data/_data.py index 219dbeae..099114db 100644 --- a/src/cellflow/data/_data.py +++ b/src/scaleflow/data/_data.py @@ -8,8 +8,8 @@ import zarr from zarr.storage import LocalStore -from cellflow._types import ArrayLike -from cellflow.data._utils import write_sharded +from scaleflow._types import ArrayLike +from scaleflow.data._utils import write_sharded __all__ = [ "BaseDataMixin", diff --git a/src/cellflow/data/_dataloader.py b/src/scaleflow/data/_dataloader.py similarity index 98% rename from src/cellflow/data/_dataloader.py rename to src/scaleflow/data/_dataloader.py index 614a5004..162579b8 100644 --- a/src/cellflow/data/_dataloader.py +++ b/src/scaleflow/data/_dataloader.py @@ -6,7 +6,7 @@ import threading from concurrent.futures import ThreadPoolExecutor, Future -from cellflow.data._data import ( +from scaleflow.data._data import ( PredictionData, TrainingData, ValidationData, @@ -22,7 +22,7 @@ class TrainSampler: - """Data sampler for :class:`~cellflow.data.TrainingData`. + """Data sampler for :class:`~scaleflow.data.TrainingData`. Parameters ---------- @@ -349,7 +349,7 @@ def _get_condition_data(self, cond_idx: int) -> dict[str, np.ndarray]: class ValidationSampler(BaseValidSampler): - """Data sampler for :class:`~cellflow.data.ValidationData`. + """Data sampler for :class:`~scaleflow.data.ValidationData`. Parameters ---------- @@ -415,7 +415,7 @@ def data(self) -> ValidationData: class PredictionSampler(BaseValidSampler): - """Data sampler for :class:`~cellflow.data.PredictionData`. + """Data sampler for :class:`~scaleflow.data.PredictionData`. 
Parameters ---------- diff --git a/src/cellflow/data/_datamanager.py b/src/scaleflow/data/_datamanager.py similarity index 99% rename from src/cellflow/data/_datamanager.py rename to src/scaleflow/data/_datamanager.py index 3ccc0793..5cb60550 100644 --- a/src/cellflow/data/_datamanager.py +++ b/src/scaleflow/data/_datamanager.py @@ -13,9 +13,9 @@ from dask.diagnostics import ProgressBar from pandas.api.types import is_numeric_dtype -from cellflow._logging import logger -from cellflow._types import ArrayLike -from cellflow.data._data import ConditionData, PredictionData, ReturnData, TrainingData, ValidationData +from scaleflow._logging import logger +from scaleflow._types import ArrayLike +from scaleflow.data._data import ConditionData, PredictionData, ReturnData, TrainingData, ValidationData from ._utils import _flatten_list, _to_list @@ -223,8 +223,8 @@ def get_prediction_data( is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. covariate_data A :class:`~pandas.DataFrame` with columns defining the covariates as - in :meth:`cellflow.model.CellFlow.prepare_data` and stored in - :attr:`cellflow.model.CellFlow.data_manager`. + in :meth:`scaleflow.model.CellFlow.prepare_data` and stored in + :attr:`scaleflow.model.CellFlow.data_manager`. rep_dict Dictionary with representations of the covariates. If not provided, :attr:`~anndata.AnnData.uns` is used. diff --git a/src/cellflow/data/_jax_dataloader.py b/src/scaleflow/data/_jax_dataloader.py similarity index 97% rename from src/cellflow/data/_jax_dataloader.py rename to src/scaleflow/data/_jax_dataloader.py index 1c181243..4178fd84 100644 --- a/src/cellflow/data/_jax_dataloader.py +++ b/src/scaleflow/data/_jax_dataloader.py @@ -6,10 +6,10 @@ import numpy as np -from cellflow.data._data import ( +from scaleflow.data._data import ( TrainingData, ) -from cellflow.data._dataloader import TrainSampler +from scaleflow.data._dataloader import TrainSampler def _prefetch_to_device( diff --git a/src/cellflow/data/_torch_dataloader.py b/src/scaleflow/data/_torch_dataloader.py similarity index 94% rename from src/cellflow/data/_torch_dataloader.py rename to src/scaleflow/data/_torch_dataloader.py index 61f12040..832746eb 100644 --- a/src/cellflow/data/_torch_dataloader.py +++ b/src/scaleflow/data/_torch_dataloader.py @@ -3,9 +3,9 @@ import numpy as np -from cellflow.compat import TorchIterableDataset -from cellflow.data._data import MappedCellData -from cellflow.data._dataloader import TrainSampler, ReservoirSampler +from scaleflow.compat import TorchIterableDataset +from scaleflow.data._data import MappedCellData +from scaleflow.data._dataloader import TrainSampler, ReservoirSampler def _worker_init_fn_helper(worker_id, random_generators): diff --git a/src/cellflow/data/_utils.py b/src/scaleflow/data/_utils.py similarity index 100% rename from src/cellflow/data/_utils.py rename to src/scaleflow/data/_utils.py diff --git a/src/cellflow/datasets.py b/src/scaleflow/datasets.py similarity index 94% rename from src/cellflow/datasets.py rename to src/scaleflow/datasets.py index d07ce340..f8bace3a 100644 --- a/src/cellflow/datasets.py +++ b/src/scaleflow/datasets.py @@ -4,7 +4,7 @@ import anndata as ad from scanpy.readwrite import _check_datafile_present_and_download -from cellflow._types import PathLike +from scaleflow._types import PathLike __all__ = [ "ineurons", @@ -13,7 +13,7 @@ def ineurons( - path: PathLike = "~/.cache/cellflow/ineurons.h5ad", + path: PathLike = "~/.cache/scaleflow/ineurons.h5ad", force_download: bool = False, **kwargs: Any, ) -> 
ad.AnnData: @@ -45,7 +45,7 @@ def ineurons( def pbmc_cytokines( - path: PathLike = "~/.cache/cellflow/pbmc_parse.h5ad", + path: PathLike = "~/.cache/scaleflow/pbmc_parse.h5ad", force_download: bool = False, **kwargs: Any, ) -> ad.AnnData: @@ -78,7 +78,7 @@ def pbmc_cytokines( def zesta( - path: PathLike = "~/.cache/cellflow/zesta.h5ad", + path: PathLike = "~/.cache/scaleflow/zesta.h5ad", force_download: bool = False, **kwargs: Any, ) -> ad.AnnData: diff --git a/src/scaleflow/external/__init__.py b/src/scaleflow/external/__init__.py new file mode 100644 index 00000000..f6d2f0e3 --- /dev/null +++ b/src/scaleflow/external/__init__.py @@ -0,0 +1,6 @@ +try: + from scaleflow.external._scvi import CFJaxSCVI +except ImportError as e: + raise ImportError( + "scaleflow.external requires more dependencies. Please install via pip install 'scaleflow[external]'" + ) from e diff --git a/src/cellflow/external/_scvi.py b/src/scaleflow/external/_scvi.py similarity index 98% rename from src/cellflow/external/_scvi.py rename to src/scaleflow/external/_scvi.py index f979d93c..e84a912f 100644 --- a/src/cellflow/external/_scvi.py +++ b/src/scaleflow/external/_scvi.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import numpy as np -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike if TYPE_CHECKING: from typing import Literal @@ -25,7 +25,7 @@ class CFJaxSCVI(JaxSCVI): - from cellflow.external._scvi_utils import CFJaxVAE + from scaleflow.external._scvi_utils import CFJaxVAE _module_cls = CFJaxVAE diff --git a/src/cellflow/external/_scvi_utils.py b/src/scaleflow/external/_scvi_utils.py similarity index 100% rename from src/cellflow/external/_scvi_utils.py rename to src/scaleflow/external/_scvi_utils.py diff --git a/src/cellflow/metrics/__init__.py b/src/scaleflow/metrics/__init__.py similarity index 92% rename from src/cellflow/metrics/__init__.py rename to src/scaleflow/metrics/__init__.py index 79cb1738..63a2aa52 100644 --- a/src/cellflow/metrics/__init__.py +++ b/src/scaleflow/metrics/__init__.py @@ -1,4 +1,4 @@ -from cellflow.metrics._metrics import ( +from scaleflow.metrics._metrics import ( compute_e_distance, compute_e_distance_fast, compute_mean_metrics, diff --git a/src/cellflow/metrics/_metrics.py b/src/scaleflow/metrics/_metrics.py similarity index 100% rename from src/cellflow/metrics/_metrics.py rename to src/scaleflow/metrics/_metrics.py diff --git a/src/scaleflow/model/__init__.py b/src/scaleflow/model/__init__.py new file mode 100644 index 00000000..32751554 --- /dev/null +++ b/src/scaleflow/model/__init__.py @@ -0,0 +1,3 @@ +from scaleflow.model._scaleflow import CellFlow + +__all__ = ["CellFlow"] diff --git a/src/cellflow/model/_cellflow.py b/src/scaleflow/model/_scaleflow.py similarity index 87% rename from src/cellflow/model/_cellflow.py rename to src/scaleflow/model/_scaleflow.py index b8c63b87..cfe8eb47 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/scaleflow/model/_scaleflow.py @@ -15,18 +15,18 @@ import pandas as pd from ott.neural.methods.flows import dynamics -from cellflow import _constants -from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t -from cellflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler -from cellflow.data._data import ConditionData, TrainingData, ValidationData -from cellflow.data._datamanager import DataManager -from cellflow.model._utils import _write_predictions -from cellflow.networks import _velocity_field -from cellflow.plotting import _utils -from cellflow.solvers import 
_genot, _otfm -from cellflow.training._callbacks import BaseCallback -from cellflow.training._trainer import CellFlowTrainer -from cellflow.utils import match_linear +from scaleflow import _constants +from scaleflow._types import ArrayLike, Layers_separate_input_t, Layers_t +from scaleflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler +from scaleflow.data._data import ConditionData, TrainingData, ValidationData +from scaleflow.data._datamanager import DataManager +from scaleflow.model._utils import _write_predictions +from scaleflow.networks import _velocity_field +from scaleflow.plotting import _utils +from scaleflow.solvers import _genot, _otfm +from scaleflow.training._callbacks import BaseCallback +from scaleflow.training._trainer import CellFlowTrainer +from scaleflow.utils import match_linear __all__ = ["CellFlow"] @@ -73,19 +73,19 @@ def prepare_data( max_combination_length: int | None = None, null_value: float = 0.0, ) -> None: - """Prepare the dataloader for training from :attr:`~cellflow.model.CellFlow.adata`. + """Prepare the dataloader for training from :attr:`~scaleflow.model.CellFlow.adata`. Parameters ---------- sample_rep - Key in :attr:`~anndata.AnnData.obsm` of :attr:`cellflow.model.CellFlow.adata` where + Key in :attr:`~anndata.AnnData.obsm` of :attr:`scaleflow.model.CellFlow.adata` where the sample representation is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. control_key Key of a boolean column in :attr:`~anndata.AnnData.obs` of - :attr:`cellflow.model.CellFlow.adata` that defines the control samples. + :attr:`scaleflow.model.CellFlow.adata` that defines the control samples. perturbation_covariates A dictionary where the keys indicate the name of the covariate group and the values are - keys in :attr:`~anndata.AnnData.obs` of :attr:`cellflow.model.CellFlow.adata`. The + keys in :attr:`~anndata.AnnData.obs` of :attr:`scaleflow.model.CellFlow.adata`. The corresponding columns can be of the following types: - categorial: The column contains categories whose representation is stored in @@ -126,8 +126,8 @@ def prepare_data( ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.data_manager` - the :class:`cellflow.data.DataManager` object. - - :attr:`cellflow.model.CellFlow.train_data` - the training data. + - :attr:`scaleflow.model.CellFlow.data_manager` - the :class:`scaleflow.data.DataManager` object. + - :attr:`scaleflow.model.CellFlow.train_data` - the training data. Example ------- @@ -203,7 +203,7 @@ def prepare_validation_data( An :class:`~anndata.AnnData` object. name Name of the validation data defining the key in - :attr:`cellflow.model.CellFlow.validation_data`. + :attr:`scaleflow.model.CellFlow.validation_data`. n_conditions_on_log_iteration Number of conditions to use for computation callbacks at each logged iteration. If :obj:`None`, use all conditions. @@ -212,14 +212,14 @@ def prepare_validation_data( If :obj:`None`, use all conditions. predict_kwargs Keyword arguments for the prediction function - :func:`cellflow.solvers._otfm.OTFlowMatching.predict` or - :func:`cellflow.solvers._genot.GENOT.predict` used during validation. + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict` or + :func:`scaleflow.solvers._genot.GENOT.predict` used during validation. Returns ------- :obj:`None`, and updates the following fields: - - :attr:`cellflow.model.CellFlow.validation_data` - a dictionary with the validation data. 
+ - :attr:`scaleflow.model.CellFlow.validation_data` - a dictionary with the validation data. """ if self.train_data is None: @@ -279,7 +279,7 @@ def prepare_model( """Prepare the model for training. This function sets up the neural network architecture and specificities of the - :attr:`solver`. When :attr:`solver` is an instance of :class:`cellflow.solvers._genot.GENOT`, + :attr:`solver`. When :attr:`solver` is an instance of :class:`scaleflow.solvers._genot.GENOT`, the following arguments have to be passed to ``'condition_encoder_kwargs'``: @@ -309,9 +309,9 @@ def prepare_model( pooling_kwargs Keyword arguments for the pooling method corresponding to: - - :class:`cellflow.networks.TokenAttentionPooling` if ``'pooling'`` is + - :class:`scaleflow.networks.TokenAttentionPooling` if ``'pooling'`` is ``'attention_token'``. - - :class:`cellflow.networks.SeedAttentionPooling` if ``'pooling'`` is ``'attention_seed'``. + - :class:`scaleflow.networks.SeedAttentionPooling` if ``'pooling'`` is ``'attention_seed'``. layers_before_pool Layers applied to the condition embeddings before pooling. Can be of type @@ -320,8 +320,8 @@ def prepare_model( - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be ``'mlp'`` or ``'self_attention'``. - - Further keyword arguments for the layer type :class:`cellflow.networks.MLPBlock` or - :class:`cellflow.networks.SelfAttentionBlock`. + - Further keyword arguments for the layer type :class:`scaleflow.networks.MLPBlock` or + :class:`scaleflow.networks.SelfAttentionBlock`. - :class:`dict` with keys corresponding to perturbation covariate keys, and values correspondinng to the above mentioned tuples. @@ -333,16 +333,16 @@ def prepare_model( - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be ``'mlp'`` or ``'self_attention'``. - - Further keys depend on the layer type, either for :class:`cellflow.networks.MLPBlock` or - for :class:`cellflow.networks.SelfAttentionBlock`. + - Further keys depend on the layer type, either for :class:`scaleflow.networks.MLPBlock` or + for :class:`scaleflow.networks.SelfAttentionBlock`. condition_embedding_dim Dimensions of the condition embedding, i.e. the last layer of the - :class:`cellflow.networks.ConditionEncoder`. + :class:`scaleflow.networks.ConditionEncoder`. cond_output_dropout - Dropout rate for the last layer of the :class:`cellflow.networks.ConditionEncoder`. + Dropout rate for the last layer of the :class:`scaleflow.networks.ConditionEncoder`. condition_encoder_kwargs - Keyword arguments for the :class:`cellflow.networks.ConditionEncoder`. + Keyword arguments for the :class:`scaleflow.networks.ConditionEncoder`. pool_sample_covariates Whether to include sample covariates in the pooling. time_freqs @@ -350,17 +350,17 @@ def prepare_model( (:func:`ott.neural.networks.layers.sinusoidal_time_encoder`). time_max_period Controls the frequency of the time embeddings, see - :func:`cellflow.networks.utils.sinusoidal_time_encoder`. + :func:`scaleflow.networks.utils.sinusoidal_time_encoder`. time_encoder_dims Dimensions of the layers processing the time embedding in - :attr:`cellflow.networks.ConditionalVelocityField.time_encoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. time_encoder_dropout - Dropout rate for the :attr:`cellflow.networks.ConditionalVelocityField.time_encoder`. + Dropout rate for the :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. 
hidden_dims Dimensions of the layers processing the input to the velocity field - via :attr:`cellflow.networks.ConditionalVelocityField.x_encoder`. + via :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. hidden_dropout - Dropout rate for :attr:`cellflow.networks.ConditionalVelocityField.x_encoder`. + Dropout rate for :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. conditioning Conditioning method, should be one of: @@ -373,18 +373,18 @@ def prepare_model( Keyword arguments for the conditioning method. decoder_dims Dimensions of the output layers in - :attr:`cellflow.networks.ConditionalVelocityField.decoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. decoder_dropout Dropout rate for the output layer - :attr:`cellflow.networks.ConditionalVelocityField.decoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. vf_act_fn - Activation function of the :class:`cellflow.networks.ConditionalVelocityField`. + Activation function of the :class:`scaleflow.networks.ConditionalVelocityField`. vf_kwargs Additional keyword arguments for the solver-specific vector field. For instance, when ``'solver==genot'``, the following keyword argument can be passed: - ``'genot_source_dims'`` of type :class:`tuple` with the dimensions - of the :class:`cellflow.networks.MLPBlock` processing the source cell. + of the :class:`scaleflow.networks.MLPBlock` processing the source cell. - ``'genot_source_dropout'`` of type :class:`float` indicating the dropout rate for the source cell processing. probability_path @@ -397,12 +397,12 @@ def prepare_model( match_fn Matching function between unperturbed and perturbed cells. Should take as input source and target data and return the optimal transport matrix, see e.g. - :func:`cellflow.utils.match_linear`. + :func:`scaleflow.utils.match_linear`. optimizer Optimizer used for training. solver_kwargs - Keyword arguments for the solver :class:`cellflow.solvers.OTFlowMatching` or - :class:`cellflow.solvers.GENOT`. + Keyword arguments for the solver :class:`scaleflow.solvers.OTFlowMatching` or + :class:`scaleflow.solvers.GENOT`. layer_norm_before_concatenation If :obj:`True`, applies layer normalization before concatenating the embedded time, embedded data, and condition embeddings. @@ -416,12 +416,12 @@ def prepare_model( ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.velocity_field` - an instance of the - :class:`cellflow.networks.ConditionalVelocityField`. - - :attr:`cellflow.model.CellFlow.solver` - an instance of :class:`cellflow.solvers.OTFlowMatching` - or :class:`cellflow.solvers.GENOT`. - - :attr:`cellflow.model.CellFlow.trainer` - an instance of the - :class:`cellflow.training.CellFlowTrainer`. + - :attr:`scaleflow.model.CellFlow.velocity_field` - an instance of the + :class:`scaleflow.networks.ConditionalVelocityField`. + - :attr:`scaleflow.model.CellFlow.solver` - an instance of :class:`scaleflow.solvers.OTFlowMatching` + or :class:`scaleflow.solvers.GENOT`. + - :attr:`scaleflow.model.CellFlow.trainer` - an instance of the + :class:`scaleflow.training.CellFlowTrainer`. """ if self.train_data is None: raise ValueError("Dataloader not initialized. Please call `prepare_data` first.") @@ -539,21 +539,21 @@ def train( callbacks Callbacks to perform at each validation step. There are two types of callbacks: - Callbacks for computations should inherit from - :class:`~cellflow.training.ComputationCallback` see e.g. :class:`cellflow.training.Metrics`. 
- - Callbacks for logging should inherit from :class:`~cellflow.training.LoggingCallback` see - e.g. :class:`~cellflow.training.WandbLogger`. + :class:`~scaleflow.training.ComputationCallback` see e.g. :class:`scaleflow.training.Metrics`. + - Callbacks for logging should inherit from :class:`~scaleflow.training.LoggingCallback` see + e.g. :class:`~scaleflow.training.WandbLogger`. monitor_metrics Metrics to monitor. out_of_core_dataloading - If :obj:`True`, use out-of-core dataloading. Uses the :class:`cellflow.data.JaxOutOfCoreTrainSampler` + If :obj:`True`, use out-of-core dataloading. Uses the :class:`scaleflow.data.JaxOutOfCoreTrainSampler` to load data that does not fit into GPU memory. Returns ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.dataloader` - the training dataloader. - - :attr:`cellflow.model.CellFlow.solver` - the trained solver. + - :attr:`scaleflow.model.CellFlow.dataloader` - the training dataloader. + - :attr:`scaleflow.model.CellFlow.solver` - the trained solver. """ if self.train_data is None: raise ValueError("Data not initialized. Please call `prepare_data` first.") @@ -595,8 +595,8 @@ def predict( covariate_data Covariate data defining the condition to predict. This :class:`~pandas.DataFrame` should have the same columns as :attr:`~anndata.AnnData.obs` of - :attr:`cellflow.model.CellFlow.adata`, and as registered in - :attr:`cellflow.model.CellFlow.data_manager`. + :attr:`scaleflow.model.CellFlow.adata`, and as registered in + :attr:`scaleflow.model.CellFlow.data_manager`. sample_rep Key in :attr:`~anndata.AnnData.obsm` where the sample representation is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. If :obj:`None`, the key is assumed to be @@ -608,12 +608,12 @@ def predict( If :obj:`None`, the predictions are not stored, and the predictions are returned as a :class:`dict`. rng - Random number generator. If :obj:`None` and :attr:`cellflow.model.CellFlow.conditino_mode` + Random number generator. If :obj:`None` and :attr:`scaleflow.model.CellFlow.condition_mode` is ``'stochastic'``, the condition vector will be the mean of the learnt distributions, otherwise samples from the distribution. kwargs Keyword arguments for the predict function, i.e. - :meth:`cellflow.solvers.OTFlowMatching.predict` or :meth:`cellflow.solvers.GENOT.predict`. + :meth:`scaleflow.solvers.OTFlowMatching.predict` or :meth:`scaleflow.solvers.GENOT.predict`. Returns ------- @@ -679,7 +679,7 @@ def get_condition_embedding( """Get the embedding of the conditions. Outputs the mean and variance of the learnt embeddings - generated by the :class:`~cellflow.networks.ConditionEncoder`. + generated by the :class:`~scaleflow.networks.ConditionEncoder`. Parameters ---------- @@ -687,8 +687,8 @@ def get_condition_embedding( Can be one of - a :class:`~pandas.DataFrame` defining the conditions with the same columns as the - :class:`~anndata.AnnData` used for the initialisation of :class:`~cellflow.model.CellFlow`. - - an instance of :class:`~cellflow.data.ConditionData`. + :class:`~anndata.AnnData` used for the initialisation of :class:`~scaleflow.model.CellFlow`. + - an instance of :class:`~scaleflow.data.ConditionData`. rep_dict Dictionary containing the representations of the perturbation covariates. Will be considered an @@ -756,7 +756,7 @@ def save( """ Save the model. - Pickles the :class:`~cellflow.model.CellFlow` object. + Pickles the :class:`~scaleflow.model.CellFlow` object. 
Parameters ---------- @@ -789,7 +789,7 @@ def load( filename: str, ) -> "CellFlow": """ - Load a :class:`~cellflow.model.CellFlow` model from a saved instance. + Load a :class:`~scaleflow.model.CellFlow` model from a saved instance. Parameters ---------- @@ -837,7 +837,7 @@ def validation_data(self) -> dict[str, ValidationData]: @property def data_manager(self) -> DataManager: - """The data manager, initialised with :attr:`cellflow.model.CellFlow.adata`.""" + """The data manager, initialised with :attr:`scaleflow.model.CellFlow.adata`.""" return self._dm @property diff --git a/src/cellflow/model/_utils.py b/src/scaleflow/model/_utils.py similarity index 96% rename from src/cellflow/model/_utils.py rename to src/scaleflow/model/_utils.py index 76384b38..920bbf77 100644 --- a/src/cellflow/model/_utils.py +++ b/src/scaleflow/model/_utils.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike def _multivariate_normal( diff --git a/src/cellflow/networks/__init__.py b/src/scaleflow/networks/__init__.py similarity index 70% rename from src/cellflow/networks/__init__.py rename to src/scaleflow/networks/__init__.py index e8051b1c..2716121d 100644 --- a/src/cellflow/networks/__init__.py +++ b/src/scaleflow/networks/__init__.py @@ -1,7 +1,7 @@ -from cellflow.networks._set_encoders import ( +from scaleflow.networks._set_encoders import ( ConditionEncoder, ) -from cellflow.networks._utils import ( +from scaleflow.networks._utils import ( FilmBlock, MLPBlock, ResNetBlock, @@ -10,7 +10,7 @@ SelfAttentionBlock, TokenAttentionPooling, ) -from cellflow.networks._velocity_field import ConditionalVelocityField, GENOTConditionalVelocityField +from scaleflow.networks._velocity_field import ConditionalVelocityField, GENOTConditionalVelocityField __all__ = [ "ConditionalVelocityField", diff --git a/src/cellflow/networks/_set_encoders.py b/src/scaleflow/networks/_set_encoders.py similarity index 98% rename from src/cellflow/networks/_set_encoders.py rename to src/scaleflow/networks/_set_encoders.py index 74279872..8c334233 100644 --- a/src/cellflow/networks/_set_encoders.py +++ b/src/scaleflow/networks/_set_encoders.py @@ -9,8 +9,8 @@ from flax.training import train_state from flax.typing import FrozenDict -from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t -from cellflow.networks import _utils as nn_utils +from scaleflow._types import ArrayLike, Layers_separate_input_t, Layers_t +from scaleflow.networks import _utils as nn_utils __all__ = [ "ConditionEncoder", diff --git a/src/cellflow/networks/_utils.py b/src/scaleflow/networks/_utils.py similarity index 99% rename from src/cellflow/networks/_utils.py rename to src/scaleflow/networks/_utils.py index 3441330c..a6a72da5 100644 --- a/src/cellflow/networks/_utils.py +++ b/src/scaleflow/networks/_utils.py @@ -7,7 +7,7 @@ from flax import linen as nn from flax.linen import initializers -from cellflow._types import Layers_t +from scaleflow._types import Layers_t __all__ = [ "SelfAttention", diff --git a/src/cellflow/networks/_velocity_field.py b/src/scaleflow/networks/_velocity_field.py similarity index 98% rename from src/cellflow/networks/_velocity_field.py rename to src/scaleflow/networks/_velocity_field.py index 157ad4d8..416dada5 100644 --- a/src/cellflow/networks/_velocity_field.py +++ b/src/scaleflow/networks/_velocity_field.py @@ -9,9 +9,9 @@ from flax import linen as nn from flax.training import train_state -from cellflow._types import Layers_separate_input_t, 
Layers_t -from cellflow.networks._set_encoders import ConditionEncoder -from cellflow.networks._utils import FilmBlock, MLPBlock, ResNetBlock, sinusoidal_time_encoder +from scaleflow._types import Layers_separate_input_t, Layers_t +from scaleflow.networks._set_encoders import ConditionEncoder +from scaleflow.networks._utils import FilmBlock, MLPBlock, ResNetBlock, sinusoidal_time_encoder __all__ = ["ConditionalVelocityField", "GENOTConditionalVelocityField"] @@ -238,7 +238,7 @@ def get_condition_embedding(self, condition: dict[str, jnp.ndarray]) -> tuple[jn Returns ------- Learnt mean and log-variance of the condition embedding. - If :attr:`cellflow.model.CellFlow.condition_mode` is ``'deterministic'``, the log-variance + If :attr:`scaleflow.model.CellFlow.condition_mode` is ``'deterministic'``, the log-variance is set to zero. """ condition_mean, condition_logvar = self.condition_encoder(condition, training=False) diff --git a/src/scaleflow/plotting/__init__.py b/src/scaleflow/plotting/__init__.py new file mode 100644 index 00000000..45364b19 --- /dev/null +++ b/src/scaleflow/plotting/__init__.py @@ -0,0 +1,3 @@ +from scaleflow.plotting._plotting import plot_condition_embedding + +__all__ = ["plot_condition_embedding"] diff --git a/src/cellflow/plotting/_plotting.py b/src/scaleflow/plotting/_plotting.py similarity index 96% rename from src/cellflow/plotting/_plotting.py rename to src/scaleflow/plotting/_plotting.py index 389f37d0..6cfbeef7 100644 --- a/src/cellflow/plotting/_plotting.py +++ b/src/scaleflow/plotting/_plotting.py @@ -7,8 +7,8 @@ import seaborn as sns from adjustText import adjust_text -from cellflow import _constants -from cellflow.plotting._utils import ( +from scaleflow import _constants +from scaleflow.plotting._utils import ( _compute_kernel_pca_from_df, _compute_pca_from_df, _compute_umap_from_df, @@ -38,7 +38,7 @@ def plot_condition_embedding( df A :class:`pandas.DataFrame` with embedding and metadata. Column names of embedding dimensions should be consecutive integers starting from 0, - e.g. as output from :meth:`~cellflow.model.CellFlow.get_condition_embedding`, and + e.g. as output from :meth:`~scaleflow.model.CellFlow.get_condition_embedding`, and metadata should be in columns with strings. embedding Embedding to plot. Options are "raw_embedding", "UMAP", "PCA", "Kernel_PCA". 
diff --git a/src/cellflow/plotting/_utils.py b/src/scaleflow/plotting/_utils.py similarity index 98% rename from src/cellflow/plotting/_utils.py rename to src/scaleflow/plotting/_utils.py index 6378122d..e89b805d 100644 --- a/src/cellflow/plotting/_utils.py +++ b/src/scaleflow/plotting/_utils.py @@ -9,7 +9,7 @@ from sklearn.decomposition import KernelPCA from sklearn.metrics.pairwise import cosine_similarity -from cellflow import _constants, _logging +from scaleflow import _constants, _logging def set_plotting_vars( diff --git a/src/scaleflow/preprocessing/__init__.py b/src/scaleflow/preprocessing/__init__.py new file mode 100644 index 00000000..36e1bff1 --- /dev/null +++ b/src/scaleflow/preprocessing/__init__.py @@ -0,0 +1,9 @@ +from scaleflow.preprocessing._gene_emb import ( + GeneInfo, + get_esm_embedding, + prot_sequence_from_ensembl, + protein_features_from_genes, +) +from scaleflow.preprocessing._pca import centered_pca, project_pca, reconstruct_pca +from scaleflow.preprocessing._preprocessing import annotate_compounds, encode_onehot, get_molecular_fingerprints +from scaleflow.preprocessing._wknn import compute_wknn, transfer_labels diff --git a/src/cellflow/preprocessing/_gene_emb.py b/src/scaleflow/preprocessing/_gene_emb.py similarity index 99% rename from src/cellflow/preprocessing/_gene_emb.py rename to src/scaleflow/preprocessing/_gene_emb.py index cbddb59f..376f25e7 100644 --- a/src/cellflow/preprocessing/_gene_emb.py +++ b/src/scaleflow/preprocessing/_gene_emb.py @@ -7,7 +7,7 @@ import anndata as ad import pandas as pd -from cellflow._logging import logger +from scaleflow._logging import logger try: import requests # type: ignore[import-untyped] @@ -21,7 +21,7 @@ EsmModel = None raise ImportError( "To use gene embedding, please install `transformers` and `torch` \ - e.g. via `pip install cellflow['embedding']`." + e.g. via `pip install scaleflow['embedding']`." 
) from e __all__ = [ diff --git a/src/cellflow/preprocessing/_pca.py b/src/scaleflow/preprocessing/_pca.py similarity index 99% rename from src/cellflow/preprocessing/_pca.py rename to src/scaleflow/preprocessing/_pca.py index b6b72238..6a0dc886 100644 --- a/src/cellflow/preprocessing/_pca.py +++ b/src/scaleflow/preprocessing/_pca.py @@ -3,7 +3,7 @@ import scanpy as sc from scipy.sparse import csr_matrix -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike __all__ = ["centered_pca", "reconstruct_pca", "project_pca"] diff --git a/src/cellflow/preprocessing/_preprocessing.py b/src/scaleflow/preprocessing/_preprocessing.py similarity index 98% rename from src/cellflow/preprocessing/_preprocessing.py rename to src/scaleflow/preprocessing/_preprocessing.py index a12bd627..96149d01 100644 --- a/src/cellflow/preprocessing/_preprocessing.py +++ b/src/scaleflow/preprocessing/_preprocessing.py @@ -5,9 +5,9 @@ import numpy as np import sklearn.preprocessing as preprocessing -from cellflow._logging import logger -from cellflow._types import ArrayLike -from cellflow.data._utils import _to_list +from scaleflow._logging import logger +from scaleflow._types import ArrayLike +from scaleflow.data._utils import _to_list __all__ = ["encode_onehot", "annotate_compounds", "get_molecular_fingerprints"] diff --git a/src/cellflow/preprocessing/_wknn.py b/src/scaleflow/preprocessing/_wknn.py similarity index 99% rename from src/cellflow/preprocessing/_wknn.py rename to src/scaleflow/preprocessing/_wknn.py index 222a9dcf..5430f926 100644 --- a/src/cellflow/preprocessing/_wknn.py +++ b/src/scaleflow/preprocessing/_wknn.py @@ -6,8 +6,8 @@ import pandas as pd from scipy import sparse -from cellflow._logging import logger -from cellflow._types import ArrayLike +from scaleflow._logging import logger +from scaleflow._types import ArrayLike __all__ = ["compute_wknn", "transfer_labels"] diff --git a/src/scaleflow/solvers/__init__.py b/src/scaleflow/solvers/__init__.py new file mode 100644 index 00000000..35ff8cb8 --- /dev/null +++ b/src/scaleflow/solvers/__init__.py @@ -0,0 +1,4 @@ +from scaleflow.solvers._genot import GENOT +from scaleflow.solvers._otfm import OTFlowMatching + +__all__ = ["GENOT", "OTFlowMatching"] diff --git a/src/cellflow/solvers/_genot.py b/src/scaleflow/solvers/_genot.py similarity index 97% rename from src/cellflow/solvers/_genot.py rename to src/scaleflow/solvers/_genot.py index 7270ad7f..079ab922 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/scaleflow/solvers/_genot.py @@ -11,9 +11,9 @@ from ott.neural.networks import velocity_field from ott.solvers import utils as solver_utils -from cellflow import utils -from cellflow._types import ArrayLike -from cellflow.model._utils import _multivariate_normal +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.model._utils import _multivariate_normal __all__ = ["GENOT"] @@ -240,7 +240,7 @@ def predict( """Generate the push-forward of ``x`` under condition ``condition``. This function solves the ODE learnt with - the :class:`~cellflow.networks.ConditionalVelocityField`. + the :class:`~scaleflow.networks.ConditionalVelocityField`. Parameters ---------- @@ -257,7 +257,7 @@ def predict( batched Whether to use batched prediction. This is only supported if the input has the same number of cells for each condition. For example, this works when using - :class:`~cellflow.data.ValidationSampler` to sample the validation data. + :class:`~scaleflow.data.ValidationSampler` to sample the validation data. 
kwargs Keyword arguments for :func:`diffrax.diffeqsolve`. diff --git a/src/cellflow/solvers/_otfm.py b/src/scaleflow/solvers/_otfm.py similarity index 95% rename from src/cellflow/solvers/_otfm.py rename to src/scaleflow/solvers/_otfm.py index 31114a6b..e987b8e1 100644 --- a/src/cellflow/solvers/_otfm.py +++ b/src/scaleflow/solvers/_otfm.py @@ -11,10 +11,10 @@ from ott.neural.methods.flows import dynamics from ott.solvers import utils as solver_utils -from cellflow import utils -from cellflow._types import ArrayLike -from cellflow.networks._velocity_field import ConditionalVelocityField -from cellflow.solvers.utils import ema_update +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.networks._velocity_field import ConditionalVelocityField +from scaleflow.solvers.utils import ema_update __all__ = ["OTFlowMatching"] @@ -34,14 +34,14 @@ class OTFlowMatching: match_fn Function to match samples from the source and the target distributions. It has a ``(src, tgt) -> matching`` signature, - see e.g. :func:`cellflow.utils.match_linear`. If :obj:`None`, no + see e.g. :func:`scaleflow.utils.match_linear`. If :obj:`None`, no matching is performed, and pure probability_path matching :cite:`lipman:22` is applied. time_sampler Time sampler with a ``(rng, n_samples) -> time`` signature, see e.g. :func:`ott.solvers.utils.uniform_sampler`. kwargs - Keyword arguments for :meth:`cellflow.networks.ConditionalVelocityField.create_train_state`. + Keyword arguments for :meth:`scaleflow.networks.ConditionalVelocityField.create_train_state`. """ def __init__( @@ -231,7 +231,7 @@ def predict( """Predict the translated source ``x`` under condition ``condition``. This function solves the ODE learnt with - the :class:`~cellflow.networks.ConditionalVelocityField`. + the :class:`~scaleflow.networks.ConditionalVelocityField`. Parameters ---------- @@ -249,7 +249,7 @@ def predict( batched Whether to use batched prediction. This is only supported if the input has the same number of cells for each condition. For example, this works when using - :class:`~cellflow.data.ValidationSampler` to sample the validation data. + :class:`~scaleflow.data.ValidationSampler` to sample the validation data. kwargs Keyword arguments for :func:`diffrax.diffeqsolve`. 
diff --git a/src/cellflow/solvers/utils.py b/src/scaleflow/solvers/utils.py similarity index 100% rename from src/cellflow/solvers/utils.py rename to src/scaleflow/solvers/utils.py diff --git a/src/cellflow/training/__init__.py b/src/scaleflow/training/__init__.py similarity index 79% rename from src/cellflow/training/__init__.py rename to src/scaleflow/training/__init__.py index 387411d2..c19a50dd 100644 --- a/src/cellflow/training/__init__.py +++ b/src/scaleflow/training/__init__.py @@ -1,4 +1,4 @@ -from cellflow.training._callbacks import ( +from scaleflow.training._callbacks import ( BaseCallback, CallbackRunner, ComputationCallback, @@ -8,7 +8,7 @@ VAEDecodedMetrics, WandbLogger, ) -from cellflow.training._trainer import CellFlowTrainer +from scaleflow.training._trainer import CellFlowTrainer __all__ = [ "CellFlowTrainer", diff --git a/src/cellflow/training/_callbacks.py b/src/scaleflow/training/_callbacks.py similarity index 93% rename from src/cellflow/training/_callbacks.py rename to src/scaleflow/training/_callbacks.py index 5b65f33f..92a4524d 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/scaleflow/training/_callbacks.py @@ -7,14 +7,14 @@ import jax.tree_util as jtu import numpy as np -from cellflow._types import ArrayLike -from cellflow.metrics._metrics import ( +from scaleflow._types import ArrayLike +from scaleflow.metrics._metrics import ( compute_e_distance_fast, compute_r_squared, compute_scalar_mmd, compute_sinkhorn_div, ) -from cellflow.solvers import _genot, _otfm +from scaleflow.solvers import _genot, _otfm __all__ = [ "BaseCallback", @@ -42,7 +42,7 @@ class BaseCallback(abc.ABC): - """Base class for callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self, *args: Any, **kwargs: Any) -> None: @@ -61,7 +61,7 @@ def on_train_end(self, *args: Any, **kwargs: Any) -> Any: class LoggingCallback(BaseCallback, abc.ABC): - """Base class for logging callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for logging callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self) -> Any: @@ -92,7 +92,7 @@ def on_train_end(self, dict_to_log: dict[str, Any]) -> Any: class ComputationCallback(BaseCallback, abc.ABC): - """Base class for computation callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for computation callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self) -> Any: @@ -118,7 +118,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -146,7 +146,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. 
Returns @@ -205,7 +205,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -240,7 +240,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -299,7 +299,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -328,7 +328,7 @@ class VAEDecodedMetrics(Metrics): ---------- vae A VAE model object with a ``'get_reconstruction'`` method, can be an instance - of :class:`cellflow.external.CFJaxSCVI`. + of :class:`scaleflow.external.CFJaxSCVI`. adata An :class:`~anndata.AnnData` object in the same format as the ``vae``. metrics @@ -374,7 +374,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -477,14 +477,14 @@ def on_train_end(self, dict_to_log: dict[str, float]) -> Any: class CallbackRunner: - """Runs a set of computational and logging callbacks in the :class:`~cellflow.training.CellFlowTrainer` + """Runs a set of computational and logging callbacks in the :class:`~scaleflow.training.CellFlowTrainer` Parameters ---------- callbacks List of callbacks to run. Callbacks should be of type - :class:`~cellflow.training.ComputationCallback` or - :class:`~cellflow.training.LoggingCallback` + :class:`~scaleflow.training.ComputationCallback` or + :class:`~scaleflow.training.LoggingCallback` Returns ------- @@ -529,7 +529,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -565,7 +565,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. 
Returns diff --git a/src/cellflow/training/_trainer.py b/src/scaleflow/training/_trainer.py similarity index 91% rename from src/cellflow/training/_trainer.py rename to src/scaleflow/training/_trainer.py index 5a359df4..ee8f3155 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/scaleflow/training/_trainer.py @@ -6,9 +6,9 @@ from numpy.typing import ArrayLike from tqdm import tqdm -from cellflow.data import JaxOutOfCoreTrainSampler, TrainSampler, ValidationSampler -from cellflow.solvers import _genot, _otfm -from cellflow.training._callbacks import BaseCallback, CallbackRunner +from scaleflow.data import JaxOutOfCoreTrainSampler, TrainSampler, ValidationSampler +from scaleflow.solvers import _genot, _otfm +from scaleflow.training._callbacks import BaseCallback, CallbackRunner class CellFlowTrainer: @@ -19,12 +19,12 @@ class CellFlowTrainer: dataloader Data sampler. solver - :class:`~cellflow.solvers._otfm.OTFlowMatching` or - :class:`~cellflow.solvers._genot.GENOT` solver with a conditional velocity field. + :class:`~scaleflow.solvers._otfm.OTFlowMatching` or + :class:`~scaleflow.solvers._genot.GENOT` solver with a conditional velocity field. predict_kwargs Keyword arguments for the prediction functions - :func:`cellflow.solvers._otfm.OTFlowMatching.predict` or - :func:`cellflow.solvers._genot.GENOT.predict` used during validation. + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict` or + :func:`scaleflow.solvers._genot.GENOT.predict` used during validation. seed Random seed for subsampling validation data. diff --git a/src/cellflow/training/_utils.py b/src/scaleflow/training/_utils.py similarity index 100% rename from src/cellflow/training/_utils.py rename to src/scaleflow/training/_utils.py diff --git a/src/cellflow/utils.py b/src/scaleflow/utils.py similarity index 100% rename from src/cellflow/utils.py rename to src/scaleflow/utils.py diff --git a/tests/conftest.py b/tests/conftest.py index d6d95fdd..e793734d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from cellflow.data._dataloader import ValidationSampler +from scaleflow.data._dataloader import ValidationSampler @pytest.fixture diff --git a/tests/data/test_cfsampler.py b/tests/data/test_cfsampler.py index 18406e52..7c803e46 100644 --- a/tests/data/test_cfsampler.py +++ b/tests/data/test_cfsampler.py @@ -3,9 +3,9 @@ import numpy as np import pytest -from cellflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler -from cellflow.data._data import ZarrTrainingData -from cellflow.data._datamanager import DataManager +from scaleflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler +from scaleflow.data._data import ZarrTrainingData +from scaleflow.data._datamanager import DataManager class TestTrainSampler: @@ -100,8 +100,8 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): class TestValidationSampler: @pytest.mark.parametrize("n_conditions_on_log_iteration", [None, 1, 3]) def test_valid_sampler(self, adata_perturbation, n_conditions_on_log_iteration): - from cellflow.data._dataloader import ValidationSampler - from cellflow.data._datamanager import DataManager + from scaleflow.data._dataloader import ValidationSampler + from scaleflow.data._datamanager import DataManager control_key = "control" sample_covariates = ["cell_type"] @@ -150,7 +150,7 @@ def test_pred_sampler( split_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from 
scaleflow.data._datamanager import DataManager perturbation_covariates = {"drug": ["drug1", "drug2"]} diff --git a/tests/data/test_datamanager.py b/tests/data/test_datamanager.py index 237af9c7..91a64859 100644 --- a/tests/data/test_datamanager.py +++ b/tests/data/test_datamanager.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from cellflow.data._datamanager import DataManager +from scaleflow.data._datamanager import DataManager perturbation_covariates_args = [ {"drug": ["drug1"]}, @@ -38,7 +38,7 @@ def test_init_DataManager( perturbation_covariate_reps, sample_covariates, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager dm = DataManager( adata_perturbation, @@ -58,7 +58,7 @@ def test_init_DataManager( @pytest.mark.parametrize("el_to_delete", ["drug", "cell_type"]) def test_raise_false_uns_dict(self, adata_perturbation: ad.AnnData, el_to_delete): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager sample_rep = "X" split_covariates = ["cell_type"] @@ -87,7 +87,7 @@ def test_raise_false_uns_dict(self, adata_perturbation: ad.AnnData, el_to_delete @pytest.mark.parametrize("el_to_delete", ["drug_b", "dosage_a"]) def test_raise_covar_mismatch(self, adata_perturbation: ad.AnnData, el_to_delete): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager sample_rep = "X" split_covariates = ["cell_type"] @@ -113,7 +113,7 @@ def test_raise_covar_mismatch(self, adata_perturbation: ad.AnnData, el_to_delete ) def test_raise_target_without_source(self, adata_perturbation: ad.AnnData): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager sample_rep = "X" split_covariates = ["cell_type"] @@ -156,8 +156,8 @@ def test_get_train_data( perturbation_covariate_reps, sample_covariates, ): - from cellflow.data._data import TrainingData - from cellflow.data._datamanager import DataManager + from scaleflow.data._data import TrainingData + from scaleflow.data._datamanager import DataManager dm = DataManager( adata_perturbation, @@ -211,7 +211,7 @@ def test_get_train_data_with_combinations( perturbation_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager dm = DataManager( adata_perturbation, @@ -300,7 +300,7 @@ def test_get_validation_data( perturbation_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager control_key = "control" sample_covariates = ["cell_type"] @@ -338,7 +338,7 @@ def test_get_validation_data( @pytest.mark.skip(reason="To discuss: why should it raise an error?") def test_raises_wrong_max_combination_length(self, adata_perturbation): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager max_combination_length = 3 adata = adata_perturbation @@ -378,7 +378,7 @@ def test_get_prediction_data( perturbation_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager control_key = "control" sample_covariates = ["cell_type"] diff --git a/tests/data/test_old_get_condition_data.py b/tests/data/test_old_get_condition_data.py index a546281a..a7e94575 100644 --- a/tests/data/test_old_get_condition_data.py +++ b/tests/data/test_old_get_condition_data.py @@ 
-9,8 +9,8 @@ import pytest from tqdm import tqdm -from cellflow._types import ArrayLike -from cellflow.data._datamanager import ( +from scaleflow._types import ArrayLike +from scaleflow.data._datamanager import ( DataManager, ReturnData, _to_list, diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py index fa2a2cf0..ced4220e 100644 --- a/tests/data/test_torch_dataloader.py +++ b/tests/data/test_torch_dataloader.py @@ -1,5 +1,5 @@ -import cellflow -from cellflow.data import TorchCombinedTrainSampler +import scaleflow +from scaleflow.data import TorchCombinedTrainSampler class TestTorchDataloader: @@ -15,7 +15,7 @@ def test_torch_dataloader_shapes( perturbation_covariate_reps = {"drug": "drug"} batch_size = 18 - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep=sample_rep, control_key=control_key, diff --git a/tests/external/test_CFJaxSCVI.py b/tests/external/test_CFJaxSCVI.py index d25d1907..9238afd4 100644 --- a/tests/external/test_CFJaxSCVI.py +++ b/tests/external/test_CFJaxSCVI.py @@ -1,7 +1,7 @@ import pytest from scvi.data import synthetic_iid -from cellflow.external import CFJaxSCVI +from scaleflow.external import CFJaxSCVI class TestCFJaxSCVI: diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 9239fcf1..e67f4c0e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -2,7 +2,7 @@ import numpy as np import pytest -import cellflow +import scaleflow class TestMetrics: @@ -11,8 +11,8 @@ def test_compute_metrics(self, metrics_data, prefix): x_test = metrics_data["x_test"] y_test = metrics_data["y_test"] - metrics = jtu.tree_map(cellflow.metrics.compute_metrics, x_test, y_test) - mean_metrics = cellflow.metrics.compute_mean_metrics(metrics, prefix) + metrics = jtu.tree_map(scaleflow.metrics.compute_metrics, x_test, y_test) + mean_metrics = scaleflow.metrics.compute_mean_metrics(metrics, prefix) assert "Alvespimycin+Pirarubicin" in metrics.keys() assert {"r_squared", "sinkhorn_div_1", "sinkhorn_div_10", "sinkhorn_div_100", "e_distance", "mmd"} == set( @@ -32,12 +32,12 @@ def test_function_output(self, metrics_data, epsilon): x_test = metrics_data["x_test"]["Alvespimycin+Pirarubicin"] y_test = metrics_data["y_test"]["Alvespimycin+Pirarubicin"] - r_squared = cellflow.metrics.compute_r_squared(x_test, y_test) - sinkhorn_div = cellflow.metrics.compute_sinkhorn_div(x_test, y_test, epsilon=epsilon) - e_distance = cellflow.metrics.compute_e_distance(x_test, y_test) - e_distance_fast = cellflow.metrics.compute_e_distance_fast(x_test, y_test) - scalar_mmd = cellflow.metrics.compute_scalar_mmd(x_test, y_test) - mmd_fast = cellflow.metrics.maximum_mean_discrepancy(x_test, y_test, exact=False) + r_squared = scaleflow.metrics.compute_r_squared(x_test, y_test) + sinkhorn_div = scaleflow.metrics.compute_sinkhorn_div(x_test, y_test, epsilon=epsilon) + e_distance = scaleflow.metrics.compute_e_distance(x_test, y_test) + e_distance_fast = scaleflow.metrics.compute_e_distance_fast(x_test, y_test) + scalar_mmd = scaleflow.metrics.compute_scalar_mmd(x_test, y_test) + mmd_fast = scaleflow.metrics.maximum_mean_discrepancy(x_test, y_test, exact=False) assert -1000 <= r_squared <= 1 assert sinkhorn_div >= 0 @@ -51,11 +51,11 @@ def test_fast_metrics(self, metrics_data, gamma): x_test = metrics_data["x_test"]["Alvespimycin+Pirarubicin"] y_test = metrics_data["y_test"]["Alvespimycin+Pirarubicin"] - e_distance = 
cellflow.metrics.compute_e_distance(x_test, y_test) - e_distance_fast = cellflow.metrics.compute_e_distance_fast(x_test, y_test) + e_distance = scaleflow.metrics.compute_e_distance(x_test, y_test) + e_distance_fast = scaleflow.metrics.compute_e_distance_fast(x_test, y_test) - mmd = cellflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=True) - mmd_fast = cellflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=False) + mmd = scaleflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=True) + mmd_fast = scaleflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=False) assert np.allclose(e_distance, e_distance_fast, rtol=1e-4, atol=1e-4) assert np.allclose(mmd, mmd_fast, rtol=1e-4, atol=1e-4) diff --git a/tests/model/test_cellflow.py b/tests/model/test_scaleflow.py similarity index 95% rename from tests/model/test_cellflow.py rename to tests/model/test_scaleflow.py index 024349cb..643ac293 100644 --- a/tests/model/test_cellflow.py +++ b/tests/model/test_scaleflow.py @@ -2,8 +2,8 @@ import pandas as pd import pytest -import cellflow -from cellflow.networks import _velocity_field +import scaleflow +from scaleflow.networks import _velocity_field perturbation_covariate_comb_args = [ {"drug": ["drug1"]}, @@ -21,7 +21,7 @@ class TestCellFlow: @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize("regularization", [0.0, 0.1]) @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"]) - def test_cellflow_solver( + def test_scaleflow_solver( self, adata_perturbation, solver, @@ -38,7 +38,7 @@ def test_cellflow_solver( condition_embedding_dim = 32 vf_kwargs = {"genot_source_dims": (32, 32), "genot_source_dropout": 0.1} if solver == "genot" else None - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep=sample_rep, control_key=control_key, @@ -135,7 +135,7 @@ def test_cellflow_solver( @pytest.mark.slow @pytest.mark.parametrize("solver", ["otfm", "genot"]) @pytest.mark.parametrize("perturbation_covariate_reps", [{}, {"drug": "drug"}]) - def test_cellflow_covar_reps( + def test_scaleflow_covar_reps( self, adata_perturbation, perturbation_covariate_reps, @@ -148,7 +148,7 @@ def test_cellflow_covar_reps( condition_embedding_dim = 32 vf_kwargs = {"genot_source_dims": (32, 32), "genot_source_dropout": 0.1} if solver == "genot" else None - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep=sample_rep, control_key=control_key, @@ -203,7 +203,7 @@ def test_cellflow_covar_reps( @pytest.mark.parametrize("perturbation_covariates", perturbation_covariate_comb_args) @pytest.mark.parametrize("n_conditions_on_log_iteration", [None, 0, 2]) @pytest.mark.parametrize("n_conditions_on_train_end", [None, 0, 2]) - def test_cellflow_val_data_loading( + def test_scaleflow_val_data_loading( self, adata_perturbation, split_covariates, @@ -211,7 +211,7 @@ def test_cellflow_val_data_loading( n_conditions_on_log_iteration, n_conditions_on_train_end, ): - cf = cellflow.model.CellFlow(adata_perturbation) + cf = scaleflow.model.CellFlow(adata_perturbation) cf.prepare_data( sample_rep="X", control_key="control", @@ -251,7 +251,7 @@ def test_cellflow_val_data_loading( @pytest.mark.parametrize("solver", ["otfm", "genot"]) @pytest.mark.parametrize("n_conditions_on_log_iteration", [None, 0, 1]) 
@pytest.mark.parametrize("n_conditions_on_train_end", [None, 0, 1]) - def test_cellflow_with_validation( + def test_scaleflow_with_validation( self, adata_perturbation, solver, @@ -259,7 +259,7 @@ def test_cellflow_with_validation( n_conditions_on_train_end, ): vf_kwargs = {"genot_source_dims": (2, 2), "genot_source_dropout": 0.1} if solver == "genot" else None - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep="X", control_key="control", @@ -300,7 +300,7 @@ def test_cellflow_with_validation( assert cf._trainer is not None metric_to_compute = "r_squared" - metrics_callback = cellflow.training.Metrics(metrics=[metric_to_compute]) + metrics_callback = scaleflow.training.Metrics(metrics=[metric_to_compute]) cf.train(num_iterations=3, callbacks=[metrics_callback], valid_freq=1) assert cf._dataloader is not None @@ -310,14 +310,14 @@ def test_cellflow_with_validation( @pytest.mark.parametrize("solver", ["otfm", "genot"]) @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize("regularization", [0.0, 0.1]) - def test_cellflow_predict( + def test_scaleflow_predict( self, adata_perturbation, solver, condition_mode, regularization, ): - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep="X", control_key="control", @@ -386,7 +386,7 @@ def test_cellflow_predict( def test_raise_otfm_vf_kwargs_passed(self, adata_perturbation): vf_kwargs = {"genot_source_dims": (2, 2), "genot_source_dropouts": 0.1} - cf = cellflow.model.CellFlow(adata_perturbation, solver="otfm") + cf = scaleflow.model.CellFlow(adata_perturbation, solver="otfm") cf.prepare_data( sample_rep="X", control_key="control", @@ -413,7 +413,7 @@ def test_raise_otfm_vf_kwargs_passed(self, adata_perturbation): @pytest.mark.parametrize("perturbation_covariates", perturbation_covariate_comb_args) @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize("regularization", [0.0, 0.1]) - def test_cellflow_get_condition_embedding( + def test_scaleflow_get_condition_embedding( self, adata_perturbation, sample_covariate_and_reps, @@ -430,7 +430,7 @@ def test_cellflow_get_condition_embedding( condition_embedding_dim = 2 solver = "otfm" - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep=sample_rep, control_key=control_key, @@ -504,7 +504,7 @@ def test_time_embedding( solver = "otfm" time_freqs = 1024 - cf = cellflow.model.CellFlow(adata_perturbation, solver=solver) + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep=sample_rep, control_key=control_key, diff --git a/tests/networks/test_aggregators.py b/tests/networks/test_aggregators.py index a2ce546a..682ae6d1 100644 --- a/tests/networks/test_aggregators.py +++ b/tests/networks/test_aggregators.py @@ -2,8 +2,8 @@ import jax.numpy as jnp import pytest -from cellflow.networks._set_encoders import ConditionEncoder -from cellflow.networks._utils import SeedAttentionPooling, TokenAttentionPooling +from scaleflow.networks._set_encoders import ConditionEncoder +from scaleflow.networks._utils import SeedAttentionPooling, TokenAttentionPooling class TestAggregator: diff --git a/tests/networks/test_condencoder.py b/tests/networks/test_condencoder.py index 
54589d62..325278ee 100644 --- a/tests/networks/test_condencoder.py +++ b/tests/networks/test_condencoder.py @@ -3,7 +3,7 @@ import optax import pytest -import cellflow +import scaleflow cond = { "pert1": jnp.ones((1, 3, 3)), @@ -54,7 +54,7 @@ class TestConditionEncoder: def test_condition_encoder_init( self, pooling, covariates_not_pooled, layers_before_pool, layers_after_pool, condition_mode, regularization ): - cond_encoder = cellflow.networks.ConditionEncoder( + cond_encoder = scaleflow.networks.ConditionEncoder( output_dim=5, condition_mode=condition_mode, regularization=regularization, diff --git a/tests/networks/test_velocityfield.py b/tests/networks/test_velocityfield.py index 4e651d26..b592573f 100644 --- a/tests/networks/test_velocityfield.py +++ b/tests/networks/test_velocityfield.py @@ -4,7 +4,7 @@ import pytest from flax.linen import activation -from cellflow.networks import _velocity_field +from scaleflow.networks import _velocity_field x_test = jnp.ones((10, 5)) * 10 t_test = jnp.ones((10, 1)) diff --git a/tests/plotting/test_plotting.py b/tests/plotting/test_plotting.py index 7f431ff2..b3b0b0e6 100644 --- a/tests/plotting/test_plotting.py +++ b/tests/plotting/test_plotting.py @@ -1,7 +1,7 @@ import matplotlib.pyplot as plt import pytest -from cellflow.plotting import plot_condition_embedding +from scaleflow.plotting import plot_condition_embedding class TestCallbacks: diff --git a/tests/preprocessing/test_gene_emb.py b/tests/preprocessing/test_gene_emb.py index 4f79277e..3a1e3be0 100644 --- a/tests/preprocessing/test_gene_emb.py +++ b/tests/preprocessing/test_gene_emb.py @@ -6,7 +6,7 @@ import pytest import torch -from cellflow.preprocessing._gene_emb import get_esm_embedding +from scaleflow.preprocessing._gene_emb import get_esm_embedding IS_PROT_CODING = Counter(["ENSG00000139618", "ENSG00000206450", "ENSG00000049192"]) ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../test_artifacts/") diff --git a/tests/preprocessing/test_pca.py b/tests/preprocessing/test_pca.py index 662c721f..dccc5ff0 100644 --- a/tests/preprocessing/test_pca.py +++ b/tests/preprocessing/test_pca.py @@ -5,9 +5,9 @@ class TestPCA: def test_centered_pca(self, adata_pca: ad.AnnData): - import cellflow + import scaleflow - cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) + scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) assert "X_pca" in adata_pca.obsm assert "PCs" in adata_pca.varm assert "X_mean" in adata_pca.varm @@ -17,10 +17,10 @@ def test_centered_pca(self, adata_pca: ad.AnnData): @pytest.mark.parametrize("layers_key_added", ["X_recon", "X_rec"]) def test_reconstruct_pca(self, adata_pca: ad.AnnData, layers_key_added): - import cellflow + import scaleflow - cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) - cellflow.pp.reconstruct_pca( + scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) + scaleflow.pp.reconstruct_pca( adata_pca, ref_adata=adata_pca, use_rep="X_pca", @@ -36,18 +36,18 @@ def test_reconstruct_pca(self, adata_pca: ad.AnnData, layers_key_added): ) def test_reconstruct_pca_with_array_input(self, adata_pca: ad.AnnData): - import cellflow + import scaleflow - cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) - cellflow.pp.reconstruct_pca(adata_pca, ref_means=adata_pca.varm["X_mean"], ref_pcs=adata_pca.varm["PCs"]) + scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) + scaleflow.pp.reconstruct_pca(adata_pca, ref_means=adata_pca.varm["X_mean"], ref_pcs=adata_pca.varm["PCs"]) assert "X_recon" in 
adata_pca.layers @pytest.mark.parametrize("obsm_key_added", ["X_pca", "X_pca_projected"]) def test_project_pca(self, adata_pca: ad.AnnData, obsm_key_added): - import cellflow + import scaleflow - cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) - adata_pca_project = cellflow.pp.project_pca( + scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False) + adata_pca_project = scaleflow.pp.project_pca( adata_pca, ref_adata=adata_pca, obsm_key_added=obsm_key_added, copy=True ) assert obsm_key_added in adata_pca_project.obsm diff --git a/tests/preprocessing/test_preprocessing.py b/tests/preprocessing/test_preprocessing.py index e5423663..ac9af815 100644 --- a/tests/preprocessing/test_preprocessing.py +++ b/tests/preprocessing/test_preprocessing.py @@ -13,10 +13,10 @@ class TestPreprocessing: ], ) def test_annotate_compounds(self, adata_with_compounds: ad.AnnData, compound_key_and_type): - import cellflow + import scaleflow try: - cellflow.pp.annotate_compounds( + scaleflow.pp.annotate_compounds( adata_with_compounds, compound_keys=compound_key_and_type[0], query_id_type=compound_key_and_type[1], @@ -45,11 +45,11 @@ def test_annotate_compounds(self, adata_with_compounds: ad.AnnData, compound_key ], ) def test_get_molecular_fingerprints(self, adata_with_compounds: ad.AnnData, n_bits, compound_and_smiles_keys): - import cellflow + import scaleflow uns_key_added = "compound_fingerprints" - cellflow.pp.get_molecular_fingerprints( + scaleflow.pp.get_molecular_fingerprints( adata_with_compounds, compound_keys=compound_and_smiles_keys[0], smiles_keys=compound_and_smiles_keys[1], @@ -67,9 +67,9 @@ def test_get_molecular_fingerprints(self, adata_with_compounds: ad.AnnData, n_bi @pytest.mark.parametrize("uns_key_added", ["compounds", "compounds_onehot"]) @pytest.mark.parametrize("exclude_values", [None, "GW0742"]) def test_encode_onehot(self, adata_with_compounds: ad.AnnData, uns_key_added, exclude_values): - import cellflow + import scaleflow - cellflow.pp.encode_onehot( + scaleflow.pp.encode_onehot( adata_with_compounds, covariate_keys="compound_name", uns_key_added=uns_key_added, diff --git a/tests/preprocessing/test_wknn.py b/tests/preprocessing/test_wknn.py index 2bbe12ab..be52cd58 100644 --- a/tests/preprocessing/test_wknn.py +++ b/tests/preprocessing/test_wknn.py @@ -6,9 +6,9 @@ class TestWKNN: @pytest.mark.parametrize("n_neighbors", [50, 100]) def test_compute_wknn_k(self, adata_perturbation: ad.AnnData, n_neighbors): - import cellflow + import scaleflow - cellflow.pp.compute_wknn( + scaleflow.pp.compute_wknn( ref_adata=adata_perturbation, query_adata=adata_perturbation, n_neighbors=n_neighbors, @@ -24,12 +24,12 @@ def test_compute_wknn_k(self, adata_perturbation: ad.AnnData, n_neighbors): @pytest.mark.parametrize("weighting_scheme", ["top_n", "jaccard", "jaccard_square"]) def test_compute_wknn_weighting(self, adata_perturbation: ad.AnnData, weighting_scheme): - import cellflow + import scaleflow n_neighbors = 50 top_n = 10 - cellflow.pp.compute_wknn( + scaleflow.pp.compute_wknn( ref_adata=adata_perturbation, query_adata=adata_perturbation, n_neighbors=n_neighbors, @@ -50,11 +50,11 @@ def test_compute_wknn_weighting(self, adata_perturbation: ad.AnnData, weighting_ @pytest.mark.parametrize("uns_key_added", ["wknn", "wknn2"]) def test_compute_wknn_key_added(self, adata_perturbation: ad.AnnData, uns_key_added): - import cellflow + import scaleflow n_neighbors = 50 - cellflow.pp.compute_wknn( + scaleflow.pp.compute_wknn( ref_adata=adata_perturbation, query_adata=adata_perturbation, 
n_neighbors=n_neighbors, @@ -71,16 +71,16 @@ def test_compute_wknn_key_added(self, adata_perturbation: ad.AnnData, uns_key_ad @pytest.mark.parametrize("label_key", ["drug1", "cell_type"]) def test_transfer_labels(self, adata_perturbation: ad.AnnData, label_key): - import cellflow + import scaleflow - cellflow.pp.compute_wknn( + scaleflow.pp.compute_wknn( ref_adata=adata_perturbation, query_adata=adata_perturbation, n_neighbors=50, copy=False, ) - cellflow.pp.transfer_labels( + scaleflow.pp.transfer_labels( adata_perturbation, adata_perturbation, label_key=label_key, diff --git a/tests/solver/test_solver.py b/tests/solver/test_solver.py index 6bf10401..fe31f73d 100644 --- a/tests/solver/test_solver.py +++ b/tests/solver/test_solver.py @@ -7,9 +7,9 @@ import pytest from ott.neural.methods.flows import dynamics -import cellflow -from cellflow.solvers import _genot, _otfm -from cellflow.utils import match_linear +import scaleflow +from scaleflow.solvers import _genot, _otfm +from scaleflow.utils import match_linear src = { ("drug_1",): np.random.rand(10, 5), @@ -26,9 +26,9 @@ class TestSolver: @pytest.mark.parametrize("solver_class", ["otfm", "genot"]) def test_predict_batch(self, dataloader, solver_class): if solver_class == "otfm": - vf_class = cellflow.networks.ConditionalVelocityField + vf_class = scaleflow.networks.ConditionalVelocityField else: - vf_class = cellflow.networks.GENOTConditionalVelocityField + vf_class = scaleflow.networks.GENOTConditionalVelocityField opt = optax.adam(1e-3) vf = vf_class( @@ -60,7 +60,7 @@ def test_predict_batch(self, dataloader, solver_class): ) predict_kwargs = {"max_steps": 3, "throw": False} - trainer = cellflow.training.CellFlowTrainer(solver=solver, predict_kwargs=predict_kwargs) + trainer = scaleflow.training.CellFlowTrainer(solver=solver, predict_kwargs=predict_kwargs) trainer.train( dataloader=dataloader, num_iterations=2, @@ -91,7 +91,7 @@ def test_predict_batch(self, dataloader, solver_class): @pytest.mark.parametrize("ema", [0.5, 1.0]) def test_EMA(self, dataloader, ema): - vf_class = cellflow.networks.ConditionalVelocityField + vf_class = scaleflow.networks.ConditionalVelocityField drug = np.random.rand(2, 1, 3) opt = optax.adam(1e-3) vf1 = vf_class( @@ -111,7 +111,7 @@ def test_EMA(self, dataloader, ema): rng=vf_rng, ema=ema, ) - trainer1 = cellflow.training.CellFlowTrainer(solver=solver1) + trainer1 = scaleflow.training.CellFlowTrainer(solver=solver1) trainer1.train( dataloader=dataloader, num_iterations=5, diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index f1346ce6..6c00ad75 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -7,7 +7,7 @@ class TestCallbacks: @pytest.mark.parametrize("metrics", [["r_squared"]]) def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics): - from cellflow.training import PCADecodedMetrics + from scaleflow.training import PCADecodedMetrics decoded_metrics_callback = PCADecodedMetrics( metrics=metrics, @@ -22,8 +22,8 @@ def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics): def test_vae_reconstruction(self, metrics): from scvi.data import synthetic_iid - from cellflow.external import CFJaxSCVI - from cellflow.training import VAEDecodedMetrics + from scaleflow.external import CFJaxSCVI + from scaleflow.training import VAEDecodedMetrics adata = synthetic_iid() CFJaxSCVI.setup_anndata( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index beef4eb1..fb693bf2 100644 --- a/tests/trainer/test_trainer.py +++ 
b/tests/trainer/test_trainer.py @@ -7,10 +7,10 @@ import pytest from ott.neural.methods.flows import dynamics -import cellflow -from cellflow.solvers import _otfm -from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics -from cellflow.utils import match_linear +import scaleflow +from scaleflow.solvers import _otfm +from scaleflow.training import CellFlowTrainer, ComputationCallback, Metrics +from scaleflow.utils import match_linear x_test = jnp.ones((10, 5)) * 10 t_test = jnp.ones((10, 1)) @@ -43,9 +43,9 @@ def on_train_end(self, source_data, validation_data, predicted_data, solver): class TestTrainer: @pytest.mark.parametrize("valid_freq", [10, 1]) - def test_cellflow_trainer(self, dataloader, valid_freq): + def test_scaleflow_trainer(self, dataloader, valid_freq): opt = optax.adam(1e-3) - vf = cellflow.networks.ConditionalVelocityField( + vf = scaleflow.networks.ConditionalVelocityField( output_dim=5, max_combination_length=2, condition_embedding_dim=12, @@ -79,9 +79,9 @@ def test_cellflow_trainer(self, dataloader, valid_freq): assert out[1].shape == (1, 12) @pytest.mark.parametrize("use_validdata", [True, False]) - def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_validdata): + def test_scaleflow_trainer_with_callback(self, dataloader, valid_loader, use_validdata): opt = optax.adam(1e-3) - vf = cellflow.networks.ConditionalVelocityField( + vf = scaleflow.networks.ConditionalVelocityField( output_dim=5, max_combination_length=2, condition_embedding_dim=12, @@ -124,9 +124,9 @@ def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_vali assert isinstance(out[1], np.ndarray) assert out[1].shape == (1, 12) - def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): + def test_scaleflow_trainer_with_custom_callback(self, dataloader, valid_loader): opt = optax.adam(1e-3) - vf = cellflow.networks.ConditionalVelocityField( + vf = scaleflow.networks.ConditionalVelocityField( condition_mode="stochastic", output_dim=5, max_combination_length=2, @@ -164,14 +164,14 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader): def test_predict_kwargs_iter(self, dataloader, valid_loader): opt_1 = optax.adam(1e-3) opt_2 = optax.adam(1e-3) - vf_1 = cellflow.networks.ConditionalVelocityField( + vf_1 = scaleflow.networks.ConditionalVelocityField( output_dim=5, max_combination_length=2, condition_embedding_dim=12, hidden_dims=(32, 32), decoder_dims=(32, 32), ) - vf_2 = cellflow.networks.ConditionalVelocityField( + vf_2 = scaleflow.networks.ConditionalVelocityField( output_dim=5, max_combination_length=2, condition_embedding_dim=12, @@ -196,13 +196,13 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader): ) metric_to_compute = "e_distance" - metrics_callback = cellflow.training.Metrics(metrics=[metric_to_compute]) + metrics_callback = scaleflow.training.Metrics(metrics=[metric_to_compute]) predict_kwargs_1 = {"max_steps": 3, "throw": False} predict_kwargs_2 = {"max_steps": 500, "throw": False} - trainer_1 = cellflow.training.CellFlowTrainer(solver=model_1, predict_kwargs=predict_kwargs_1) - trainer_2 = cellflow.training.CellFlowTrainer(solver=model_2, predict_kwargs=predict_kwargs_2) + trainer_1 = scaleflow.training.CellFlowTrainer(solver=model_1, predict_kwargs=predict_kwargs_1) + trainer_2 = scaleflow.training.CellFlowTrainer(solver=model_2, predict_kwargs=predict_kwargs_2) start_1 = time.time() trainer_1.train( From 85431c4aad0b12eb08b91ad80d875f6c28802022 Mon Sep 17 00:00:00 2001 From: 
selmanozleyen Date: Mon, 29 Sep 2025 17:52:53 +0200 Subject: [PATCH 32/35] load initial caching parallel --- src/scaleflow/data/_dataloader.py | 19 ++++++++++++++----- src/scaleflow/data/_datamanager.py | 27 ++++++++++++++------------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/scaleflow/data/_dataloader.py b/src/scaleflow/data/_dataloader.py index 162579b8..4e3aed9c 100644 --- a/src/scaleflow/data/_dataloader.py +++ b/src/scaleflow/data/_dataloader.py @@ -3,6 +3,7 @@ import numpy as np import tqdm +import os import threading from concurrent.futures import ThreadPoolExecutor, Future @@ -191,11 +192,19 @@ def _init_cache_pool_elements(self): raise ValueError("Pool not initialized. Call init_pool(rng) first.") with self._lock: self._cached_srcs = {i: self._data.src_cell_data[i][...] for i in self._src_idx_pool} - self._cached_tgts = { - j: self._data.tgt_cell_data[j][...] - for i in self._src_idx_pool - for j in self._data.control_to_perturbation[i] - } + tgt_indices = sorted( + {int(j) for i in self._src_idx_pool for j in self._data.control_to_perturbation[i]} + ) + + def _load_tgt(j: int): + return j, self._data.tgt_cell_data[j][...] + + max_workers = min(32, (os.cpu_count() or 4)) + with ThreadPoolExecutor(max_workers=max_workers) as ex: + results = list(ex.map(_load_tgt, tgt_indices)) + + with self._lock: + self._cached_tgts = {j: arr for j, arr in results} def _init_pool(self, rng): """Initialize the pool with random source distribution indices.""" diff --git a/src/scaleflow/data/_datamanager.py b/src/scaleflow/data/_datamanager.py index 5cb60550..2274256d 100644 --- a/src/scaleflow/data/_datamanager.py +++ b/src/scaleflow/data/_datamanager.py @@ -1,17 +1,18 @@ -from collections import OrderedDict -from collections.abc import Sequence -from typing import Any - -import anndata -import dask -import dask.dataframe as dd -import dask.delayed +import abc +from typing import Any, Literal + import numpy as np -import pandas as pd -import scipy.sparse as sp -import sklearn.preprocessing as preprocessing -from dask.diagnostics import ProgressBar -from pandas.api.types import is_numeric_dtype +import tqdm +import threading +from concurrent.futures import ThreadPoolExecutor, Future +import os + +from scaleflow.data._data import ( + PredictionData, + TrainingData, + ValidationData, + MappedCellData, +) from scaleflow._logging import logger from scaleflow._types import ArrayLike From f0d0be318a6aa69e962df06c627b2b545203e810 Mon Sep 17 00:00:00 2001 From: AlejandroTL Date: Tue, 30 Sep 2025 14:19:52 +0200 Subject: [PATCH 33/35] alejandro/data_splitter --- src/scaleflow/data/_data_splitter.py | 855 +++++++++++++++++++++++++++ 1 file changed, 855 insertions(+) create mode 100644 src/scaleflow/data/_data_splitter.py diff --git a/src/scaleflow/data/_data_splitter.py b/src/scaleflow/data/_data_splitter.py new file mode 100644 index 00000000..cd3fb157 --- /dev/null +++ b/src/scaleflow/data/_data_splitter.py @@ -0,0 +1,855 @@ +"""Data splitter for creating train/validation/test splits from TrainingData objects.""" + +import logging +import warnings +from pathlib import Path +from typing import Literal + +import numpy as np +from sklearn.model_selection import train_test_split + +from scaleflow.data._data import MappedCellData, TrainingData + +logger = logging.getLogger(__name__) + +SplitType = Literal["holdout_groups", "holdout_combinations", "random", "stratified"] + + +class DataSplitter: + """ + A lightweight class for creating train/validation/test splits from TrainingData 
objects. + + This class extracts metadata from TrainingData objects and returns split indices, + making it memory-efficient for large datasets. + + Supports various splitting strategies: + - holdout_groups: Hold out specific groups (perturbations, cell lines, etc.) for validation/test + - holdout_combinations: Keep single treatments in training, hold out combinations for validation/test + - random: Random split of cells + - stratified: Stratified split maintaining proportions + + Parameters + ---------- + training_datasets : list[TrainingData | MappedCellData] + List of TrainingData or MappedCellData objects to process + dataset_names : list[str] + List of names for each dataset (for saving/loading) + split_ratios : list[list[float]] + List of triples, each indicating [train, validation, test] ratios for each dataset. + Each triple must sum to 1.0. Length must match training_datasets. + split_type : SplitType + Type of split to perform + split_key : str | list[str] | None + Column name(s) in adata.obs to use for splitting (required for holdout_groups and holdout_combinations). + Can be a single column or list of columns for combination treatments. + force_training_values : list[str] | None + Values that should be forced to appear only in training (e.g., ['control', 'dmso']). + These values will never appear in validation or test sets. + control_value : str | list[str] | None + Value(s) that represent control/untreated condition (e.g., 'control' or ['control', 'dmso']). + Required for holdout_combinations split type. + hard_test_split : bool + If True, validation and test get completely different groups (no overlap). + If False, validation and test can share groups, split at cell level. + Applies to all split types for consistent val/test separation control. + random_state : int + Random seed for reproducible splits + + Examples + -------- + >>> # Split by holdout groups with forced training values + >>> splitter = DataSplitter( + ... training_datasets=[train_data1, train_data2], + ... dataset_names=["dataset1", "dataset2"], + ... split_ratios=[[0.8, 0.2, 0.0], [0.9, 0.1, 0.0]], + ... split_type="holdout_groups", + ... split_key=["drug1", "drug2"], + ... force_training_values=["control", "dmso"], + ... ) + >>> # Split by holding out combinations (singletons in training) + >>> splitter = DataSplitter( + ... training_datasets=[train_data], + ... dataset_names=["dataset"], + ... split_ratios=[[0.8, 0.2, 0.0]], + ... split_type="holdout_combinations", + ... split_key=["drug1", "drug2"], + ... control_value=["control", "dmso"], + ... 
) + >>> results = splitter.split_all_datasets() + >>> splitter.save_splits("./splits") + + >>> # Load split information later + >>> split_info = DataSplitter.load_split_info("./splits", "dataset1") + >>> train_indices = split_info["indices"]["train"] + """ + + def __init__( + self, + training_datasets: list[TrainingData | MappedCellData], + dataset_names: list[str], + split_ratios: list[list[float]], + split_type: SplitType = "random", + split_key: str | list[str] | None = None, + force_training_values: list[str] | None = None, + control_value: str | list[str] | None = None, + hard_test_split: bool = True, + random_state: int = 42, + ): + self.training_datasets = training_datasets + self.dataset_names = dataset_names + self.split_ratios = split_ratios + self.split_type = split_type + self.split_key = split_key + self.force_training_values = force_training_values or [] + self.control_value = [control_value] if isinstance(control_value, str) else control_value + self.hard_test_split = hard_test_split + self.random_state = random_state + + self._validate_inputs() + + self.split_results: dict[str, dict] = {} + + def _validate_inputs(self) -> None: + """Validate input parameters.""" + if len(self.training_datasets) != len(self.dataset_names): + raise ValueError( + f"training_datasets length ({len(self.training_datasets)}) must match " + f"dataset_names length ({len(self.dataset_names)})" + ) + + if not isinstance(self.split_ratios, list): + raise ValueError("split_ratios must be a list of lists") + + if len(self.split_ratios) != len(self.training_datasets): + raise ValueError( + f"split_ratios length ({len(self.split_ratios)}) must match " + f"training_datasets length ({len(self.training_datasets)})" + ) + + # Check each split ratio + for i, ratios in enumerate(self.split_ratios): + if not isinstance(ratios, list) or len(ratios) != 3: + raise ValueError(f"split_ratios[{i}] must be a list of 3 values [train, val, test]") + + if not np.isclose(sum(ratios), 1.0): + raise ValueError(f"split_ratios[{i}] must sum to 1.0, got {sum(ratios)}") + + if any(ratio < 0 for ratio in ratios): + raise ValueError(f"All values in split_ratios[{i}] must be non-negative") + + # Check split key requirement + if self.split_type in ["holdout_groups", "holdout_combinations"] and self.split_key is None: + raise ValueError(f"split_key must be provided for split_type '{self.split_type}'") + + # Check control_value requirement for holdout_combinations + if self.split_type == "holdout_combinations" and self.control_value is None: + raise ValueError("control_value must be provided for split_type 'holdout_combinations'") + + for i, td in enumerate(self.training_datasets): + if not isinstance(td, (TrainingData, MappedCellData)): + raise ValueError(f"training_datasets[{i}] must be a TrainingData or MappedCellData object") + + def extract_perturbation_info(self, training_data: TrainingData | MappedCellData) -> dict: + """ + Extract perturbation information from TrainingData or MappedCellData. 
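+
+        The perturbation mask is materialized as a NumPy array via ``np.asarray``,
+        so the same downstream indexing works for :class:`TrainingData` and
+        :class:`MappedCellData` inputs alike.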
+ + Parameters + ---------- + training_data : TrainingData | MappedCellData + Training data object + + Returns + ------- + dict + Dictionary containing: + - perturbation_covariates_mask: array mapping cells to perturbation indices + - perturbation_idx_to_covariates: dict mapping perturbation indices to covariate tuples + - n_cells: total number of cells + """ + perturbation_covariates_mask = np.asarray(training_data.perturbation_covariates_mask) + perturbation_idx_to_covariates = training_data.perturbation_idx_to_covariates + + n_cells = len(perturbation_covariates_mask) + + logger.info(f"Extracted perturbation info for {n_cells} cells") + logger.info(f"Number of unique perturbations: {len(perturbation_idx_to_covariates)}") + + return { + "perturbation_covariates_mask": perturbation_covariates_mask, + "perturbation_idx_to_covariates": perturbation_idx_to_covariates, + "n_cells": n_cells, + } + + def _get_unique_perturbation_values( + self, perturbation_idx_to_covariates: dict[int, tuple[str, ...]] + ) -> list[str]: + """Get all unique covariate values from perturbation dictionary.""" + all_unique_vals = set() + for covariates in perturbation_idx_to_covariates.values(): + all_unique_vals.update(covariates) + return list(all_unique_vals) + + def _split_random(self, n_cells: int, split_ratios: list[float]) -> dict[str, np.ndarray]: + """Perform random split of cells.""" + train_ratio, val_ratio, test_ratio = split_ratios + + # Generate random indices + indices = np.arange(n_cells) + np.random.seed(self.random_state) + np.random.shuffle(indices) + + if self.hard_test_split: + # HARD: Val and test are completely separate + train_end = int(train_ratio * n_cells) + val_end = train_end + int(val_ratio * n_cells) + + train_idx = indices[:train_end] + val_idx = indices[train_end:val_end] if val_ratio > 0 else np.array([]) + test_idx = indices[val_end:] if test_ratio > 0 else np.array([]) + + logger.info("HARD RANDOM SPLIT: Completely separate val/test") + else: + # SOFT: Val and test can overlap (split val+test at cell level) + train_end = int(train_ratio * n_cells) + train_idx = indices[:train_end] + val_test_idx = indices[train_end:] + + # Split val+test according to val/test ratios + if len(val_test_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + val_test_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("SOFT RANDOM SPLIT: Val/test can overlap") + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def _split_by_values( + self, + perturbation_covariates_mask: np.ndarray, + perturbation_idx_to_covariates: dict[int, tuple[str, ...]], + split_ratios: list[float], + ) -> dict[str, np.ndarray]: + """Split by holding out specific perturbations.""" + if self.split_key is None: + raise ValueError("split_key must be provided for holdout_groups splitting") + + # Get all unique covariate values + unique_values = self._get_unique_perturbation_values(perturbation_idx_to_covariates) + + # Remove forced training values from consideration for val/test splits + available_values = [v for v in unique_values if v not in self.force_training_values] + forced_train_values = [v for v in unique_values if v in self.force_training_values] + + logger.info(f"Total unique values: {len(unique_values)}") + logger.info(f"Forced training values: {forced_train_values}") + logger.info(f"Available for val/test: {len(available_values)}") + + 
n_values = len(available_values) + + if n_values < 3: + warnings.warn( + f"Only {n_values} unique values found across columns {self.split_key}. " + "Consider using random split instead.", + stacklevel=2, + ) + + # Split values according to ratios + train_ratio, val_ratio, test_ratio = split_ratios + + # Calculate number of values for each split + n_train = max(1, int(train_ratio * n_values)) + n_val = int(val_ratio * n_values) + n_test = n_values - n_train - n_val + + # Ensure we don't exceed total values + if n_train + n_val + n_test != n_values: + n_test = n_values - n_train - n_val + + # Shuffle available values for random assignment (excluding forced training values) + np.random.seed(self.random_state) + shuffled_values = np.random.permutation(available_values) + + # Assign values to splits + train_values_random = shuffled_values[:n_train] + val_values = shuffled_values[n_train : n_train + n_val] if n_val > 0 else [] + test_values = shuffled_values[n_train + n_val :] if n_test > 0 else [] + + # Combine forced training values with randomly assigned training values + train_values = list(train_values_random) + forced_train_values + + logger.info(f"Split values - Train: {len(train_values)}, Val: {len(val_values)}, Test: {len(test_values)}") + logger.info(f"Train values: {train_values}") + logger.info(f"Val values: {val_values}") + logger.info(f"Test values: {test_values}") + + # Create masks by checking which perturbation indices contain which values + def _get_cells_with_values(values_set): + """Get cell indices for perturbations containing any of the specified values.""" + if len(values_set) == 0: + return np.array([], dtype=int) + + # Find perturbation indices that contain any of these values + matching_pert_indices = [] + for pert_idx, covariates in perturbation_idx_to_covariates.items(): + if any(val in covariates for val in values_set): + matching_pert_indices.append(pert_idx) + + # Get cells with these perturbation indices + if len(matching_pert_indices) == 0: + return np.array([], dtype=int) + + cell_mask = np.isin(perturbation_covariates_mask, matching_pert_indices) + return np.where(cell_mask)[0] + + if self.hard_test_split: + # HARD: Val and test get different values (existing logic) + train_idx = _get_cells_with_values(train_values) + val_idx = _get_cells_with_values(val_values) + test_idx = _get_cells_with_values(test_values) + + logger.info("HARD HOLDOUT GROUPS: Val and test get different values") + else: + # SOFT: Val and test can share values, split at cell level + train_values_all = list(train_values_random) + forced_train_values + val_test_values = list(val_values) + list(test_values) + + train_idx = _get_cells_with_values(train_values_all) + val_test_idx = _get_cells_with_values(val_test_values) + + # Split val+test cells according to val/test ratios + if len(val_test_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + val_test_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("SOFT HOLDOUT GROUPS: Val/test can share values") + + # Log overlap information (important for combination treatments) + total_assigned = len(set(train_idx) | set(val_idx) | set(test_idx)) + logger.info(f"Total cells assigned to splits: {total_assigned} out of {len(perturbation_covariates_mask)}") + + overlaps = [] + if len(set(train_idx) & set(val_idx)) > 0: + overlaps.append("train-val") + if len(set(train_idx) & set(test_idx)) 
> 0: + overlaps.append("train-test") + if len(set(val_idx) & set(test_idx)) > 0: + overlaps.append("val-test") + + if overlaps: + logger.warning( + f"Found overlapping cells between splits: {overlaps}. This is expected with combination treatments." + ) + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def _split_holdout_combinations( + self, + perturbation_covariates_mask: np.ndarray, + perturbation_idx_to_covariates: dict[int, tuple[str, ...]], + split_ratios: list[float], + ) -> dict[str, np.ndarray]: + """Split by keeping singletons in training and holding out combinations for val/test.""" + if self.split_key is None: + raise ValueError("split_key must be provided for holdout_combinations splitting") + if self.control_value is None: + raise ValueError("control_value must be provided for holdout_combinations splitting") + + logger.info("Identifying combinations vs singletons from perturbation covariates") + logger.info(f"Control value(s): {self.control_value}") + + # Classify each perturbation index as control, singleton, or combination + control_pert_indices = [] + singleton_pert_indices = [] + combination_pert_indices = [] + + for pert_idx, covariates in perturbation_idx_to_covariates.items(): + non_control_values = [c for c in covariates if c not in self.control_value] + n_non_control = len(non_control_values) + + if n_non_control == 0: + control_pert_indices.append(pert_idx) + elif n_non_control == 1: + singleton_pert_indices.append(pert_idx) + else: + combination_pert_indices.append(pert_idx) + + # Get cell indices for each type + if len(control_pert_indices) > 0: + control_mask = np.isin(perturbation_covariates_mask, control_pert_indices) + else: + control_mask = np.zeros(len(perturbation_covariates_mask), dtype=bool) + + if len(singleton_pert_indices) > 0: + singleton_mask = np.isin(perturbation_covariates_mask, singleton_pert_indices) + else: + singleton_mask = np.zeros(len(perturbation_covariates_mask), dtype=bool) + + if len(combination_pert_indices) > 0: + combination_mask = np.isin(perturbation_covariates_mask, combination_pert_indices) + else: + combination_mask = np.zeros(len(perturbation_covariates_mask), dtype=bool) + + # Count each type + n_combinations = combination_mask.sum() + n_singletons = singleton_mask.sum() + n_controls = control_mask.sum() + + logger.info(f"Found {n_combinations} combination treatments") + logger.info(f"Found {n_singletons} singleton treatments") + logger.info(f"Found {n_controls} control treatments") + + if n_combinations == 0: + warnings.warn("No combination treatments found. 
Consider using 'holdout_groups' instead.", stacklevel=2) + + # Get indices for each type + combination_indices = np.where(combination_mask)[0] + singleton_indices = np.where(singleton_mask)[0] + control_indices = np.where(control_mask)[0] + + # All singletons and controls go to training + train_idx = np.concatenate([singleton_indices, control_indices]) + + # Split combinations according to the provided ratios + train_ratio, val_ratio, test_ratio = split_ratios + + if n_combinations > 0: + # Get perturbation identifiers for combination cells + # Map each cell to its perturbation tuple (non-control values only) + perturbation_ids = [] + for cell_idx in combination_indices: + pert_idx = perturbation_covariates_mask[cell_idx] + covariates = perturbation_idx_to_covariates[pert_idx] + # Extract non-control values + non_control_vals = [c for c in covariates if c not in self.control_value] + perturbation_id = tuple(sorted(non_control_vals)) + perturbation_ids.append(perturbation_id) + + # Get unique perturbation combinations + unique_perturbations = list(set(perturbation_ids)) + n_unique_perturbations = len(unique_perturbations) + + logger.info(f"Found {n_unique_perturbations} unique perturbation combinations") + + if self.hard_test_split: + # HARD TEST SPLIT: Val and test get completely different perturbations + # Calculate number of perturbation combinations for each split + n_train_perturbations = int(train_ratio * n_unique_perturbations) + n_val_perturbations = int(val_ratio * n_unique_perturbations) + n_test_perturbations = n_unique_perturbations - n_train_perturbations - n_val_perturbations + + # Ensure we don't exceed total perturbations + if n_train_perturbations + n_val_perturbations + n_test_perturbations != n_unique_perturbations: + n_test_perturbations = n_unique_perturbations - n_train_perturbations - n_val_perturbations + + # Shuffle perturbations for random assignment + np.random.seed(self.random_state) + shuffled_perturbations = np.random.permutation(unique_perturbations) + + # Assign perturbations to splits + train_perturbations = ( + shuffled_perturbations[:n_train_perturbations] if n_train_perturbations > 0 else [] + ) + val_perturbations = ( + shuffled_perturbations[n_train_perturbations : n_train_perturbations + n_val_perturbations] + if n_val_perturbations > 0 + else [] + ) + test_perturbations = ( + shuffled_perturbations[n_train_perturbations + n_val_perturbations :] + if n_test_perturbations > 0 + else [] + ) + + # Assign all cells with same perturbation to same split + train_combo_idx = [] + val_combo_idx = [] + test_combo_idx = [] + + for i, perturbation_id in enumerate(perturbation_ids): + cell_idx = combination_indices[i] + if perturbation_id in train_perturbations: + train_combo_idx.append(cell_idx) + elif perturbation_id in val_perturbations: + val_combo_idx.append(cell_idx) + elif perturbation_id in test_perturbations: + test_combo_idx.append(cell_idx) + + logger.info( + f"HARD TEST SPLIT - Perturbation split: Train={len(train_perturbations)}, Val={len(val_perturbations)}, Test={len(test_perturbations)}" + ) + + else: + # SOFT TEST SPLIT: Val and test can share perturbations, split at cell level + # First assign perturbations to train vs (val+test) + n_train_perturbations = int(train_ratio * n_unique_perturbations) + n_val_test_perturbations = n_unique_perturbations - n_train_perturbations + + # Shuffle perturbations + np.random.seed(self.random_state) + shuffled_perturbations = np.random.permutation(unique_perturbations) + + train_perturbations = ( + 
shuffled_perturbations[:n_train_perturbations] if n_train_perturbations > 0 else [] + ) + val_test_perturbations = ( + shuffled_perturbations[n_train_perturbations:] if n_val_test_perturbations > 0 else [] + ) + + # Get cells for train perturbations (all go to train) + train_combo_idx = [] + val_test_combo_idx = [] + + for i, perturbation_id in enumerate(perturbation_ids): + cell_idx = combination_indices[i] + if perturbation_id in train_perturbations: + train_combo_idx.append(cell_idx) + else: + val_test_combo_idx.append(cell_idx) + + # Now split val_test cells according to val/test ratios + if len(val_test_combo_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + np.random.seed(self.random_state + 1) # Different seed for cell-level split + + val_combo_idx, test_combo_idx = train_test_split( + val_test_combo_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_combo_idx = np.array([]) + test_combo_idx = np.array([]) + + logger.info( + f"SOFT TEST SPLIT - Perturbation split: Train={len(train_perturbations)}, Val+Test={len(val_test_perturbations)}" + ) + logger.info(f"Cell split within Val+Test: Val={len(val_combo_idx)}, Test={len(test_combo_idx)}") + + # Convert to numpy arrays + train_combo_idx = np.array(train_combo_idx) + val_combo_idx = np.array(val_combo_idx) + test_combo_idx = np.array(test_combo_idx) + + # Combine singletons/controls with assigned combinations + train_idx = np.concatenate([train_idx, train_combo_idx]) + val_idx = val_combo_idx + test_idx = test_combo_idx + + logger.info( + f"Final cell split: Train={len(train_combo_idx)}, Val={len(val_combo_idx)}, Test={len(test_combo_idx)}" + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info( + f"Final split - Train: {len(train_idx)} (singletons + controls + {len(train_combo_idx) if n_combinations > 0 else 0} combination cells)" + ) + logger.info(f"Final split - Val: {len(val_idx)} (combination cells only)") + logger.info(f"Final split - Test: {len(test_idx)} (combination cells only)") + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def _split_stratified( + self, + perturbation_covariates_mask: np.ndarray, + split_ratios: list[float], + ) -> dict[str, np.ndarray]: + """Perform stratified split maintaining proportions of perturbations.""" + if self.split_key is None: + raise ValueError("split_key must be provided for stratified splitting") + + train_ratio, val_ratio, test_ratio = split_ratios + # Use perturbation indices as stratification labels + labels = perturbation_covariates_mask + indices = np.arange(len(perturbation_covariates_mask)) + + if self.hard_test_split: + # HARD: Val and test get different stratification groups (existing logic) + if val_ratio + test_ratio > 0: + train_idx, temp_idx = train_test_split( + indices, train_size=train_ratio, stratify=labels, random_state=self.random_state + ) + + if val_ratio > 0 and test_ratio > 0: + temp_labels = labels[temp_idx] + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + temp_idx, train_size=val_size, stratify=temp_labels, random_state=self.random_state + ) + elif val_ratio > 0: + val_idx = temp_idx + test_idx = np.array([]) + else: + val_idx = np.array([]) + test_idx = temp_idx + else: + train_idx = indices + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("HARD STRATIFIED SPLIT: Val and test get different strata") + else: + # SOFT: Val and test can share stratification groups, split at cell level + if 
val_ratio + test_ratio > 0: + train_idx, val_test_idx = train_test_split( + indices, train_size=train_ratio, stratify=labels, random_state=self.random_state + ) + + # Split val+test cells (not stratified) + if len(val_test_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + val_test_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + else: + train_idx = indices + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("SOFT STRATIFIED SPLIT: Val/test can share strata") + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def split_single_dataset(self, training_data: TrainingData | MappedCellData, dataset_index: int) -> dict: + """ + Split a single TrainingData or MappedCellData object according to the specified strategy. + + Parameters + ---------- + training_data : TrainingData | MappedCellData + Training data object to split + dataset_index : int + Index of the dataset to get the correct split ratios + + Returns + ------- + dict + Dictionary containing split indices and metadata + """ + # Extract perturbation information + pert_info = self.extract_perturbation_info(training_data) + perturbation_covariates_mask = pert_info["perturbation_covariates_mask"] + perturbation_idx_to_covariates = pert_info["perturbation_idx_to_covariates"] + n_cells = pert_info["n_cells"] + + # Get split ratios for this specific dataset + current_split_ratios = self.split_ratios[dataset_index] + + # Perform split based on strategy + if self.split_type == "random": + split_indices = self._split_random(n_cells, current_split_ratios) + elif self.split_type == "holdout_groups": + split_indices = self._split_by_values( + perturbation_covariates_mask, perturbation_idx_to_covariates, current_split_ratios + ) + elif self.split_type == "holdout_combinations": + split_indices = self._split_holdout_combinations( + perturbation_covariates_mask, perturbation_idx_to_covariates, current_split_ratios + ) + elif self.split_type == "stratified": + split_indices = self._split_stratified(perturbation_covariates_mask, current_split_ratios) + else: + raise ValueError(f"Unknown split_type: {self.split_type}") + + # Create result dictionary with indices and metadata + result = { + "indices": split_indices, + "metadata": { + "total_cells": n_cells, + "split_type": self.split_type, + "split_key": self.split_key, + "split_ratios": current_split_ratios, + "random_state": self.random_state, + }, + } + + # Add force_training_values and control_value to metadata + if self.force_training_values: + result["metadata"]["force_training_values"] = self.force_training_values + if self.control_value: + result["metadata"]["control_value"] = self.control_value + + # Add split values information if applicable + if self.split_type in ["holdout_groups", "holdout_combinations"] and self.split_key: + unique_values = self._get_unique_perturbation_values(perturbation_idx_to_covariates) + + def _get_split_values(indices): + """Get all unique covariate values for cells in this split.""" + if len(indices) == 0: + return [] + split_vals = set() + for idx in indices: + pert_idx = perturbation_covariates_mask[idx] + covariates = perturbation_idx_to_covariates[pert_idx] + split_vals.update(covariates) + return list(split_vals) + + train_values = _get_split_values(split_indices["train"]) + val_values = _get_split_values(split_indices["val"]) + test_values = 
_get_split_values(split_indices["test"]) + + result["split_values"] = { + "train": train_values, + "val": val_values, + "test": test_values, + "all_unique": unique_values, + } + + # Log split statistics + logger.info(f"Split results for {self.dataset_names[dataset_index]}:") + for split_name, indices in split_indices.items(): + if len(indices) > 0: + logger.info(f" {split_name}: {len(indices)} cells") + + return result + + def split_all_datasets(self) -> dict[str, dict]: + """ + Split all TrainingData objects according to the specified strategy. + + Returns + ------- + dict[str, dict] + Nested dictionary with dataset names as keys and split information as values + """ + logger.info(f"Starting data splitting with strategy: {self.split_type}") + logger.info(f"Number of datasets: {len(self.training_datasets)}") + for i, ratios in enumerate(self.split_ratios): + logger.info(f"Dataset {i} ratios: train={ratios[0]}, val={ratios[1]}, test={ratios[2]}") + + for i, (training_data, dataset_name) in enumerate(zip(self.training_datasets, self.dataset_names, strict=True)): + logger.info(f"\nProcessing dataset {i}: {dataset_name}") + logger.info(f"Using split ratios: {self.split_ratios[i]}") + + split_result = self.split_single_dataset(training_data, i) + self.split_results[dataset_name] = split_result + + logger.info(f"\nCompleted splitting {len(self.training_datasets)} datasets") + return self.split_results + + def save_splits(self, output_dir: str | Path) -> None: + """ + Save all split information to the specified directory. + + Parameters + ---------- + output_dir : str | Path + Directory to save the split information + """ + import json + import pickle + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving splits to: {output_dir}") + + for dataset_name, split_info in self.split_results.items(): + # Save indices as numpy arrays (more efficient for large datasets) + indices_dir = output_dir / dataset_name / "indices" + indices_dir.mkdir(parents=True, exist_ok=True) + + for split_name, indices in split_info["indices"].items(): + if len(indices) > 0: + indices_file = indices_dir / f"{split_name}_indices.npy" + np.save(indices_file, indices) + logger.info(f"Saved {split_name} indices: {len(indices)} cells -> {indices_file}") + + # Save metadata as JSON + metadata_file = output_dir / dataset_name / "metadata.json" + with open(metadata_file, "w") as f: + # Convert numpy arrays to lists for JSON serialization + metadata = split_info["metadata"].copy() + json.dump(metadata, f, indent=2) + logger.info(f"Saved metadata -> {metadata_file}") + + # Save split values if available + if "split_values" in split_info: + split_values_file = output_dir / dataset_name / "split_values.json" + with open(split_values_file, "w") as f: + json.dump(split_info["split_values"], f, indent=2) + logger.info(f"Saved split values -> {split_values_file}") + + # Save complete split info as pickle for easy loading + complete_file = output_dir / dataset_name / "split_info.pkl" + with open(complete_file, "wb") as f: + pickle.dump(split_info, f) + logger.info(f"Saved complete split info -> {complete_file}") + + logger.info("All splits saved successfully") + + @staticmethod + def load_split_info(split_dir: str | Path, dataset_name: str) -> dict: + """ + Load split information from disk. 
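+
+        Prefers the pickled ``split_info.pkl`` written by :meth:`save_splits`; if it is
+        missing, the split is reconstructed from the individual ``indices/*.npy`` and
+        JSON files in the dataset directory.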
+ + Parameters + ---------- + split_dir : str | Path + Directory containing saved splits + dataset_name : str + Name of the dataset + + Returns + ------- + dict + Dictionary containing split indices and metadata + """ + import pickle + + split_dir = Path(split_dir) + dataset_dir = split_dir / dataset_name + + if not dataset_dir.exists(): + raise FileNotFoundError(f"Split directory not found: {dataset_dir}") + + # Load complete split info from pickle + complete_file = dataset_dir / "split_info.pkl" + if complete_file.exists(): + with open(complete_file, "rb") as f: + return pickle.load(f) + + # Fallback: reconstruct from individual files + logger.warning("Complete split info not found, reconstructing from individual files") + + # Load indices + indices_dir = dataset_dir / "indices" + indices = {} + for split_name in ["train", "val", "test"]: + indices_file = indices_dir / f"{split_name}_indices.npy" + if indices_file.exists(): + indices[split_name] = np.load(indices_file) + else: + indices[split_name] = np.array([]) + + # Load metadata + import json + + metadata_file = dataset_dir / "metadata.json" + with open(metadata_file) as f: + metadata = json.load(f) + + # Load split values if available + split_values = None + split_values_file = dataset_dir / "split_values.json" + if split_values_file.exists(): + with open(split_values_file) as f: + split_values = json.load(f) + + result = {"indices": indices, "metadata": metadata} + if split_values: + result["split_values"] = split_values + + return result From f62c68b10663663987c0933d988560e40b341550 Mon Sep 17 00:00:00 2001 From: AlejandroTL Date: Wed, 8 Oct 2025 21:23:38 +0200 Subject: [PATCH 34/35] eqm testing --- src/scaleflow/model/_scaleflow.py | 128 +++++--- src/scaleflow/networks/__init__.py | 3 +- src/scaleflow/networks/_velocity_field.py | 159 +++++++++- src/scaleflow/solvers/__init__.py | 3 +- src/scaleflow/solvers/_eqm.py | 356 ++++++++++++++++++++++ 5 files changed, 602 insertions(+), 47 deletions(-) create mode 100644 src/scaleflow/solvers/_eqm.py diff --git a/src/scaleflow/model/_scaleflow.py b/src/scaleflow/model/_scaleflow.py index cfe8eb47..a22d651a 100644 --- a/src/scaleflow/model/_scaleflow.py +++ b/src/scaleflow/model/_scaleflow.py @@ -23,7 +23,7 @@ from scaleflow.model._utils import _write_predictions from scaleflow.networks import _velocity_field from scaleflow.plotting import _utils -from scaleflow.solvers import _genot, _otfm +from scaleflow.solvers import _genot, _otfm, _eqm from scaleflow.training._callbacks import BaseCallback from scaleflow.training._trainer import CellFlowTrainer from scaleflow.utils import match_linear @@ -43,23 +43,28 @@ class CellFlow: adata An :class:`~anndata.AnnData` object to extract the training data from. solver - Solver to use for training. Either ``'otfm'`` or ``'genot'``. + Solver to use for training. Either ``'otfm'``, ``'genot'`` or ``'eqm'``. 
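+
+    Examples
+    --------
+    Minimal sketch of the intended workflow (argument values are illustrative; see
+    the tests for complete argument sets):
+
+    >>> cf = CellFlow(adata, solver="eqm")
+    >>> cf.prepare_data(sample_rep="X", control_key="control", perturbation_covariates={"drug": ["drug1"]})
+    >>> cf.prepare_model(condition_embedding_dim=32)
+    >>> cf.train(num_iterations=1000)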
""" - def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot"] = "otfm"): + def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot", "eqm"] = "otfm"): self._adata = adata - self._solver_class = _otfm.OTFlowMatching if solver == "otfm" else _genot.GENOT - self._vf_class = ( - _velocity_field.ConditionalVelocityField - if solver == "otfm" - else _velocity_field.GENOTConditionalVelocityField - ) + if solver == "otfm": + self._solver_class = _otfm.OTFlowMatching + self._vf_class = _velocity_field.ConditionalVelocityField + elif solver == "genot": + self._solver_class = _genot.GENOT + self._vf_class = _velocity_field.GENOTConditionalVelocityField + elif solver == "eqm": + self._solver_class = _eqm.EquilibriumMatching + self._vf_class = _velocity_field.EquilibriumVelocityField + else: + raise ValueError(f"Unknown solver: {solver}. Must be 'otfm', 'genot', or 'eqm'.") self._dataloader: TrainSampler | JaxOutOfCoreTrainSampler | None = None self._trainer: CellFlowTrainer | None = None self._validation_data: dict[str, ValidationData] = {"predict_kwargs": {}} - self._solver: _otfm.OTFlowMatching | _genot.GENOT | None = None + self._solver: _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching | None = None self._condition_dim: int | None = None - self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None = None + self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | _velocity_field.EquilibriumVelocityField | None = None def prepare_data( self, @@ -431,8 +436,8 @@ def prepare_model( raise ValueError("Stochastic condition embeddings require `regularization`>0.") condition_encoder_kwargs = condition_encoder_kwargs or {} - if self._solver_class == _otfm.OTFlowMatching and vf_kwargs is not None: - raise ValueError("For `solver='otfm'`, `vf_kwargs` must be `None`.") + if (self._solver_class == _otfm.OTFlowMatching or self._solver_class == _eqm.EquilibriumMatching) and vf_kwargs is not None: + raise ValueError("For `solver='otfm'` or `solver='eqm'`, `vf_kwargs` must be `None`.") if self._solver_class == _genot.GENOT: if vf_kwargs is None: vf_kwargs = {"genot_source_dims": [1024, 1024, 1024], "genot_source_dropout": 0.0} @@ -446,34 +451,59 @@ def prepare_model( solver_kwargs = solver_kwargs or {} probability_path = probability_path or {"constant_noise": 0.0} - self.vf = self._vf_class( - output_dim=self._data_dim, - max_combination_length=self.train_data.max_combination_length, - condition_mode=condition_mode, - regularization=regularization, - condition_embedding_dim=condition_embedding_dim, - covariates_not_pooled=covariates_not_pooled, - pooling=pooling, - pooling_kwargs=pooling_kwargs, - layers_before_pool=layers_before_pool, - layers_after_pool=layers_after_pool, - cond_output_dropout=cond_output_dropout, - condition_encoder_kwargs=condition_encoder_kwargs, - act_fn=vf_act_fn, - time_freqs=time_freqs, - time_max_period=time_max_period, - time_encoder_dims=time_encoder_dims, - time_encoder_dropout=time_encoder_dropout, - hidden_dims=hidden_dims, - hidden_dropout=hidden_dropout, - conditioning=conditioning, - conditioning_kwargs=conditioning_kwargs, - decoder_dims=decoder_dims, - decoder_dropout=decoder_dropout, - layer_norm_before_concatenation=layer_norm_before_concatenation, - linear_projection_before_concatenation=linear_projection_before_concatenation, - **vf_kwargs, - ) + if self._solver_class == _eqm.EquilibriumMatching: + self.vf = self._vf_class( + 
output_dim=self._data_dim, + max_combination_length=self.train_data.max_combination_length, + condition_mode=condition_mode, + regularization=regularization, + condition_embedding_dim=condition_embedding_dim, + covariates_not_pooled=covariates_not_pooled, + pooling=pooling, + pooling_kwargs=pooling_kwargs, + layers_before_pool=layers_before_pool, + layers_after_pool=layers_after_pool, + cond_output_dropout=cond_output_dropout, + condition_encoder_kwargs=condition_encoder_kwargs, + act_fn=vf_act_fn, + hidden_dims=hidden_dims, + hidden_dropout=hidden_dropout, + conditioning=conditioning, + conditioning_kwargs=conditioning_kwargs, + decoder_dims=decoder_dims, + decoder_dropout=decoder_dropout, + layer_norm_before_concatenation=layer_norm_before_concatenation, + linear_projection_before_concatenation=linear_projection_before_concatenation, + ) + else: + self.vf = self._vf_class( + output_dim=self._data_dim, + max_combination_length=self.train_data.max_combination_length, + condition_mode=condition_mode, + regularization=regularization, + condition_embedding_dim=condition_embedding_dim, + covariates_not_pooled=covariates_not_pooled, + pooling=pooling, + pooling_kwargs=pooling_kwargs, + layers_before_pool=layers_before_pool, + layers_after_pool=layers_after_pool, + cond_output_dropout=cond_output_dropout, + condition_encoder_kwargs=condition_encoder_kwargs, + act_fn=vf_act_fn, + time_freqs=time_freqs, + time_max_period=time_max_period, + time_encoder_dims=time_encoder_dims, + time_encoder_dropout=time_encoder_dropout, + hidden_dims=hidden_dims, + hidden_dropout=hidden_dropout, + conditioning=conditioning, + conditioning_kwargs=conditioning_kwargs, + decoder_dims=decoder_dims, + decoder_dropout=decoder_dropout, + layer_norm_before_concatenation=layer_norm_before_concatenation, + linear_projection_before_concatenation=linear_projection_before_concatenation, + **vf_kwargs, + ) probability_path, noise = next(iter(probability_path.items())) if probability_path == "constant_noise": @@ -495,6 +525,16 @@ def prepare_model( rng=jax.random.PRNGKey(seed), **solver_kwargs, ) + elif self._solver_class == _eqm.EquilibriumMatching: + # EqM doesn't use probability_path, only match_fn + self._solver = self._solver_class( + vf=self.vf, + match_fn=match_fn, + optimizer=optimizer, + conditions=self.train_data.condition_data, + rng=jax.random.PRNGKey(seed), + **solver_kwargs, + ) elif self._solver_class == _genot.GENOT: self._solver = self._solver_class( vf=self.vf, @@ -508,7 +548,7 @@ def prepare_model( **solver_kwargs, ) else: - raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}") + raise NotImplementedError(f"Solver must be an instance of OTFlowMatching, EquilibriumMatching, or GENOT, got {type(self.solver)}") self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=self.validation_data["predict_kwargs"]) # type: ignore[arg-type] @@ -816,7 +856,7 @@ def adata(self) -> ad.AnnData: return self._adata @property - def solver(self) -> _otfm.OTFlowMatching | _genot.GENOT | None: + def solver(self) -> _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching | None: """The solver.""" return self._solver @@ -843,7 +883,7 @@ def data_manager(self) -> DataManager: @property def velocity_field( self, - ) -> _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None: + ) -> _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | _velocity_field.EquilibriumVelocityField | None: """The 
conditional velocity field.""" return self._vf diff --git a/src/scaleflow/networks/__init__.py b/src/scaleflow/networks/__init__.py index 2716121d..48285109 100644 --- a/src/scaleflow/networks/__init__.py +++ b/src/scaleflow/networks/__init__.py @@ -10,11 +10,12 @@ SelfAttentionBlock, TokenAttentionPooling, ) -from scaleflow.networks._velocity_field import ConditionalVelocityField, GENOTConditionalVelocityField +from scaleflow.networks._velocity_field import ConditionalVelocityField, GENOTConditionalVelocityField, EquilibriumVelocityField __all__ = [ "ConditionalVelocityField", "GENOTConditionalVelocityField", + "EquilibriumVelocityField", "ConditionEncoder", "MLPBlock", "SelfAttention", diff --git a/src/scaleflow/networks/_velocity_field.py b/src/scaleflow/networks/_velocity_field.py index 416dada5..68e0d46e 100644 --- a/src/scaleflow/networks/_velocity_field.py +++ b/src/scaleflow/networks/_velocity_field.py @@ -13,7 +13,7 @@ from scaleflow.networks._set_encoders import ConditionEncoder from scaleflow.networks._utils import FilmBlock, MLPBlock, ResNetBlock, sinusoidal_time_encoder -__all__ = ["ConditionalVelocityField", "GENOTConditionalVelocityField"] +__all__ = ["ConditionalVelocityField", "GENOTConditionalVelocityField", "EquilibriumVelocityField"] class ConditionalVelocityField(nn.Module): @@ -587,3 +587,160 @@ def create_train_state( train=False, )["params"] return train_state.TrainState.create(apply_fn=self.apply, params=params, tx=optimizer) + + +class EquilibriumVelocityField(nn.Module): + """Parameterized neural gradient field for Equilibrium Matching (no time conditioning). + + Same as ConditionalVelocityField but without time encoder. + """ + + output_dim: int + max_combination_length: int + condition_mode: Literal["deterministic", "stochastic"] = "deterministic" + regularization: float = 1.0 + condition_embedding_dim: int = 32 + covariates_not_pooled: Sequence[str] = dc_field(default_factory=lambda: []) + pooling: Literal["mean", "attention_token", "attention_seed"] = "attention_token" + pooling_kwargs: dict[str, Any] = dc_field(default_factory=lambda: {}) + layers_before_pool: Layers_separate_input_t | Layers_t = dc_field(default_factory=lambda: []) + layers_after_pool: Layers_t = dc_field(default_factory=lambda: []) + cond_output_dropout: float = 0.0 + mask_value: float = 0.0 + condition_encoder_kwargs: dict[str, Any] = dc_field(default_factory=lambda: {}) + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu + hidden_dims: Sequence[int] = (1024, 1024, 1024) + hidden_dropout: float = 0.0 + conditioning: Literal["concatenation", "film", "resnet"] = "concatenation" + conditioning_kwargs: dict[str, Any] = dc_field(default_factory=lambda: {}) + decoder_dims: Sequence[int] = (1024, 1024, 1024) + decoder_dropout: float = 0.0 + layer_norm_before_concatenation: bool = False + linear_projection_before_concatenation: bool = False + + def setup(self): + """Initialize the network.""" + if isinstance(self.conditioning_kwargs, dataclasses.Field): + conditioning_kwargs = dict(self.conditioning_kwargs.default_factory()) + else: + conditioning_kwargs = dict(self.conditioning_kwargs) + self.condition_encoder = ConditionEncoder( + condition_mode=self.condition_mode, + regularization=self.regularization, + output_dim=self.condition_embedding_dim, + pooling=self.pooling, + pooling_kwargs=self.pooling_kwargs, + layers_before_pool=self.layers_before_pool, + layers_after_pool=self.layers_after_pool, + covariates_not_pooled=self.covariates_not_pooled, + mask_value=self.mask_value, + 
**self.condition_encoder_kwargs, + ) + + self.layer_cond_output_dropout = nn.Dropout(rate=self.cond_output_dropout) + self.layer_norm_condition = nn.LayerNorm() if self.layer_norm_before_concatenation else lambda x: x + + self.x_encoder = MLPBlock( + dims=self.hidden_dims, + act_fn=self.act_fn, + dropout_rate=self.hidden_dropout, + act_last_layer=(False if self.linear_projection_before_concatenation else True), + ) + self.layer_norm_x = nn.LayerNorm() if self.layer_norm_before_concatenation else lambda x: x + + self.decoder = MLPBlock( + dims=self.decoder_dims, + act_fn=self.act_fn, + dropout_rate=self.decoder_dropout, + act_last_layer=(False if self.linear_projection_before_concatenation else True), + ) + + self.output_layer = nn.Dense(self.output_dim) + + if self.conditioning == "film": + self.film_block = FilmBlock( + input_dim=self.hidden_dims[-1], + cond_dim=self.condition_embedding_dim, # No time encoder! + **conditioning_kwargs, + ) + elif self.conditioning == "resnet": + self.resnet_block = ResNetBlock( + input_dim=self.hidden_dims[-1], + **conditioning_kwargs, + ) + elif self.conditioning == "concatenation": + if len(conditioning_kwargs) > 0: + raise ValueError("If `conditioning=='concatenation' mode, no conditioning kwargs can be passed.") + else: + raise ValueError(f"Unknown conditioning mode: {self.conditioning}") + + def __call__( + self, + x: jnp.ndarray, + cond: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + train: bool = True, + ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + squeeze = x.ndim == 1 + cond_mean, cond_logvar = self.condition_encoder(cond, training=train) + if self.condition_mode == "deterministic": + cond_embedding = cond_mean + else: + cond_embedding = cond_mean + encoder_noise * jnp.exp(cond_logvar / 2.0) + + cond_embedding = self.layer_cond_output_dropout(cond_embedding, deterministic=not train) + x_encoded = self.x_encoder(x, training=train) + + x_encoded = self.layer_norm_x(x_encoded) + cond_embedding = self.layer_norm_condition(cond_embedding) + + if squeeze: + cond_embedding = jnp.squeeze(cond_embedding) + elif cond_embedding.shape[0] != x.shape[0]: + cond_embedding = jnp.tile(cond_embedding, (x.shape[0], 1)) + + if self.conditioning == "concatenation": + out = jnp.concatenate((x_encoded, cond_embedding), axis=-1) # No time! + elif self.conditioning == "film": + out = self.film_block(x_encoded, cond_embedding) # No time! + elif self.conditioning == "resnet": + out = self.resnet_block(x_encoded, cond_embedding) # No time! + else: + raise ValueError(f"Unknown conditioning mode: {self.conditioning}.") + + out = self.decoder(out, training=train) + return self.output_layer(out), cond_mean, cond_logvar + + def get_condition_embedding(self, condition: dict[str, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]: + """Get the embedding of the condition.""" + condition_mean, condition_logvar = self.condition_encoder(condition, training=False) + return condition_mean, condition_logvar + + def create_train_state( + self, + rng: jax.Array, + optimizer: optax.OptState, + input_dim: int, + conditions: dict[str, jnp.ndarray], + ) -> train_state.TrainState: + """Create the training state.""" + x = jnp.ones((1, input_dim)) # No time variable! 
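+        # encoder_noise and cond below, like x above, are dummy inputs whose
+        # only purpose is to fix input shapes when initializing the parameters.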
+ encoder_noise = jnp.ones((1, self.condition_embedding_dim)) + cond = { + pert_cov: jnp.ones((1, self.max_combination_length, condition.shape[-1])) + for pert_cov, condition in conditions.items() + } + params_rng, condition_encoder_rng = jax.random.split(rng, 2) + params = self.init( + {"params": params_rng, "condition_encoder": condition_encoder_rng}, + x=x, + cond=cond, + encoder_noise=encoder_noise, + train=False, + )["params"] + return train_state.TrainState.create(apply_fn=self.apply, params=params, tx=optimizer) + + @property + def output_dims(self): + """Dimensions of the output layers.""" + return tuple(self.decoder_dims) + (self.output_dim,) diff --git a/src/scaleflow/solvers/__init__.py b/src/scaleflow/solvers/__init__.py index 35ff8cb8..6c7aa964 100644 --- a/src/scaleflow/solvers/__init__.py +++ b/src/scaleflow/solvers/__init__.py @@ -1,4 +1,5 @@ from scaleflow.solvers._genot import GENOT from scaleflow.solvers._otfm import OTFlowMatching +from scaleflow.solvers._eqm import EquilibriumMatching -__all__ = ["GENOT", "OTFlowMatching"] +__all__ = ["GENOT", "OTFlowMatching", "EquilibriumMatching"] diff --git a/src/scaleflow/solvers/_eqm.py b/src/scaleflow/solvers/_eqm.py new file mode 100644 index 00000000..14cc65dc --- /dev/null +++ b/src/scaleflow/solvers/_eqm.py @@ -0,0 +1,356 @@ +# /home/icb/alejandro.tejada/CellFlow2/src/scaleflow/solvers/_eqm.py + +from collections.abc import Callable +from functools import partial +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core import frozen_dict +from flax.training import train_state +from ott.solvers import utils as solver_utils + +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.networks._velocity_field import ConditionalVelocityField +from scaleflow.solvers.utils import ema_update + +__all__ = ["EquilibriumMatching"] + + +class EquilibriumMatching: + """Equilibrium Matching for generative modeling. + + Based on "Equilibrium Matching" (Wang & Du, 2024). + Learns a time-invariant equilibrium gradient field instead of + time-conditional velocities. + + Parameters + ---------- + vf + Vector field parameterized by a neural network (without time conditioning). + match_fn + Function to match samples from the source and the target + distributions. It has a ``(src, tgt) -> matching`` signature, + see e.g. :func:`scaleflow.utils.match_linear`. If :obj:`None`, no + matching is performed. + gamma_sampler + Noise level sampler with a ``(rng, n_samples) -> gamma`` signature. + Defaults to uniform sampling on [0, 1]. + c_fn + Weighting function c(gamma). Defaults to c(gamma) = 1 - gamma. + kwargs + Keyword arguments for :meth:`scaleflow.networks.ConditionalVelocityField.create_train_state`. 
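+
+    Examples
+    --------
+    Minimal construction sketch; ``vf``, ``conditions`` and the optimizer are
+    assumed to be prepared elsewhere (e.g. by the ``CellFlow`` model wrapper),
+    and the trailing keyword arguments are forwarded to ``create_train_state``:
+
+    >>> import jax
+    >>> import optax
+    >>> from scaleflow.utils import match_linear
+    >>> solver = EquilibriumMatching(
+    ...     vf=vf,
+    ...     match_fn=match_linear,
+    ...     optimizer=optax.adam(1e-4),
+    ...     conditions=conditions,
+    ...     rng=jax.random.PRNGKey(0),
+    ... )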
+ """ + + def __init__( + self, + vf: ConditionalVelocityField, + match_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None, + gamma_sampler: Callable[[jax.Array, int], jnp.ndarray] = solver_utils.uniform_sampler, + c_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None, + **kwargs: Any, + ): + self._is_trained: bool = False + self.vf = vf + self.condition_encoder_mode = self.vf.condition_mode + self.condition_encoder_regularization = self.vf.regularization + self.gamma_sampler = gamma_sampler + self.c_fn = c_fn if c_fn is not None else lambda gamma: 1.0 - gamma + self.match_fn = jax.jit(match_fn) if match_fn is not None else None + self.ema = kwargs.pop("ema", 1.0) + + self.vf_state = self.vf.create_train_state(input_dim=self.vf.output_dims[-1], **kwargs) + self.vf_state_inference = self.vf.create_train_state(input_dim=self.vf.output_dims[-1], **kwargs) + self.vf_step_fn = self._get_vf_step_fn() + + def _get_vf_step_fn(self) -> Callable: + @jax.jit + def vf_step_fn( + rng: jax.Array, + vf_state: train_state.TrainState, + gamma: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + ): + def loss_fn( + params: jnp.ndarray, + gamma: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + rng: jax.Array, + ) -> jnp.ndarray: + rng_encoder, rng_dropout = jax.random.split(rng, 2) + + # Interpolate between source (noise) and target (data) + gamma_expanded = gamma[:, None] + x_gamma = gamma_expanded * target + (1.0 - gamma_expanded) * source + + # Predict gradient field (no time input) + f_pred, mean_cond, logvar_cond = vf_state.apply_fn( + {"params": params}, + x_gamma, + conditions, + encoder_noise=encoder_noise, + rngs={"dropout": rng_dropout, "condition_encoder": rng_encoder}, + ) + + # Target gradient: (source - target) * c(gamma) + c_gamma = self.c_fn(gamma)[:, None] + target_gradient = (source - target) * c_gamma + + # EqM loss + eqm_loss = jnp.mean((f_pred - target_gradient) ** 2) + + # Condition encoder regularization (same as flow matching) + condition_mean_regularization = 0.5 * jnp.mean(mean_cond**2) + condition_var_regularization = -0.5 * jnp.mean(1 + logvar_cond - jnp.exp(logvar_cond)) + + if self.condition_encoder_mode == "stochastic": + encoder_loss = condition_mean_regularization + condition_var_regularization + elif (self.condition_encoder_mode == "deterministic") and (self.condition_encoder_regularization > 0): + encoder_loss = condition_mean_regularization + else: + encoder_loss = 0.0 + + return eqm_loss + encoder_loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(vf_state.params, gamma, source, target, conditions, encoder_noise, rng) + return vf_state.apply_gradients(grads=grads), loss + + return vf_step_fn + + def step_fn( + self, + rng: jnp.ndarray, + batch: dict[str, ArrayLike], + ) -> float: + """Single step function of the solver. + + Parameters + ---------- + rng + Random number generator. + batch + Data batch with keys ``src_cell_data``, ``tgt_cell_data``, and + optionally ``condition``. + + Returns + ------- + Loss value. 
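+
+        Examples
+        --------
+        Illustrative single optimization step; ``sampler`` stands in for any
+        object yielding batches with ``src_cell_data``/``tgt_cell_data`` (and
+        optionally ``condition``) entries and is not part of this class:
+
+        >>> import jax
+        >>> rng = jax.random.PRNGKey(0)
+        >>> batch = sampler.sample(rng)
+        >>> loss = solver.step_fn(rng, batch)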
+ """ + src, tgt = batch["src_cell_data"], batch["tgt_cell_data"] + condition = batch.get("condition") + rng_resample, rng_gamma, rng_step_fn, rng_encoder_noise = jax.random.split(rng, 4) + n = src.shape[0] + gamma = self.gamma_sampler(rng_gamma, n) + encoder_noise = jax.random.normal(rng_encoder_noise, (n, self.vf.condition_embedding_dim)) + + if self.match_fn is not None: + tmat = self.match_fn(src, tgt) + src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat) + src, tgt = src[src_ixs], tgt[tgt_ixs] + + self.vf_state, loss = self.vf_step_fn( + rng_step_fn, + self.vf_state, + gamma, + src, + tgt, + condition, + encoder_noise, + ) + + if self.ema == 1.0: + self.vf_state_inference = self.vf_state + else: + self.vf_state_inference = self.vf_state_inference.replace( + params=ema_update(self.vf_state_inference.params, self.vf_state.params, self.ema) + ) + return loss + + def get_condition_embedding(self, condition: dict[str, ArrayLike], return_as_numpy=True) -> ArrayLike: + """Get learnt embeddings of the conditions. + + Parameters + ---------- + condition + Conditions to encode + return_as_numpy + Whether to return the embeddings as numpy arrays. + + Returns + ------- + Mean and log-variance of encoded conditions. + """ + cond_mean, cond_logvar = self.vf.apply( + {"params": self.vf_state_inference.params}, + condition, + method="get_condition_embedding", + ) + if return_as_numpy: + return np.asarray(cond_mean), np.asarray(cond_logvar) + return cond_mean, cond_logvar + + def _predict_jit( + self, + x: ArrayLike, + condition: dict[str, ArrayLike], + rng: jax.Array | None = None, + eta: float = 0.003, + max_steps: int = 250, + use_nesterov: bool = True, + mu: float = 0.35, + **kwargs: Any, + ) -> ArrayLike: + """Predict using gradient descent sampling. + + Parameters + ---------- + x + Initial samples (typically noise). + condition + Conditioning information. + rng + Random number generator for stochastic conditioning. + eta + Step size for gradient descent. + max_steps + Maximum number of gradient descent steps. + use_nesterov + Whether to use Nesterov accelerated gradient. + mu + Momentum parameter for Nesterov. + + Returns + ------- + Generated samples. 
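+
+        Notes
+        -----
+        Plain gradient descent repeats ``x <- x - eta * f(x)`` for
+        ``max_steps`` iterations. The Nesterov variant keeps a velocity term,
+        ``v <- mu * v + eta * f(x - mu * v)`` followed by ``x <- x - v``,
+        mirroring the samplers defined in the body below.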
+ """ + noise_dim = (1, self.vf.condition_embedding_dim) + use_mean = rng is None or self.condition_encoder_mode == "deterministic" + rng = utils.default_prng_key(rng) + encoder_noise = jnp.zeros(noise_dim) if use_mean else jax.random.normal(rng, noise_dim) + + def gradient_field(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + params = self.vf_state_inference.params + return self.vf_state_inference.apply_fn({"params": params}, x, condition, encoder_noise, train=False)[0] + + def sample_gd(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + """Basic gradient descent sampler.""" + for _ in range(max_steps): + f = gradient_field(x, condition, encoder_noise) + x = x - eta * f + return x + + def sample_nag(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + """Nesterov accelerated gradient descent sampler.""" + velocity = jnp.zeros_like(x) + for _ in range(max_steps): + x_lookahead = x - mu * velocity + f = gradient_field(x_lookahead, condition, encoder_noise) + velocity = mu * velocity + eta * f + x = x - velocity + return x + + sampler = sample_nag if use_nesterov else sample_gd + x_pred = jax.jit(jax.vmap(sampler, in_axes=[0, None, None]))(x, condition, encoder_noise) + return x_pred + + def predict( + self, + x: ArrayLike | dict[str, ArrayLike], + condition: dict[str, ArrayLike] | dict[str, dict[str, ArrayLike]], + rng: jax.Array | None = None, + batched: bool = False, + eta: float = 0.003, + max_steps: int = 250, + use_nesterov: bool = True, + mu: float = 0.35, + **kwargs: Any, + ) -> ArrayLike | dict[str, ArrayLike]: + """Predict the translated source ``x`` under condition ``condition``. + + This function performs gradient descent on the learned equilibrium landscape. + + Parameters + ---------- + x + A dictionary with keys indicating the name of the condition and values containing + the input data as arrays. If ``batched=False`` provide an array of shape [batch_size, ...]. + condition + A dictionary with keys indicating the name of the condition and values containing + the condition of input data as arrays. If ``batched=False`` provide an array of shape + [batch_size, ...]. + rng + Random number generator to sample from the latent distribution, + only used if ``condition_mode='stochastic'``. If :obj:`None`, the + mean embedding is used. + batched + Whether to use batched prediction. + eta + Step size for gradient descent (default: 0.003 as in paper). + max_steps + Number of gradient descent steps (default: 250 as in paper). + use_nesterov + Whether to use Nesterov accelerated gradient (recommended). + mu + Momentum parameter for Nesterov (default: 0.35 as in paper). + kwargs + Additional keyword arguments (for compatibility). + + Returns + ------- + The push-forward distribution of ``x`` under condition ``condition``. 
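+
+        Examples
+        --------
+        Illustrative call for a single condition; ``source_cells`` and
+        ``condition`` are placeholders and the array shape is only a sketch:
+
+        >>> x = source_cells  # array of shape [batch_size, output_dim]
+        >>> pred = solver.predict(x, condition, eta=0.003, max_steps=250, use_nesterov=True)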
+ """ + if batched and not x: + return {} + + predict_fn = partial( + self._predict_jit, + rng=rng, + eta=eta, + max_steps=max_steps, + use_nesterov=use_nesterov, + mu=mu, + **kwargs, + ) + + if batched: + keys = sorted(x.keys()) + condition_keys = sorted(set().union(*(condition[k].keys() for k in keys))) + _predict_jit = jax.jit(lambda x, condition: predict_fn(x, condition)) + batched_predict = jax.vmap(_predict_jit, in_axes=(0, dict.fromkeys(condition_keys, 0))) + n_cells = x[keys[0]].shape[0] + for k in keys: + assert x[k].shape[0] == n_cells, "The number of cells must be the same for each condition" + src_inputs = jnp.stack([x[k] for k in keys], axis=0) + batched_conditions = {} + for cond_key in condition_keys: + batched_conditions[cond_key] = jnp.stack([condition[k][cond_key] for k in keys]) + + pred_targets = batched_predict(src_inputs, batched_conditions) + return {k: pred_targets[i] for i, k in enumerate(keys)} + elif isinstance(x, dict): + return jax.tree.map( + predict_fn, + x, + condition, + ) + else: + x_pred = predict_fn(x, condition) + return np.array(x_pred) + + @property + def is_trained(self) -> bool: + """Whether the model is trained.""" + return self._is_trained + + @is_trained.setter + def is_trained(self, value: bool) -> None: + self._is_trained = value From b6d8e1fbcbf846bffb3679d54d5279776f2b5d05 Mon Sep 17 00:00:00 2001 From: AlejandroTL Date: Tue, 14 Oct 2025 14:30:43 +0200 Subject: [PATCH 35/35] tests for eqm and data splitter --- src/scaleflow/data/_data_splitter.py | 328 +++++++--- src/scaleflow/data/_datamanager.py | 14 +- src/scaleflow/model/_scaleflow.py | 3 +- src/scaleflow/networks/_velocity_field.py | 2 +- src/scaleflow/solvers/_eqm.py | 6 +- src/scaleflow/solvers/_genot.py | 7 + src/scaleflow/training/_callbacks.py | 7 + src/scaleflow/training/_trainer.py | 35 +- tests/data/test_datasplitter.py | 729 ++++++++++++++++++++++ tests/model/test_scaleflow.py | 53 +- tests/networks/test_velocityfield.py | 24 +- tests/solver/test_solver.py | 87 ++- 12 files changed, 1155 insertions(+), 140 deletions(-) create mode 100644 tests/data/test_datasplitter.py diff --git a/src/scaleflow/data/_data_splitter.py b/src/scaleflow/data/_data_splitter.py index cd3fb157..135f6789 100644 --- a/src/scaleflow/data/_data_splitter.py +++ b/src/scaleflow/data/_data_splitter.py @@ -23,10 +23,10 @@ class DataSplitter: making it memory-efficient for large datasets. Supports various splitting strategies: - - holdout_groups: Hold out specific groups (perturbations, cell lines, etc.) for validation/test - - holdout_combinations: Keep single treatments in training, hold out combinations for validation/test - - random: Random split of cells - - stratified: Stratified split maintaining proportions + - holdout_groups: Hold out specific groups (drugs, cell lines, donors, etc.) for validation/test + - holdout_combinations: Keep single treatments in training, hold out combination treatments for validation/test + - random: Random split of observations + - stratified: Stratified split maintaining condition proportions Parameters ---------- @@ -53,11 +53,25 @@ class DataSplitter: If False, validation and test can share groups, split at cell level. Applies to all split types for consistent val/test separation control. random_state : int - Random seed for reproducible splits + Random seed for reproducible splits. 
This controls: + - Observation-level splits in soft mode (hard_test_split=False) + - Fallback for test_random_state and val_random_state if they are None + Note: In hard mode with test_random_state and val_random_state specified, + this parameter only affects downstream training randomness (not DataSplitter itself). + test_random_state : int | None + Random seed specifically for selecting which conditions go to the test set. + If None, uses random_state as fallback. Only applies to 'holdout_groups' and + 'holdout_combinations' split types. This enables running multiple experiments with + different train/val splits while keeping the test set fixed for fair comparison. + val_random_state : int | None + Random seed specifically for selecting which conditions go to the validation set + (from the remaining conditions after test set selection). If None, uses random_state + as fallback. Only applies to 'holdout_groups' and 'holdout_combinations' split types. + This enables varying the validation set across runs while keeping test set fixed Examples -------- - >>> # Split by holdout groups with forced training values + >>> # Example 1: Basic split with forced training values >>> splitter = DataSplitter( ... training_datasets=[train_data1, train_data2], ... dataset_names=["dataset1", "dataset2"], @@ -66,7 +80,8 @@ class DataSplitter: ... split_key=["drug1", "drug2"], ... force_training_values=["control", "dmso"], ... ) - >>> # Split by holding out combinations (singletons in training) + + >>> # Example 2: Split by holding out combinations (singletons in training) >>> splitter = DataSplitter( ... training_datasets=[train_data], ... dataset_names=["dataset"], @@ -75,10 +90,37 @@ class DataSplitter: ... split_key=["drug1", "drug2"], ... control_value=["control", "dmso"], ... ) + + >>> # Example 3: Fixed test set across multiple runs (for drug discovery benchmarking) + >>> # All runs will test on the same drugs, but with different train/val splits + >>> for seed in [42, 43, 44, 45]: + ... splitter = DataSplitter( + ... training_datasets=[train_data], + ... dataset_names=["experiment"], + ... split_ratios=[[0.6, 0.2, 0.2]], + ... split_type="holdout_groups", + ... split_key=["drug"], + ... test_random_state=999, # Fixed: same test drugs across all runs + ... val_random_state=seed, # Varies: different validation drugs per run + ... random_state=seed, # Varies: different training randomness + ... ) + ... results = splitter.split_all_datasets() + ... # Train model with this split... + + >>> # Example 4: Completely different splits per run (current behavior) + >>> for seed in [42, 43, 44]: + ... splitter = DataSplitter( + ... training_datasets=[train_data], + ... dataset_names=["experiment"], + ... split_ratios=[[0.8, 0.1, 0.1]], + ... split_type="holdout_groups", + ... split_key=["drug"], + ... random_state=seed, # All three seeds derived from this + ... 
) + + >>> # Save and load splits >>> results = splitter.split_all_datasets() >>> splitter.save_splits("./splits") - - >>> # Load split information later >>> split_info = DataSplitter.load_split_info("./splits", "dataset1") >>> train_indices = split_info["indices"]["train"] """ @@ -94,6 +136,8 @@ def __init__( control_value: str | list[str] | None = None, hard_test_split: bool = True, random_state: int = 42, + test_random_state: int | None = None, + val_random_state: int | None = None, ): self.training_datasets = training_datasets self.dataset_names = dataset_names @@ -104,6 +148,8 @@ def __init__( self.control_value = [control_value] if isinstance(control_value, str) else control_value self.hard_test_split = hard_test_split self.random_state = random_state + self.test_random_state = test_random_state if test_random_state is not None else random_state + self.val_random_state = val_random_state if val_random_state is not None else random_state self._validate_inputs() @@ -151,7 +197,11 @@ def _validate_inputs(self) -> None: def extract_perturbation_info(self, training_data: TrainingData | MappedCellData) -> dict: """ - Extract perturbation information from TrainingData or MappedCellData. + Extract condition information from TrainingData or MappedCellData. + + Note: Internal variable names use 'perturbation' for compatibility with + TrainingData structure, but conceptually these represent any conditions + (drugs, cell lines, donors, etc.). Parameters ---------- @@ -162,17 +212,17 @@ def extract_perturbation_info(self, training_data: TrainingData | MappedCellData ------- dict Dictionary containing: - - perturbation_covariates_mask: array mapping cells to perturbation indices - - perturbation_idx_to_covariates: dict mapping perturbation indices to covariate tuples - - n_cells: total number of cells + - perturbation_covariates_mask: array mapping observations to condition indices + - perturbation_idx_to_covariates: dict mapping condition indices to covariate tuples + - n_cells: total number of observations """ perturbation_covariates_mask = np.asarray(training_data.perturbation_covariates_mask) perturbation_idx_to_covariates = training_data.perturbation_idx_to_covariates n_cells = len(perturbation_covariates_mask) - logger.info(f"Extracted perturbation info for {n_cells} cells") - logger.info(f"Number of unique perturbations: {len(perturbation_idx_to_covariates)}") + logger.info(f"Extracted condition info for {n_cells} observations") + logger.info(f"Number of unique conditions: {len(perturbation_idx_to_covariates)}") return { "perturbation_covariates_mask": perturbation_covariates_mask, @@ -234,7 +284,7 @@ def _split_by_values( perturbation_idx_to_covariates: dict[int, tuple[str, ...]], split_ratios: list[float], ) -> dict[str, np.ndarray]: - """Split by holding out specific perturbations.""" + """Split by holding out specific condition groups.""" if self.split_key is None: raise ValueError("split_key must be provided for holdout_groups splitting") @@ -258,28 +308,32 @@ def _split_by_values( stacklevel=2, ) - # Split values according to ratios + # Split values according to ratios using three-level seed hierarchy train_ratio, val_ratio, test_ratio = split_ratios # Calculate number of values for each split - n_train = max(1, int(train_ratio * n_values)) + n_test = int(test_ratio * n_values) n_val = int(val_ratio * n_values) - n_test = n_values - n_train - n_val - - # Ensure we don't exceed total values - if n_train + n_val + n_test != n_values: - n_test = n_values - n_train - n_val - - # Shuffle 
available values for random assignment (excluding forced training values) - np.random.seed(self.random_state) - shuffled_values = np.random.permutation(available_values) - - # Assign values to splits - train_values_random = shuffled_values[:n_train] - val_values = shuffled_values[n_train : n_train + n_val] if n_val > 0 else [] - test_values = shuffled_values[n_train + n_val :] if n_test > 0 else [] - - # Combine forced training values with randomly assigned training values + n_train = n_values - n_test - n_val + + # Ensure we have at least one value for train if train_ratio > 0 + if train_ratio > 0 and n_train == 0: + n_train = 1 + n_test = max(0, n_test - 1) + + # Step 1: Select test values using test_random_state + np.random.seed(self.test_random_state) + shuffled_for_test = np.random.permutation(available_values) + test_values = shuffled_for_test[-n_test:] if n_test > 0 else [] + remaining_after_test = shuffled_for_test[:-n_test] if n_test > 0 else shuffled_for_test + + # Step 2: Select val values from remaining using val_random_state + np.random.seed(self.val_random_state) + shuffled_for_val = np.random.permutation(remaining_after_test) + val_values = shuffled_for_val[-n_val:] if n_val > 0 else [] + train_values_random = shuffled_for_val[:-n_val] if n_val > 0 else shuffled_for_val + + # Step 3: Combine forced training values with randomly assigned training values train_values = list(train_values_random) + forced_train_values logger.info(f"Split values - Train: {len(train_values)}, Val: {len(val_values)}, Test: {len(test_values)}") @@ -335,7 +389,7 @@ def _get_cells_with_values(values_set): # Log overlap information (important for combination treatments) total_assigned = len(set(train_idx) | set(val_idx) | set(test_idx)) - logger.info(f"Total cells assigned to splits: {total_assigned} out of {len(perturbation_covariates_mask)}") + logger.info(f"Total observations assigned to splits: {total_assigned} out of {len(perturbation_covariates_mask)}") overlaps = [] if len(set(train_idx) & set(val_idx)) > 0: @@ -358,13 +412,13 @@ def _split_holdout_combinations( perturbation_idx_to_covariates: dict[int, tuple[str, ...]], split_ratios: list[float], ) -> dict[str, np.ndarray]: - """Split by keeping singletons in training and holding out combinations for val/test.""" + """Split by keeping single conditions in training and holding out combinations for val/test.""" if self.split_key is None: raise ValueError("split_key must be provided for holdout_combinations splitting") if self.control_value is None: raise ValueError("control_value must be provided for holdout_combinations splitting") - logger.info("Identifying combinations vs singletons from perturbation covariates") + logger.info("Identifying combinations vs singletons from condition covariates") logger.info(f"Control value(s): {self.control_value}") # Classify each perturbation index as control, singleton, or combination @@ -438,36 +492,38 @@ def _split_holdout_combinations( unique_perturbations = list(set(perturbation_ids)) n_unique_perturbations = len(unique_perturbations) - logger.info(f"Found {n_unique_perturbations} unique perturbation combinations") + logger.info(f"Found {n_unique_perturbations} unique condition combinations") if self.hard_test_split: - # HARD TEST SPLIT: Val and test get completely different perturbations + # HARD TEST SPLIT: Val and test get completely different conditions # Calculate number of perturbation combinations for each split - n_train_perturbations = int(train_ratio * n_unique_perturbations) + 
n_test_perturbations = int(test_ratio * n_unique_perturbations) n_val_perturbations = int(val_ratio * n_unique_perturbations) - n_test_perturbations = n_unique_perturbations - n_train_perturbations - n_val_perturbations + n_train_perturbations = n_unique_perturbations - n_test_perturbations - n_val_perturbations - # Ensure we don't exceed total perturbations - if n_train_perturbations + n_val_perturbations + n_test_perturbations != n_unique_perturbations: - n_test_perturbations = n_unique_perturbations - n_train_perturbations - n_val_perturbations + # Ensure we have at least one perturbation for train if train_ratio > 0 + if train_ratio > 0 and n_train_perturbations == 0: + n_train_perturbations = 1 + n_test_perturbations = max(0, n_test_perturbations - 1) - # Shuffle perturbations for random assignment - np.random.seed(self.random_state) - shuffled_perturbations = np.random.permutation(unique_perturbations) - - # Assign perturbations to splits - train_perturbations = ( - shuffled_perturbations[:n_train_perturbations] if n_train_perturbations > 0 else [] + # Step 1: Select test perturbations using test_random_state + np.random.seed(self.test_random_state) + shuffled_for_test = np.random.permutation(unique_perturbations) + test_perturbations = ( + [tuple(p) for p in shuffled_for_test[-n_test_perturbations:]] if n_test_perturbations > 0 else [] + ) + remaining_after_test = ( + shuffled_for_test[:-n_test_perturbations] if n_test_perturbations > 0 else shuffled_for_test ) + + # Step 2: Select val perturbations from remaining using val_random_state + np.random.seed(self.val_random_state) + shuffled_for_val = np.random.permutation(remaining_after_test) val_perturbations = ( - shuffled_perturbations[n_train_perturbations : n_train_perturbations + n_val_perturbations] - if n_val_perturbations > 0 - else [] + [tuple(p) for p in shuffled_for_val[-n_val_perturbations:]] if n_val_perturbations > 0 else [] ) - test_perturbations = ( - shuffled_perturbations[n_train_perturbations + n_val_perturbations :] - if n_test_perturbations > 0 - else [] + train_perturbations = ( + [tuple(p) for p in shuffled_for_val[:-n_val_perturbations]] if n_val_perturbations > 0 else [tuple(p) for p in shuffled_for_val] ) # Assign all cells with same perturbation to same split @@ -485,17 +541,22 @@ def _split_holdout_combinations( test_combo_idx.append(cell_idx) logger.info( - f"HARD TEST SPLIT - Perturbation split: Train={len(train_perturbations)}, Val={len(val_perturbations)}, Test={len(test_perturbations)}" + f"HARD TEST SPLIT - Condition split: Train={len(train_perturbations)}, Val={len(val_perturbations)}, Test={len(test_perturbations)}" ) + if len(test_perturbations) > 0: + logger.info(f"Test perturbations: {list(test_perturbations)[:3]}") + if len(val_perturbations) > 0: + logger.info(f"Val perturbations: {list(val_perturbations)[:3]}") else: - # SOFT TEST SPLIT: Val and test can share perturbations, split at cell level - # First assign perturbations to train vs (val+test) + # SOFT TEST SPLIT: Val and test can share conditions, split at cell level + # First assign conditions to train vs (val+test) using test_random_state + # (In soft mode, val and test share conditions, so we only need one seed for this split) n_train_perturbations = int(train_ratio * n_unique_perturbations) n_val_test_perturbations = n_unique_perturbations - n_train_perturbations - # Shuffle perturbations - np.random.seed(self.random_state) + # Shuffle perturbations using test_random_state + np.random.seed(self.test_random_state) 
shuffled_perturbations = np.random.permutation(unique_perturbations) train_perturbations = ( @@ -529,7 +590,7 @@ def _split_holdout_combinations( test_combo_idx = np.array([]) logger.info( - f"SOFT TEST SPLIT - Perturbation split: Train={len(train_perturbations)}, Val+Test={len(val_test_perturbations)}" + f"SOFT TEST SPLIT - Condition split: Train={len(train_perturbations)}, Val+Test={len(val_test_perturbations)}" ) logger.info(f"Cell split within Val+Test: Val={len(val_combo_idx)}, Test={len(test_combo_idx)}") @@ -551,10 +612,10 @@ def _split_holdout_combinations( test_idx = np.array([]) logger.info( - f"Final split - Train: {len(train_idx)} (singletons + controls + {len(train_combo_idx) if n_combinations > 0 else 0} combination cells)" + f"Final split - Train: {len(train_idx)} (singletons + controls + {len(train_combo_idx) if n_combinations > 0 else 0} combination observations)" ) - logger.info(f"Final split - Val: {len(val_idx)} (combination cells only)") - logger.info(f"Final split - Test: {len(test_idx)} (combination cells only)") + logger.info(f"Final split - Val: {len(val_idx)} (combination observations only)") + logger.info(f"Final split - Test: {len(test_idx)} (combination observations only)") return {"train": train_idx, "val": val_idx, "test": test_idx} @@ -563,7 +624,7 @@ def _split_stratified( perturbation_covariates_mask: np.ndarray, split_ratios: list[float], ) -> dict[str, np.ndarray]: - """Perform stratified split maintaining proportions of perturbations.""" + """Perform stratified split maintaining proportions of conditions.""" if self.split_key is None: raise ValueError("split_key must be provided for stratified splitting") @@ -672,6 +733,9 @@ def split_single_dataset(self, training_data: TrainingData | MappedCellData, dat "split_key": self.split_key, "split_ratios": current_split_ratios, "random_state": self.random_state, + "test_random_state": self.test_random_state, + "val_random_state": self.val_random_state, + "hard_test_split": self.hard_test_split, }, } @@ -711,7 +775,7 @@ def _get_split_values(indices): logger.info(f"Split results for {self.dataset_names[dataset_index]}:") for split_name, indices in split_indices.items(): if len(indices) > 0: - logger.info(f" {split_name}: {len(indices)} cells") + logger.info(f" {split_name}: {len(indices)} observations") return result @@ -739,10 +803,127 @@ def split_all_datasets(self) -> dict[str, dict]: logger.info(f"\nCompleted splitting {len(self.training_datasets)} datasets") return self.split_results + def generate_split_summary(self) -> dict[str, dict]: + """ + Generate a human-readable summary of split conditions for each dataset. + + This method creates a comprehensive summary showing which specific conditions + (perturbations, cell lines, donors, etc.) are assigned to train/val/test splits. + Useful for tracking what was tested across different random seeds. + + Returns + ------- + dict[str, dict] + Dictionary with dataset names as keys and split summaries as values. + Each summary contains: + - conditions_per_split: Lists of condition values in each split + - observations_per_condition: Number of observations for each condition in each split + - statistics: Observation and condition counts per split + - configuration: Random states and split parameters used + + Examples + -------- + >>> splitter = DataSplitter(...) 
+ >>> results = splitter.split_all_datasets() + >>> summary = splitter.generate_split_summary() + >>> print(summary["dataset1"]["conditions_per_split"]["test"]) + ['DrugA', 'DrugB', 'DrugC'] + >>> print(summary["dataset1"]["observations_per_condition"]["test"]["DrugA"]) + 150 + """ + if not self.split_results: + raise ValueError("No split results available. Run split_all_datasets() first.") + + summary = {} + + for i, (dataset_name, split_info) in enumerate(self.split_results.items()): + dataset_summary = { + "configuration": { + "split_type": split_info["metadata"]["split_type"], + "split_key": split_info["metadata"]["split_key"], + "split_ratios": split_info["metadata"]["split_ratios"], + "random_state": split_info["metadata"]["random_state"], + "test_random_state": split_info["metadata"]["test_random_state"], + "val_random_state": split_info["metadata"]["val_random_state"], + "hard_test_split": split_info["metadata"]["hard_test_split"], + }, + "statistics": { + "total_observations": split_info["metadata"]["total_cells"], + }, + } + + if self.force_training_values: + dataset_summary["configuration"]["force_training_values"] = self.force_training_values + if self.control_value: + dataset_summary["configuration"]["control_value"] = self.control_value + + # Add split statistics + for split_name, indices in split_info["indices"].items(): + dataset_summary["statistics"][f"{split_name}_observations"] = len(indices) + if split_info["metadata"]["total_cells"] > 0: + percentage = 100 * len(indices) / split_info["metadata"]["total_cells"] + dataset_summary["statistics"][f"{split_name}_percentage"] = round(percentage, 2) + + # Add condition information if available + if "split_values" in split_info: + dataset_summary["conditions_per_split"] = { + "train": sorted(split_info["split_values"]["train"]), + "val": sorted(split_info["split_values"]["val"]), + "test": sorted(split_info["split_values"]["test"]), + } + dataset_summary["statistics"]["total_unique_conditions"] = len( + split_info["split_values"]["all_unique"] + ) + dataset_summary["statistics"]["train_conditions"] = len(split_info["split_values"]["train"]) + dataset_summary["statistics"]["val_conditions"] = len(split_info["split_values"]["val"]) + dataset_summary["statistics"]["test_conditions"] = len(split_info["split_values"]["test"]) + + # Add observations per condition for each split + training_data = self.training_datasets[i] + pert_info = self.extract_perturbation_info(training_data) + perturbation_covariates_mask = pert_info["perturbation_covariates_mask"] + perturbation_idx_to_covariates = pert_info["perturbation_idx_to_covariates"] + + observations_per_condition = {} + for split_name, indices in split_info["indices"].items(): + if len(indices) == 0: + observations_per_condition[split_name] = {} + continue + + # Count observations per condition for this split + condition_counts = {} + for idx in indices: + pert_idx = perturbation_covariates_mask[idx] + condition_tuple = perturbation_idx_to_covariates[pert_idx] + + # Convert tuple to string representation for JSON compatibility + if len(condition_tuple) == 1: + condition_str = condition_tuple[0] + else: + condition_str = "+".join(condition_tuple) + + condition_counts[condition_str] = condition_counts.get(condition_str, 0) + 1 + + # Sort by condition name for consistent output + observations_per_condition[split_name] = dict(sorted(condition_counts.items())) + + dataset_summary["observations_per_condition"] = observations_per_condition + + summary[dataset_name] = dataset_summary + + 
return summary + def save_splits(self, output_dir: str | Path) -> None: """ Save all split information to the specified directory. + This saves multiple files per dataset: + - split_summary.json: Human-readable summary with conditions per split + - indices/*.npy: Cell indices for each split + - metadata.json: Configuration and parameters + - split_values.json: Condition values per split (if applicable) + - split_info.pkl: Complete split information + Parameters ---------- output_dir : str | Path @@ -756,6 +937,13 @@ def save_splits(self, output_dir: str | Path) -> None: logger.info(f"Saving splits to: {output_dir}") + # Generate and save split summary + split_summary = self.generate_split_summary() + summary_file = output_dir / "split_summary.json" + with open(summary_file, "w") as f: + json.dump(split_summary, f, indent=2) + logger.info(f"Saved split summary -> {summary_file}") + for dataset_name, split_info in self.split_results.items(): # Save indices as numpy arrays (more efficient for large datasets) indices_dir = output_dir / dataset_name / "indices" @@ -765,7 +953,7 @@ def save_splits(self, output_dir: str | Path) -> None: if len(indices) > 0: indices_file = indices_dir / f"{split_name}_indices.npy" np.save(indices_file, indices) - logger.info(f"Saved {split_name} indices: {len(indices)} cells -> {indices_file}") + logger.info(f"Saved {split_name} indices: {len(indices)} observations -> {indices_file}") # Save metadata as JSON metadata_file = output_dir / dataset_name / "metadata.json" diff --git a/src/scaleflow/data/_datamanager.py b/src/scaleflow/data/_datamanager.py index 2274256d..3826e83e 100644 --- a/src/scaleflow/data/_datamanager.py +++ b/src/scaleflow/data/_datamanager.py @@ -1,11 +1,21 @@ -import abc -from typing import Any, Literal +from collections import OrderedDict +from collections.abc import Sequence +from typing import Any + +import scipy.sparse as sp +import sklearn.preprocessing as preprocessing import numpy as np +import pandas as pd +from pandas.api.types import is_numeric_dtype import tqdm import threading from concurrent.futures import ThreadPoolExecutor, Future import os +import anndata +import dask +import dask.dataframe as dd +from dask.diagnostics import ProgressBar from scaleflow.data._data import ( PredictionData, diff --git a/src/scaleflow/model/_scaleflow.py b/src/scaleflow/model/_scaleflow.py index a22d651a..8113ae43 100644 --- a/src/scaleflow/model/_scaleflow.py +++ b/src/scaleflow/model/_scaleflow.py @@ -243,7 +243,8 @@ def prepare_validation_data( predict_kwargs = predict_kwargs or {} # Check if predict_kwargs is alreday provided from an earlier call if "predict_kwargs" in self._validation_data and len(predict_kwargs): - predict_kwargs = self._validation_data["predict_kwargs"].update(predict_kwargs) + self._validation_data["predict_kwargs"].update(predict_kwargs) + predict_kwargs = self._validation_data["predict_kwargs"] # Set batched prediction to False if split_val is True if split_val: predict_kwargs["batched"] = False diff --git a/src/scaleflow/networks/_velocity_field.py b/src/scaleflow/networks/_velocity_field.py index 68e0d46e..113b3b40 100644 --- a/src/scaleflow/networks/_velocity_field.py +++ b/src/scaleflow/networks/_velocity_field.py @@ -495,7 +495,7 @@ def setup(self): elif self.conditioning == "resnet": self.resnet_block = ResNetBlock( input_dim=self.hidden_dims[-1], - **self.conditioning_kwargs, + **conditioning_kwargs, ) elif self.conditioning == "concatenation": if len(conditioning_kwargs) > 0: diff --git 
a/src/scaleflow/solvers/_eqm.py b/src/scaleflow/solvers/_eqm.py index 14cc65dc..af436bf8 100644 --- a/src/scaleflow/solvers/_eqm.py +++ b/src/scaleflow/solvers/_eqm.py @@ -88,7 +88,7 @@ def loss_fn( rng_encoder, rng_dropout = jax.random.split(rng, 2) # Interpolate between source (noise) and target (data) - gamma_expanded = gamma[:, None] + gamma_expanded = gamma[:, jnp.newaxis] x_gamma = gamma_expanded * target + (1.0 - gamma_expanded) * source # Predict gradient field (no time input) @@ -101,7 +101,7 @@ def loss_fn( ) # Target gradient: (source - target) * c(gamma) - c_gamma = self.c_fn(gamma)[:, None] + c_gamma = self.c_fn(gamma)[:, jnp.newaxis] target_gradient = (source - target) * c_gamma # EqM loss @@ -149,7 +149,7 @@ def step_fn( condition = batch.get("condition") rng_resample, rng_gamma, rng_step_fn, rng_encoder_noise = jax.random.split(rng, 4) n = src.shape[0] - gamma = self.gamma_sampler(rng_gamma, n) + gamma = self.gamma_sampler(rng_gamma, n).squeeze() encoder_noise = jax.random.normal(rng_encoder_noise, (n, self.vf.condition_embedding_dim)) if self.match_fn is not None: diff --git a/src/scaleflow/solvers/_genot.py b/src/scaleflow/solvers/_genot.py index 079ab922..588a6036 100644 --- a/src/scaleflow/solvers/_genot.py +++ b/src/scaleflow/solvers/_genot.py @@ -284,6 +284,13 @@ def predict( pred_targets = batched_predict(src_inputs, batched_conditions) return {k: pred_targets[i] for i, k in enumerate(keys)} + elif isinstance(x, dict): + predict_fn = functools.partial(self._predict_jit, rng=rng, rng_genot=rng_genot, **kwargs) + return jax.tree.map( + predict_fn, + x, + condition, + ) else: x_pred = self._predict_jit(x, condition, rng, rng_genot, **kwargs) return np.array(x_pred) diff --git a/src/scaleflow/training/_callbacks.py b/src/scaleflow/training/_callbacks.py index 92a4524d..82fef53c 100644 --- a/src/scaleflow/training/_callbacks.py +++ b/src/scaleflow/training/_callbacks.py @@ -517,6 +517,7 @@ def on_log_iteration( valid_data: dict[str, dict[str, ArrayLike]], pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, + additional_metrics: dict[str, Any] | None = None, ) -> dict[str, Any]: """Called at each validation/log iteration to run callbacks. First computes metrics with computation callbacks and then logs data with logging callbacks. @@ -531,6 +532,8 @@ def on_log_iteration( solver :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. 
+ additional_metrics + Optional dictionary of metrics to include before computing validation metrics (e.g., train_loss) Returns ------- @@ -538,6 +541,10 @@ def on_log_iteration( """ dict_to_log: dict[str, Any] = {} + # Add additional metrics first + if additional_metrics is not None: + dict_to_log.update(additional_metrics) + for callback in self.computation_callbacks: results = callback.on_log_iteration(valid_source_data, valid_data, pred_data, solver) dict_to_log.update(results) diff --git a/src/scaleflow/training/_trainer.py b/src/scaleflow/training/_trainer.py index ee8f3155..03816a88 100644 --- a/src/scaleflow/training/_trainer.py +++ b/src/scaleflow/training/_trainer.py @@ -7,24 +7,26 @@ from tqdm import tqdm from scaleflow.data import JaxOutOfCoreTrainSampler, TrainSampler, ValidationSampler -from scaleflow.solvers import _genot, _otfm +from scaleflow.solvers import _eqm, _genot, _otfm from scaleflow.training._callbacks import BaseCallback, CallbackRunner class CellFlowTrainer: - """Trainer for the OTFM/GENOT solver with a conditional velocity field. + """Trainer for the OTFM/GENOT/EqM solver with a conditional velocity field. Parameters ---------- dataloader Data sampler. solver - :class:`~scaleflow.solvers._otfm.OTFlowMatching` or - :class:`~scaleflow.solvers._genot.GENOT` solver with a conditional velocity field. + :class:`~scaleflow.solvers._otfm.OTFlowMatching`, + :class:`~scaleflow.solvers._genot.GENOT`, or + :class:`~scaleflow.solvers._eqm.EquilibriumMatching` solver with a conditional velocity field. predict_kwargs Keyword arguments for the prediction functions - :func:`scaleflow.solvers._otfm.OTFlowMatching.predict` or - :func:`scaleflow.solvers._genot.GENOT.predict` used during validation. + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict`, + :func:`scaleflow.solvers._genot.GENOT.predict`, or + :func:`scaleflow.solvers._eqm.EquilibriumMatching.predict` used during validation. seed Random seed for subsampling validation data. @@ -35,12 +37,12 @@ class CellFlowTrainer: def __init__( self, - solver: _otfm.OTFlowMatching | _genot.GENOT, + solver: _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching, predict_kwargs: dict[str, Any] | None = None, seed: int = 0, ): - if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT)): - raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(solver)}") + if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching)): + raise NotImplementedError(f"Solver must be an instance of OTFlowMatching, GENOT, or EquilibriumMatching, got {type(solver)}") self.solver = solver self.predict_kwargs = predict_kwargs or {} @@ -87,7 +89,7 @@ def train( valid_loaders: dict[str, ValidationSampler] | None = None, monitor_metrics: Sequence[str] = [], callbacks: Sequence[BaseCallback] = [], - ) -> _otfm.OTFlowMatching | _genot.GENOT: + ) -> _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching: """Trains the model. 
Parameters @@ -136,14 +138,19 @@ def train( valid_loaders, mode="on_log_iteration" ) - # Run callbacks - metrics = crun.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, self.solver) # type: ignore[arg-type] + # Calculate mean loss + mean_loss = np.mean(self.training_logs["loss"][-valid_freq:]) + + # Run callbacks with loss as additional metric + metrics = crun.on_log_iteration( + valid_source_data, valid_true_data, valid_pred_data, self.solver, + additional_metrics={"train_loss": mean_loss} + ) self._update_logs(metrics) # Update progress bar - mean_loss = np.mean(self.training_logs["loss"][-valid_freq:]) postfix_dict = {metric: round(self.training_logs[metric][-1], 3) for metric in monitor_metrics} - postfix_dict["loss"] = round(mean_loss, 3) + postfix_dict["train_loss"] = round(mean_loss, 3) # or keep as "loss" pbar.set_postfix(postfix_dict) if num_iterations > 0: diff --git a/tests/data/test_datasplitter.py b/tests/data/test_datasplitter.py new file mode 100644 index 00000000..3957e508 --- /dev/null +++ b/tests/data/test_datasplitter.py @@ -0,0 +1,729 @@ +from pathlib import Path + +import numpy as np +import pytest + +from scaleflow.data import DataManager +from scaleflow.data._data_splitter import DataSplitter + + +class TestDataSplitterValidation: + def test_mismatched_datasets_and_names(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="training_datasets length.*must match.*dataset_names length"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1", "dataset2"], + split_ratios=[[0.8, 0.1, 0.1]], + ) + + def test_mismatched_datasets_and_ratios(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="split_ratios length.*must match.*training_datasets length"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1], [0.7, 0.2, 0.1]], + ) + + def test_invalid_split_ratios_format(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="must be a list of 3 values"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.2]], + ) + + def test_split_ratios_dont_sum_to_one(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="must sum to 1.0"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.2]], + ) + + def test_negative_split_ratios(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = 
dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="must be non-negative"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.9, 0.2, -0.1]], + ) + + def test_holdout_groups_requires_split_key(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="split_key must be provided"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_groups", + ) + + def test_holdout_combinations_requires_control_value(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1", "drug2"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="control_value must be provided"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_combinations", + split_key="drug", + ) + + +class TestRandomSplit: + @pytest.mark.parametrize("hard_test_split", [True, False]) + @pytest.mark.parametrize("split_ratios", [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [1.0, 0.0, 0.0]]) + def test_random_split_ratios(self, adata_perturbation, hard_test_split, split_ratios): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[split_ratios], + split_type="random", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + indices = results["dataset1"]["indices"] + + n_cells = train_data.perturbation_covariates_mask.shape[0] + total_assigned = len(indices["train"]) + len(indices["val"]) + len(indices["test"]) + assert total_assigned == n_cells + + train_ratio, val_ratio, test_ratio = split_ratios + assert len(indices["train"]) == pytest.approx(train_ratio * n_cells, abs=1) + if val_ratio > 0: + assert len(indices["val"]) > 0 + if test_ratio > 0: + assert len(indices["test"]) > 0 + + def test_random_split_reproducibility(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter1 = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + random_state=42, + ) + results1 = splitter1.split_all_datasets() + + splitter2 = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + random_state=42, + ) + results2 = splitter2.split_all_datasets() + + assert np.array_equal(results1["dataset1"]["indices"]["train"], results2["dataset1"]["indices"]["train"]) + assert np.array_equal(results1["dataset1"]["indices"]["val"], results2["dataset1"]["indices"]["val"]) + assert 
np.array_equal(results1["dataset1"]["indices"]["test"], results2["dataset1"]["indices"]["test"]) + + def test_random_split_no_overlap_hard(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.7, 0.2, 0.1]], + split_type="random", + hard_test_split=True, + random_state=42, + ) + results = splitter.split_all_datasets() + indices = results["dataset1"]["indices"] + + train_set = set(indices["train"]) + val_set = set(indices["val"]) + test_set = set(indices["test"]) + + assert len(train_set & val_set) == 0 + assert len(train_set & test_set) == 0 + assert len(val_set & test_set) == 0 + + +class TestHoldoutGroups: + @pytest.mark.parametrize("hard_test_split", [True, False]) + def test_holdout_groups_basic(self, adata_perturbation, hard_test_split): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_groups", + split_key="drug", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + assert "split_values" in results["dataset1"] + + split_values = results["dataset1"]["split_values"] + assert "train" in split_values + assert "val" in split_values + assert "test" in split_values + + def test_holdout_groups_force_training_values(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=[], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + # Get available perturbation values (not control) + unique_values = set() + for covariates in train_data.perturbation_idx_to_covariates.values(): + unique_values.update(covariates) + + # Use "drug_a" instead of "control" since control cells are filtered out + force_value = "drug_a" + if force_value not in unique_values: + pytest.skip("drug_a not in perturbation values") + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_groups", + split_key="drug", + force_training_values=[force_value], + random_state=42, + ) + + results = splitter.split_all_datasets() + split_values = results["dataset1"]["split_values"] + + assert force_value in split_values["train"] + assert force_value not in split_values["val"] + assert force_value not in split_values["test"] + + def test_holdout_groups_fixed_test_seed(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + results_list = [] + for seed in [42, 43, 44]: + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_groups", + split_key="drug", + test_random_state=999, + val_random_state=seed, + random_state=seed, + ) + results = 
splitter.split_all_datasets() + results_list.append(results) + + test_values_1 = set(results_list[0]["dataset1"]["split_values"]["test"]) + test_values_2 = set(results_list[1]["dataset1"]["split_values"]["test"]) + test_values_3 = set(results_list[2]["dataset1"]["split_values"]["test"]) + + assert test_values_1 == test_values_2 == test_values_3 + + val_values_1 = set(results_list[0]["dataset1"]["split_values"]["val"]) + val_values_2 = set(results_list[1]["dataset1"]["split_values"]["val"]) + + if len(val_values_1) > 0 and len(val_values_2) > 0: + assert val_values_1 != val_values_2 + + +class TestHoldoutCombinations: + @pytest.mark.parametrize("hard_test_split", [True, False]) + def test_holdout_combinations_basic(self, adata_perturbation, hard_test_split): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1", "drug2"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_combinations", + split_key=["drug1", "drug2"], + control_value="control", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + indices = results["dataset1"]["indices"] + + assert len(indices["train"]) > 0 + assert len(indices["val"]) >= 0 + assert len(indices["test"]) >= 0 + + def test_holdout_combinations_singletons_in_train(self): + # Create test data with a good number of combinations + import anndata as ad + + n_obs = 1000 # Increased to accommodate more combinations + n_vars = 50 + n_pca = 10 + + X_data = np.random.rand(n_obs, n_vars) + my_counts = np.random.rand(n_obs, n_vars) + X_pca = np.random.rand(n_obs, n_pca) + + # Use 5 drugs to get 20 unique combinations (5 * 4) + drugs = ["drug_a", "drug_b", "drug_c", "drug_d", "drug_e"] + cell_lines = ["cell_line_a", "cell_line_b", "cell_line_c"] + + # Create structured data with known combinations + drug1_list = [] + drug2_list = [] + + # Control cells (100) + drug1_list.extend(["control"] * 100) + drug2_list.extend(["control"] * 100) + + # Singleton on drug1 (250 cells: 50 per drug) + for drug in drugs: + drug1_list.extend([drug] * 50) + drug2_list.extend(["control"] * 50) + + # Singleton on drug2 (250 cells: 50 per drug) + for drug in drugs: + drug1_list.extend(["control"] * 50) + drug2_list.extend([drug] * 50) + + # Combinations (400 cells distributed across 20 combinations = 20 cells each) + # Create all possible non-control combinations + combinations = [] + for d1 in drugs: + for d2 in drugs: + if d1 != d2: # Different drugs (true combinations) + combinations.append((d1, d2)) + + # Distribute 400 cells evenly across combinations (20 cells per combination) + cells_per_combo = 400 // len(combinations) + + for d1, d2 in combinations: + drug1_list.extend([d1] * cells_per_combo) + drug2_list.extend([d2] * cells_per_combo) + + # Create cell line assignments + import pandas as pd + cell_type_list = np.random.choice(cell_lines, n_obs) + dosages = np.random.choice([10.0, 100.0, 1000.0], n_obs) + + obs_data = pd.DataFrame({ + "cell_type": cell_type_list, + "dosage": dosages, + "drug1": drug1_list, + "drug2": drug2_list, + "drug3": ["control"] * n_obs, + "dosage_a": np.random.choice([10.0, 100.0, 1000.0], n_obs), + "dosage_b": np.random.choice([10.0, 100.0, 1000.0], n_obs), + "dosage_c": np.random.choice([10.0, 100.0, 1000.0], n_obs), 
+ }) + + # Create an AnnData object + adata_combinations = ad.AnnData(X=X_data, obs=obs_data) + adata_combinations.layers["my_counts"] = my_counts + adata_combinations.obsm["X_pca"] = X_pca + + # Add boolean columns for each drug + for drug in drugs: + adata_combinations.obs[drug] = ( + (adata_combinations.obs["drug1"] == drug) | + (adata_combinations.obs["drug2"] == drug) | + (adata_combinations.obs["drug3"] == drug) + ) + + adata_combinations.obs["control"] = ( + (adata_combinations.obs["drug1"] == "control") & + (adata_combinations.obs["drug2"] == "control") + ) + + # Convert to categorical EXCEPT for control and boolean drug columns + for col in adata_combinations.obs.columns: + if col not in ["control"] + drugs: + adata_combinations.obs[col] = adata_combinations.obs[col].astype("category") + + # Add embeddings + drug_emb = {} + for drug in adata_combinations.obs["drug1"].cat.categories: + drug_emb[drug] = np.random.randn(5, 1) + adata_combinations.uns["drug"] = drug_emb + + cell_type_emb = {} + for cell_type in adata_combinations.obs["cell_type"].cat.categories: + cell_type_emb[cell_type] = np.random.randn(3, 1) + adata_combinations.uns["cell_type"] = cell_type_emb + + # Now run the actual test + dm = DataManager( + adata_combinations, + sample_rep="X", + split_covariates=[], + control_key="control", + perturbation_covariates={"drug": ["drug1", "drug2"]}, + ) + train_data = dm.get_train_data(adata_combinations) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_combinations", + split_key=["drug1", "drug2"], + control_value="control", + random_state=42, + ) + + results = splitter.split_all_datasets() + + perturbation_covariates_mask = train_data.perturbation_covariates_mask + perturbation_idx_to_covariates = train_data.perturbation_idx_to_covariates + + train_indices = results["dataset1"]["indices"]["train"] + val_indices = results["dataset1"]["indices"]["val"] + test_indices = results["dataset1"]["indices"]["test"] + + # Verify that ALL singletons and controls are in training + all_singletons = [] + all_combinations = [] + + for idx in range(len(perturbation_covariates_mask)): + pert_idx = perturbation_covariates_mask[idx] + if pert_idx >= 0: + covariates = perturbation_idx_to_covariates[pert_idx] + non_control_count = sum(1 for c in covariates if c != "control") + if non_control_count == 1: + all_singletons.append(idx) + elif non_control_count > 1: + all_combinations.append(idx) + + train_set = set(train_indices) + + # All singletons should be in training + for singleton_idx in all_singletons: + assert singleton_idx in train_set, "All singleton perturbations should be in training" + + # Some (but not all) combinations should be in training according to split_ratios + combinations_in_train = [idx for idx in all_combinations if idx in train_set] + combinations_in_val = [idx for idx in all_combinations if idx in set(val_indices)] + combinations_in_test = [idx for idx in all_combinations if idx in set(test_indices)] + + # With enough combinations, we should see proper distribution + assert len(all_combinations) > 0, "Test data should have combination perturbations" + + train_combo_ratio = len(combinations_in_train) / len(all_combinations) + val_combo_ratio = len(combinations_in_val) / len(all_combinations) + test_combo_ratio = len(combinations_in_test) / len(all_combinations) + + # With 0.6, 0.2, 0.2 ratios, allow some tolerance + assert 0.4 < train_combo_ratio < 0.8, f"Expected ~60% of 
combinations in training, got {train_combo_ratio:.2%}" + assert 0.05 < val_combo_ratio < 0.35, f"Expected ~20% of combinations in val, got {val_combo_ratio:.2%}" + assert 0.05 < test_combo_ratio < 0.35, f"Expected ~20% of combinations in test, got {test_combo_ratio:.2%}" + + +class TestStratifiedSplit: + @pytest.mark.parametrize("hard_test_split", [True, False]) + def test_stratified_split_basic(self, adata_perturbation, hard_test_split): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="stratified", + split_key="drug", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + indices = results["dataset1"]["indices"] + + n_cells = train_data.perturbation_covariates_mask.shape[0] + total_assigned = len(indices["train"]) + len(indices["val"]) + len(indices["test"]) + assert total_assigned == n_cells + + +class TestMultipleDatasets: + def test_multiple_datasets_different_ratios(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data1 = dm.get_train_data(adata_perturbation) + train_data2 = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data1, train_data2], + dataset_names=["dataset1", "dataset2"], + split_ratios=[[0.8, 0.1, 0.1], [0.7, 0.2, 0.1]], + split_type="random", + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + assert "dataset2" in results + + n_cells = train_data1.perturbation_covariates_mask.shape[0] + + assert len(results["dataset1"]["indices"]["train"]) == pytest.approx(0.8 * n_cells, abs=1) + assert len(results["dataset2"]["indices"]["train"]) == pytest.approx(0.7 * n_cells, abs=1) + + +class TestSaveAndLoad: + def test_save_and_load_splits(self, adata_perturbation, tmp_path): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_groups", + split_key="drug", + random_state=42, + ) + + results = splitter.split_all_datasets() + splitter.save_splits(tmp_path / "splits") + + assert (tmp_path / "splits" / "split_summary.json").exists() + assert (tmp_path / "splits" / "dataset1" / "metadata.json").exists() + assert (tmp_path / "splits" / "dataset1" / "split_info.pkl").exists() + + loaded_info = DataSplitter.load_split_info(tmp_path / "splits", "dataset1") + + assert "indices" in loaded_info + assert "metadata" in loaded_info + + assert np.array_equal(loaded_info["indices"]["train"], results["dataset1"]["indices"]["train"]) + assert np.array_equal(loaded_info["indices"]["val"], results["dataset1"]["indices"]["val"]) + assert np.array_equal(loaded_info["indices"]["test"], results["dataset1"]["indices"]["test"]) + + def test_load_nonexistent_split(self, tmp_path): + with pytest.raises(FileNotFoundError): + DataSplitter.load_split_info(tmp_path 
/ "nonexistent", "dataset1") + + +class TestSplitSummary: + def test_generate_split_summary(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_groups", + split_key="drug", + random_state=42, + ) + + splitter.split_all_datasets() + summary = splitter.generate_split_summary() + + assert "dataset1" in summary + assert "configuration" in summary["dataset1"] + assert "statistics" in summary["dataset1"] + assert "observations_per_condition" in summary["dataset1"] + + config = summary["dataset1"]["configuration"] + assert config["split_type"] == "holdout_groups" + assert config["random_state"] == 42 + + stats = summary["dataset1"]["statistics"] + assert "total_observations" in stats + assert "train_observations" in stats + assert "val_observations" in stats + assert "test_observations" in stats + + def test_summary_before_split_raises(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + random_state=42, + ) + + with pytest.raises(ValueError, match="No split results available"): + splitter.generate_split_summary() + + +class TestExtractPerturbationInfo: + def test_extract_perturbation_info(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + ) + + pert_info = splitter.extract_perturbation_info(train_data) + + assert "perturbation_covariates_mask" in pert_info + assert "perturbation_idx_to_covariates" in pert_info + assert "n_cells" in pert_info + + assert isinstance(pert_info["perturbation_covariates_mask"], np.ndarray) + assert isinstance(pert_info["perturbation_idx_to_covariates"], dict) + assert pert_info["n_cells"] == len(train_data.perturbation_covariates_mask) diff --git a/tests/model/test_scaleflow.py b/tests/model/test_scaleflow.py index 643ac293..90566fcc 100644 --- a/tests/model/test_scaleflow.py +++ b/tests/model/test_scaleflow.py @@ -17,7 +17,7 @@ class TestCellFlow: @pytest.mark.slow - @pytest.mark.parametrize("solver", ["otfm"]) # , "genot"]) + @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"]) @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize("regularization", [0.0, 0.1]) @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"]) @@ -29,7 +29,7 @@ def test_scaleflow_solver( regularization, conditioning, ): - if solver == "genot" and ((condition_mode == "stochastic") or (regularization > 0.0)): + if solver in ["genot", "eqm"] and ((condition_mode == "stochastic") or (regularization > 0.0)): return None sample_rep = "X" control_key = "control" @@ -77,15 +77,14 @@ def 
test_scaleflow_solver( cf.train(num_iterations=3) assert cf._dataloader is not None - # we assume these are all source cells now in adata_perturbation adata_perturbation_pred = adata_perturbation.copy() adata_perturbation_pred.obs["control"] = True + predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False} pred = cf.predict( adata_perturbation_pred, sample_rep=sample_rep, covariate_data=adata_perturbation_pred.obs, - max_steps=3, - throw=False, + **predict_kwargs, ) assert isinstance(pred, dict) key, out = next(iter(pred.items())) @@ -97,16 +96,14 @@ def test_scaleflow_solver( sample_rep=sample_rep, covariate_data=adata_perturbation_pred.obs, key_added_prefix="MY_PREDICTION_", - max_steps=3, - throw=False, + **predict_kwargs, ) assert pred_stored is None - if solver == "otfm": + if solver in ["otfm", "genot", "eqm"]: assert "MY_PREDICTION_" + str(key) in adata_perturbation_pred.obsm if solver == "genot": - assert "MY_PREDICTION_" + str(key) in adata_perturbation_pred.obsm pred2 = cf.predict( adata_perturbation_pred, sample_rep=sample_rep, @@ -133,7 +130,7 @@ def test_scaleflow_solver( assert cond_embed_var.shape[1] == condition_embedding_dim @pytest.mark.slow - @pytest.mark.parametrize("solver", ["otfm", "genot"]) + @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"]) @pytest.mark.parametrize("perturbation_covariate_reps", [{}, {"drug": "drug"}]) def test_scaleflow_covar_reps( self, @@ -166,24 +163,24 @@ def test_scaleflow_covar_reps( ) assert cf._trainer is not None - vector_field_class = ( - _velocity_field.ConditionalVelocityField - if solver == "otfm" - else _velocity_field.GENOTConditionalVelocityField - ) + if solver == "otfm": + vector_field_class = _velocity_field.ConditionalVelocityField + elif solver == "genot": + vector_field_class = _velocity_field.GENOTConditionalVelocityField + else: + vector_field_class = _velocity_field.EquilibriumVelocityField assert cf._vf_class == vector_field_class cf.train(num_iterations=3) assert cf._dataloader is not None - # we assume these are all source cells now in adata_perturbation adata_perturbation_pred = adata_perturbation.copy() adata_perturbation_pred.obs["control"] = True + predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False} pred = cf.predict( adata_perturbation_pred, sample_rep=sample_rep, covariate_data=adata_perturbation_pred.obs, - max_steps=3, - throw=False, + **predict_kwargs, ) assert isinstance(pred, dict) out = next(iter(pred.values())) @@ -248,7 +245,7 @@ def test_scaleflow_val_data_loading( assert cond_data[k].shape[1] == cf.train_data.max_combination_length @pytest.mark.slow - @pytest.mark.parametrize("solver", ["otfm", "genot"]) + @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"]) @pytest.mark.parametrize("n_conditions_on_log_iteration", [None, 0, 1]) @pytest.mark.parametrize("n_conditions_on_train_end", [None, 0, 1]) def test_scaleflow_with_validation( @@ -259,6 +256,7 @@ def test_scaleflow_with_validation( n_conditions_on_train_end, ): vf_kwargs = {"genot_source_dims": (2, 2), "genot_source_dropout": 0.1} if solver == "genot" else None + predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False} cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep="X", @@ -274,7 +272,7 @@ def test_scaleflow_with_validation( name="val", n_conditions_on_log_iteration=n_conditions_on_log_iteration, 
n_conditions_on_train_end=n_conditions_on_train_end, - predict_kwargs={"max_steps": 3, "throw": False}, + predict_kwargs=predict_kwargs, ) assert isinstance(cf._validation_data, dict) assert "val" in cf._validation_data @@ -307,7 +305,7 @@ def test_scaleflow_with_validation( assert f"val_{metric_to_compute}_mean" in cf._trainer.training_logs @pytest.mark.slow - @pytest.mark.parametrize("solver", ["otfm", "genot"]) + @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"]) @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize("regularization", [0.0, 0.1]) def test_scaleflow_predict( @@ -317,6 +315,8 @@ def test_scaleflow_predict( condition_mode, regularization, ): + if solver in ["genot", "eqm"] and ((condition_mode == "stochastic") or (regularization > 0.0)): + return None cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) cf.prepare_data( sample_rep="X", @@ -354,7 +354,8 @@ def test_scaleflow_predict( adata_pred.obs["control"] = True covariate_data = adata_perturbation.obs.iloc[:3] - pred = cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, max_steps=3, throw=False) + predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False} + pred = cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, **predict_kwargs) assert isinstance(pred, dict) out = next(iter(pred.values())) @@ -365,11 +366,11 @@ def test_scaleflow_predict( ValueError, match=r".*If both `adata` and `covariate_data` are given, all samples in `adata` must be control samples*", ): - cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, max_steps=3, throw=False) + cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, **predict_kwargs) with pytest.raises(ValueError, match="`covariate_data` is empty."): empty_covariate_data = covariate_data.head(0) - cf.predict(adata_pred, sample_rep="X", covariate_data=empty_covariate_data, max_steps=3, throw=False) + cf.predict(adata_pred, sample_rep="X", covariate_data=empty_covariate_data, **predict_kwargs) with pytest.raises( ValueError, @@ -381,7 +382,7 @@ def test_scaleflow_predict( adata_pred_cell_type_2 = adata_pred[adata_pred.obs["cell_type"] == "cell_line_b"] adata_pred_cell_type_2.obs["control"] = True cf.predict( - adata_pred_cell_type_2, sample_rep="X", covariate_data=cov_data_cell_type_1, max_steps=3, throw=False + adata_pred_cell_type_2, sample_rep="X", covariate_data=cov_data_cell_type_1, **predict_kwargs ) def test_raise_otfm_vf_kwargs_passed(self, adata_perturbation): @@ -395,7 +396,7 @@ def test_raise_otfm_vf_kwargs_passed(self, adata_perturbation): ) with pytest.raises( ValueError, - match=r".*For `solver='otfm'`, `vf_kwargs` must be `None`.*", + match=r".*For `solver='otfm'` or `solver='eqm'`, `vf_kwargs` must be `None`.*", ): cf.prepare_model( condition_embedding_dim=2, diff --git a/tests/networks/test_velocityfield.py b/tests/networks/test_velocityfield.py index b592573f..96db1f2d 100644 --- a/tests/networks/test_velocityfield.py +++ b/tests/networks/test_velocityfield.py @@ -19,7 +19,7 @@ class TestVelocityField: @pytest.mark.parametrize("linear_projection_before_concatenation", [True, False]) @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize( - "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField] + "velocity_field_cls", [_velocity_field.ConditionalVelocityField, 
_velocity_field.GENOTConditionalVelocityField, _velocity_field.EquilibriumVelocityField] ) @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"]) def test_velocity_field_init( @@ -62,6 +62,15 @@ def test_velocity_field_init( train=True, rngs={"condition_encoder": apply_rng}, ) + elif isinstance(vf, _velocity_field.EquilibriumVelocityField): + out, out_mean, out_logvar = vf_state.apply_fn( + {"params": vf_state.params}, + x_test, + cond, + encoder_noise, + train=True, + rngs={"condition_encoder": apply_rng}, + ) elif isinstance(vf, _velocity_field.ConditionalVelocityField): out, out_mean, out_logvar = vf_state.apply_fn( {"params": vf_state.params}, @@ -84,7 +93,7 @@ def test_velocity_field_init( @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize( - "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField] + "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField, _velocity_field.EquilibriumVelocityField] ) @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"]) def test_velocityfield_conditioning_kwargs(self, condition_mode, velocity_field_cls, conditioning): @@ -127,6 +136,15 @@ def test_velocityfield_conditioning_kwargs(self, condition_mode, velocity_field_ train=True, rngs={"condition_encoder": apply_rng, "dropout": dropout_rng}, ) + elif isinstance(vf, _velocity_field.EquilibriumVelocityField): + out, out_mean, out_logvar = vf_state.apply_fn( + {"params": vf_state.params}, + x_test, + cond, + encoder_noise, + train=True, + rngs={"condition_encoder": apply_rng, "dropout": dropout_rng}, + ) elif isinstance(vf, _velocity_field.ConditionalVelocityField): out, out_mean, out_logvar = vf_state.apply_fn( {"params": vf_state.params}, @@ -145,7 +163,7 @@ def test_velocityfield_conditioning_kwargs(self, condition_mode, velocity_field_ @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"]) @pytest.mark.parametrize( - "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField] + "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField, _velocity_field.EquilibriumVelocityField] ) @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"]) def test_velocityfield_conditioning_raises(self, condition_mode, velocity_field_cls, conditioning): diff --git a/tests/solver/test_solver.py b/tests/solver/test_solver.py index fe31f73d..08454141 100644 --- a/tests/solver/test_solver.py +++ b/tests/solver/test_solver.py @@ -2,13 +2,14 @@ import time import jax +import jax.numpy as jnp import numpy as np import optax import pytest from ott.neural.methods.flows import dynamics import scaleflow -from scaleflow.solvers import _genot, _otfm +from scaleflow.solvers import _eqm, _genot, _otfm from scaleflow.utils import match_linear src = { @@ -22,13 +23,30 @@ vf_rng = jax.random.PRNGKey(111) +@pytest.fixture +def eqm_dataloader(): + class DataLoader: + n_conditions = 10 + + def sample(self, rng): + return { + "src_cell_data": jnp.ones((10, 5)) * 10, + "tgt_cell_data": jnp.ones((10, 5)), + "condition": {"pert1": jnp.ones((10, 2, 3))}, + } + + return DataLoader() + + class TestSolver: - @pytest.mark.parametrize("solver_class", ["otfm", "genot"]) - def test_predict_batch(self, dataloader, solver_class): + @pytest.mark.parametrize("solver_class", ["otfm", "genot", "eqm"]) + def 
test_predict_batch(self, dataloader, eqm_dataloader, solver_class): if solver_class == "otfm": vf_class = scaleflow.networks.ConditionalVelocityField - else: + elif solver_class == "genot": vf_class = scaleflow.networks.GENOTConditionalVelocityField + else: + vf_class = scaleflow.networks.EquilibriumVelocityField opt = optax.adam(1e-3) vf = vf_class( @@ -47,7 +65,7 @@ def test_predict_batch(self, dataloader, solver_class): conditions={"drug": np.random.rand(2, 1, 3)}, rng=vf_rng, ) - else: + elif solver_class == "genot": solver = _genot.GENOT( vf=vf, data_match_fn=match_linear, @@ -58,11 +76,20 @@ def test_predict_batch(self, dataloader, solver_class): conditions={"drug": np.random.rand(2, 1, 3)}, rng=vf_rng, ) + else: + solver = _eqm.EquilibriumMatching( + vf=vf, + match_fn=match_linear, + optimizer=opt, + conditions={"pert1": np.random.rand(2, 2, 3)}, + rng=vf_rng, + ) - predict_kwargs = {"max_steps": 3, "throw": False} + predict_kwargs = {"max_steps": 3, "throw": False} if solver_class != "eqm" else {"max_steps": 3, "eta": 0.01} trainer = scaleflow.training.CellFlowTrainer(solver=solver, predict_kwargs=predict_kwargs) + train_dataloader = eqm_dataloader if solver_class == "eqm" else dataloader trainer.train( - dataloader=dataloader, + dataloader=train_dataloader, num_iterations=2, valid_freq=1, ) @@ -89,10 +116,18 @@ def test_predict_batch(self, dataloader, solver_class): ) assert diff_nonbatched - diff_batched > 0.5 + @pytest.mark.parametrize("solver_class", ["otfm", "eqm"]) @pytest.mark.parametrize("ema", [0.5, 1.0]) - def test_EMA(self, dataloader, ema): - vf_class = scaleflow.networks.ConditionalVelocityField - drug = np.random.rand(2, 1, 3) + def test_EMA(self, dataloader, eqm_dataloader, solver_class, ema): + if solver_class == "otfm": + vf_class = scaleflow.networks.ConditionalVelocityField + drug = np.random.rand(2, 1, 3) + condition_key = "drug" + else: + vf_class = scaleflow.networks.EquilibriumVelocityField + drug = np.random.rand(2, 2, 3) + condition_key = "pert1" + opt = optax.adam(1e-3) vf1 = vf_class( output_dim=5, @@ -102,18 +137,30 @@ def test_EMA(self, dataloader, ema): decoder_dims=(5, 5), ) - solver1 = _otfm.OTFlowMatching( - vf=vf1, - match_fn=match_linear, - probability_path=dynamics.ConstantNoiseFlow(0.0), - optimizer=opt, - conditions={"drug": drug}, - rng=vf_rng, - ema=ema, - ) + if solver_class == "otfm": + solver1 = _otfm.OTFlowMatching( + vf=vf1, + match_fn=match_linear, + probability_path=dynamics.ConstantNoiseFlow(0.0), + optimizer=opt, + conditions={condition_key: drug}, + rng=vf_rng, + ema=ema, + ) + else: + solver1 = _eqm.EquilibriumMatching( + vf=vf1, + match_fn=match_linear, + optimizer=opt, + conditions={condition_key: drug}, + rng=vf_rng, + ema=ema, + ) + trainer1 = scaleflow.training.CellFlowTrainer(solver=solver1) + train_dataloader = eqm_dataloader if solver_class == "eqm" else dataloader trainer1.train( - dataloader=dataloader, + dataloader=train_dataloader, num_iterations=5, valid_freq=10, )
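# Usage sketch consolidating the DataSplitter API exercised by the new
# tests/data/test_datasplitter.py above. Argument and result names are taken from
# those tests; `adata_perturbation` stands in for the pytest fixture, so treat this
# as an inferred example rather than a definitive reference.
from scaleflow.data import DataManager
from scaleflow.data._data_splitter import DataSplitter

dm = DataManager(
    adata_perturbation,  # an AnnData prepared as in the test fixtures
    sample_rep="X",
    split_covariates=["cell_type"],
    control_key="control",
    perturbation_covariates={"drug": ["drug1"]},
)
train_data = dm.get_train_data(adata_perturbation)

splitter = DataSplitter(
    training_datasets=[train_data],
    dataset_names=["dataset1"],
    split_ratios=[[0.8, 0.1, 0.1]],  # train/val/test per dataset, must sum to 1.0
    split_type="holdout_groups",     # also: "random", "holdout_combinations", "stratified"
    split_key="drug",                # required for "holdout_groups"
    random_state=42,
)
results = splitter.split_all_datasets()
train_idx = results["dataset1"]["indices"]["train"]
held_out = results["dataset1"]["split_values"]["test"]

splitter.save_splits("splits")  # writes split_summary.json plus per-dataset metadata/split_info
loaded = DataSplitter.load_split_info("splits", "dataset1")
summary = splitter.generate_split_summary()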