Changes from all commits
28 commits
08c0bc9
integrate new dataloaders
selmanozleyen Aug 8, 2025
14dda75
put working state
selmanozleyen Aug 8, 2025
e40e575
add new files
selmanozleyen Aug 8, 2025
691e941
fix
selmanozleyen Aug 8, 2025
a2a8ab1
fix this
selmanozleyen Aug 8, 2025
4a2cb76
format
selmanozleyen Aug 8, 2025
73900f6
remove extra test files
selmanozleyen Aug 19, 2025
245b595
update the write function
selmanozleyen Aug 25, 2025
2a0d870
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2025
6e34cc7
remove compat test
selmanozleyen Aug 25, 2025
1b78a05
Merge branch 'feature/zarr-data' of https://github.com/theislab/cellf…
selmanozleyen Aug 25, 2025
cc2d53b
fix import problems and rename function to write_zarr
selmanozleyen Aug 25, 2025
297a83c
hide explicit torch imports
selmanozleyen Aug 25, 2025
2be2bd6
add read and write zarr tests
selmanozleyen Aug 25, 2025
a1f974c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2025
f4062bb
push working state
selmanozleyen Aug 25, 2025
93c66a7
Merge branch 'feature/zarr-data' of https://github.com/theislab/cellf…
selmanozleyen Aug 25, 2025
7ac0f8f
remove torch test for cellflow workflow
selmanozleyen Aug 26, 2025
8abe9b1
Merge branch 'main' into feature/zarr-data
selmanozleyen Aug 26, 2025
042e07a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2025
d454e34
Delete tests/test_optional.py
selmanozleyen Aug 26, 2025
9e56b37
fix unintentionally removed line
selmanozleyen Aug 26, 2025
e67de7d
ability to add names and tests
selmanozleyen Aug 26, 2025
feae2dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2025
5c611f9
bug fix
selmanozleyen Sep 17, 2025
3423c39
add trainsampler with pool
selmanozleyen Sep 19, 2025
8291b7a
save current state
selmanozleyen Sep 20, 2025
2a26de9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2025
470 changes: 470 additions & 0 deletions docs/notebooks/600_trainsampler.ipynb

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions scripts/create_tahoe.py
@@ -0,0 +1,32 @@
from sc_exp_design.models import CellFlow
import anndata as ad
import h5py

from anndata.experimental import read_lazy

print("loading data")
with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f:
    adata_all = ad.AnnData(
        obs=ad.io.read_elem(f["obs"]),
        var=read_lazy(f["var"]),
        uns=read_lazy(f["uns"]),
        obsm=read_lazy(f["obsm"]),
    )
cf = CellFlow()

print("preparing train data")
cf.prepare_train_data(
    adata_all,
    sample_rep="X_pca",
    control_key="control",
    perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)},
    perturbation_covariate_reps={"drugs": "drug_embeddings"},
    sample_covariates=["cell_line"],
    sample_covariate_reps={"cell_line": "cell_line_embeddings"},
    split_covariates=["cell_line"],
)

print("writing zarr")
cf.train_data.write_zarr("/lustre/groups/ml01/workspace/100mil/tahoe_train_data.zarr")
print("zarr written")
9 changes: 9 additions & 0 deletions src/cellflow/_optional.py
@@ -0,0 +1,9 @@
class OptionalDependencyNotAvailable(ImportError):
    pass


def torch_required_msg() -> str:
    return (
        "Optional dependency 'torch' is required for this feature.\n"
        "Install it via: pip install torch  # or pip install 'cellflow-tools[torch]'"
    )
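A minimal sketch (not part of this diff) of how these helpers might guard a torch-only code path; the function name _require_torch is hypothetical:

from cellflow._optional import OptionalDependencyNotAvailable, torch_required_msg


def _require_torch() -> None:
    # Raise the library's optional-dependency error when torch is absent.
    try:
        import torch  # noqa: F401
    except ImportError as err:
        raise OptionalDependencyNotAvailable(torch_required_msg()) from err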
3 changes: 3 additions & 0 deletions src/cellflow/compat/__init__.py
@@ -0,0 +1,3 @@
from .torch_ import TorchIterableDataset

__all__ = ["TorchIterableDataset"]
19 changes: 19 additions & 0 deletions src/cellflow/compat/torch_.py
@@ -0,0 +1,19 @@
from typing import TYPE_CHECKING

from cellflow._optional import OptionalDependencyNotAvailable, torch_required_msg

try:
    from torch.utils.data import IterableDataset as TorchIterableDataset  # type: ignore

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

    class TorchIterableDataset:  # noqa: D101
        def __init__(self, *args, **kwargs):
            raise OptionalDependencyNotAvailable(torch_required_msg())


if TYPE_CHECKING:
    # keeps type checkers aligned with the real type
    from torch.utils.data import IterableDataset as TorchIterableDataset  # noqa: F401
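A short usage sketch (assumed behavior, not part of this diff): with torch installed the shim is torch's real IterableDataset, so subclasses work as usual; without torch, instantiating any subclass raises OptionalDependencyNotAvailable from the fallback __init__:

from cellflow.compat import TorchIterableDataset


class RangeStream(TorchIterableDataset):
    # Trivial iterable dataset used only to exercise the shim.
    def __iter__(self):
        yield from range(3)


try:
    print(list(RangeStream()))  # [0, 1, 2] when torch is available
except ImportError as err:  # OptionalDependencyNotAvailable otherwise
    print(err)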
22 changes: 20 additions & 2 deletions src/cellflow/data/__init__.py
@@ -1,6 +1,20 @@
from cellflow.data._data import BaseDataMixin, ConditionData, PredictionData, TrainingData, ValidationData
from cellflow.data._dataloader import PredictionSampler, TrainSampler, ValidationSampler
from cellflow.data._data import (
BaseDataMixin,
ConditionData,
PredictionData,
TrainingData,
ValidationData,
ZarrTrainingData,
)
from cellflow.data._dataloader import (
PredictionSampler,
TrainSampler,
TrainSamplerWithPool,
ValidationSampler,
)
from cellflow.data._datamanager import DataManager
from cellflow.data._jax_dataloader import JaxOutOfCoreTrainSampler
from cellflow.data._torch_dataloader import TorchCombinedTrainSampler

__all__ = [
    "DataManager",
@@ -9,7 +23,11 @@
    "PredictionData",
    "TrainingData",
    "ValidationData",
    "ZarrTrainingData",
    "TrainSampler",
    "ValidationSampler",
    "PredictionSampler",
    "TorchCombinedTrainSampler",
    "JaxOutOfCoreTrainSampler",
    "TrainSamplerWithPool",
]
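The expanded exports can be smoke-tested with a plain import (a sketch, assuming the package is installed with these names exported as above):

from cellflow.data import (
    JaxOutOfCoreTrainSampler,
    TrainSamplerWithPool,
    ZarrTrainingData,
)

print(ZarrTrainingData, TrainSamplerWithPool, JaxOutOfCoreTrainSampler)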
141 changes: 138 additions & 3 deletions src/cellflow/data/_data.py
@@ -1,18 +1,23 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import jax
import numpy as np
import zarr
from zarr.storage import LocalStore

from cellflow._types import ArrayLike
from cellflow.data._utils import write_sharded

__all__ = [
    "BaseDataMixin",
    "ConditionData",
    "PredictionData",
    "TrainingData",
    "ValidationData",
    "ZarrTrainingData",
]


@@ -121,6 +126,65 @@ class TrainingData(BaseDataMixin):
    null_value: Any
    data_manager: Any

    # --- Zarr export helpers -------------------------------------------------
    def write_zarr(
        self,
        path: str,
        *,
        chunk_size: int = 4096,
        shard_size: int = 65536,
        compressors: Any | None = None,
    ) -> None:
        """Write this training data to Zarr v3 with sharded, compressed arrays.

        Parameters
        ----------
        path
            Path to a Zarr group to create or open for writing.
        chunk_size
            Chunk size along the first axis.
        shard_size
            Shard size along the first axis.
        compressors
            Optional list/tuple of Zarr codecs. If ``None``, a sensible default is used.
        """
        # Convert to numpy-backed containers for serialization.
        cell_data = np.asarray(self.cell_data)
        split_covariates_mask = np.asarray(self.split_covariates_mask)
        perturbation_covariates_mask = np.asarray(self.perturbation_covariates_mask)
        condition_data = {str(k): np.asarray(v) for k, v in (self.condition_data or {}).items()}
        control_to_perturbation = {str(k): np.asarray(v) for k, v in (self.control_to_perturbation or {}).items()}
        split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (self.split_idx_to_covariates or {}).items()}
        perturbation_idx_to_covariates = {
            str(k): np.asarray(v) for k, v in (self.perturbation_idx_to_covariates or {}).items()
        }
        perturbation_idx_to_id = {str(k): v for k, v in (self.perturbation_idx_to_id or {}).items()}

        train_data_dict: dict[str, Any] = {
            "cell_data": cell_data,
            "split_covariates_mask": split_covariates_mask,
            "perturbation_covariates_mask": perturbation_covariates_mask,
            "split_idx_to_covariates": split_idx_to_covariates,
            "perturbation_idx_to_covariates": perturbation_idx_to_covariates,
            "perturbation_idx_to_id": perturbation_idx_to_id,
            "condition_data": condition_data,
            "control_to_perturbation": control_to_perturbation,
            "max_combination_length": int(self.max_combination_length),
        }

        additional_kwargs = {}
        if compressors is not None:
            additional_kwargs["compressors"] = compressors

        zgroup = zarr.open_group(path, mode="w")
        write_sharded(
            zgroup,
            train_data_dict,
            chunk_size=chunk_size,
            shard_size=shard_size,
            **additional_kwargs,
        )


@dataclass
class ValidationData(BaseDataMixin):
@@ -171,6 +235,11 @@ class ValidationData(BaseDataMixin):
    n_conditions_on_train_end: int | None = None


def _read_dict(zgroup: zarr.Group, key: str) -> dict[str, Any]:
    keys = zgroup[key].keys()
    return {k: zgroup[key][k] for k in keys}


@dataclass
class PredictionData(BaseDataMixin):
"""Data container to perform prediction.
@@ -191,8 +260,8 @@
        Token to use for masking ``null_value``.
    """

    cell_data: ArrayLike  # (n_cells, n_features)
    split_covariates_mask: ArrayLike  # (n_cells,), which cell is assigned to which source distribution
    split_idx_to_covariates: dict[int, tuple[Any, ...]]  # (n_sources,) dictionary explaining split_covariates_mask
    perturbation_idx_to_covariates: dict[
        int, tuple[str, ...]
@@ -203,3 +272,69 @@
    max_combination_length: int
    null_value: Any
    data_manager: Any


@dataclass
class ZarrTrainingData(BaseDataMixin):
    """Lazy, Zarr-backed variant of :class:`TrainingData`.

    Fields mirror those in :class:`TrainingData`, but array-like members are
    Zarr arrays or Zarr-backed mappings. This enables out-of-core training and
    composition without loading everything into memory.

    Use :meth:`read_zarr` to construct an instance from a Zarr v3 group
    written via :meth:`TrainingData.write_zarr`.
    """

    # Note: annotations use Any to allow zarr.Array and zarr groups without
    # importing zarr at module import time.
    cell_data: Any
    split_covariates_mask: Any
    perturbation_covariates_mask: Any
    split_idx_to_covariates: dict[int, tuple[Any, ...]]
    perturbation_idx_to_covariates: dict[int, tuple[str, ...]]
    perturbation_idx_to_id: dict[int, Any]
    condition_data: dict[str, Any]
    control_to_perturbation: dict[int, Any]
    max_combination_length: int

    def __post_init__(self):
        # Load everything except cell_data into memory.

        # Load the masks as numpy arrays.
        self.split_covariates_mask = self.split_covariates_mask[...]
        self.perturbation_covariates_mask = self.perturbation_covariates_mask[...]

        self.condition_data = {k: np.asarray(v) for k, v in self.condition_data.items()}
        self.control_to_perturbation = {int(k): np.asarray(v) for k, v in self.control_to_perturbation.items()}
        self.perturbation_idx_to_id = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_id.items()}
        self.perturbation_idx_to_covariates = {
            int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items()
        }
        self.split_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.split_idx_to_covariates.items()}

    @classmethod
    def read_zarr(cls, path: str) -> ZarrTrainingData:
        if isinstance(path, str):
            path = LocalStore(path, read_only=True)
        group = zarr.open_group(path, mode="r")
        max_len_node = group.get("max_combination_length")
        if max_len_node is None:
            max_combination_length = 0
        else:
            try:
                max_combination_length = int(max_len_node[()])
            except Exception:  # noqa: BLE001
                max_combination_length = int(max_len_node)

        return cls(
            cell_data=group["cell_data"],
            split_covariates_mask=group["split_covariates_mask"],
            perturbation_covariates_mask=group["perturbation_covariates_mask"],
            split_idx_to_covariates=_read_dict(group, "split_idx_to_covariates"),
            perturbation_idx_to_covariates=_read_dict(group, "perturbation_idx_to_covariates"),
            perturbation_idx_to_id=_read_dict(group, "perturbation_idx_to_id"),
            condition_data=_read_dict(group, "condition_data"),
            control_to_perturbation=_read_dict(group, "control_to_perturbation"),
            max_combination_length=max_combination_length,
        )
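A minimal round-trip sketch tying write_zarr and read_zarr together (a sketch, assuming a prepared TrainingData instance such as cf.train_data from scripts/create_tahoe.py above):

from cellflow.data import TrainingData, ZarrTrainingData


def roundtrip(train_data: TrainingData, path: str) -> ZarrTrainingData:
    # Serialize to a sharded Zarr v3 group, then reopen it lazily:
    # cell_data stays on disk, while masks and dictionaries load
    # eagerly in ZarrTrainingData.__post_init__.
    train_data.write_zarr(path)
    return ZarrTrainingData.read_zarr(path)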