diff --git a/.gitignore b/.gitignore index 1fd1f144..9872ae7b 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,4 @@ tests/plotting/actual_figures/ # huggingface hub/ +docs/notebooks/test.zarr \ No newline at end of file diff --git a/README.md b/README.md index e0b0f9d3..3cac74ea 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ -CellFlow +CellFlow -[![PyPI](https://img.shields.io/pypi/v/cellflow-tools.svg)](https://pypi.org/project/cellflow-tools/) -[![Downloads](https://static.pepy.tech/badge/cellflow-tools)](https://pepy.tech/project/cellflow-tools) -[![CI](https://img.shields.io/github/actions/workflow/status/theislab/cellflow/test.yaml?branch=main)](https://github.com/theislab/cellflow/actions) +[![PyPI](https://img.shields.io/pypi/v/scaleflow-tools.svg)](https://pypi.org/project/scaleflow-tools/) +[![Downloads](https://static.pepy.tech/badge/scaleflow-tools)](https://pepy.tech/project/scaleflow-tools) +[![CI](https://img.shields.io/github/actions/workflow/status/theislab/scaleflow/test.yaml?branch=main)](https://github.com/theislab/scaleflow/actions) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/theislab/CellFlow/main.svg)](https://results.pre-commit.ci/latest/github/theislab/CellFlow/main) -[![Codecov](https://codecov.io/gh/theislab/cellflow/branch/main/graph/badge.svg?token=Rgtm5Tsblo)](https://codecov.io/gh/theislab/cellflow) -[![Docs](https://img.shields.io/readthedocs/cellflow)](https://cellflow.readthedocs.io/en/latest/) +[![Codecov](https://codecov.io/gh/theislab/scaleflow/branch/main/graph/badge.svg?token=Rgtm5Tsblo)](https://codecov.io/gh/theislab/scaleflow) +[![Docs](https://img.shields.io/readthedocs/scaleflow)](https://scaleflow.readthedocs.io/en/latest/) CellFlow - Modeling Complex Perturbations with Flow Matching ============================================================ @@ -21,20 +21,20 @@ Check out the [preprint](https://www.biorxiv.org/content/10.1101/2025.04.11.6482 - Modeling the development of perturbed organisms - Cell fate engineering - Optimizing protocols for growing organoids -- ... and more; check out the [documentation](https://cellflow.readthedocs.io) for more information. +- ... and more; check out the [documentation](https://scaleflow.readthedocs.io) for more information. Installation ------------ Install **CellFlow** by running:: - pip install cellflow-tools + pip install scaleflow-tools In order to install **CellFlow** in editable mode, run:: - git clone https://github.com/theislab/cellflow - cd cellflow + git clone https://github.com/theislab/scaleflow + cd scaleflow pip install -e . For further instructions how to install jax, please refer to https://github.com/google/jax. diff --git a/docs/_static/images/cellflow_dark.png b/docs/_static/images/scaleflow_dark.png similarity index 100% rename from docs/_static/images/cellflow_dark.png rename to docs/_static/images/scaleflow_dark.png diff --git a/docs/_static/images/cellflow_light.png b/docs/_static/images/scaleflow_light.png similarity index 100% rename from docs/_static/images/cellflow_light.png rename to docs/_static/images/scaleflow_light.png diff --git a/docs/conf.py b/docs/conf.py index b21eedfd..46cdf374 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,15 +13,15 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
-import cellflow +import scaleflow sys.path.insert(0, str(Path(__file__).parent / "extensions")) # -- Project information ----------------------------------------------------- -project = cellflow.__name__ -author = "CellFlow team" -version = ilm.version("cellflow-tools") +project = scaleflow.__name__ +author = "ScaleFlow team" +version = ilm.version("scaleflow-tools") copyright = f"{datetime.now():%Y}, Theislab" # -- General configuration --------------------------------------------------- @@ -67,9 +67,9 @@ ] # TODO(michalk8): remove once typing has been cleaned-up nitpick_ignore_regex = [ - (r"py:class", r"cellflow\..*(K|B|O)"), - (r"py:class", r"cellflow\._typing.*"), - (r"py:class", r"cellflow\..*Protocol.*"), + (r"py:class", r"scaleflow\..*(K|B|O)"), + (r"py:class", r"scaleflow\._typing.*"), + (r"py:class", r"scaleflow\..*Protocol.*"), ] @@ -152,8 +152,8 @@ html_show_sourcelink = False html_theme_options = { "sidebar_hide_name": True, - "light_logo": "images/cellflow_dark.png", - "dark_logo": "images/cellflow_dark.png", + "light_logo": "images/scaleflow_dark.png", + "dark_logo": "images/scaleflow_dark.png", "light_css_variables": { "color-brand-primary": "#003262", "color-brand-content": "#003262", @@ -164,7 +164,7 @@ "footer_icons": [ { "name": "GitHub", - "url": "https://github.com/theislab/cellflow", + "url": "https://github.com/theislab/scaleflow", "html": "", "class": "fab fa-github", }, diff --git a/docs/developer.rst b/docs/developer.rst index a8e16530..8eecafa1 100644 --- a/docs/developer.rst +++ b/docs/developer.rst @@ -4,8 +4,8 @@ Developer API CellFlow model ~~~~~~~~~~~~~~ -.. module:: cellflow.data -.. currentmodule:: cellflow.data +.. module:: scaleflow.data +.. currentmodule:: scaleflow.data .. autosummary:: :toctree: genapi diff --git a/docs/index.rst b/docs/index.rst index a1e64452..0291acc6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,11 +1,11 @@ CellFlow ===================== -.. module:: cellflow +.. module:: scaleflow -:mod:`cellflow` is a framework for modeling single-cell perturbation screens. CellFlow is very flexible and enables researchers to systematically explore how cells respond to a wide range of experimental interventions, including drug treatments, genetic modifications, cytokine stimulation, morphogen pathway modulation or even entire organoid protocols. +:mod:`scaleflow` is a framework for modeling single-cell perturbation screens. CellFlow is very flexible and enables researchers to systematically explore how cells respond to a wide range of experimental interventions, including drug treatments, genetic modifications, cytokine stimulation, morphogen pathway modulation or even entire organoid protocols. -:note: This is a work in progress. We are actively working on extending the documentation of :mod:`cellflow` with more tutorials to cover a wide range of use cases. If you have any questions or suggestions, please feel free to reach out to us. +:note: This is a work in progress. We are actively working on extending the documentation of :mod:`scaleflow` with more tutorials to cover a wide range of use cases. If you have any questions or suggestions, please feel free to reach out to us. .. grid:: 3 @@ -15,13 +15,13 @@ CellFlow :link: installation :link-type: doc - Learn how to install :mod:`cellflow`. + Learn how to install :mod:`scaleflow`. .. grid-item-card:: User API :link: user/index :link-type: doc - The API reference with all the details on how to use :mod:`cellflow` functions. 
+ The API reference with all the details on how to use :mod:`scaleflow` functions. .. grid-item-card:: Manuscript diff --git a/docs/installation.rst b/docs/installation.rst index 9e8f5711..4baf5d48 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,22 +1,22 @@ Installation ============ -:mod:`cellflow` requires Python version >= 3.11 to run. +:mod:`scaleflow` requires Python version >= 3.11 to run. PyPI ---- -Install :mod:`cellflow` by running:: +Install :mod:`scaleflow` by running:: - pip install cellflow-tools + pip install scaleflow-tools Installing `rapids-singlecell` and `cuml`: -While it's not necessary to install :mod:`cellflow` with `rapids-singlecell` and `cuml`, +While it's not necessary to install :mod:`scaleflow` with `rapids-singlecell` and `cuml`, it is recommended to do so for faster preprocessing or downstream functions. -To install :mod:`cellflow` with `rapids-singlecell` and `cuml`, please refer to +To install :mod:`scaleflow` with `rapids-singlecell` and `cuml`, please refer to `instructions how to install rapids `_. Development version ------------------- -To install :mod:`cellflow` from `GitHub `_, run:: +To install :mod:`scaleflow` from `GitHub `_, run:: pip install git+https://github.com/theislab/CellFlow.git@main diff --git a/docs/notebooks/100_pbmc.ipynb b/docs/notebooks/100_pbmc.ipynb index 4260e0cd..b28b9b68 100644 --- a/docs/notebooks/100_pbmc.ipynb +++ b/docs/notebooks/100_pbmc.ipynb @@ -42,7 +42,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/icb/dominik.klein/mambaforge/envs/scaleflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -66,13 +66,13 @@ "import rapids_singlecell as rsc\n", "import flax.linen as nn\n", "import optax\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata = cellflow.datasets.pbmc_cytokines()" + "adata = scaleflow.datasets.pbmc_cytokines()" ] }, { @@ -444,7 +444,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to setup the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to setup the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. 
We select the solver `\"otfm\"`, which deterministically maps a cell to its perturbed equivalent. If we wanted to incorporate stochasticity on single-cell level, we would select `\"genot\"`." ] @@ -464,7 +464,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -476,9 +476,9 @@ "\n", "The `perturbation_covariates` indicates the external intervention, i.e. the cytokine treatment. We define a key (of arbitrary name) `\"cytokine_treatment\"` for this, and have the values be tuples with the perturbation and potential perturbation covariates. As we don't have a perturbation covariate (e.g. always the same dose), we only have one tuple, and as we don't observe combinations of treatments, the tuple has length 1. We use ESM2 embeddings for representing the cytokines, which we have precomputed already for the purpose of this notebook, saved in {attr}`uns['esm2_embeddings'] `. Thus, we pass the information that `\"esm2_embeddings\"` stores embeddings of the {attr}`obs['cytokine'] ` treatments via `perturbation_covariate_reps`.\n", "\n", - "The sample covariate describes the cellular context independent of the perturbation. In our case, these are donors, and given in the {attr}`obs['donor'] ` column. We use the mean of the control sample as donor representation, precomputed and saved in {attr}`uns['donor_embeddings'] `. We thus pass this piece of information to {class}`~cellflow.model.CellFlow` via `sample_covariate_reps`. \n", + "The sample covariate describes the cellular context independent of the perturbation. In our case, these are donors, and given in the {attr}`obs['donor'] ` column. We use the mean of the control sample as donor representation, precomputed and saved in {attr}`uns['donor_embeddings'] `. We thus pass this piece of information to {class}`~scaleflow.model.CellFlow` via `sample_covariate_reps`. \n", "\n", - "It remains to define `split_covariates`, according to which {class}`~cellflow.model.CellFlow` trains and predicts perturbations. In effect, `split_covariates` defines how to split the control distributions, and often coincides with `sample_covariates`. This ensure that we don't learn a mapping from the control distribution of donor A to a perturbed population of donor B, but only within the same donor. \n", + "It remains to define `split_covariates`, according to which {class}`~scaleflow.model.CellFlow` trains and predicts perturbations. In effect, `split_covariates` defines how to split the control distributions, and often coincides with `sample_covariates`. This ensure that we don't learn a mapping from the control distribution of donor A to a perturbed population of donor B, but only within the same donor. \n", "\n", "Finally, we can pass `max_combination_length` and `null_value`. These are relevant for combinations of treatments, which doesn't apply for this use case, as we don't want to predict combinationatorial effects of cytokines. In particular, `max_combination_length` is the maximum number of combinations of cytokines which we train on or we want to eventually predict for. The null value is the token representing no treatment, e.g. relevant when we have a treatment with fewer interventions than `max_combination_length`, see tutorials with combinatorial treatments as examples." 
] @@ -548,7 +548,7 @@ "id": "a1fc1515-30d3-4ee9-92bb-0c8299f94d21", "metadata": {}, "source": [ - "We can now prepare the data for validation using {meth}`~cellflow.model.CellFlow.prepare_validation_data`. We can pass arbitrary splits, which we define with the `name` parameter. The corresponding {class}`adata ` object has to contain the true value, such that during evaluation, we can compare the generated with the true cells.\n", + "We can now prepare the data for validation using {meth}`~scaleflow.model.CellFlow.prepare_validation_data`. We can pass arbitrary splits, which we define with the `name` parameter. The corresponding {class}`adata ` object has to contain the true value, such that during evaluation, we can compare the generated with the true cells.\n", "\n", "Note that inference takes relatively long due to solving a neural ODE, hence we might not want to evaluate on the full {class}`adata ` objects, but only on a subset of conditions, the number of which we define using `n_conditions_on_log_iteration` and `n_conditions_on_train_end`. The number of cells we generate for each condition corresponds to the number of control cells, in our case to the number of control cells specific to each donor. As in this dataset the number of control cells is relatively large, we now first subsample the {class}`adata ` object to accelerate inference. " ] @@ -642,7 +642,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -650,36 +650,36 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", "We walk through the parameters one by one:\n", "\n", - "- `condition_mode` defines the structure of the learnt condition embedding space. We will use `deterministic` mode with `regularization=0.0`, which means we learn point estimates of the condition embedding. If we added `regularization>0.0`, this would mean we impose some regularization with respect to the L2-norm of the embeddings. `condition_mode=\"stochastic\"` parameterizes the embeddings space as like a decoder-free variational auto-encoder, i.e. we set a normal isotropic prior on the embeddings, this allows to learn a stochastic mapping and evaluate the uncertainty of predictions on a distributional level (rather than on a single-cell level which can be done with {class}`~cellflow.solvers.GENOT`).\n", + "- `condition_mode` defines the structure of the learnt condition embedding space. We will use `deterministic` mode with `regularization=0.0`, which means we learn point estimates of the condition embedding. If we added `regularization>0.0`, this would mean we impose some regularization with respect to the L2-norm of the embeddings. `condition_mode=\"stochastic\"` parameterizes the embeddings space as like a decoder-free variational auto-encoder, i.e. 
we set a normal isotropic prior on the embeddings, this allows to learn a stochastic mapping and evaluate the uncertainty of predictions on a distributional level (rather than on a single-cell level which can be done with {class}`~scaleflow.solvers.GENOT`).\n", "- `regularization`, as mentioned above, is a tradeoff between the flow matching loss (which also implicitly learns the condition embeddings space), and the regularization of the mean and potentially the variance of the embedding space. Here, we learn point-wise estimates without any prior on the embeddings space, thus setting `regularization` to 0.0.\n", "- `pooling` defines how we aggregate combinations of conditions, which doesn't apply here. Putting `\"mean\"` thus has no effect, while `\"attention_token\"` or `\"attention_seed\"` would reduce to self-attention.\n", - "- `pooling_kwargs` specifies further keyword arguments for {class}`~cellflow.networks.TokenAttentionPooling` if `pooling` is\n", - " `\"attention_token\"` or {class}`~cellflow.networks.SeedAttentionPooling` if `pooling` is `\"attention_seed\"`.\n", - "- `layers_before_pool` specifies the layers processing the perturbation variables, i.e. perturbations, perturbation covariates, and sample covariates. It must be a dictionary with keys corresponding to the keys we used in {meth}`~cellflow.model.CellFlow.prepare_data`. In this case, this means that we have keys `\"cytokine_treatment\"` and `\"donor_embeddings\"`, with values specifying the architecture, e.g. the type of the module (`\"mlp\"` or `\"self_attention\"`) and layer specifications like number of layers, width, and dropout rate.\n", + "- `pooling_kwargs` specifies further keyword arguments for {class}`~scaleflow.networks.TokenAttentionPooling` if `pooling` is\n", + " `\"attention_token\"` or {class}`~scaleflow.networks.SeedAttentionPooling` if `pooling` is `\"attention_seed\"`.\n", + "- `layers_before_pool` specifies the layers processing the perturbation variables, i.e. perturbations, perturbation covariates, and sample covariates. It must be a dictionary with keys corresponding to the keys we used in {meth}`~scaleflow.model.CellFlow.prepare_data`. In this case, this means that we have keys `\"cytokine_treatment\"` and `\"donor_embeddings\"`, with values specifying the architecture, e.g. the type of the module (`\"mlp\"` or `\"self_attention\"`) and layer specifications like number of layers, width, and dropout rate.\n", "- `layers_before_pool` specifies the architecture of the module after the pooling has been performed.\n", "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder. We set it to 64.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding, we recommend to set it relatively high, especially if the `condition_embedding_dim` is large.\n", - "- `condition_encoder_kwargs` specify the architecture of the {class}`~cellflow.networks.ConditionEncoder`. Here, we don't apply any more specifications.\n", + "- `condition_encoder_kwargs` specify the architecture of the {class}`~scaleflow.networks.ConditionEncoder`. Here, we don't apply any more specifications.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", "- `time_freqs` thus (deterministically) embeds the time component before being processed by a feed-forward neural network. 
This choice is relatively independent of the data. \n", "- `time_encoder_dims` specifies the architecture how to process the time embedding needed for the neural ODE. Note that we pre-encode the time with a sinusoidal embedding of dimension `time_freqs`. This choice is relatively independent of the data. \n", "- `time_encoder_dropout` denotes the dropout applied to the layers processing the time component. This choice is relatively independent of the data. \n", "- `hidden_dims` specifies the layers processing the control cells. The choice depends on the dimensionality of the cell embedding.\n", "- `hidden_dropout` specifies the dropout in the layers defined by `hidden_dims`.\n", - "- `conditioning` specifies the method we use to integrate the different embeddings into the model. Here, we use `\"concatenation\"`, which simply concatenates the time, condition and data embeddings into a single array. Alternative options for `conditioning` are `\"film\"`, which conditions using a {class}`~cellflow.networks.FilmBlock` based on [Perez et al.](https://arxiv.org/abs/1709.07871) and `\"resnet\"` which conditions using a {class}`~cellflow.networks.ResNetBlock` based on [He et al.](https://arxiv.org/abs/1512.03385). \n", - "- `conditioning_kwargs` provides further keyword arguments when the conditioning is not `\"concatenation\"`, e.g. it provides keywords for {class}`~cellflow.networks.FilmBlock` and {class}`~cellflow.networks.ResNetBlock`, which we don't require for this use case.\n", + "- `conditioning` specifies the method we use to integrate the different embeddings into the model. Here, we use `\"concatenation\"`, which simply concatenates the time, condition and data embeddings into a single array. Alternative options for `conditioning` are `\"film\"`, which conditions using a {class}`~scaleflow.networks.FilmBlock` based on [Perez et al.](https://arxiv.org/abs/1709.07871) and `\"resnet\"` which conditions using a {class}`~scaleflow.networks.ResNetBlock` based on [He et al.](https://arxiv.org/abs/1512.03385). \n", + "- `conditioning_kwargs` provides further keyword arguments when the conditioning is not `\"concatenation\"`, e.g. it provides keywords for {class}`~scaleflow.networks.FilmBlock` and {class}`~scaleflow.networks.ResNetBlock`, which we don't require for this use case.\n", "- `decoder_dims` specifies the layers processing the embedding of the condition, the embedding of the cell, and the embedding of the time. It depends on the dimensionality of the cell representation, i.e. the higher-dimensional the cell representation, the higher `decoder_dims` should be chosen.\n", "- `decoder_dropout` sets the dropout rate of the layers processing `decoder_dims`.\n", - "- `vf_act_fn` sets the activation function in the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` if not specified otherwise.\n", - "- `vf_kwargs` provides further keyword arguments when the solver is not `\"otfm\"`, e.g. it provides keywords for {class}`~cellflow.networks._velocity_field.GENOTConditionalVelocityField`, which we don't require for this use case.\n", - "- `probability_path` defines the path between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, which internally applies {class}`~ott.neural.methods.flows.dynamics.ConstantNoiseFlow`. This means that the paths are augmented with random normal noise. Note that the magnitude should depend on the support / variance of the cell embedding. 
The higher the noise, the more the data is augmented, but the less the marginal distributions are fitted. To maintain convergence on the marginals, one can use `{\"bridge\"}\n", + "- `vf_act_fn` sets the activation function in the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` if not specified otherwise.\n", + "- `vf_kwargs` provides further keyword arguments when the solver is not `\"otfm\"`, e.g. it provides keywords for {class}`~scaleflow.networks._velocity_field.GENOTConditionalVelocityField`, which we don't require for this use case.\n", + "- `probability_path` defines the path between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, which internally applies {class}`~ott.neural.methods.flows.dynamics.ConstantNoiseFlow`. This means that the paths are augmented with random normal noise. Note that the magnitude should depend on the support / variance of the cell embedding. The higher the noise, the more the data is augmented, but the less the marginal distributions are fitted. To maintain convergence on the marginals, one can use `{\"bridge\"}\n", "- `match_fn` defines how to sample pairs batch-wise. If we have largely heterogeneous populations (e.g. whole embryos), we should choose a small entropic regularisation, while for homoegeneous cell populations like cell lines, a large entropic regularisation parameter is sufficient. Moreover, we can select the hyperparameters `tau_a` and `tau_b` determining the extent of unbalancedness in the learnt coupling, see e.g. [moscot](moscot-tools.org) for an in-depth discussion of optimal transport parameters.\n", "- `optimizer` should be used with gradient averaging to have parameter updates after having seen also multiple conditions, not only multiple cells. We found 20 to be a good value, but we recommend to perform a hyperparameter search. \n", - "- `solver_kwargs` is primarily necessary for using a different {attr}`~cellflow.model.CellFlow.solver` than {class}`~cellflow.solvers.OTFlowMatching`, e.g. when using {class}`~cellflow.solvers.GENOT`. So this doesn't apply here.\n", + "- `solver_kwargs` is primarily necessary for using a different {attr}`~scaleflow.model.CellFlow.solver` than {class}`~scaleflow.solvers.OTFlowMatching`, e.g. when using {class}`~scaleflow.solvers.GENOT`. So this doesn't apply here.\n", "- `layer_norm_before_concatenation` determines whether to apply a linear layer before concatenating the condition embedding, the time embedding, and the cell embedding. It can be hyperparameterized over, but we generally found it to not significantly help.\n", "- `linear_projection_before_concatenation` applies linear layers to the embeddings of the condition, the time, and the cell. It can be hyperparameterized over, but we generally found it to not significantly help.\n", "- `seed` sets the seed for solvers." @@ -779,11 +779,11 @@ "source": [ "## Computing and logging metrics during training \n", "\n", - "For computing metrics during training, we provide callbacks. We divide callbacks into two categories: The first one performs computations, thus is an instance of {class}`~cellflow.training.ComputationCallback`; the second one are instances of {class}`~cellflow.training.LoggingCallback` and is used for logging. Users can either provide their own callbacks, or make use of existing ones, including {class}`~cellflow.training.Metrics` for computing metrics in the space which the cells are generated in, e.g. 
in PCA or VAE-space. For computing metrics in gene space, we can use {class}`~cellflow.training.PCADecodedMetrics` in case cells are PCA-embedded, or {class}`~cellflow.training.VAEDecodedMetrics` in case cells are embedding using {class}`~cellflow.external.CFJaxSCVI`. For computing metrics, we can provide user-defined ones, or metrics provided by CellFlow, which we will do below.\n", + "For computing metrics during training, we provide callbacks. We divide callbacks into two categories: The first one performs computations, thus is an instance of {class}`~scaleflow.training.ComputationCallback`; the second one are instances of {class}`~scaleflow.training.LoggingCallback` and is used for logging. Users can either provide their own callbacks, or make use of existing ones, including {class}`~scaleflow.training.Metrics` for computing metrics in the space which the cells are generated in, e.g. in PCA or VAE-space. For computing metrics in gene space, we can use {class}`~scaleflow.training.PCADecodedMetrics` in case cells are PCA-embedded, or {class}`~scaleflow.training.VAEDecodedMetrics` in case cells are embedding using {class}`~scaleflow.external.CFJaxSCVI`. For computing metrics, we can provide user-defined ones, or metrics provided by CellFlow, which we will do below.\n", "\n", - "For logging, we recommend using [Weights and Biases](https://wandb.ai), for which we provide a callback: {class}`~cellflow.training.WandbLogger`.\n", + "For logging, we recommend using [Weights and Biases](https://wandb.ai), for which we provide a callback: {class}`~scaleflow.training.WandbLogger`.\n", "\n", - "As our cells live in PCA-space, we use the {class}`~cellflow.training.PCADecodedMetrics` callback, which takes as input also an {class}`adata ` object which contains the PCs computed from the training data." + "As our cells live in PCA-space, we use the {class}`~scaleflow.training.PCADecodedMetrics` callback, which takes as input also an {class}`adata ` object which contains the PCs computed from the training data." ] }, { @@ -793,9 +793,9 @@ "metadata": {}, "outputs": [], "source": [ - "metrics_callback = cellflow.training.Metrics(metrics=[\"r_squared\", \"mmd\", \"e_distance\"])\n", - "decoded_metrics_callback = cellflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", - "wandb_callback = cellflow.training.WandbLogger(project=\"cellflow_tutorials\", out_dir=\"~\", config={\"name\": \"100m_pbmc\"})\n", + "metrics_callback = scaleflow.training.Metrics(metrics=[\"r_squared\", \"mmd\", \"e_distance\"])\n", + "decoded_metrics_callback = scaleflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", + "wandb_callback = scaleflow.training.WandbLogger(project=\"scaleflow_tutorials\", out_dir=\"~\", config={\"name\": \"100m_pbmc\"})\n", "\n", "# we don't pass the wandb_callback as it requires a user-specific account, but recommend setting it up\n", "callbacks = [metrics_callback, decoded_metrics_callback]\n" @@ -839,7 +839,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [ - "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`." + "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`." ] }, { @@ -926,7 +926,7 @@ "id": "c1e33895-4d76-4113-97b1-8f46a3de9037", "metadata": {}, "source": [ - "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. 
Therefore, we have to provide a {class}`~pandas.DataFrame` with the same structure of {attr}`adata.obs ` (at least the columns which we used for {meth}`~cellflow.model.CellFlow.prepare_data`). Note that the embedding is independent of the cells, we thus don't need to pass the cellular representation. Moreover, {meth}`~cellflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"stochastic\"`, hence we now only visualize the learnt mean. \n", + "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. Therefore, we have to provide a {class}`~pandas.DataFrame` with the same structure of {attr}`adata.obs ` (at least the columns which we used for {meth}`~scaleflow.model.CellFlow.prepare_data`). Note that the embedding is independent of the cells, we thus don't need to pass the cellular representation. Moreover, {meth}`~scaleflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"stochastic\"`, hence we now only visualize the learnt mean. \n", "For now, let's use all conditions, but indicate whether they're seen during training or not:" ] }, @@ -983,7 +983,7 @@ "id": "078cfff2-f938-44f1-9630-491a4db408ca", "metadata": {}, "source": [ - "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~cellflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionaly reduction methods as well. We can see that the unseen conditions integrate well." + "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~scaleflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionaly reduction methods as well. We can see that the unseen conditions integrate well." ] }, { @@ -1746,7 +1746,7 @@ ], "metadata": { "kernelspec": { - "display_name": "cellflow_mod", + "display_name": "scaleflow_mod", "language": "python", "name": "python3" }, diff --git a/docs/notebooks/200_zebrafish.ipynb b/docs/notebooks/200_zebrafish.ipynb index c566c686..e1164d57 100644 --- a/docs/notebooks/200_zebrafish.ipynb +++ b/docs/notebooks/200_zebrafish.ipynb @@ -13,7 +13,7 @@ "id": "bc061b8c-aaab-413f-8f20-a0f07c812dde", "metadata": {}, "source": [ - "In this tutorial, we predict perturbations on embryo-scale. Therefore, we consider [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2), which captures up to 23 perturbations at 5 different time points, resulting in 71 perturbed phenotypes. The experimental design is sparse, hence we investigate to what extent we can fill missing measurements with {class}`~cellflow.model.CellFlow`'s predictions." + "In this tutorial, we predict perturbations on embryo-scale. Therefore, we consider [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2), which captures up to 23 perturbations at 5 different time points, resulting in 71 perturbed phenotypes. The experimental design is sparse, hence we investigate to what extent we can fill missing measurements with {class}`~scaleflow.model.CellFlow`'s predictions." 
] }, { @@ -42,7 +42,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/icb/dominik.klein/mambaforge/envs/scaleflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -66,13 +66,13 @@ "import rapids_singlecell as rsc\n", "import flax.linen as nn\n", "import optax\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata = cellflow.datasets.zesta()" + "adata = scaleflow.datasets.zesta()" ] }, { @@ -484,7 +484,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to setup the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to setup the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. We select the default solver `\"otfm\"`." ] @@ -504,7 +504,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -557,7 +557,7 @@ "id": "a1fc1515-30d3-4ee9-92bb-0c8299f94d21", "metadata": {}, "source": [ - "We now prepare the data validation data using {meth}`~cellflow.model.CellFlow.prepare_validation_data`. \n", + "We now prepare the data validation data using {meth}`~scaleflow.model.CellFlow.prepare_validation_data`. \n", "\n", "As for some conditions, and in particular for control cells, we have a large number of measurements, we subsample for inference to be faster. However, due to the heterogeneity of the cellular distribution, covering hundreds of cell types, we should not subsample by too much." 
] @@ -635,7 +635,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -643,9 +643,9 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", - "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", + "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", "\n", "- We use `condition_mode=\"deterministic\"` to learn point estimates of condition embeddings, and thus have a fully deterministic mapping. We set `regularization=0.0`, thus don't regularize the learnt latent space. \n", "- `pooling` defines how we aggregate combinations of conditions in a permutation-invariant manner, which we choose to do learning a class token indicated by `\"attention_token\"`.\n", @@ -653,7 +653,7 @@ "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding, we recommend to set it relatively high, especially if the `condition_embedding_dim` is large.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", - "- `probability_path` defines the path between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", + "- `probability_path` defines the path between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", "- `match_fn` defines how to sample pairs between the control and the perturbed cells. As we have a strongly heterogeneous population, we choose a higher batch size of 2048. We don't expect large outliers, and are not interested in the trajectory of a single cell, hence we choose `tau_a=tau_b=1.0`." 
] }, @@ -754,7 +754,7 @@ "metadata": {}, "outputs": [], "source": [ - "metrics_callback = cellflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", + "metrics_callback = scaleflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", "callbacks = [metrics_callback]\n" ] }, @@ -796,7 +796,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [ - "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`." + "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`." ] }, { @@ -966,7 +966,7 @@ "source": [ "## Predicting with CellFlow\n", "\n", - "Predictions with {class}`~cellflow.model.CellFlow` require an {class}`adata ` object with control cells. As we only want to generate cells corresponding the unseen perturbation cdx4 and cdx1a, we only need control cells for time point 36. Moreover, we require `covariate_data` to store the information about what we would like to predict. " + "Predictions with {class}`~scaleflow.model.CellFlow` require an {class}`adata ` object with control cells. As we only want to generate cells corresponding the unseen perturbation cdx4 and cdx1a, we only need control cells for time point 36. Moreover, we require `covariate_data` to store the information about what we would like to predict. " ] }, { @@ -1227,9 +1227,9 @@ ], "metadata": { "kernelspec": { - "display_name": "cellflow", + "display_name": "scaleflow", "language": "python", - "name": "cellflow" + "name": "scaleflow" }, "language_info": { "codemirror_mode": { diff --git a/docs/notebooks/201_zebrafish_continuous.ipynb b/docs/notebooks/201_zebrafish_continuous.ipynb index 60f78516..38a74c6d 100644 --- a/docs/notebooks/201_zebrafish_continuous.ipynb +++ b/docs/notebooks/201_zebrafish_continuous.ipynb @@ -13,7 +13,7 @@ "id": "bc061b8c-aaab-413f-8f20-a0f07c812dde", "metadata": {}, "source": [ - "Similary to {doc}`200_zebrafish_continuous`, we make use of the [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2) dataset, which captures up to 23 perturbations at 5 different time points. Here, we leverage {class}`~cellflow.model.CellFlow` to interpolate the perturbed development at densely sampled time points." + "Similary to {doc}`200_zebrafish_continuous`, we make use of the [ZSCAPE](https://www.nature.com/articles/s41586-023-06720-2) dataset, which captures up to 23 perturbations at 5 different time points. Here, we leverage {class}`~scaleflow.model.CellFlow` to interpolate the perturbed development at densely sampled time points." ] }, { @@ -34,7 +34,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/icb/dominik.klein/mambaforge/envs/scaleflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -58,13 +58,13 @@ "import rapids_singlecell as rsc\n", "import flax.linen as nn\n", "import optax\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -74,7 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "adata = cellflow.datasets.zesta()" + "adata = scaleflow.datasets.zesta()" ] }, { @@ -532,7 +532,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to setup the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to setup the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. We select the default solver `\"otfm\"`." ] @@ -552,7 +552,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -632,7 +632,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -640,9 +640,9 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", - "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", + "We only consider the most relevant parameters, for a detailed description, please have a look at the documentation of {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", "\n", "- We use `condition_mode=\"deterministic\"` to learn point estimates of condition embeddings, and thus have a fully deterministic mapping. We set `regularization=0.0`, thus don't regularize the learnt latent space. 
\n", "- `pooling` defines how we aggregate combinations of conditions in a permutation-invariant manner, which we choose to do learning a class token indicated by `\"attention_token\"`.\n", @@ -650,7 +650,7 @@ "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding, we recommend to set it relatively high, especially if the `condition_embedding_dim` is large.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", - "- `probability_path` defines the reference vector field between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", + "- `probability_path` defines the reference vector field between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", "- `match_fn` defines how to sample pairs between the control and the perturbed cells. As we have a strongly heterogeneous population, we choose a higher batch size of 2048. We don't expect large outliers, and are not interested in the trajectory of a single cell, hence we choose `tau_a=tau_b=1.0`." ] }, @@ -751,7 +751,7 @@ "metadata": {}, "outputs": [], "source": [ - "metrics_callback = cellflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", + "metrics_callback = scaleflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", "callbacks = [metrics_callback]\n" ] }, @@ -808,7 +808,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [ - "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`." + "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`." ] }, { @@ -867,7 +867,7 @@ "id": "c1e33895-4d76-4113-97b1-8f46a3de9037", "metadata": {}, "source": [ - "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. Note that {meth}`~cellflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"stochastic\"`, hence we now only visualize the learnt mean. \n", + "We can visualize the learnt latent space for any condition using {meth}`~CellFlow.get_condition_embedding`. Note that {meth}`~scaleflow.model.CellFlow.get_condition_embedding` returns both the learnt mean embedding and the logvariance. The latter is 0 when `condition_mode=\"stochastic\"`, hence we now only visualize the learnt mean. 
\n", "For now, let's use all conditions, but indicate whether they're seen during training or not:" ] }, @@ -903,7 +903,7 @@ "id": "078cfff2-f938-44f1-9630-491a4db408ca", "metadata": {}, "source": [ - "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~cellflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionaly reduction methods as well. We can see that the unseen conditions integrate well." + "We can now visualize the embedding, which is 256-dimensional, by calling {meth}`~scaleflow.plotting.plot_condition_embedding`. We first visualize it according to whether it was seen during training or not. We choose a kernel PCA representation, but we recommend trying other dimensionaly reduction methods as well. We can see that the unseen conditions integrate well." ] }, { @@ -1416,9 +1416,9 @@ ], "metadata": { "kernelspec": { - "display_name": "cellflow", + "display_name": "scaleflow", "language": "python", - "name": "cellflow" + "name": "scaleflow" }, "language_info": { "codemirror_mode": { diff --git a/docs/notebooks/300_ineuron_tutorial.ipynb b/docs/notebooks/300_ineuron_tutorial.ipynb index 25d87eef..26c3bc6b 100644 --- a/docs/notebooks/300_ineuron_tutorial.ipynb +++ b/docs/notebooks/300_ineuron_tutorial.ipynb @@ -7,7 +7,7 @@ "source": [ "# Neuron fate prediction from combinatorial morphogen treatment\n", "\n", - "In this notebook, we show how {class}`~cellflow.model.CellFlow` can be used to predict the outcome of **neuron fate programming experiments**. We use the the dataset from [Lin, Janssens et al.](https://www.biorxiv.org/content/10.1101/2023.12.12.571318v2), which contains scRNA-seq data from an morphogen screen in NGN2-induced neurons (iNeurons). The treatment conditions comprised combinations of modulators of anterior-posterior (AP) patterning (RA, CHIR99021, XAV-939, FGF8) with modulators of dorso-ventral (DV) patterning (BMP4, SHH), each applied in multiple concentrations. We use CellFlow to predict neuron distributions for held-out combinations of morphogens. \n", + "In this notebook, we show how {class}`~scaleflow.model.CellFlow` can be used to predict the outcome of **neuron fate programming experiments**. We use the the dataset from [Lin, Janssens et al.](https://www.biorxiv.org/content/10.1101/2023.12.12.571318v2), which contains scRNA-seq data from an morphogen screen in NGN2-induced neurons (iNeurons). The treatment conditions comprised combinations of modulators of anterior-posterior (AP) patterning (RA, CHIR99021, XAV-939, FGF8) with modulators of dorso-ventral (DV) patterning (BMP4, SHH), each applied in multiple concentrations. We use CellFlow to predict neuron distributions for held-out combinations of morphogens. 
\n", "\n", "## Preparing the data" ] @@ -33,8 +33,8 @@ "from scipy.sparse import csr_matrix\n", "from sklearn.preprocessing import OneHotEncoder\n", "\n", - "import cellflow\n", - "import cellflow.preprocessing as cfpp" + "import scaleflow\n", + "import scaleflow.preprocessing as cfpp" ] }, { @@ -59,7 +59,7 @@ } ], "source": [ - "adata = cellflow.datasets.ineurons()\n", + "adata = scaleflow.datasets.ineurons()\n", "print(adata)" ] }, @@ -222,7 +222,7 @@ "metadata": {}, "outputs": [], "source": [ - "cf = cellflow.model.CellFlow(adata_train_full, solver=\"otfm\")" + "cf = scaleflow.model.CellFlow(adata_train_full, solver=\"otfm\")" ] }, { @@ -230,7 +230,7 @@ "id": "d8849016", "metadata": {}, "source": [ - "### Preparing CellFlow’s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`\n", + "### Preparing CellFlow’s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`\n", "We set up the data as follows:\n", "- We use `.obsm[\"X_pca\"]` as the cellular representation (`sample_rep`)\n", "- `\"CTRL\"` indicated the source distribution we constructed earlier\n", @@ -266,16 +266,16 @@ "id": "3cdd1103", "metadata": {}, "source": [ - "### Preparing CellFlow’s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`\n", - "Now we can set up the architecture of the CellFlow model. For a detailed description of all hyperparameters, please have a look at {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`. \n", + "### Preparing CellFlow’s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`\n", + "Now we can set up the architecture of the CellFlow model. For a detailed description of all hyperparameters, please have a look at {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`. \n", "\n", "While there is some intuition behind which parameter settings to use, we generally we use hyperparameter optimization on a separate validation set to find the best hyperparameters for each task. These are some of the most relevant parameters for this task:\n", "\n", "- `layers_before_pool` and `layers_after_pool` define the networks before and after permutation-invariant pooling of combinatorial conditions. Here, we only define a network before pooling to encode the one-hot-encoded morphogen representations and no `layers_after_pool` to use only one layer to transform the pooled representation into the condition embedding.\n", "- We found that that pooling the combinations by their mean (`pooling_type=\"mean\"`) works best for this task. This might be due to the fact that the morphogen combination conditions are *relatively* simple and their total number is somewhat small, which might make it harder to learn a more complex attention-based pooling.\n", "- `match_fn` defines how to sample pairs between the source and the perturbed cells. Here, the source distribution is a random distribution rather than a control condition, but because the output distributions are relatively complex an might contain outliers, we still found some unbalancedness to be useful for this task, so we se `tau_a=tau_b=0.99`.\n", - "- `flow` defines the reference vector field between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. 
We don't use any noise here as our cell population is highly heterogenous.\n", - "- We found that sometimes the relationship between sizes of the condition embedding as well as encoded `x`, and `time` in the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` can matter quite a bit to the model. We are not sure exactly why this is tha case, but we found it to be especially important for iNeuron and organoid applications, where we generate from noise into a complex output distribution. We therefore set the `hidden_dims=[2048] * 2 + [128]` to transform the `x` embedding into a smaller dimension with the last layer." + "- `flow` defines the reference vector field between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. We don't use any noise here as our cell population is highly heterogeneous.\n", + "- We found that sometimes the relationship between sizes of the condition embedding as well as encoded `x`, and `time` in the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` can matter quite a bit to the model. We are not sure exactly why this is the case, but we found it to be especially important for iNeuron and organoid applications, where we generate from noise into a complex output distribution. We therefore set the `hidden_dims=[2048] * 2 + [128]` to transform the `x` embedding into a smaller dimension with the last layer." ] }, { @@ -339,7 +339,7 @@ "id": "c806a95a", "metadata": {}, "source": [ - "Now we can train the model. To make training quicker, we here don't compute validation metrics during training, but only evaluare predictions afterwards. If you are running a model for the fist time, we recommend to monitor training behaviour with validation data through {meth}`~cellflow.model.CellFlow.prepare_validation_data` as explained in {doc}`100_pbmc`." + "Now we can train the model. To make training quicker, we here don't compute validation metrics during training, but only evaluate predictions afterwards. If you are running a model for the first time, we recommend monitoring training behaviour with validation data through {meth}`~scaleflow.model.CellFlow.prepare_validation_data` as explained in {doc}`100_pbmc`." ] }, { @@ -365,7 +365,7 @@ "id": "87a4848b", "metadata": {}, "source": [ - "After training, we can save the model to disk with {meth}`~cellflow.model.CellFlow.save_model` and load it again with {meth}`~cellflow.model.CellFlow.load_model`. " + "After training, we can save the model to disk with {meth}`~scaleflow.model.CellFlow.save_model` and load it again with {meth}`~scaleflow.model.CellFlow.load_model`. " ] }, { @@ -375,9 +375,9 @@ "metadata": {}, "outputs": [], "source": [ - "cf.save(\"cellflow_model/\", overwrite=True)\n", - "cf = cellflow.model.CellFlow.load(\n", - " \"cellflow_model/\"\n", + "cf.save(\"scaleflow_model/\", overwrite=True)\n", + "cf = scaleflow.model.CellFlow.load(\n", + " \"scaleflow_model/\"\n", ")" ] }, { @@ -387,7 +387,7 @@ "metadata": {}, "source": [ "### Making predictions\n", - "Now we can finally check out the predictions. we use {meth}`~cellflow.model.CellFlow.predict` to generate predictions for the held-out conditions in the validation dataset." + "Now we can finally check out the predictions. We use {meth}`~scaleflow.model.CellFlow.predict` to generate predictions for the held-out conditions in the validation dataset."
] }, { @@ -428,7 +428,7 @@ "id": "d037aadd", "metadata": {}, "source": [ - "{meth}`~cellflow.model.CellFlow.predict` returns a dictionary with predictions for each condition. We now convert the predictions into an {class}`adata ` object." + "{meth}`~scaleflow.model.CellFlow.predict` returns a dictionary with predictions for each condition. We now convert the predictions into an {class}`adata ` object." ] }, { @@ -456,7 +456,7 @@ "id": "deadfa4e", "metadata": {}, "source": [ - "To obtain gene expression values for our predictions, we use {meth}`cellflow.preprocessing.reconstruct_pca` to reconstruct the PCA space where the predictions were made. We then reproject the predictions into a new PCA space with the full ground truth data. " + "To obtain gene expression values for our predictions, we use {meth}`scaleflow.preprocessing.reconstruct_pca` to reconstruct the PCA space where the predictions were made. We then reproject the predictions into a new PCA space with the full ground truth data. " ] }, { diff --git a/docs/notebooks/500_combosciplex.ipynb b/docs/notebooks/500_combosciplex.ipynb index 8e42b30b..879ba0a9 100644 --- a/docs/notebooks/500_combosciplex.ipynb +++ b/docs/notebooks/500_combosciplex.ipynb @@ -34,7 +34,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/icb/dominik.klein/mambaforge/envs/cellflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/icb/dominik.klein/mambaforge/envs/scaleflow/lib/python3.12/site-packages/optuna/study/_optimize.py:29: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from optuna import progress_bar as pbar_module\n" ] } @@ -63,13 +63,13 @@ "import flax.linen as nn\n", "import optax\n", "import pertpy\n", - "import cellflow\n", - "from cellflow.model import CellFlow\n", - "import cellflow.preprocessing as cfpp\n", - "from cellflow.utils import match_linear\n", - "from cellflow.plotting import plot_condition_embedding\n", - "from cellflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca, annotate_compounds, get_molecular_fingerprints\n", - "from cellflow.metrics import compute_r_squared, compute_e_distance\n" + "import scaleflow\n", + "from scaleflow.model import CellFlow\n", + "import scaleflow.preprocessing as cfpp\n", + "from scaleflow.utils import match_linear\n", + "from scaleflow.plotting import plot_condition_embedding\n", + "from scaleflow.preprocessing import transfer_labels, compute_wknn, centered_pca, project_pca, reconstruct_pca, annotate_compounds, get_molecular_fingerprints\n", + "from scaleflow.metrics import compute_r_squared, compute_e_distance\n" ] }, { @@ -356,7 +356,7 @@ "id": "82a74985-1d8b-43e5-97a6-a3cf220e33e5", "metadata": {}, "source": [ - "We require embeddings for the drugs. While we encourage users to try different ones, we use molecular fingerprints in the following. Therefore, we first annotate the drugs, i.e. 
we retrieve the SMILES and [PubChem](https://pubchem.ncbi.nlm.nih.gov/) metadata using {func}`~scaleflow.preprocessing.annotate_compounds`:" ] }, { @@ -634,7 +634,7 @@ "id": "7438db45-50af-4c97-8be9-c546a0982f70", "metadata": {}, "source": [ - "Among other things, this gave us the SMILES strings, such that we can now get the molecular fingerprints for the SMILES strings using {func}`~cellflow.preprocessing.get_molecular_fingerprints`. We have {attr}`uns['fingerprints'] ` added, and see that all drugs have been assigned a fingerprint." + "Among other things, this gave us the SMILES strings, such that we can now get the molecular fingerprints for the SMILES strings using {func}`~scaleflow.preprocessing.get_molecular_fingerprints`. We have {attr}`uns['fingerprints'] ` added, and see that all drugs have been assigned a fingerprint." ] }, { @@ -664,7 +664,7 @@ "id": "abe53039-afc8-42fc-a255-40c94d8e79b4", "metadata": {}, "source": [ - "We now add a zero token which is going to be ignored during training for \"filling\" the second drug in case of single drug perturbations. Note that this zero token will be specified later in {meth}`~cellflow.model.CellFlow.prepare_data`." + "We now add a zero token which is going to be ignored during training for \"filling\" the second drug in case of single drug perturbations. Note that this zero token will be specified later in {meth}`~scaleflow.model.CellFlow.prepare_data`." ] }, { @@ -770,7 +770,7 @@ "source": [ "## Setting up the CellFlow model\n", "\n", - "We are now ready to set up the {class}`~cellflow.model.CellFlow` model.\n", + "We are now ready to set up the {class}`~scaleflow.model.CellFlow` model.\n", "\n", "Therefore, we first choose the flow matching solver. We select the default solver `\"otfm\"`." ] @@ -790,7 +790,7 @@ "id": "e1500afe-18b6-4d18-aa6a-91451548cca4", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s data handling with {meth}`~cellflow.model.CellFlow.prepare_data`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s data handling with {meth}`~scaleflow.model.CellFlow.prepare_data`" ] }, { @@ -837,7 +837,7 @@ "id": "a1fc1515-30d3-4ee9-92bb-0c8299f94d21", "metadata": {}, "source": [ - "We now prepare the validation data using {meth}`~cellflow.model.CellFlow.prepare_validation_data`. \n", + "We now prepare the validation data using {meth}`~scaleflow.model.CellFlow.prepare_validation_data`. \n", "\n", "As we have a large number of measurements for some conditions, in particular for control cells, we subsample to make inference faster. However, due to the heterogeneity of the cellular distribution, covering hundreds of cell types, we should not subsample too much."
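A hedged sketch of the validation setup described above follows; it is not taken verbatim from the notebook, and the subsampling fraction, the variable names (`adata_val`, `adata_val_sub`), the `name` argument, and `n_conditions_on_log_iteration` are all assumptions.

```python
import scanpy as sc

# Illustrative sketch: subsample the validation cells before registering them so
# that on-the-fly evaluation stays fast, while keeping enough cells to cover the
# many cell types in this heterogeneous dataset.
adata_val_sub = sc.pp.subsample(adata_val, fraction=0.5, copy=True)  # assumed fraction
cf.prepare_validation_data(
    adata_val_sub,
    name="combosciplex_val",           # assumed identifier
    n_conditions_on_log_iteration=20,  # assumed number of conditions evaluated per log step
)
```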
] @@ -878,7 +878,7 @@ "id": "806d7551-1a1a-4080-abfc-d8839724d7a2", "metadata": {}, "source": [ - "## Preparing {class}`~cellflow.model.CellFlow`'s model architecture with {meth}`~cellflow.model.CellFlow.prepare_model`" + "## Preparing {class}`~scaleflow.model.CellFlow`'s model architecture with {meth}`~scaleflow.model.CellFlow.prepare_model`" ] }, { @@ -886,9 +886,9 @@ "id": "93f0ed52-4cd9-45c6-af62-3b230a49903c", "metadata": {}, "source": [ - "We are now ready to specify the architecture of {class}`~cellflow.model.CellFlow`.\n", + "We are now ready to specify the architecture of {class}`~scaleflow.model.CellFlow`.\n", "\n", - "We only consider the most relevant parameters; for a detailed description, please have a look at the documentation of {meth}`~cellflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", + "We only consider the most relevant parameters; for a detailed description, please have a look at the documentation of {meth}`~scaleflow.model.CellFlow.prepare_model` or {doc}`100_pbmc`.\n", "\n", "- We use `condition_mode=\"deterministic\"` to learn point estimates of condition embeddings, and thus have a fully deterministic mapping. We set `regularization=0.0` and thus don't regularize the learnt latent space. \n", "- `pooling` defines how we aggregate combinations of conditions in a permutation-invariant manner, which we choose to do by learning a class token, indicated by `\"attention_token\"`.\n", @@ -896,7 +896,7 @@ "- `condition_embedding_dim` is the dimension of the latent space of the condition encoder.\n", "- `cond_output_dropout` is the dropout applied to the condition embedding; we recommend setting it relatively high, especially if the `condition_embedding_dim` is large.\n", "- `pool_sample_covariates` defines whether the concatenation of the sample covariates should happen before or after pooling, in our case indicating whether it's part of the self-attention or only appended afterwards. \n", - "- `flow` defines the reference vector field between pairs of samples which the {class}`~cellflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", + "- `flow` defines the reference vector field between pairs of samples which the {class}`~scaleflow.networks._velocity_field.ConditionalVelocityField` is regressed against. Here, we use `{\"constant_noise\": 0.5}`, i.e. we use a relatively small value as we have a highly heterogeneous cell population. In fact, if we augment a cell with noise, we should be careful not to augment it to the extent that it is e.g. in a completely different organ of the zebrafish.\n", "- `match_fn` defines how to sample pairs between the control and the perturbed cells. As we have a strongly heterogeneous population, we choose a higher batch size of 2048. We don't expect large outliers, and are not interested in the trajectory of a single cell, hence we choose `tau_a=tau_b=1.0`."
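Putting the bullet points above together, a call to {meth}`~scaleflow.model.CellFlow.prepare_model` might look roughly like the sketch below. The concrete numbers (`condition_embedding_dim=256`, `cond_output_dropout=0.9`) are illustrative assumptions rather than the values used in this notebook, and building `match_fn` via `functools.partial(match_linear, ...)` is likewise an assumption based on the imports shown earlier in this diff.

```python
import functools

from scaleflow.utils import match_linear

# Hedged sketch of the architecture choices discussed above; numeric values are assumptions.
cf.prepare_model(
    condition_mode="deterministic",   # point estimates of the condition embeddings
    regularization=0.0,               # no regularization of the learnt latent space
    pooling="attention_token",        # permutation-invariant pooling via a learned class token
    condition_embedding_dim=256,      # assumed latent dimension of the condition encoder
    cond_output_dropout=0.9,          # relatively high dropout on the condition embedding
    pool_sample_covariates=True,      # concatenate sample covariates before pooling
    flow={"constant_noise": 0.5},     # small noise level for the heterogeneous population
    match_fn=functools.partial(match_linear, tau_a=1.0, tau_b=1.0),  # balanced pairing
)
```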
] }, @@ -998,8 +998,8 @@ "metadata": {}, "outputs": [], "source": [ - "metrics_callback = cellflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", - "decoded_metrics_callback = cellflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", + "metrics_callback = scaleflow.training.Metrics(metrics=[\"mmd\", \"e_distance\"])\n", + "decoded_metrics_callback = scaleflow.training.PCADecodedMetrics(ref_adata=adata_train, metrics=[\"r_squared\"])\n", "callbacks = [metrics_callback, decoded_metrics_callback]\n" ] }, @@ -1041,7 +1041,7 @@ "id": "44171594-47e4-458d-8e29-34c3d5e2979f", "metadata": {}, "source": [ - "We can now investigate some training statistics, stored by the {class}`~cellflow.training.CellFlowTrainer`." + "We can now investigate some training statistics, stored by the {class}`~scaleflow.training.CellFlowTrainer`." ] }, { @@ -1178,7 +1178,7 @@ "source": [ "## Predicting with CellFlow\n", "\n", - "Predictions with {class}`~cellflow.model.CellFlow` require an {class}`adata ` object with control cells. Moreover, we need `covariate_data` to store the information about what we would like to predict. " + "Predictions with {class}`~scaleflow.model.CellFlow` require an {class}`adata ` object with control cells. Moreover, we need `covariate_data` to store the information about what we would like to predict. " ] }, { @@ -1613,7 +1613,7 @@ "id": "2b290bcf-30c7-4077-b3ac-959dabed0aee", "metadata": {}, "source": [ - "We also compute the metrics of CellFlow with respect to the ground truth data going through the encoder-decoder in order to separate CellFlow's model performance from the encoder-decoder. Note that this is what is computed during training with {class}`~cellflow.training.PCADecodedMetrics`. " + "We also compute the metrics of CellFlow with respect to the ground truth data going through the encoder-decoder in order to separate CellFlow's model performance from the encoder-decoder. Note that this is what is computed during training with {class}`~scaleflow.training.PCADecodedMetrics`. " ] }, { @@ -1989,9 +1989,9 @@ ], "metadata": { "kernelspec": { - "display_name": "cellflow", + "display_name": "scaleflow", "language": "python", - "name": "cellflow" + "name": "scaleflow" }, "language_info": { "codemirror_mode": { diff --git a/docs/notebooks/600_trainsampler copy.ipynb b/docs/notebooks/600_trainsampler copy.ipynb new file mode 100644 index 00000000..79a4072e --- /dev/null +++ b/docs/notebooks/600_trainsampler copy.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 21, + "id": "5765bb6c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. 
To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from scaleflow.data import MappedCellData" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5e77bb94", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "data_path = Path(\"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "cb38a3f8", + "metadata": {}, + "outputs": [], + "source": [ + "mcd = MappedCellData.read_zarr(data_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "675044bc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: ,\n", + " 1: ,\n", + " 10: ,\n", + " 11: ,\n", + " 12: ,\n", + " 13: ,\n", + " 14: ,\n", + " 15: ,\n", + " 16: ,\n", + " 17: ,\n", + " 18: ,\n", + " 19: ,\n", + " 2: ,\n", + " 20: ,\n", + " 21: ,\n", + " 22: ,\n", + " 23: ,\n", + " 24: ,\n", + " 25: ,\n", + " 26: ,\n", + " 27: ,\n", + " 28: ,\n", + " 29: ,\n", + " 3: ,\n", + " 30: ,\n", + " 31: ,\n", + " 32: ,\n", + " 33: ,\n", + " 34: ,\n", + " 35: ,\n", + " 36: ,\n", + " 37: ,\n", + " 38: ,\n", + " 39: ,\n", + " 4: ,\n", + " 40: ,\n", + " 41: ,\n", + " 42: ,\n", + " 43: ,\n", + " 44: ,\n", + " 45: ,\n", + " 46: ,\n", + " 47: ,\n", + " 48: ,\n", + " 49: ,\n", + " 5: ,\n", + " 6: ,\n", + " 7: ,\n", + " 8: ,\n", + " 9: }" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mcd.src_cell_data" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "33793ea8", + "metadata": {}, + "outputs": [], + "source": [ + "from scaleflow.data import ReservoirSampler" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "05bd4946", + "metadata": {}, + "outputs": [], + "source": [ + "rs = ReservoirSampler(mcd, batch_size=1024, pool_size=40, replacement_prob=0.01)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b40a9520", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "rng = np.random.default_rng(0)\n", + "rs.init_pool(rng)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "799aad1f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "scheduled replacement of 40 with 10 (slot 2)\n", + "scheduled replacement of 49 with 16 (slot 0)\n", + "average time per iteration: 0.0003986350893974304\n", + "iterations per second: 2508.5599000117672\n" + ] + } + ], + "source": [ + "import time\n", + "iter_times = []\n", + "start_time = time.time()\n", + "for iter in range(4000):\n", + " batch = rs.sample(rng)\n", + " end_time = time.time()\n", + " iter_times.append(end_time - start_time)\n", + " start_time = end_time\n", + "\n", + "print(\"average time per iteration: \", np.mean(iter_times))\n", + "print(\"iterations per second: \", 1 / np.mean(iter_times))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/tahoe_sizes.ipynb b/docs/notebooks/tahoe_sizes.ipynb new file mode 100644 index 00000000..ea311b09 --- 
/dev/null +++ b/docs/notebooks/tahoe_sizes.ipynb @@ -0,0 +1,348 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "3ed731bd", + "metadata": {}, + "outputs": [], + "source": [ + "from scaleflow.data import MappedCellData" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "62955dea", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def calculate_memory_cost(\n", + " data: MappedCellData,\n", + " src_idx: int,\n", + " include_condition_data: bool = True\n", + ") -> dict[str, int | list | dict]:\n", + " \"\"\"Calculate memory cost in bytes for a given source index and its target distributions.\n", + " \n", + " Parameters\n", + " ----------\n", + " data\n", + " The training data.\n", + " src_idx\n", + " The source distribution index.\n", + " include_condition_data\n", + " Whether to include condition data in memory calculations.\n", + " \n", + " Returns\n", + " -------\n", + " Dictionary with memory statistics in bytes for the source and its targets.\n", + " \"\"\"\n", + " if src_idx not in data.control_to_perturbation:\n", + " raise ValueError(f\"Source index {src_idx} not found in control_to_perturbation mapping\")\n", + " \n", + " # Get target indices for this source\n", + " target_indices = data.control_to_perturbation[src_idx]\n", + " \n", + " # Calculate memory for source cells\n", + " source_mask = data.split_covariates_mask == src_idx\n", + " n_source_cells = data.src_cell_idx[src_idx].shape[0]\n", + " source_memory = data.src_cell_data[src_idx].nbytes\n", + " \n", + " # Calculate memory for target cells\n", + " target_memories = {}\n", + " total_target_memory = 0\n", + " \n", + " for target_idx in target_indices:\n", + " n_target_cells = data.tgt_cell_idx[target_idx].shape[0]\n", + " target_memory = data.tgt_cell_data[target_idx].nbytes\n", + " target_memories[f\"target_{target_idx}\"] = target_memory\n", + " total_target_memory += target_memory\n", + " \n", + " # Calculate condition data memory if available and requested\n", + " condition_memory = 0\n", + " condition_details = {}\n", + " if include_condition_data and data.condition_data is not None:\n", + " for cond_name, cond_array in data.condition_data.items():\n", + " # Condition data is indexed by target indices\n", + " relevant_condition_size = len(target_indices) * cond_array.shape[1] * cond_array.dtype.itemsize\n", + " condition_details[f\"condition_{cond_name}\"] = relevant_condition_size\n", + " condition_memory += relevant_condition_size\n", + " \n", + " # Calculate total memory\n", + " total_memory = source_memory + total_target_memory + condition_memory\n", + " \n", + " # Calculate average target memory\n", + " avg_target_memory = total_target_memory // len(target_indices) if target_indices.size > 0 else 0\n", + " \n", + " result = {\n", + " \"source_idx\": src_idx,\n", + " \"target_indices\": target_indices.tolist(),\n", + " \"source_memory\": source_memory,\n", + " \"source_cell_count\": int(n_source_cells),\n", + " \"total_target_memory\": total_target_memory,\n", + " \"avg_target_memory\": avg_target_memory,\n", + " \"condition_memory\": condition_memory,\n", + " \"total_memory\": total_memory,\n", + " \"target_details\": target_memories,\n", + " }\n", + " \n", + " if condition_details:\n", + " result[\"condition_details\"] = condition_details\n", + " \n", + " return result\n", + "\n", + "def format_memory_stats(memory_stats: dict, unit: str = \"auto\", summary: bool = False) -> str:\n", + " \"\"\"Format memory statistics into a 
human-readable string.\n", + " \n", + " Parameters\n", + " ----------\n", + " memory_stats\n", + " Dictionary with memory statistics from calculate_memory_cost.\n", + " unit\n", + " Memory unit to use for display. Options: 'B', 'KB', 'MB', 'GB', 'auto'.\n", + " If 'auto', the most appropriate unit will be chosen automatically.\n", + " summary\n", + " If True, includes a summary with average, min, and max target memory statistics\n", + " and omits detailed per-target breakdown.\n", + " \n", + " Returns\n", + " -------\n", + " Human-readable string representation of memory statistics.\n", + " \"\"\"\n", + " def format_bytes(bytes_value, unit=\"auto\"):\n", + " if unit == \"auto\":\n", + " # Choose appropriate unit\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " elif unit == \"KB\":\n", + " bytes_value /= 1024\n", + " elif unit == \"MB\":\n", + " bytes_value /= (1024 * 1024)\n", + " elif unit == \"GB\":\n", + " bytes_value /= (1024 * 1024 * 1024)\n", + " \n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " src_idx = memory_stats[\"source_idx\"]\n", + " target_indices = memory_stats[\"target_indices\"]\n", + " \n", + " # Base information\n", + " lines = [\n", + " f\"Memory statistics for source index {src_idx} with {len(target_indices)} targets:\",\n", + " f\"- Source cells: {memory_stats['source_cell_count']} cells, {format_bytes(memory_stats['source_memory'], unit)}\",\n", + " f\"- Total memory: {format_bytes(memory_stats['total_memory'], unit)}\",\n", + " ]\n", + " \n", + " # Calculate min and max target memory if summary is requested\n", + " if summary and memory_stats[\"target_details\"]:\n", + " target_memories = list(memory_stats[\"target_details\"].values())\n", + " min_target = min(target_memories)\n", + " max_target = max(target_memories)\n", + " \n", + " lines.extend([\n", + " \"\\nTarget memory summary:\",\n", + " f\"- Total: {format_bytes(memory_stats['total_target_memory'], unit)}\",\n", + " f\"- Average: {format_bytes(memory_stats['avg_target_memory'], unit)}\",\n", + " f\"- Min: {format_bytes(min_target, unit)}\",\n", + " f\"- Max: {format_bytes(max_target, unit)}\",\n", + " f\"- Range: {format_bytes(max_target - min_target, unit)}\"\n", + " ])\n", + " \n", + " # Add condition memory summary if available\n", + " if memory_stats[\"condition_memory\"] > 0:\n", + " lines.append(f\"\\nCondition memory: {format_bytes(memory_stats['condition_memory'], unit)}\")\n", + " else:\n", + " # Detailed output (original format)\n", + " lines.extend([\n", + " f\"- Target memory: {format_bytes(memory_stats['total_target_memory'], unit)} total, {format_bytes(memory_stats['avg_target_memory'], unit)} average per target\",\n", + " f\"- Condition memory: {format_bytes(memory_stats['condition_memory'], unit)}\",\n", + " \"\\nTarget details:\"\n", + " ])\n", + " \n", + " for target_key, target_memory in memory_stats[\"target_details\"].items():\n", + " target_id = target_key.split(\"_\")[1]\n", + " lines.append(f\" - Target {target_id}: {format_bytes(target_memory, unit)}\")\n", + " \n", + " if \"condition_details\" in memory_stats:\n", + " lines.append(\"\\nCondition details:\")\n", + " for cond_key, cond_memory in memory_stats[\"condition_details\"].items():\n", + " cond_name = cond_key.split(\"_\", 1)[1]\n", + " lines.append(f\" - {cond_name}: {format_bytes(cond_memory, unit)}\")\n", + " \n", + " return \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + 
"id": "316e3a6a", + "metadata": {}, + "outputs": [], + "source": [ + "data = MappedCellData.read_zarr(\n", + " \"/lustre/groups/ml01/workspace/100mil/tahoe.zarr\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "3d101216", + "metadata": {}, + "outputs": [], + "source": [ + "stats = calculate_memory_cost(data, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a79f9fc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics for source index 0 with 194 targets:\n", + "- Source cells: 60135 cells, 68.82 MB\n", + "- Total memory: 548.11 MB\n", + "\n", + "Target memory summary:\n", + "- Total: 479.28 MB\n", + "- Average: 2.47 MB\n", + "- Min: 44.53 KB\n", + "- Max: 6.35 MB\n", + "- Range: 6.31 MB\n", + "\n", + "Condition memory: 4.55 KB\n" + ] + } + ], + "source": [ + "print(format_memory_stats(stats, summary=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8c400080", + "metadata": {}, + "outputs": [], + "source": [ + "data_stats = {}\n", + "for i in range(data.n_controls):\n", + " data_stats[i] = calculate_memory_cost(data, i)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "710fb69d", + "metadata": {}, + "outputs": [], + "source": [ + "def print_average_memory_per_source(stats_dict):\n", + " \"\"\"Print the average total memory per source index.\n", + " \n", + " Parameters\n", + " ----------\n", + " stats_dict\n", + " Optional pre-calculated memory statistics dictionary.\n", + " If None, statistics will be calculated for all source indices.\n", + " \"\"\"\n", + " \n", + " \n", + " # Extract total memory for each source index\n", + " total_memories = [stats[\"total_memory\"] for stats in stats_dict.values()]\n", + " \n", + " # Calculate statistics\n", + " avg_memory = np.mean(total_memories)\n", + " min_memory = np.min(total_memories)\n", + " max_memory = np.max(total_memories)\n", + " median_memory = np.median(total_memories)\n", + " \n", + " # Format the output\n", + " def format_bytes(bytes_value):\n", + " for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n", + " if bytes_value < 1024 or unit == \"GB\":\n", + " break\n", + " bytes_value /= 1024\n", + " return f\"{bytes_value:.2f} {unit}\"\n", + " \n", + " print(f\"Memory statistics across {len(stats_dict)} source indices:\")\n", + " print(f\"- Average total memory per source: {format_bytes(avg_memory)}\")\n", + " print(f\"- Minimum total memory: {format_bytes(min_memory)}\")\n", + " print(f\"- Maximum total memory: {format_bytes(max_memory)}\")\n", + " print(f\"- Median total memory: {format_bytes(median_memory)}\")\n", + " print(f\"- Range: {format_bytes(max_memory - min_memory)}\")\n", + " \n", + " # Identify source indices with min and max memory\n", + " min_idx = min(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " max_idx = max(stats_dict.keys(), key=lambda k: stats_dict[k][\"total_memory\"])\n", + " \n", + " print(f\"\\nSource index with minimum memory: {min_idx} ({format_bytes(min_memory)})\")\n", + " print(f\"Source index with maximum memory: {max_idx} ({format_bytes(max_memory)})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e2f8f809", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory statistics across 50 source indices:\n", + "- Average total memory per source: 2.14 GB\n", + "- Minimum total memory: 21.01 MB\n", + "- Maximum total memory: 6.75 GB\n", + "- 
Median total memory: 2.05 GB\n", + "- Range: 6.73 GB\n", + "\n", + "Source index with minimum memory: 39 (21.01 MB)\n", + "Source index with maximum memory: 22 (6.75 GB)\n" + ] + } + ], + "source": [ + "print_average_memory_per_source(data_stats)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f07c55d9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/user/datasets.rst b/docs/user/datasets.rst index b674e7db..8c70f2f7 100644 --- a/docs/user/datasets.rst +++ b/docs/user/datasets.rst @@ -1,7 +1,7 @@ Datasets ~~~~~~~~ -.. module:: cellflow.datasets -.. currentmodule:: cellflow +.. module:: scaleflow.datasets +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/external.rst b/docs/user/external.rst index 124c16b2..2556f735 100644 --- a/docs/user/external.rst +++ b/docs/user/external.rst @@ -1,7 +1,7 @@ External ~~~~~~~~ -.. module:: cellflow.external -.. currentmodule:: cellflow +.. module:: scaleflow.external +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/index.rst b/docs/user/index.rst index af12c42e..ea9e5584 100644 --- a/docs/user/index.rst +++ b/docs/user/index.rst @@ -1,7 +1,7 @@ User API ######## -.. module:: cellflow.user +.. module:: scaleflow.user .. toctree:: :maxdepth: 2 diff --git a/docs/user/metrics.rst b/docs/user/metrics.rst index 3f1b558e..19927d7f 100644 --- a/docs/user/metrics.rst +++ b/docs/user/metrics.rst @@ -1,7 +1,7 @@ Metrics ~~~~~~~ -.. module:: cellflow.metrics -.. currentmodule:: cellflow +.. module:: scaleflow.metrics +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/model.rst b/docs/user/model.rst index aae3774f..9f8c2aad 100644 --- a/docs/user/model.rst +++ b/docs/user/model.rst @@ -1,7 +1,7 @@ Model ~~~~~ -.. module:: cellflow.model -.. currentmodule:: cellflow +.. module:: scaleflow.model +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/networks.rst b/docs/user/networks.rst index 76302a4b..d6148740 100644 --- a/docs/user/networks.rst +++ b/docs/user/networks.rst @@ -1,7 +1,7 @@ Networks ~~~~~~~~ -.. module:: cellflow.networks -.. currentmodule:: cellflow +.. module:: scaleflow.networks +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/plotting.rst b/docs/user/plotting.rst index 88a0389b..9414813e 100644 --- a/docs/user/plotting.rst +++ b/docs/user/plotting.rst @@ -1,7 +1,7 @@ Plotting ~~~~~~~~ -.. module:: cellflow.plotting -.. currentmodule:: cellflow +.. module:: scaleflow.plotting +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/preprocessing.rst b/docs/user/preprocessing.rst index ada1ca2b..8c897774 100644 --- a/docs/user/preprocessing.rst +++ b/docs/user/preprocessing.rst @@ -1,7 +1,7 @@ Preprocessing ~~~~~~~~~~~~~ -.. module:: cellflow.preprocessing -.. currentmodule:: cellflow +.. module:: scaleflow.preprocessing +.. currentmodule:: scaleflow .. 
autosummary:: :toctree: genapi diff --git a/docs/user/solvers.rst b/docs/user/solvers.rst index f1991264..4f1dc4dc 100644 --- a/docs/user/solvers.rst +++ b/docs/user/solvers.rst @@ -1,8 +1,8 @@ Solvers ~~~~~~~ -.. module:: cellflow.solvers -.. currentmodule:: cellflow +.. module:: scaleflow.solvers +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/training.rst b/docs/user/training.rst index a2741424..dd257024 100644 --- a/docs/user/training.rst +++ b/docs/user/training.rst @@ -1,7 +1,7 @@ Training ~~~~~~~~ -.. module:: cellflow.training -.. currentmodule:: cellflow +.. module:: scaleflow.training +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/docs/user/utils.rst b/docs/user/utils.rst index 28be6224..ec1f4d19 100644 --- a/docs/user/utils.rst +++ b/docs/user/utils.rst @@ -1,7 +1,7 @@ Utils ~~~~~ -.. module:: cellflow.utils -.. currentmodule:: cellflow +.. module:: scaleflow.utils +.. currentmodule:: scaleflow .. autosummary:: :toctree: genapi diff --git a/pyproject.toml b/pyproject.toml index 152c690b..eda7b744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ build-backend = "hatchling.build" requires = [ "hatch-vcs", "hatchling" ] [project] -name = "cellflow-tools" +name = "scaleflow-tools" description = "Modeling complex perturbations with flow matching at single-cell resolution" readme = "README.md" license = "PolyForm-Noncommercial-1.0.0" @@ -88,9 +88,9 @@ optional-dependencies.pp = [ "rdkit", ] optional-dependencies.test = [ - "cellflow-tools[embedding]", - "cellflow-tools[external]", - "cellflow-tools[pp]", + "scaleflow-tools[embedding]", + "scaleflow-tools[external]", + "scaleflow-tools[pp]", "coverage[toml]>=7", "pytest", "pytest-cov>=6", @@ -98,12 +98,12 @@ optional-dependencies.test = [ "pytest-xdist>=3", ] -urls.Documentation = "https://cellflow.readthedocs.io/" -urls.Home-page = "https://github.com/theislab/cellflow" -urls.Source = "https://github.com/theislab/cellflow" +urls.Documentation = "https://scaleflow.readthedocs.io/" +urls.Home-page = "https://github.com/theislab/scaleflow" +urls.Source = "https://github.com/theislab/scaleflow" [tool.hatch.build.targets.wheel] -packages = [ 'src/cellflow' ] +packages = [ 'src/scaleflow' ] [tool.hatch.version] source = "vcs" @@ -201,7 +201,7 @@ extras = test,pp,external,embedding pass_env = PYTEST_*,CI commands = coverage run -m pytest {tty:--color=yes} {posargs: \ - --cov={env_site_packages_dir}{/}cellflow --cov-config={tox_root}{/}pyproject.toml \ + --cov={env_site_packages_dir}{/}scaleflow --cov-config={tox_root}{/}pyproject.toml \ --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered} [testenv:lint-code] @@ -236,7 +236,7 @@ deps = leidenalg changedir = {tox_root}{/}docs commands = - python -m ipykernel install --user --name=cellflow + python -m ipykernel install --user --name=scaleflow bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks [testenv:clean-docs] diff --git a/scripts/process_tahoe.py b/scripts/process_tahoe.py new file mode 100644 index 00000000..8d1669c5 --- /dev/null +++ b/scripts/process_tahoe.py @@ -0,0 +1,210 @@ +# %% +# %load_ext autoreload +# %autoreload 2 + + +# %% +import anndata as ad +import h5py +import zarr +from scaleflow.data._utils import write_sharded +from anndata.experimental import read_lazy +from scaleflow.data import DataManager +import cupy as cp +import tqdm +import dask +import concurrent.futures +from functools import partial +import numpy as np +import dask.array as da +from 
dask.diagnostics import ProgressBar + +print("loading data") +with h5py.File("/lustre/groups/ml01/workspace/100mil/100m_int_indices.h5ad", "r") as f: + adata_all = ad.AnnData( + obs=ad.io.read_elem(f["obs"]), + var=read_lazy(f["var"]), + uns = read_lazy(f["uns"]), + obsm = read_lazy(f["obsm"]), + ) + +dm = DataManager(adata_all, + sample_rep="X_pca", + control_key="control", + perturbation_covariates={"drugs": ("drug",), "dosage": ("dosage",)}, + perturbation_covariate_reps={"drugs": "drug_embeddings"}, + sample_covariates=["cell_line"], + sample_covariate_reps={"cell_line": "cell_line_embeddings"}, + split_covariates=["cell_line"], + max_combination_length=None, + null_value=0.0 +) +print("data loaded") + +# %% +cond_data = dm._get_condition_data(adata=adata_all) +cell_data = dm._get_cell_data(adata_all) + +# %% +n_source_dists = len(cond_data.split_idx_to_covariates) +n_target_dists = len(cond_data.perturbation_idx_to_covariates) + +tgt_cell_data = {} +src_cell_data = {} +gpu_per_cov_mask = cp.asarray(cond_data.perturbation_covariates_mask) +gpu_spl_cov_mask = cp.asarray(cond_data.split_covariates_mask) + +for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data idcs"): + mask = gpu_spl_cov_mask == src_idx + src_cell_data[str(src_idx)] = { + "cell_data_index": cp.where(mask)[0].get(), + } + +for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data idcs"): + mask = gpu_per_cov_mask == tgt_idx + tgt_cell_data[str(tgt_idx)] = { + "cell_data_index": cp.where(mask)[0].get(), + } + +# %% + +print("Computing cell data") +cell_data = cell_data.compute() +print("cell data computed") + +for src_idx in tqdm.tqdm(range(n_source_dists), desc="Computing source to cell data"): + indices = src_cell_data[str(src_idx)]["cell_data_index"] + src_cell_data[str(src_idx)]["cell_data"] = cell_data[indices] + +for tgt_idx in tqdm.tqdm(range(n_target_dists), desc="Computing target to cell data"): + indices = tgt_cell_data[str(tgt_idx)]["cell_data_index"] + tgt_cell_data[str(tgt_idx)]["cell_data"] = cell_data[indices] + + +# %% + +split_covariates_mask = np.asarray(cond_data.split_covariates_mask) +perturbation_covariates_mask = np.asarray(cond_data.perturbation_covariates_mask) +condition_data = {str(k): np.asarray(v) for k, v in (cond_data.condition_data or {}).items()} +control_to_perturbation = {str(k): np.asarray(v) for k, v in (cond_data.control_to_perturbation or {}).items()} +split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (cond_data.split_idx_to_covariates or {}).items()} +perturbation_idx_to_covariates = { + str(k): np.asarray(v) for k, v in (cond_data.perturbation_idx_to_covariates or {}).items() +} +perturbation_idx_to_id = {str(k): v for k, v in (cond_data.perturbation_idx_to_id or {}).items()} + +train_data_dict = { + "split_covariates_mask": split_covariates_mask, + "perturbation_covariates_mask": perturbation_covariates_mask, + "split_idx_to_covariates": split_idx_to_covariates, + "perturbation_idx_to_covariates": perturbation_idx_to_covariates, + "perturbation_idx_to_id": perturbation_idx_to_id, + "condition_data": condition_data, + "control_to_perturbation": control_to_perturbation, + "max_combination_length": int(cond_data.max_combination_length), + # "src_cell_data": src_cell_data, + # "tgt_cell_data": tgt_cell_data, +} + +print("prepared train_data_dict") +# %% +path = "/lustre/groups/ml01/workspace/100mil/tahoe.zarr" +zgroup = zarr.open_group(path, mode="w") +chunk_size = 131072 +shard_size = chunk_size * 8 + 
+ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr + +def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]: + shard_size_used = shard_size + chunk_size_used = chunk_size + if chunk_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + elif chunk_size < shape[0] or shard_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + return chunk_size_used, shard_size_used + + + + +def write_single_array(group, key, arr, idxs, chunk_size, shard_size): + """Write a single array - designed for threading""" + chunk_size_used, shard_size_used = get_size(arr.shape, chunk_size, shard_size) + + group.create_array( + name=key, + data=arr, + chunks=(chunk_size_used, arr.shape[1]), + shards=(shard_size_used, arr.shape[1]), + compressors=None, + ) + + group.create_array( + name=f"{key}_index", + data=idxs, + chunks=(len(idxs),), + shards=(len(idxs),), + compressors=None, + ) + return key + +def write_cell_data_threaded(group, cell_data, chunk_size, shard_size, max_workers=8): + """Write cell data using threading for I/O parallelism""" + + write_func = partial(write_single_array, group, chunk_size=chunk_size, shard_size=shard_size) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all write tasks + future_to_key = { + executor.submit(write_single_array, group, k, cell_data[k]["cell_data"], cell_data[k]["cell_data_index"], chunk_size, shard_size): k + for k in cell_data.keys() + } + + # Process results with progress bar + for future in tqdm.tqdm( + concurrent.futures.as_completed(future_to_key), + total=len(future_to_key), + desc=f"Writing {group.name}" + ): + key = future_to_key[future] + try: + future.result() # This will raise any exceptions + except Exception as exc: + print(f'Array {key} generated an exception: {exc}') + raise + +# %% + + +src_group = zgroup.create_group("src_cell_data", overwrite=True) +tgt_group = zgroup.create_group("tgt_cell_data", overwrite=True) + + +# Use the fast threaded approach you already implemented +write_cell_data_threaded(src_group, src_cell_data, chunk_size, shard_size, max_workers=24) +print("done writing src_cell_data") +write_cell_data_threaded(tgt_group, tgt_cell_data, chunk_size, shard_size, max_workers=24) +print("done writing tgt_cell_data") + + + + + + +# %% + +print("Writing mapping data") +mapping_data = zgroup.create_group("mapping_data", overwrite=True) + + +write_sharded( + group=mapping_data, + name="mapping_data", + data=train_data_dict, + chunk_size=chunk_size, + shard_size=shard_size, + compressors=None, +) +print("done") + + diff --git a/scripts/process_tahoe.sbatch b/scripts/process_tahoe.sbatch new file mode 100644 index 00000000..fecb5f55 --- /dev/null +++ b/scripts/process_tahoe.sbatch @@ -0,0 +1,17 @@ +#!/bin/zsh + +#SBATCH -o logs/process_tahoe.out +#SBATCH -e logs/process_tahoe.err +#SBATCH -J process_tahoe +#SBATCH --nice=1 +#SBATCH --time=23:00:00 +#SBATCH --partition=gpu_p +#SBATCH --qos=gpu_normal +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=26 +#SBATCH --mem=500G + +source /home/icb/selman.ozleyen/.zshrc + +mamba activate lpert +python /home/icb/selman.ozleyen/projects/CellFlow2/scripts/process_tahoe.py diff --git a/src/cellflow/__init__.py b/src/cellflow/__init__.py deleted file mode 100644 index 526fc741..00000000 --- a/src/cellflow/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from importlib import metadata - -import cellflow.preprocessing as pp -from cellflow import data, datasets, metrics, model, networks, solvers, 
training, utils diff --git a/src/cellflow/data/__init__.py b/src/cellflow/data/__init__.py deleted file mode 100644 index e6f6f2de..00000000 --- a/src/cellflow/data/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from cellflow.data._data import BaseDataMixin, ConditionData, PredictionData, TrainingData, ValidationData -from cellflow.data._dataloader import PredictionSampler, TrainSampler, ValidationSampler -from cellflow.data._datamanager import DataManager - -__all__ = [ - "DataManager", - "BaseDataMixin", - "ConditionData", - "PredictionData", - "TrainingData", - "ValidationData", - "TrainSampler", - "ValidationSampler", - "PredictionSampler", -] diff --git a/src/cellflow/data/_dataloader.py b/src/cellflow/data/_dataloader.py deleted file mode 100644 index 70bd91ef..00000000 --- a/src/cellflow/data/_dataloader.py +++ /dev/null @@ -1,303 +0,0 @@ -import abc -import queue -import threading -from collections.abc import Generator -from typing import Any, Literal - -import jax -import numpy as np - -from cellflow.data._data import PredictionData, TrainingData, ValidationData - -__all__ = ["TrainSampler", "ValidationSampler", "PredictionSampler", "OOCTrainSampler"] - - -class TrainSampler: - """Data sampler for :class:`~cellflow.data.TrainingData`. - - Parameters - ---------- - data - The training data. - batch_size - The batch size. - - """ - - def __init__(self, data: TrainingData, batch_size: int = 1024): - self._data = data - self._data_idcs = np.arange(data.cell_data.shape[0]) - self.batch_size = batch_size - self.n_source_dists = data.n_controls - self.n_target_dists = data.n_perturbations - - self._control_to_perturbation_keys = sorted(data.control_to_perturbation.keys()) - self._has_condition_data = data.condition_data is not None - - def _sample_target_dist_idx(self, source_dist_idx, rng): - """Sample a target distribution index given the source distribution index.""" - return rng.choice(self._data.control_to_perturbation[source_dist_idx]) - - def _get_embeddings(self, idx, condition_data) -> dict[str, np.ndarray]: - """Get embeddings for a given index.""" - result = {} - for key, arr in condition_data.items(): - result[key] = np.expand_dims(arr[idx], 0) - return result - - def _sample_from_mask(self, rng, mask) -> np.ndarray: - """Sample indices according to a mask.""" - # Convert mask to probability distribution - valid_indices = np.where(mask)[0] - - # Handle case with no valid indices (should not happen in practice) - if len(valid_indices) == 0: - raise ValueError("No valid indices found in the mask") - - # Sample from valid indices with equal probability - batch_idcs = rng.choice(valid_indices, self.batch_size, replace=True) - return batch_idcs - - def sample(self, rng) -> dict[str, Any]: - """Sample a batch of data. 
- - Parameters - ---------- - seed : int, optional - Random seed - - Returns - ------- - Dictionary with source and target data - """ - # Sample source distribution index - source_dist_idx = rng.integers(0, self.n_source_dists) - - # Get source cells - source_cells_mask = self._data.split_covariates_mask == source_dist_idx - source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) - source_batch = self._data.cell_data[source_batch_idcs] - - target_dist_idx = self._sample_target_dist_idx(source_dist_idx, rng) - target_cells_mask = self._data.perturbation_covariates_mask == target_dist_idx - target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) - target_batch = self._data.cell_data[target_batch_idcs] - - if not self._has_condition_data: - return {"src_cell_data": source_batch, "tgt_cell_data": target_batch} - else: - condition_batch = self._get_embeddings(target_dist_idx, self._data.condition_data) - return { - "src_cell_data": source_batch, - "tgt_cell_data": target_batch, - "condition": condition_batch, - } - - @property - def data(self): - """The training data.""" - return self._data - - -class BaseValidSampler(abc.ABC): - @abc.abstractmethod - def sample(*args, **kwargs): - pass - - def _get_key(self, cond_idx: int) -> tuple[str, ...]: - if len(self._data.perturbation_idx_to_id): # type: ignore[attr-defined] - return self._data.perturbation_idx_to_id[cond_idx] # type: ignore[attr-defined] - cov_combination = self._data.perturbation_idx_to_covariates[cond_idx] # type: ignore[attr-defined] - return tuple(cov_combination[i] for i in range(len(cov_combination))) - - def _get_perturbation_to_control(self, data: ValidationData | PredictionData) -> dict[int, np.ndarray]: - d = {} - for k, v in data.control_to_perturbation.items(): - for el in v: - d[el] = k - return d - - def _get_condition_data(self, cond_idx: int) -> dict[str, np.ndarray]: - return {k: v[[cond_idx], ...] for k, v in self._data.condition_data.items()} # type: ignore[attr-defined] - - -class ValidationSampler(BaseValidSampler): - """Data sampler for :class:`~cellflow.data.ValidationData`. - - Parameters - ---------- - val_data - The validation data. - seed - Random seed. - """ - - def __init__(self, val_data: ValidationData, seed: int = 0) -> None: - self._data = val_data - self.perturbation_to_control = self._get_perturbation_to_control(val_data) - self.n_conditions_on_log_iteration = ( - val_data.n_conditions_on_log_iteration - if val_data.n_conditions_on_log_iteration is not None - else val_data.n_perturbations - ) - self.n_conditions_on_train_end = ( - val_data.n_conditions_on_train_end - if val_data.n_conditions_on_train_end is not None - else val_data.n_perturbations - ) - self.rng = np.random.default_rng(seed) - if self._data.condition_data is None: - raise NotImplementedError("Validation data must have condition data.") - - def sample(self, mode: Literal["on_log_iteration", "on_train_end"]) -> Any: - """Sample data for validation. - - Parameters - ---------- - mode - Sampling mode. Either ``"on_log_iteration"`` or ``"on_train_end"``. - - Returns - ------- - Dictionary with source, condition, and target data from the validation data. 
- """ - size = self.n_conditions_on_log_iteration if mode == "on_log_iteration" else self.n_conditions_on_train_end - condition_idcs = self.rng.choice(self._data.n_perturbations, size=(size,), replace=False) - - source_idcs = [self.perturbation_to_control[cond_idx] for cond_idx in condition_idcs] - source_cells_mask = [self._data.split_covariates_mask == source_idx for source_idx in source_idcs] - source_cells = [self._data.cell_data[mask] for mask in source_cells_mask] - target_cells_mask = [cond_idx == self._data.perturbation_covariates_mask for cond_idx in condition_idcs] - target_cells = [self._data.cell_data[mask] for mask in target_cells_mask] - conditions = [self._get_condition_data(cond_idx) for cond_idx in condition_idcs] - cell_rep_dict = {} - cond_dict = {} - true_dict = {} - for i in range(len(condition_idcs)): - k = self._get_key(condition_idcs[i]) - cell_rep_dict[k] = source_cells[i] - cond_dict[k] = conditions[i] - true_dict[k] = target_cells[i] - - return {"source": cell_rep_dict, "condition": cond_dict, "target": true_dict} - - @property - def data(self) -> ValidationData: - """The validation data.""" - return self._data - - -class PredictionSampler(BaseValidSampler): - """Data sampler for :class:`~cellflow.data.PredictionData`. - - Parameters - ---------- - pred_data - The prediction data. - - """ - - def __init__(self, pred_data: PredictionData) -> None: - self._data = pred_data - self.perturbation_to_control = self._get_perturbation_to_control(pred_data) - if self._data.condition_data is None: - raise NotImplementedError("Validation data must have condition data.") - - def sample(self) -> Any: - """Sample data for prediction. - - Returns - ------- - Dictionary with source and condition data from the prediction data. - """ - condition_idcs = range(self._data.n_perturbations) - - source_idcs = [self.perturbation_to_control[cond_idx] for cond_idx in condition_idcs] - source_cells_mask = [self._data.split_covariates_mask == source_idx for source_idx in source_idcs] - source_cells = [self._data.cell_data[mask] for mask in source_cells_mask] - conditions = [self._get_condition_data(cond_idx) for cond_idx in condition_idcs] - cell_rep_dict = {} - cond_dict = {} - for i in range(len(condition_idcs)): - k = self._get_key(condition_idcs[i]) - cell_rep_dict[k] = source_cells[i] - cond_dict[k] = conditions[i] - - return { - "source": cell_rep_dict, - "condition": cond_dict, - } - - @property - def data(self) -> PredictionData: - """The training data.""" - return self._data - - -def prefetch_to_device( - sampler: TrainSampler, seed: int, num_iterations: int, prefetch_factor: int = 2, num_workers: int = 4 -) -> Generator[dict[str, Any], None, None]: - seq = np.random.SeedSequence(seed) - random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] - - q: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=prefetch_factor * num_workers) - sem = threading.Semaphore(num_iterations) - stop_event = threading.Event() - - def worker(rng: np.random.Generator): - while not stop_event.is_set() and sem.acquire(blocking=False): - batch = sampler.sample(rng) - batch = jax.device_put(batch, jax.devices()[0], donate=True) - jax.block_until_ready(batch) - while not stop_event.is_set(): - try: - q.put(batch, timeout=1.0) - break # Batch successfully put into the queue; break out of retry loop - except queue.Full: - continue - - return - - # Start multiple worker threads - ts = [] - for i in range(num_workers): - t = threading.Thread(target=worker, daemon=True, name=f"worker-{i}", 
args=(random_generators[i],)) - t.start() - ts.append(t) - - try: - for _ in range(num_iterations): - # Yield batches from the queue; will block waiting for available batch - yield q.get() - finally: - # When the generator is closed or garbage collected, clean up the worker threads - stop_event.set() # Signal all workers to exit - for t in ts: - t.join() # Wait for all worker threads to finish - - -class OOCTrainSampler: - def __init__( - self, data: TrainingData, seed: int, batch_size: int = 1024, num_workers: int = 4, prefetch_factor: int = 2 - ): - self.inner = TrainSampler(data=data, batch_size=batch_size) - self.num_workers = num_workers - self.prefetch_factor = prefetch_factor - self.seed = seed - self._iterator = None - - def set_sampler(self, num_iterations: int) -> None: - self._iterator = prefetch_to_device( - sampler=self.inner, seed=self.seed, num_iterations=num_iterations, prefetch_factor=self.prefetch_factor - ) - - def sample(self, rng=None) -> dict[str, Any]: - if self._iterator is None: - raise ValueError( - "Sampler not set. Use `set_sampler` to set the sampler with" - "the number of iterations. Without the number of iterations," - " the sampler will not be able to sample the data." - ) - if rng is not None: - del rng - return next(self._iterator) diff --git a/src/cellflow/data/_utils.py b/src/cellflow/data/_utils.py deleted file mode 100644 index c22d50bb..00000000 --- a/src/cellflow/data/_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from collections.abc import Iterable -from typing import Any - - -def _to_list(x: list[Any] | tuple[Any] | Any) -> list[Any] | tuple[Any]: - """Converts x to a list if it is not already a list or tuple.""" - if isinstance(x, (list | tuple)): - return x - return [x] - - -def _flatten_list(x: Iterable[Iterable[Any]]) -> list[Any]: - """Flattens a list of lists.""" - return [item for sublist in x for item in sublist] diff --git a/src/cellflow/external/__init__.py b/src/cellflow/external/__init__.py deleted file mode 100644 index 7a03a1c8..00000000 --- a/src/cellflow/external/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -try: - from cellflow.external._scvi import CFJaxSCVI -except ImportError as e: - raise ImportError( - "cellflow.external requires more dependencies. 
Please install via pip install 'cellflow[external]'" - ) from e diff --git a/src/cellflow/model/__init__.py b/src/cellflow/model/__init__.py deleted file mode 100644 index 8731f241..00000000 --- a/src/cellflow/model/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellflow.model._cellflow import CellFlow - -__all__ = ["CellFlow"] diff --git a/src/cellflow/plotting/__init__.py b/src/cellflow/plotting/__init__.py deleted file mode 100644 index c7fd387e..00000000 --- a/src/cellflow/plotting/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cellflow.plotting._plotting import plot_condition_embedding - -__all__ = ["plot_condition_embedding"] diff --git a/src/cellflow/preprocessing/__init__.py b/src/cellflow/preprocessing/__init__.py deleted file mode 100644 index 21eaa993..00000000 --- a/src/cellflow/preprocessing/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from cellflow.preprocessing._gene_emb import ( - GeneInfo, - get_esm_embedding, - prot_sequence_from_ensembl, - protein_features_from_genes, -) -from cellflow.preprocessing._pca import centered_pca, project_pca, reconstruct_pca -from cellflow.preprocessing._preprocessing import annotate_compounds, encode_onehot, get_molecular_fingerprints -from cellflow.preprocessing._wknn import compute_wknn, transfer_labels diff --git a/src/cellflow/solvers/__init__.py b/src/cellflow/solvers/__init__.py deleted file mode 100644 index a02a5510..00000000 --- a/src/cellflow/solvers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from cellflow.solvers._genot import GENOT -from cellflow.solvers._otfm import OTFlowMatching - -__all__ = ["GENOT", "OTFlowMatching"] diff --git a/src/cfp/__init__.py b/src/cfp/__init__.py deleted file mode 100644 index 76f0742f..00000000 --- a/src/cfp/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from cellflow import * # noqa: F403 diff --git a/src/scaleflow/__init__.py b/src/scaleflow/__init__.py new file mode 100644 index 00000000..60891e49 --- /dev/null +++ b/src/scaleflow/__init__.py @@ -0,0 +1,4 @@ +from importlib import metadata + +import scaleflow.preprocessing as pp +from scaleflow import data, datasets, metrics, model, networks, solvers, training, utils diff --git a/src/cellflow/_constants.py b/src/scaleflow/_constants.py similarity index 57% rename from src/cellflow/_constants.py rename to src/scaleflow/_constants.py index 3b782c8d..92f38201 100644 --- a/src/cellflow/_constants.py +++ b/src/scaleflow/_constants.py @@ -1,4 +1,4 @@ -CONTROL_HELPER = "_cellflow_control" +CONTROL_HELPER = "_scaleflow_control" CONDITION_EMBEDDING = "condition_embedding" -CELLFLOW_KEY = "cellflow" +CELLFLOW_KEY = "scaleflow" GENOT_CELL_KEY = "cell_embedding_condition" diff --git a/src/cellflow/_logging.py b/src/scaleflow/_logging.py similarity index 100% rename from src/cellflow/_logging.py rename to src/scaleflow/_logging.py diff --git a/src/scaleflow/_optional.py b/src/scaleflow/_optional.py new file mode 100644 index 00000000..c05ccec0 --- /dev/null +++ b/src/scaleflow/_optional.py @@ -0,0 +1,9 @@ +class OptionalDependencyNotAvailable(ImportError): + pass + + +def torch_required_msg() -> str: + return ( + "Optional dependency 'torch' is required for this feature.\n" + "Install it via: pip install torch # or pip install 'scaleflow-tools[torch]'" + ) diff --git a/src/cellflow/_types.py b/src/scaleflow/_types.py similarity index 100% rename from src/cellflow/_types.py rename to src/scaleflow/_types.py diff --git a/src/scaleflow/compat/__init__.py b/src/scaleflow/compat/__init__.py new file mode 100644 index 00000000..82ff1adb --- /dev/null +++ 
b/src/scaleflow/compat/__init__.py @@ -0,0 +1,3 @@ +from .torch_ import TorchIterableDataset + +__all__ = ["TorchIterableDataset"] diff --git a/src/scaleflow/compat/torch_.py b/src/scaleflow/compat/torch_.py new file mode 100644 index 00000000..b79f134e --- /dev/null +++ b/src/scaleflow/compat/torch_.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from scaleflow._optional import OptionalDependencyNotAvailable, torch_required_msg + +try: + from torch.utils.data import IterableDataset as TorchIterableDataset # type: ignore + + TORCH_AVAILABLE = True +except ImportError as _: + TORCH_AVAILABLE = False + + class TorchIterableDataset: # noqa: D101 + def __init__(self, *args, **kwargs): + raise OptionalDependencyNotAvailable(torch_required_msg()) + + +if TYPE_CHECKING: + # keeps type checkers aligned with the real type + from torch.utils.data import IterableDataset as TorchIterableDataset # noqa: F401 diff --git a/src/scaleflow/data/__init__.py b/src/scaleflow/data/__init__.py new file mode 100644 index 00000000..23cf6b76 --- /dev/null +++ b/src/scaleflow/data/__init__.py @@ -0,0 +1,33 @@ +from scaleflow.data._data import ( + BaseDataMixin, + ConditionData, + PredictionData, + TrainingData, + ValidationData, + MappedCellData, +) +from scaleflow.data._dataloader import ( + PredictionSampler, + TrainSampler, + ReservoirSampler, + ValidationSampler, +) +from scaleflow.data._datamanager import DataManager +from scaleflow.data._jax_dataloader import JaxOutOfCoreTrainSampler +from scaleflow.data._torch_dataloader import TorchCombinedTrainSampler + +__all__ = [ + "DataManager", + "BaseDataMixin", + "ConditionData", + "PredictionData", + "TrainingData", + "ValidationData", + "MappedCellData", + "TrainSampler", + "ValidationSampler", + "PredictionSampler", + "TorchCombinedTrainSampler", + "JaxOutOfCoreTrainSampler", + "ReservoirSampler", +] diff --git a/src/cellflow/data/_data.py b/src/scaleflow/data/_data.py similarity index 50% rename from src/cellflow/data/_data.py rename to src/scaleflow/data/_data.py index 0f51d304..099114db 100644 --- a/src/cellflow/data/_data.py +++ b/src/scaleflow/data/_data.py @@ -1,11 +1,15 @@ +from __future__ import annotations + from collections.abc import Callable from dataclasses import dataclass from typing import Any -import jax import numpy as np +import zarr +from zarr.storage import LocalStore -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike +from scaleflow.data._utils import write_sharded __all__ = [ "BaseDataMixin", @@ -13,6 +17,7 @@ "PredictionData", "TrainingData", "ValidationData", + "MappedCellData", ] @@ -121,6 +126,65 @@ class TrainingData(BaseDataMixin): null_value: Any data_manager: Any + # --- Zarr export helpers ------------------------------------------------- + def write_zarr( + self, + path: str, + *, + chunk_size: int = 4096, + shard_size: int = 65536, + compressors: Any | None = None, + ) -> None: + """Write this training data to Zarr v3 with sharded, compressed arrays. + + Parameters + ---------- + path + Path to a Zarr group to create or open for writing. + chunk_size + Chunk size along the first axis. + shard_size + Shard size along the first axis. + compressors + Optional list/tuple of Zarr codecs. If ``None``, a sensible default is used. 
+ """ + # Convert to numpy-backed containers for serialization + cell_data = np.asarray(self.cell_data) + split_covariates_mask = np.asarray(self.split_covariates_mask) + perturbation_covariates_mask = np.asarray(self.perturbation_covariates_mask) + condition_data = {str(k): np.asarray(v) for k, v in (self.condition_data or {}).items()} + control_to_perturbation = {str(k): np.asarray(v) for k, v in (self.control_to_perturbation or {}).items()} + split_idx_to_covariates = {str(k): np.asarray(v) for k, v in (self.split_idx_to_covariates or {}).items()} + perturbation_idx_to_covariates = { + str(k): np.asarray(v) for k, v in (self.perturbation_idx_to_covariates or {}).items() + } + perturbation_idx_to_id = {str(k): v for k, v in (self.perturbation_idx_to_id or {}).items()} + + train_data_dict: dict[str, Any] = { + "cell_data": cell_data, + "split_covariates_mask": split_covariates_mask, + "perturbation_covariates_mask": perturbation_covariates_mask, + "split_idx_to_covariates": split_idx_to_covariates, + "perturbation_idx_to_covariates": perturbation_idx_to_covariates, + "perturbation_idx_to_id": perturbation_idx_to_id, + "condition_data": condition_data, + "control_to_perturbation": control_to_perturbation, + "max_combination_length": int(self.max_combination_length), + } + + additional_kwargs = {} + if compressors is not None: + additional_kwargs["compressors"] = compressors + + zgroup = zarr.open_group(path, mode="w") + write_sharded( + zgroup, + train_data_dict, + chunk_size=chunk_size, + shard_size=shard_size, + **additional_kwargs, + ) + @dataclass class ValidationData(BaseDataMixin): @@ -171,6 +235,9 @@ class ValidationData(BaseDataMixin): n_conditions_on_train_end: int | None = None + + + @dataclass class PredictionData(BaseDataMixin): """Data container to perform prediction. @@ -191,8 +258,8 @@ class PredictionData(BaseDataMixin): Token to use for masking ``null_value``. """ - cell_data: jax.Array # (n_cells, n_features) - split_covariates_mask: jax.Array # (n_cells,), which cell assigned to which source distribution + cell_data: ArrayLike # (n_cells, n_features) + split_covariates_mask: ArrayLike # (n_cells,), which cell assigned to which source distribution split_idx_to_covariates: dict[int, tuple[Any, ...]] # (n_sources,) dictionary explaining split_covariates_mask perturbation_idx_to_covariates: dict[ int, tuple[str, ...] @@ -203,3 +270,99 @@ class PredictionData(BaseDataMixin): max_combination_length: int null_value: Any data_manager: Any + + +@dataclass +class MappedCellData(BaseDataMixin): + """Lazy, Zarr-backed variant of :class:`TrainingData`. + + Fields mirror those in :class:`TrainingData`, but array-like members are + Zarr arrays or Zarr-backed mappings. This enables out-of-core training and + composition without loading everything into memory. + + Use :meth:`read_zarr` to construct from a Zarr v3 group written via + :meth:`TrainingData.to_zarr`. + """ + + # Note: annotations use Any to allow zarr.Array and zarr groups without + # importing zarr at module import time. 
+ src_cell_data: dict[str, Any] + tgt_cell_data: dict[str, Any] + src_cell_idx: dict[str, Any] + tgt_cell_idx: dict[str, Any] + split_covariates_mask: Any + perturbation_covariates_mask: Any + split_idx_to_covariates: dict[int, tuple[Any, ...]] + perturbation_idx_to_covariates: dict[int, tuple[str, ...]] + perturbation_idx_to_id: dict[int, Any] + condition_data: dict[str, Any] + control_to_perturbation: dict[int, Any] + max_combination_length: int + mapping_data_full_cached: bool = False + + def __post_init__(self): + # load everything except cell_data to memory + + # load masks as numpy arrays + self.condition_data = {k: np.asarray(v) for k, v in self.condition_data.items()} + self.control_to_perturbation = {int(k): np.asarray(v) for k, v in self.control_to_perturbation.items()} + if self.mapping_data_full_cached: + # used in validation usually + self.perturbation_idx_to_id = {int(k): np.asarray(v) for k, v in self.perturbation_idx_to_id.items()} + self.perturbation_idx_to_covariates = { + int(k): np.asarray(v) for k, v in self.perturbation_idx_to_covariates.items() + } + # not used in nested structure + self.src_cell_idx = self.src_cell_idx[...] + self.tgt_cell_idx = self.tgt_cell_idx[...] + self.split_covariates_mask = self.split_covariates_mask[...] + self.perturbation_covariates_mask = self.perturbation_covariates_mask[...] + self.split_idx_to_covariates = {int(k): np.asarray(v) for k, v in self.split_idx_to_covariates.items()} + + @staticmethod + def _get_mapping_data(group: zarr.Group) -> dict[str, Any]: + return group["mapping_data"]["mapping_data"] + + @staticmethod + def _read_dict(zgroup: zarr.Group, key: str) -> dict[int, Any]: + keys = zgroup[key].keys() + return {k: zgroup[key][k] for k in keys} + + @staticmethod + def _read_cell_data(zgroup: zarr.Group, key: str) -> dict[int, Any]: + keys = sorted(zgroup[key].keys()) + data_key = [k for k in keys if not k.endswith("_index")] + return {int(k): zgroup[key][k] for k in data_key}, {int(k): zgroup[key][f"{k}_index"] for k in data_key} + + @classmethod + def read_zarr(cls, path: str) -> MappedCellData: + if isinstance(path, str): + path = LocalStore(path, read_only=True) + group = zarr.open_group(path, mode="r") + max_len_node = group.get("max_combination_length") + if max_len_node is None: + max_combination_length = 0 + else: + try: + max_combination_length = int(max_len_node[()]) + except Exception: # noqa: BLE001 + max_combination_length = int(max_len_node) + + mapping_group = cls._get_mapping_data(group) + + src_cell_data, src_cell_idx = cls._read_cell_data(group, "src_cell_data") + tgt_cell_data, tgt_cell_idx = cls._read_cell_data(group, "tgt_cell_data") + return cls( + tgt_cell_data=tgt_cell_data, + tgt_cell_idx=tgt_cell_idx, + src_cell_data=src_cell_data, + src_cell_idx=src_cell_idx, + split_covariates_mask=mapping_group["split_covariates_mask"], + perturbation_covariates_mask=mapping_group["perturbation_covariates_mask"], + split_idx_to_covariates=cls._read_dict(mapping_group, "split_idx_to_covariates"), + perturbation_idx_to_covariates=cls._read_dict(mapping_group, "perturbation_idx_to_covariates"), + perturbation_idx_to_id=cls._read_dict(mapping_group, "perturbation_idx_to_id"), + condition_data=cls._read_dict(mapping_group, "condition_data"), + control_to_perturbation=cls._read_dict(mapping_group, "control_to_perturbation"), + max_combination_length=max_combination_length, + ) diff --git a/src/scaleflow/data/_data_splitter.py b/src/scaleflow/data/_data_splitter.py new file mode 100644 index 00000000..135f6789 --- 
/dev/null +++ b/src/scaleflow/data/_data_splitter.py @@ -0,0 +1,1043 @@ +"""Data splitter for creating train/validation/test splits from TrainingData objects.""" + +import logging +import warnings +from pathlib import Path +from typing import Literal + +import numpy as np +from sklearn.model_selection import train_test_split + +from scaleflow.data._data import MappedCellData, TrainingData + +logger = logging.getLogger(__name__) + +SplitType = Literal["holdout_groups", "holdout_combinations", "random", "stratified"] + + +class DataSplitter: + """ + A lightweight class for creating train/validation/test splits from TrainingData objects. + + This class extracts metadata from TrainingData objects and returns split indices, + making it memory-efficient for large datasets. + + Supports various splitting strategies: + - holdout_groups: Hold out specific groups (drugs, cell lines, donors, etc.) for validation/test + - holdout_combinations: Keep single treatments in training, hold out combination treatments for validation/test + - random: Random split of observations + - stratified: Stratified split maintaining condition proportions + + Parameters + ---------- + training_datasets : list[TrainingData | MappedCellData] + List of TrainingData or MappedCellData objects to process + dataset_names : list[str] + List of names for each dataset (for saving/loading) + split_ratios : list[list[float]] + List of triples, each indicating [train, validation, test] ratios for each dataset. + Each triple must sum to 1.0. Length must match training_datasets. + split_type : SplitType + Type of split to perform + split_key : str | list[str] | None + Column name(s) in adata.obs to use for splitting (required for holdout_groups and holdout_combinations). + Can be a single column or list of columns for combination treatments. + force_training_values : list[str] | None + Values that should be forced to appear only in training (e.g., ['control', 'dmso']). + These values will never appear in validation or test sets. + control_value : str | list[str] | None + Value(s) that represent control/untreated condition (e.g., 'control' or ['control', 'dmso']). + Required for holdout_combinations split type. + hard_test_split : bool + If True, validation and test get completely different groups (no overlap). + If False, validation and test can share groups, split at cell level. + Applies to all split types for consistent val/test separation control. + random_state : int + Random seed for reproducible splits. This controls: + - Observation-level splits in soft mode (hard_test_split=False) + - Fallback for test_random_state and val_random_state if they are None + Note: In hard mode with test_random_state and val_random_state specified, + this parameter only affects downstream training randomness (not DataSplitter itself). + test_random_state : int | None + Random seed specifically for selecting which conditions go to the test set. + If None, uses random_state as fallback. Only applies to 'holdout_groups' and + 'holdout_combinations' split types. This enables running multiple experiments with + different train/val splits while keeping the test set fixed for fair comparison. + val_random_state : int | None + Random seed specifically for selecting which conditions go to the validation set + (from the remaining conditions after test set selection). If None, uses random_state + as fallback. Only applies to 'holdout_groups' and 'holdout_combinations' split types. 
+ This enables varying the validation set across runs while keeping test set fixed + + Examples + -------- + >>> # Example 1: Basic split with forced training values + >>> splitter = DataSplitter( + ... training_datasets=[train_data1, train_data2], + ... dataset_names=["dataset1", "dataset2"], + ... split_ratios=[[0.8, 0.2, 0.0], [0.9, 0.1, 0.0]], + ... split_type="holdout_groups", + ... split_key=["drug1", "drug2"], + ... force_training_values=["control", "dmso"], + ... ) + + >>> # Example 2: Split by holding out combinations (singletons in training) + >>> splitter = DataSplitter( + ... training_datasets=[train_data], + ... dataset_names=["dataset"], + ... split_ratios=[[0.8, 0.2, 0.0]], + ... split_type="holdout_combinations", + ... split_key=["drug1", "drug2"], + ... control_value=["control", "dmso"], + ... ) + + >>> # Example 3: Fixed test set across multiple runs (for drug discovery benchmarking) + >>> # All runs will test on the same drugs, but with different train/val splits + >>> for seed in [42, 43, 44, 45]: + ... splitter = DataSplitter( + ... training_datasets=[train_data], + ... dataset_names=["experiment"], + ... split_ratios=[[0.6, 0.2, 0.2]], + ... split_type="holdout_groups", + ... split_key=["drug"], + ... test_random_state=999, # Fixed: same test drugs across all runs + ... val_random_state=seed, # Varies: different validation drugs per run + ... random_state=seed, # Varies: different training randomness + ... ) + ... results = splitter.split_all_datasets() + ... # Train model with this split... + + >>> # Example 4: Completely different splits per run (current behavior) + >>> for seed in [42, 43, 44]: + ... splitter = DataSplitter( + ... training_datasets=[train_data], + ... dataset_names=["experiment"], + ... split_ratios=[[0.8, 0.1, 0.1]], + ... split_type="holdout_groups", + ... split_key=["drug"], + ... random_state=seed, # All three seeds derived from this + ... 
) + + >>> # Save and load splits + >>> results = splitter.split_all_datasets() + >>> splitter.save_splits("./splits") + >>> split_info = DataSplitter.load_split_info("./splits", "dataset1") + >>> train_indices = split_info["indices"]["train"] + """ + + def __init__( + self, + training_datasets: list[TrainingData | MappedCellData], + dataset_names: list[str], + split_ratios: list[list[float]], + split_type: SplitType = "random", + split_key: str | list[str] | None = None, + force_training_values: list[str] | None = None, + control_value: str | list[str] | None = None, + hard_test_split: bool = True, + random_state: int = 42, + test_random_state: int | None = None, + val_random_state: int | None = None, + ): + self.training_datasets = training_datasets + self.dataset_names = dataset_names + self.split_ratios = split_ratios + self.split_type = split_type + self.split_key = split_key + self.force_training_values = force_training_values or [] + self.control_value = [control_value] if isinstance(control_value, str) else control_value + self.hard_test_split = hard_test_split + self.random_state = random_state + self.test_random_state = test_random_state if test_random_state is not None else random_state + self.val_random_state = val_random_state if val_random_state is not None else random_state + + self._validate_inputs() + + self.split_results: dict[str, dict] = {} + + def _validate_inputs(self) -> None: + """Validate input parameters.""" + if len(self.training_datasets) != len(self.dataset_names): + raise ValueError( + f"training_datasets length ({len(self.training_datasets)}) must match " + f"dataset_names length ({len(self.dataset_names)})" + ) + + if not isinstance(self.split_ratios, list): + raise ValueError("split_ratios must be a list of lists") + + if len(self.split_ratios) != len(self.training_datasets): + raise ValueError( + f"split_ratios length ({len(self.split_ratios)}) must match " + f"training_datasets length ({len(self.training_datasets)})" + ) + + # Check each split ratio + for i, ratios in enumerate(self.split_ratios): + if not isinstance(ratios, list) or len(ratios) != 3: + raise ValueError(f"split_ratios[{i}] must be a list of 3 values [train, val, test]") + + if not np.isclose(sum(ratios), 1.0): + raise ValueError(f"split_ratios[{i}] must sum to 1.0, got {sum(ratios)}") + + if any(ratio < 0 for ratio in ratios): + raise ValueError(f"All values in split_ratios[{i}] must be non-negative") + + # Check split key requirement + if self.split_type in ["holdout_groups", "holdout_combinations"] and self.split_key is None: + raise ValueError(f"split_key must be provided for split_type '{self.split_type}'") + + # Check control_value requirement for holdout_combinations + if self.split_type == "holdout_combinations" and self.control_value is None: + raise ValueError("control_value must be provided for split_type 'holdout_combinations'") + + for i, td in enumerate(self.training_datasets): + if not isinstance(td, (TrainingData, MappedCellData)): + raise ValueError(f"training_datasets[{i}] must be a TrainingData or MappedCellData object") + + def extract_perturbation_info(self, training_data: TrainingData | MappedCellData) -> dict: + """ + Extract condition information from TrainingData or MappedCellData. + + Note: Internal variable names use 'perturbation' for compatibility with + TrainingData structure, but conceptually these represent any conditions + (drugs, cell lines, donors, etc.). 
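As an aside, a small illustration of the structures the splitter operates on; the values below are made up and only show how the mask indexes into the covariate dictionary.

import numpy as np

# made-up example: three conditions, five observations
perturbation_idx_to_covariates = {0: ("control",), 1: ("drugA",), 2: ("drugA", "drugB")}
perturbation_covariates_mask = np.array([0, 0, 1, 2, 2])  # one condition index per observation
n_cells = perturbation_covariates_mask.shape[0]  # 5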
+ + Parameters + ---------- + training_data : TrainingData | MappedCellData + Training data object + + Returns + ------- + dict + Dictionary containing: + - perturbation_covariates_mask: array mapping observations to condition indices + - perturbation_idx_to_covariates: dict mapping condition indices to covariate tuples + - n_cells: total number of observations + """ + perturbation_covariates_mask = np.asarray(training_data.perturbation_covariates_mask) + perturbation_idx_to_covariates = training_data.perturbation_idx_to_covariates + + n_cells = len(perturbation_covariates_mask) + + logger.info(f"Extracted condition info for {n_cells} observations") + logger.info(f"Number of unique conditions: {len(perturbation_idx_to_covariates)}") + + return { + "perturbation_covariates_mask": perturbation_covariates_mask, + "perturbation_idx_to_covariates": perturbation_idx_to_covariates, + "n_cells": n_cells, + } + + def _get_unique_perturbation_values( + self, perturbation_idx_to_covariates: dict[int, tuple[str, ...]] + ) -> list[str]: + """Get all unique covariate values from perturbation dictionary.""" + all_unique_vals = set() + for covariates in perturbation_idx_to_covariates.values(): + all_unique_vals.update(covariates) + return list(all_unique_vals) + + def _split_random(self, n_cells: int, split_ratios: list[float]) -> dict[str, np.ndarray]: + """Perform random split of cells.""" + train_ratio, val_ratio, test_ratio = split_ratios + + # Generate random indices + indices = np.arange(n_cells) + np.random.seed(self.random_state) + np.random.shuffle(indices) + + if self.hard_test_split: + # HARD: Val and test are completely separate + train_end = int(train_ratio * n_cells) + val_end = train_end + int(val_ratio * n_cells) + + train_idx = indices[:train_end] + val_idx = indices[train_end:val_end] if val_ratio > 0 else np.array([]) + test_idx = indices[val_end:] if test_ratio > 0 else np.array([]) + + logger.info("HARD RANDOM SPLIT: Completely separate val/test") + else: + # SOFT: Val and test can overlap (split val+test at cell level) + train_end = int(train_ratio * n_cells) + train_idx = indices[:train_end] + val_test_idx = indices[train_end:] + + # Split val+test according to val/test ratios + if len(val_test_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + val_test_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("SOFT RANDOM SPLIT: Val/test can overlap") + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def _split_by_values( + self, + perturbation_covariates_mask: np.ndarray, + perturbation_idx_to_covariates: dict[int, tuple[str, ...]], + split_ratios: list[float], + ) -> dict[str, np.ndarray]: + """Split by holding out specific condition groups.""" + if self.split_key is None: + raise ValueError("split_key must be provided for holdout_groups splitting") + + # Get all unique covariate values + unique_values = self._get_unique_perturbation_values(perturbation_idx_to_covariates) + + # Remove forced training values from consideration for val/test splits + available_values = [v for v in unique_values if v not in self.force_training_values] + forced_train_values = [v for v in unique_values if v in self.force_training_values] + + logger.info(f"Total unique values: {len(unique_values)}") + logger.info(f"Forced training values: {forced_train_values}") + logger.info(f"Available for val/test: 
{len(available_values)}") + + n_values = len(available_values) + + if n_values < 3: + warnings.warn( + f"Only {n_values} unique values found across columns {self.split_key}. " + "Consider using random split instead.", + stacklevel=2, + ) + + # Split values according to ratios using three-level seed hierarchy + train_ratio, val_ratio, test_ratio = split_ratios + + # Calculate number of values for each split + n_test = int(test_ratio * n_values) + n_val = int(val_ratio * n_values) + n_train = n_values - n_test - n_val + + # Ensure we have at least one value for train if train_ratio > 0 + if train_ratio > 0 and n_train == 0: + n_train = 1 + n_test = max(0, n_test - 1) + + # Step 1: Select test values using test_random_state + np.random.seed(self.test_random_state) + shuffled_for_test = np.random.permutation(available_values) + test_values = shuffled_for_test[-n_test:] if n_test > 0 else [] + remaining_after_test = shuffled_for_test[:-n_test] if n_test > 0 else shuffled_for_test + + # Step 2: Select val values from remaining using val_random_state + np.random.seed(self.val_random_state) + shuffled_for_val = np.random.permutation(remaining_after_test) + val_values = shuffled_for_val[-n_val:] if n_val > 0 else [] + train_values_random = shuffled_for_val[:-n_val] if n_val > 0 else shuffled_for_val + + # Step 3: Combine forced training values with randomly assigned training values + train_values = list(train_values_random) + forced_train_values + + logger.info(f"Split values - Train: {len(train_values)}, Val: {len(val_values)}, Test: {len(test_values)}") + logger.info(f"Train values: {train_values}") + logger.info(f"Val values: {val_values}") + logger.info(f"Test values: {test_values}") + + # Create masks by checking which perturbation indices contain which values + def _get_cells_with_values(values_set): + """Get cell indices for perturbations containing any of the specified values.""" + if len(values_set) == 0: + return np.array([], dtype=int) + + # Find perturbation indices that contain any of these values + matching_pert_indices = [] + for pert_idx, covariates in perturbation_idx_to_covariates.items(): + if any(val in covariates for val in values_set): + matching_pert_indices.append(pert_idx) + + # Get cells with these perturbation indices + if len(matching_pert_indices) == 0: + return np.array([], dtype=int) + + cell_mask = np.isin(perturbation_covariates_mask, matching_pert_indices) + return np.where(cell_mask)[0] + + if self.hard_test_split: + # HARD: Val and test get different values (existing logic) + train_idx = _get_cells_with_values(train_values) + val_idx = _get_cells_with_values(val_values) + test_idx = _get_cells_with_values(test_values) + + logger.info("HARD HOLDOUT GROUPS: Val and test get different values") + else: + # SOFT: Val and test can share values, split at cell level + train_values_all = list(train_values_random) + forced_train_values + val_test_values = list(val_values) + list(test_values) + + train_idx = _get_cells_with_values(train_values_all) + val_test_idx = _get_cells_with_values(val_test_values) + + # Split val+test cells according to val/test ratios + if len(val_test_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + val_test_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("SOFT HOLDOUT GROUPS: Val/test can share values") + + # Log overlap information (important for combination treatments) + 
total_assigned = len(set(train_idx) | set(val_idx) | set(test_idx)) + logger.info(f"Total observations assigned to splits: {total_assigned} out of {len(perturbation_covariates_mask)}") + + overlaps = [] + if len(set(train_idx) & set(val_idx)) > 0: + overlaps.append("train-val") + if len(set(train_idx) & set(test_idx)) > 0: + overlaps.append("train-test") + if len(set(val_idx) & set(test_idx)) > 0: + overlaps.append("val-test") + + if overlaps: + logger.warning( + f"Found overlapping cells between splits: {overlaps}. This is expected with combination treatments." + ) + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def _split_holdout_combinations( + self, + perturbation_covariates_mask: np.ndarray, + perturbation_idx_to_covariates: dict[int, tuple[str, ...]], + split_ratios: list[float], + ) -> dict[str, np.ndarray]: + """Split by keeping single conditions in training and holding out combinations for val/test.""" + if self.split_key is None: + raise ValueError("split_key must be provided for holdout_combinations splitting") + if self.control_value is None: + raise ValueError("control_value must be provided for holdout_combinations splitting") + + logger.info("Identifying combinations vs singletons from condition covariates") + logger.info(f"Control value(s): {self.control_value}") + + # Classify each perturbation index as control, singleton, or combination + control_pert_indices = [] + singleton_pert_indices = [] + combination_pert_indices = [] + + for pert_idx, covariates in perturbation_idx_to_covariates.items(): + non_control_values = [c for c in covariates if c not in self.control_value] + n_non_control = len(non_control_values) + + if n_non_control == 0: + control_pert_indices.append(pert_idx) + elif n_non_control == 1: + singleton_pert_indices.append(pert_idx) + else: + combination_pert_indices.append(pert_idx) + + # Get cell indices for each type + if len(control_pert_indices) > 0: + control_mask = np.isin(perturbation_covariates_mask, control_pert_indices) + else: + control_mask = np.zeros(len(perturbation_covariates_mask), dtype=bool) + + if len(singleton_pert_indices) > 0: + singleton_mask = np.isin(perturbation_covariates_mask, singleton_pert_indices) + else: + singleton_mask = np.zeros(len(perturbation_covariates_mask), dtype=bool) + + if len(combination_pert_indices) > 0: + combination_mask = np.isin(perturbation_covariates_mask, combination_pert_indices) + else: + combination_mask = np.zeros(len(perturbation_covariates_mask), dtype=bool) + + # Count each type + n_combinations = combination_mask.sum() + n_singletons = singleton_mask.sum() + n_controls = control_mask.sum() + + logger.info(f"Found {n_combinations} combination treatments") + logger.info(f"Found {n_singletons} singleton treatments") + logger.info(f"Found {n_controls} control treatments") + + if n_combinations == 0: + warnings.warn("No combination treatments found. 
Consider using 'holdout_groups' instead.", stacklevel=2) + + # Get indices for each type + combination_indices = np.where(combination_mask)[0] + singleton_indices = np.where(singleton_mask)[0] + control_indices = np.where(control_mask)[0] + + # All singletons and controls go to training + train_idx = np.concatenate([singleton_indices, control_indices]) + + # Split combinations according to the provided ratios + train_ratio, val_ratio, test_ratio = split_ratios + + if n_combinations > 0: + # Get perturbation identifiers for combination cells + # Map each cell to its perturbation tuple (non-control values only) + perturbation_ids = [] + for cell_idx in combination_indices: + pert_idx = perturbation_covariates_mask[cell_idx] + covariates = perturbation_idx_to_covariates[pert_idx] + # Extract non-control values + non_control_vals = [c for c in covariates if c not in self.control_value] + perturbation_id = tuple(sorted(non_control_vals)) + perturbation_ids.append(perturbation_id) + + # Get unique perturbation combinations + unique_perturbations = list(set(perturbation_ids)) + n_unique_perturbations = len(unique_perturbations) + + logger.info(f"Found {n_unique_perturbations} unique condition combinations") + + if self.hard_test_split: + # HARD TEST SPLIT: Val and test get completely different conditions + # Calculate number of perturbation combinations for each split + n_test_perturbations = int(test_ratio * n_unique_perturbations) + n_val_perturbations = int(val_ratio * n_unique_perturbations) + n_train_perturbations = n_unique_perturbations - n_test_perturbations - n_val_perturbations + + # Ensure we have at least one perturbation for train if train_ratio > 0 + if train_ratio > 0 and n_train_perturbations == 0: + n_train_perturbations = 1 + n_test_perturbations = max(0, n_test_perturbations - 1) + + # Step 1: Select test perturbations using test_random_state + np.random.seed(self.test_random_state) + shuffled_for_test = np.random.permutation(unique_perturbations) + test_perturbations = ( + [tuple(p) for p in shuffled_for_test[-n_test_perturbations:]] if n_test_perturbations > 0 else [] + ) + remaining_after_test = ( + shuffled_for_test[:-n_test_perturbations] if n_test_perturbations > 0 else shuffled_for_test + ) + + # Step 2: Select val perturbations from remaining using val_random_state + np.random.seed(self.val_random_state) + shuffled_for_val = np.random.permutation(remaining_after_test) + val_perturbations = ( + [tuple(p) for p in shuffled_for_val[-n_val_perturbations:]] if n_val_perturbations > 0 else [] + ) + train_perturbations = ( + [tuple(p) for p in shuffled_for_val[:-n_val_perturbations]] if n_val_perturbations > 0 else [tuple(p) for p in shuffled_for_val] + ) + + # Assign all cells with same perturbation to same split + train_combo_idx = [] + val_combo_idx = [] + test_combo_idx = [] + + for i, perturbation_id in enumerate(perturbation_ids): + cell_idx = combination_indices[i] + if perturbation_id in train_perturbations: + train_combo_idx.append(cell_idx) + elif perturbation_id in val_perturbations: + val_combo_idx.append(cell_idx) + elif perturbation_id in test_perturbations: + test_combo_idx.append(cell_idx) + + logger.info( + f"HARD TEST SPLIT - Condition split: Train={len(train_perturbations)}, Val={len(val_perturbations)}, Test={len(test_perturbations)}" + ) + if len(test_perturbations) > 0: + logger.info(f"Test perturbations: {list(test_perturbations)[:3]}") + if len(val_perturbations) > 0: + logger.info(f"Val perturbations: {list(val_perturbations)[:3]}") + + else: + # SOFT 
TEST SPLIT: Val and test can share conditions, split at cell level + # First assign conditions to train vs (val+test) using test_random_state + # (In soft mode, val and test share conditions, so we only need one seed for this split) + n_train_perturbations = int(train_ratio * n_unique_perturbations) + n_val_test_perturbations = n_unique_perturbations - n_train_perturbations + + # Shuffle perturbations using test_random_state + np.random.seed(self.test_random_state) + shuffled_perturbations = np.random.permutation(unique_perturbations) + + train_perturbations = ( + shuffled_perturbations[:n_train_perturbations] if n_train_perturbations > 0 else [] + ) + val_test_perturbations = ( + shuffled_perturbations[n_train_perturbations:] if n_val_test_perturbations > 0 else [] + ) + + # Get cells for train perturbations (all go to train) + train_combo_idx = [] + val_test_combo_idx = [] + + for i, perturbation_id in enumerate(perturbation_ids): + cell_idx = combination_indices[i] + if perturbation_id in train_perturbations: + train_combo_idx.append(cell_idx) + else: + val_test_combo_idx.append(cell_idx) + + # Now split val_test cells according to val/test ratios + if len(val_test_combo_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + np.random.seed(self.random_state + 1) # Different seed for cell-level split + + val_combo_idx, test_combo_idx = train_test_split( + val_test_combo_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_combo_idx = np.array([]) + test_combo_idx = np.array([]) + + logger.info( + f"SOFT TEST SPLIT - Condition split: Train={len(train_perturbations)}, Val+Test={len(val_test_perturbations)}" + ) + logger.info(f"Cell split within Val+Test: Val={len(val_combo_idx)}, Test={len(test_combo_idx)}") + + # Convert to numpy arrays + train_combo_idx = np.array(train_combo_idx) + val_combo_idx = np.array(val_combo_idx) + test_combo_idx = np.array(test_combo_idx) + + # Combine singletons/controls with assigned combinations + train_idx = np.concatenate([train_idx, train_combo_idx]) + val_idx = val_combo_idx + test_idx = test_combo_idx + + logger.info( + f"Final cell split: Train={len(train_combo_idx)}, Val={len(val_combo_idx)}, Test={len(test_combo_idx)}" + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info( + f"Final split - Train: {len(train_idx)} (singletons + controls + {len(train_combo_idx) if n_combinations > 0 else 0} combination observations)" + ) + logger.info(f"Final split - Val: {len(val_idx)} (combination observations only)") + logger.info(f"Final split - Test: {len(test_idx)} (combination observations only)") + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def _split_stratified( + self, + perturbation_covariates_mask: np.ndarray, + split_ratios: list[float], + ) -> dict[str, np.ndarray]: + """Perform stratified split maintaining proportions of conditions.""" + if self.split_key is None: + raise ValueError("split_key must be provided for stratified splitting") + + train_ratio, val_ratio, test_ratio = split_ratios + # Use perturbation indices as stratification labels + labels = perturbation_covariates_mask + indices = np.arange(len(perturbation_covariates_mask)) + + if self.hard_test_split: + # HARD: Val and test get different stratification groups (existing logic) + if val_ratio + test_ratio > 0: + train_idx, temp_idx = train_test_split( + indices, train_size=train_ratio, stratify=labels, random_state=self.random_state + ) + + if val_ratio > 0 and test_ratio > 0: + 
temp_labels = labels[temp_idx] + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + temp_idx, train_size=val_size, stratify=temp_labels, random_state=self.random_state + ) + elif val_ratio > 0: + val_idx = temp_idx + test_idx = np.array([]) + else: + val_idx = np.array([]) + test_idx = temp_idx + else: + train_idx = indices + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("HARD STRATIFIED SPLIT: Val and test get different strata") + else: + # SOFT: Val and test can share stratification groups, split at cell level + if val_ratio + test_ratio > 0: + train_idx, val_test_idx = train_test_split( + indices, train_size=train_ratio, stratify=labels, random_state=self.random_state + ) + + # Split val+test cells (not stratified) + if len(val_test_idx) > 0 and val_ratio + test_ratio > 0: + val_size = val_ratio / (val_ratio + test_ratio) + val_idx, test_idx = train_test_split( + val_test_idx, train_size=val_size, random_state=self.random_state + 1 + ) + else: + val_idx = np.array([]) + test_idx = np.array([]) + else: + train_idx = indices + val_idx = np.array([]) + test_idx = np.array([]) + + logger.info("SOFT STRATIFIED SPLIT: Val/test can share strata") + + return {"train": train_idx, "val": val_idx, "test": test_idx} + + def split_single_dataset(self, training_data: TrainingData | MappedCellData, dataset_index: int) -> dict: + """ + Split a single TrainingData or MappedCellData object according to the specified strategy. + + Parameters + ---------- + training_data : TrainingData | MappedCellData + Training data object to split + dataset_index : int + Index of the dataset to get the correct split ratios + + Returns + ------- + dict + Dictionary containing split indices and metadata + """ + # Extract perturbation information + pert_info = self.extract_perturbation_info(training_data) + perturbation_covariates_mask = pert_info["perturbation_covariates_mask"] + perturbation_idx_to_covariates = pert_info["perturbation_idx_to_covariates"] + n_cells = pert_info["n_cells"] + + # Get split ratios for this specific dataset + current_split_ratios = self.split_ratios[dataset_index] + + # Perform split based on strategy + if self.split_type == "random": + split_indices = self._split_random(n_cells, current_split_ratios) + elif self.split_type == "holdout_groups": + split_indices = self._split_by_values( + perturbation_covariates_mask, perturbation_idx_to_covariates, current_split_ratios + ) + elif self.split_type == "holdout_combinations": + split_indices = self._split_holdout_combinations( + perturbation_covariates_mask, perturbation_idx_to_covariates, current_split_ratios + ) + elif self.split_type == "stratified": + split_indices = self._split_stratified(perturbation_covariates_mask, current_split_ratios) + else: + raise ValueError(f"Unknown split_type: {self.split_type}") + + # Create result dictionary with indices and metadata + result = { + "indices": split_indices, + "metadata": { + "total_cells": n_cells, + "split_type": self.split_type, + "split_key": self.split_key, + "split_ratios": current_split_ratios, + "random_state": self.random_state, + "test_random_state": self.test_random_state, + "val_random_state": self.val_random_state, + "hard_test_split": self.hard_test_split, + }, + } + + # Add force_training_values and control_value to metadata + if self.force_training_values: + result["metadata"]["force_training_values"] = self.force_training_values + if self.control_value: + result["metadata"]["control_value"] = self.control_value + + # Add 
split values information if applicable + if self.split_type in ["holdout_groups", "holdout_combinations"] and self.split_key: + unique_values = self._get_unique_perturbation_values(perturbation_idx_to_covariates) + + def _get_split_values(indices): + """Get all unique covariate values for cells in this split.""" + if len(indices) == 0: + return [] + split_vals = set() + for idx in indices: + pert_idx = perturbation_covariates_mask[idx] + covariates = perturbation_idx_to_covariates[pert_idx] + split_vals.update(covariates) + return list(split_vals) + + train_values = _get_split_values(split_indices["train"]) + val_values = _get_split_values(split_indices["val"]) + test_values = _get_split_values(split_indices["test"]) + + result["split_values"] = { + "train": train_values, + "val": val_values, + "test": test_values, + "all_unique": unique_values, + } + + # Log split statistics + logger.info(f"Split results for {self.dataset_names[dataset_index]}:") + for split_name, indices in split_indices.items(): + if len(indices) > 0: + logger.info(f" {split_name}: {len(indices)} observations") + + return result + + def split_all_datasets(self) -> dict[str, dict]: + """ + Split all TrainingData objects according to the specified strategy. + + Returns + ------- + dict[str, dict] + Nested dictionary with dataset names as keys and split information as values + """ + logger.info(f"Starting data splitting with strategy: {self.split_type}") + logger.info(f"Number of datasets: {len(self.training_datasets)}") + for i, ratios in enumerate(self.split_ratios): + logger.info(f"Dataset {i} ratios: train={ratios[0]}, val={ratios[1]}, test={ratios[2]}") + + for i, (training_data, dataset_name) in enumerate(zip(self.training_datasets, self.dataset_names, strict=True)): + logger.info(f"\nProcessing dataset {i}: {dataset_name}") + logger.info(f"Using split ratios: {self.split_ratios[i]}") + + split_result = self.split_single_dataset(training_data, i) + self.split_results[dataset_name] = split_result + + logger.info(f"\nCompleted splitting {len(self.training_datasets)} datasets") + return self.split_results + + def generate_split_summary(self) -> dict[str, dict]: + """ + Generate a human-readable summary of split conditions for each dataset. + + This method creates a comprehensive summary showing which specific conditions + (perturbations, cell lines, donors, etc.) are assigned to train/val/test splits. + Useful for tracking what was tested across different random seeds. + + Returns + ------- + dict[str, dict] + Dictionary with dataset names as keys and split summaries as values. + Each summary contains: + - conditions_per_split: Lists of condition values in each split + - observations_per_condition: Number of observations for each condition in each split + - statistics: Observation and condition counts per split + - configuration: Random states and split parameters used + + Examples + -------- + >>> splitter = DataSplitter(...) + >>> results = splitter.split_all_datasets() + >>> summary = splitter.generate_split_summary() + >>> print(summary["dataset1"]["conditions_per_split"]["test"]) + ['DrugA', 'DrugB', 'DrugC'] + >>> print(summary["dataset1"]["observations_per_condition"]["test"]["DrugA"]) + 150 + """ + if not self.split_results: + raise ValueError("No split results available. 
Run split_all_datasets() first.") + + summary = {} + + for i, (dataset_name, split_info) in enumerate(self.split_results.items()): + dataset_summary = { + "configuration": { + "split_type": split_info["metadata"]["split_type"], + "split_key": split_info["metadata"]["split_key"], + "split_ratios": split_info["metadata"]["split_ratios"], + "random_state": split_info["metadata"]["random_state"], + "test_random_state": split_info["metadata"]["test_random_state"], + "val_random_state": split_info["metadata"]["val_random_state"], + "hard_test_split": split_info["metadata"]["hard_test_split"], + }, + "statistics": { + "total_observations": split_info["metadata"]["total_cells"], + }, + } + + if self.force_training_values: + dataset_summary["configuration"]["force_training_values"] = self.force_training_values + if self.control_value: + dataset_summary["configuration"]["control_value"] = self.control_value + + # Add split statistics + for split_name, indices in split_info["indices"].items(): + dataset_summary["statistics"][f"{split_name}_observations"] = len(indices) + if split_info["metadata"]["total_cells"] > 0: + percentage = 100 * len(indices) / split_info["metadata"]["total_cells"] + dataset_summary["statistics"][f"{split_name}_percentage"] = round(percentage, 2) + + # Add condition information if available + if "split_values" in split_info: + dataset_summary["conditions_per_split"] = { + "train": sorted(split_info["split_values"]["train"]), + "val": sorted(split_info["split_values"]["val"]), + "test": sorted(split_info["split_values"]["test"]), + } + dataset_summary["statistics"]["total_unique_conditions"] = len( + split_info["split_values"]["all_unique"] + ) + dataset_summary["statistics"]["train_conditions"] = len(split_info["split_values"]["train"]) + dataset_summary["statistics"]["val_conditions"] = len(split_info["split_values"]["val"]) + dataset_summary["statistics"]["test_conditions"] = len(split_info["split_values"]["test"]) + + # Add observations per condition for each split + training_data = self.training_datasets[i] + pert_info = self.extract_perturbation_info(training_data) + perturbation_covariates_mask = pert_info["perturbation_covariates_mask"] + perturbation_idx_to_covariates = pert_info["perturbation_idx_to_covariates"] + + observations_per_condition = {} + for split_name, indices in split_info["indices"].items(): + if len(indices) == 0: + observations_per_condition[split_name] = {} + continue + + # Count observations per condition for this split + condition_counts = {} + for idx in indices: + pert_idx = perturbation_covariates_mask[idx] + condition_tuple = perturbation_idx_to_covariates[pert_idx] + + # Convert tuple to string representation for JSON compatibility + if len(condition_tuple) == 1: + condition_str = condition_tuple[0] + else: + condition_str = "+".join(condition_tuple) + + condition_counts[condition_str] = condition_counts.get(condition_str, 0) + 1 + + # Sort by condition name for consistent output + observations_per_condition[split_name] = dict(sorted(condition_counts.items())) + + dataset_summary["observations_per_condition"] = observations_per_condition + + summary[dataset_name] = dataset_summary + + return summary + + def save_splits(self, output_dir: str | Path) -> None: + """ + Save all split information to the specified directory. 
+ + This saves multiple files per dataset: + - split_summary.json: Human-readable summary with conditions per split + - indices/*.npy: Cell indices for each split + - metadata.json: Configuration and parameters + - split_values.json: Condition values per split (if applicable) + - split_info.pkl: Complete split information + + Parameters + ---------- + output_dir : str | Path + Directory to save the split information + """ + import json + import pickle + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving splits to: {output_dir}") + + # Generate and save split summary + split_summary = self.generate_split_summary() + summary_file = output_dir / "split_summary.json" + with open(summary_file, "w") as f: + json.dump(split_summary, f, indent=2) + logger.info(f"Saved split summary -> {summary_file}") + + for dataset_name, split_info in self.split_results.items(): + # Save indices as numpy arrays (more efficient for large datasets) + indices_dir = output_dir / dataset_name / "indices" + indices_dir.mkdir(parents=True, exist_ok=True) + + for split_name, indices in split_info["indices"].items(): + if len(indices) > 0: + indices_file = indices_dir / f"{split_name}_indices.npy" + np.save(indices_file, indices) + logger.info(f"Saved {split_name} indices: {len(indices)} observations -> {indices_file}") + + # Save metadata as JSON + metadata_file = output_dir / dataset_name / "metadata.json" + with open(metadata_file, "w") as f: + # Convert numpy arrays to lists for JSON serialization + metadata = split_info["metadata"].copy() + json.dump(metadata, f, indent=2) + logger.info(f"Saved metadata -> {metadata_file}") + + # Save split values if available + if "split_values" in split_info: + split_values_file = output_dir / dataset_name / "split_values.json" + with open(split_values_file, "w") as f: + json.dump(split_info["split_values"], f, indent=2) + logger.info(f"Saved split values -> {split_values_file}") + + # Save complete split info as pickle for easy loading + complete_file = output_dir / dataset_name / "split_info.pkl" + with open(complete_file, "wb") as f: + pickle.dump(split_info, f) + logger.info(f"Saved complete split info -> {complete_file}") + + logger.info("All splits saved successfully") + + @staticmethod + def load_split_info(split_dir: str | Path, dataset_name: str) -> dict: + """ + Load split information from disk. 
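Assuming a single dataset named `dataset1` (the name is a placeholder), `save_splits` produces roughly the layout below; index files for empty splits are skipped and `split_values.json` is only written for the holdout split types.

splits/
├── split_summary.json
└── dataset1/
    ├── metadata.json
    ├── split_values.json
    ├── split_info.pkl
    └── indices/
        ├── train_indices.npy
        ├── val_indices.npy
        └── test_indices.npy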
+ + Parameters + ---------- + split_dir : str | Path + Directory containing saved splits + dataset_name : str + Name of the dataset + + Returns + ------- + dict + Dictionary containing split indices and metadata + """ + import pickle + + split_dir = Path(split_dir) + dataset_dir = split_dir / dataset_name + + if not dataset_dir.exists(): + raise FileNotFoundError(f"Split directory not found: {dataset_dir}") + + # Load complete split info from pickle + complete_file = dataset_dir / "split_info.pkl" + if complete_file.exists(): + with open(complete_file, "rb") as f: + return pickle.load(f) + + # Fallback: reconstruct from individual files + logger.warning("Complete split info not found, reconstructing from individual files") + + # Load indices + indices_dir = dataset_dir / "indices" + indices = {} + for split_name in ["train", "val", "test"]: + indices_file = indices_dir / f"{split_name}_indices.npy" + if indices_file.exists(): + indices[split_name] = np.load(indices_file) + else: + indices[split_name] = np.array([]) + + # Load metadata + import json + + metadata_file = dataset_dir / "metadata.json" + with open(metadata_file) as f: + metadata = json.load(f) + + # Load split values if available + split_values = None + split_values_file = dataset_dir / "split_values.json" + if split_values_file.exists(): + with open(split_values_file) as f: + split_values = json.load(f) + + result = {"indices": indices, "metadata": metadata} + if split_values: + result["split_values"] = split_values + + return result diff --git a/src/scaleflow/data/_dataloader.py b/src/scaleflow/data/_dataloader.py new file mode 100644 index 00000000..7b7feb1f --- /dev/null +++ b/src/scaleflow/data/_dataloader.py @@ -0,0 +1,491 @@ +import abc +from typing import Any, Literal + +import numpy as np +import tqdm +import os +import threading +from concurrent.futures import ThreadPoolExecutor, Future + +from scaleflow.data._data import ( + PredictionData, + TrainingData, + ValidationData, + MappedCellData, +) + +__all__ = [ + "TrainSampler", + "ValidationSampler", + "PredictionSampler", + "ReservoirSampler", +] + + +class TrainSampler: + """Data sampler for :class:`~scaleflow.data.TrainingData`. + + Parameters + ---------- + data + The training data. + batch_size + The batch size. 
+ + """ + + def __init__(self, data: TrainingData, batch_size: int = 1024): + self._data = data + self._data_idcs = np.arange(data.cell_data.shape[0]) + self.batch_size = batch_size + self.n_source_dists = data.n_controls + self.n_target_dists = data.n_perturbations + + self._control_to_perturbation_keys = sorted(data.control_to_perturbation.keys()) + self._has_condition_data = data.condition_data is not None + + def _sample_target_dist_idx(self, rng, source_dist_idx: int) -> int: + """Sample a target distribution index given the source distribution index.""" + return rng.choice(self._data.control_to_perturbation[source_dist_idx]) + + def _sample_source_dist_idx(self, rng) -> int: + """Sample a source distribution index.""" + return rng.choice(self.n_source_dists) + + def _get_embeddings(self, idx, condition_data) -> dict[str, np.ndarray]: + """Get embeddings for a given index.""" + result = {} + for key, arr in condition_data.items(): + result[key] = np.expand_dims(arr[idx], 0) + return result + + def _sample_from_mask(self, rng, mask) -> np.ndarray: + """Sample indices according to a mask.""" + # Convert mask to probability distribution + valid_indices = np.where(mask)[0] + # Handle case with no valid indices (should not happen in practice) + if len(valid_indices) == 0: + raise ValueError("No valid indices found in the mask") + + # Sample from valid indices with equal probability + batch_idcs = rng.choice(valid_indices, self.batch_size, replace=True) + return batch_idcs + + def _get_source_cells_mask(self, source_dist_idx: int) -> np.ndarray: + return self._data.split_covariates_mask == source_dist_idx + + def _get_target_cells_mask(self, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: + return self._data.perturbation_covariates_mask == target_dist_idx + + def _sample_source_batch_idcs(self, rng, source_dist_idx: int) -> dict[str, Any]: + source_cells_mask = self._get_source_cells_mask(source_dist_idx) + source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) + return source_batch_idcs + + def _sample_target_batch_idcs(self, rng, source_dist_idx: int, target_dist_idx: int) -> dict[str, Any]: + target_cells_mask = self._get_target_cells_mask(source_dist_idx, target_dist_idx) + target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) + return target_batch_idcs + + def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: + source_cells_mask = self._get_source_cells_mask(source_dist_idx) + source_batch_idcs = self._sample_from_mask(rng, source_cells_mask) + return self._data.cell_data[source_batch_idcs] + + def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: + target_cells_mask = self._get_target_cells_mask(source_dist_idx, target_dist_idx) + target_batch_idcs = self._sample_from_mask(rng, target_cells_mask) + return self._data.cell_data[target_batch_idcs] + + def sample(self, rng) -> dict[str, Any]: + """Sample a batch of data. 
+
+        Parameters
+        ----------
+        rng : numpy.random.Generator
+            Random number generator used to draw the batch.
+
+        Returns
+        -------
+        Dictionary with source and target data
+        """
+        # Sample source and target
+        source_dist_idx = self._sample_source_dist_idx(rng)
+        target_dist_idx = self._sample_target_dist_idx(rng, source_dist_idx)
+
+        # Sample source and target cells
+        source_batch = self._sample_source_cells(rng, source_dist_idx)
+        target_batch = self._sample_target_cells(rng, source_dist_idx, target_dist_idx)
+
+        res = {"src_cell_data": source_batch, "tgt_cell_data": target_batch}
+        if self._has_condition_data:
+            condition_batch = self._get_embeddings(target_dist_idx, self._data.condition_data)
+            res["condition"] = condition_batch
+        return res
+
+    @property
+    def data(self) -> TrainingData:
+        """The training data."""
+        return self._data
+
+class ReservoirSampler(TrainSampler):
+    """Data sampler with gradual pool replacement using reservoir sampling.
+
+    This approach replaces pool elements one by one rather than refreshing
+    the entire pool, providing better cache locality while maintaining
+    reasonable randomness.
+
+    Parameters
+    ----------
+    data
+        The training data.
+    batch_size
+        The batch size.
+    pool_size
+        The size of the pool of source distribution indices.
+    replacement_prob
+        Probability of replacing a pool element after each sample.
+        Lower values = longer cache retention, less randomness.
+        Higher values = faster cache turnover, more randomness.
+    """
+
+    def __init__(
+        self,
+        data: MappedCellData,
+        batch_size: int = 1024,
+        pool_size: int = 100,
+        replacement_prob: float = 0.01,
+    ):
+        self.batch_size = batch_size
+        self.n_source_dists = data.n_controls
+        self.n_target_dists = data.n_perturbations
+        self._data = data
+
+        self._control_to_perturbation_keys = sorted(data.control_to_perturbation.keys())
+        self._has_condition_data = data.condition_data is not None
+        self._pool_size = pool_size
+        self._replacement_prob = replacement_prob
+        self._pool_usage_count = np.zeros(self.n_source_dists, dtype=int)
+        self._initialized = False
+
+        # Concurrency primitives
+        self._lock = threading.RLock()
+        self._executor = ThreadPoolExecutor(max_workers=2)
+        # Map pool position -> {"old": int, "new": int, "future": Future}
+        self._pending_replacements: dict[int, dict[str, Any]] = {}
+
+    def init_pool(self, rng):
+        self._init_pool(rng)
+        self._init_cache_pool_elements()
+
+    @staticmethod
+    def _get_target_idx_pool(src_idx_pool: np.ndarray, control_to_perturbation: dict[int, np.ndarray]) -> set[int]:
+        tgt_idx_pool = set()
+        for src_idx in src_idx_pool:
+            tgt_idx_pool.update(control_to_perturbation[src_idx].tolist())
+        return tgt_idx_pool
+
+    def _init_cache_pool_elements(self):
+        if not self._initialized:
+            raise ValueError("Pool not initialized. Call init_pool(rng) first.")
+        with self._lock:
+            self._cached_srcs = {i: self._data.src_cell_data[i][...] for i in self._src_idx_pool}
+        tgt_indices = sorted(
+            {int(j) for i in self._src_idx_pool for j in self._data.control_to_perturbation[i]}
+        )
+
+        def _load_tgt(j: int):
+            return j, self._data.tgt_cell_data[j][...]
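A minimal sampling sketch (outside the diff), assuming `train_data` is an existing TrainingData instance:

import numpy as np

rng = np.random.default_rng(0)
sampler = TrainSampler(train_data, batch_size=1024)
batch = sampler.sample(rng)
# batch["src_cell_data"] and batch["tgt_cell_data"] are (1024, n_features) arrays;
# batch["condition"] is included only when condition embeddings are available.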
+ + max_workers = min(32, (os.cpu_count() or 4)) + with ThreadPoolExecutor(max_workers=max_workers) as ex: + results = list(ex.map(_load_tgt, tgt_indices)) + + with self._lock: + self._cached_tgts = {j: arr for j, arr in results} + + def _init_pool(self, rng): + """Initialize the pool with random source distribution indices.""" + self._src_idx_pool = rng.choice(self.n_source_dists, size=self._pool_size, replace=False) + self._initialized = True + + def _sample_source_dist_idx(self, rng) -> int: + """Sample a source distribution index with gradual pool replacement.""" + if not self._initialized: + raise ValueError("Pool not initialized. Call init_pool(rng) first.") + + # Opportunistically apply any ready replacements (non-blocking) + self._apply_ready_replacements() + + # Sample from current pool + with self._lock: + source_idx = rng.choice(sorted(self._cached_srcs.keys())) + + # Increment usage count for monitoring + self._pool_usage_count[source_idx] += 1 + + # Gradually replace elements based on replacement probability (schedule only) + if rng.random() < self._replacement_prob: + self._schedule_replacement(rng) + + return source_idx + + def _schedule_replacement(self, rng): + """Schedule a single pool element replacement without blocking.""" + # weights same as previous logic + most_used_weight = (self._pool_usage_count == self._pool_usage_count.max()).astype(float) + if most_used_weight.sum() == 0: + return + most_used_weight /= most_used_weight.sum() + replaced_pool_idx = rng.choice(self.n_source_dists, p=most_used_weight) + + with self._lock: + pool_set = set(self._src_idx_pool.tolist()) + if replaced_pool_idx not in pool_set: + return + in_pool_idx = int(np.where(self._src_idx_pool == replaced_pool_idx)[0][0]) + + # If there's already a pending replacement for this pool slot, skip + if in_pool_idx in self._pending_replacements: + return + + least_used_weight = (self._pool_usage_count == self._pool_usage_count.min()).astype(float) + if least_used_weight.sum() == 0: + return + least_used_weight /= least_used_weight.sum() + new_pool_idx = int(rng.choice(self.n_source_dists, p=least_used_weight)) + + # Kick off background load for new indices + fut: Future = self._executor.submit(self._load_new_cache, new_pool_idx) + self._pending_replacements[in_pool_idx] = { + "old": replaced_pool_idx, + "new": new_pool_idx, + "future": fut, + } + print(f"scheduled replacement of {replaced_pool_idx} with {new_pool_idx} (slot {in_pool_idx})") + + def _apply_ready_replacements(self): + """Apply any finished background loads; non-blocking.""" + to_apply: list[int] = [] + with self._lock: + for slot, info in self._pending_replacements.items(): + fut: Future = info["future"] + if fut.done() and not fut.cancelled(): + to_apply.append(slot) + + for slot in to_apply: + with self._lock: + info = self._pending_replacements.pop(slot, None) + if info is None: + continue + old_idx = int(info["old"]) + new_idx = int(info["new"]) + fut: Future = info["future"] + try: + prepared = fut.result(timeout=0) # already done + except Exception as e: + print(f"background load failed for {new_idx}: {e}") + continue + + # Swap pool index + self._src_idx_pool[slot] = new_idx + + # Add new entries first + self._cached_srcs[new_idx] = prepared["src"] + for k, arr in prepared["tgts"].items(): + self._cached_tgts[k] = arr + + # Remove old entries + if old_idx in self._cached_srcs: + del self._cached_srcs[old_idx] + for k in self._data.control_to_perturbation[old_idx]: + if k in self._cached_tgts: + del self._cached_tgts[k] + + 
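Correspondingly, a hedged sketch for the pooled variant, assuming `mapped` is a MappedCellData opened with `MappedCellData.read_zarr`:

import numpy as np

rng = np.random.default_rng(0)
pool_sampler = ReservoirSampler(mapped, batch_size=1024, pool_size=100, replacement_prob=0.01)
pool_sampler.init_pool(rng)       # loads the initial source/target caches
batch = pool_sampler.sample(rng)  # same dict layout as TrainSampler.sample
pool_sampler.get_pool_stats()     # pool size, usage counts, current pool indices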
print(f"applied replacement: {old_idx} -> {new_idx} (slot {slot})") + + def _load_new_cache(self, src_idx: int) -> dict[str, Any]: + """Load new src and corresponding tgt arrays in the background.""" + src_arr = self._data.src_cell_data[src_idx][...] + tgt_dict = {k: self._data.tgt_cell_data[k][...] for k in self._data.control_to_perturbation[src_idx]} + return {"src": src_arr, "tgts": tgt_dict} + + def get_pool_stats(self) -> dict: + """Get statistics about the current pool state.""" + if self._src_idx_pool is None: + return {"pool_size": 0, "avg_usage": 0, "unique_sources": 0} + return { + "pool_size": self._pool_size, + "avg_usage": float(np.mean(self._pool_usage_count)), + "unique_sources": len(set(self._src_idx_pool)), + "pool_elements": self._src_idx_pool.copy(), + "usage_counts": self._pool_usage_count.copy(), + } + + def _sample_source_cells(self, rng, source_dist_idx: int) -> np.ndarray: + with self._lock: + arr = self._cached_srcs[source_dist_idx] + return rng.choice(arr, size=self.batch_size, replace=True) + + def _sample_target_cells(self, rng, source_dist_idx: int, target_dist_idx: int) -> np.ndarray: + with self._lock: + arr = self._cached_tgts[target_dist_idx] + return rng.choice(arr, size=self.batch_size, replace=True) + + +class BaseValidSampler(abc.ABC): + @abc.abstractmethod + def sample(*args, **kwargs): + pass + + def _get_key(self, cond_idx: int) -> tuple[str, ...]: + if len(self._data.perturbation_idx_to_id): # type: ignore[attr-defined] + return self._data.perturbation_idx_to_id[cond_idx] # type: ignore[attr-defined] + cov_combination = self._data.perturbation_idx_to_covariates[cond_idx] # type: ignore[attr-defined] + return tuple(cov_combination[i] for i in range(len(cov_combination))) + + def _get_perturbation_to_control(self, data: ValidationData | PredictionData) -> dict[int, np.ndarray]: + d = {} + for k, v in data.control_to_perturbation.items(): + for el in v: + d[el] = k + return d + + def _get_condition_data(self, cond_idx: int) -> dict[str, np.ndarray]: + return {k: v[[cond_idx], ...] for k, v in self._data.condition_data.items()} # type: ignore[attr-defined] + + +class ValidationSampler(BaseValidSampler): + """Data sampler for :class:`~scaleflow.data.ValidationData`. + + Parameters + ---------- + val_data + The validation data. + seed + Random seed. + validation_batch_size + Maximum number of cells to sample per condition during validation. + If None, uses all available cells. + """ + + def __init__(self, val_data: ValidationData, seed: int = 0, validation_batch_size: int | None = None) -> None: + self._data = val_data + self.perturbation_to_control = self._get_perturbation_to_control(val_data) + self.n_conditions_on_log_iteration = ( + val_data.n_conditions_on_log_iteration + if val_data.n_conditions_on_log_iteration is not None + else val_data.n_perturbations + ) + self.n_conditions_on_train_end = ( + val_data.n_conditions_on_train_end + if val_data.n_conditions_on_train_end is not None + else val_data.n_perturbations + ) + self.validation_batch_size = validation_batch_size + self.rng = np.random.default_rng(seed) + if self._data.condition_data is None: + raise NotImplementedError("Validation data must have condition data.") + + def sample(self, mode: Literal["on_log_iteration", "on_train_end"]) -> Any: + """Sample data for validation. + + Parameters + ---------- + mode + Sampling mode. Either ``"on_log_iteration"`` or ``"on_train_end"``. + + Returns + ------- + Dictionary with source, condition, and target data from the validation data. 
+ """ + size = self.n_conditions_on_log_iteration if mode == "on_log_iteration" else self.n_conditions_on_train_end + condition_idcs = self.rng.choice(self._data.n_perturbations, size=(size,), replace=False) + + source_idcs = [self.perturbation_to_control[cond_idx] for cond_idx in condition_idcs] + source_cells_mask = [self._data.split_covariates_mask == source_idx for source_idx in source_idcs] + source_cells = [self._data.cell_data[mask] for mask in source_cells_mask] + target_cells_mask = [cond_idx == self._data.perturbation_covariates_mask for cond_idx in condition_idcs] + target_cells = [self._data.cell_data[mask] for mask in target_cells_mask] + + # Apply validation batch size if specified + if self.validation_batch_size is not None: + source_cells = self._subsample_cells(source_cells) + target_cells = self._subsample_cells(target_cells) + + conditions = [self._get_condition_data(cond_idx) for cond_idx in condition_idcs] + cell_rep_dict = {} + cond_dict = {} + true_dict = {} + for i in range(len(condition_idcs)): + k = self._get_key(condition_idcs[i]) + cell_rep_dict[k] = source_cells[i] + cond_dict[k] = conditions[i] + true_dict[k] = target_cells[i] + + return {"source": cell_rep_dict, "condition": cond_dict, "target": true_dict} + + def _subsample_cells(self, cells_list: list[np.ndarray]) -> list[np.ndarray]: + """Subsample cells from each condition to validation_batch_size.""" + subsampled_cells = [] + for cells in cells_list: + if len(cells) > self.validation_batch_size: + indices = self.rng.choice(len(cells), size=self.validation_batch_size, replace=False) + subsampled_cells.append(cells[indices]) + else: + subsampled_cells.append(cells) + return subsampled_cells + + @property + def data(self) -> ValidationData: + """The validation data.""" + return self._data + + +class PredictionSampler(BaseValidSampler): + """Data sampler for :class:`~scaleflow.data.PredictionData`. + + Parameters + ---------- + pred_data + The prediction data. + + """ + + def __init__(self, pred_data: PredictionData) -> None: + self._data = pred_data + self.perturbation_to_control = self._get_perturbation_to_control(pred_data) + if self._data.condition_data is None: + raise NotImplementedError("Validation data must have condition data.") + + def sample(self) -> Any: + """Sample data for prediction. + + Returns + ------- + Dictionary with source and condition data from the prediction data. 
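+
+        Example
+        -------
+        A minimal sketch (``pred_data`` is assumed to be an existing
+        :class:`~scaleflow.data.PredictionData` instance):
+
+        .. code-block:: python
+
+            sampler = PredictionSampler(pred_data)
+            batch = sampler.sample()
+            batch["source"]     # condition key -> control cells
+            batch["condition"]  # condition key -> condition arrays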
+ """ + condition_idcs = range(self._data.n_perturbations) + + source_idcs = [self.perturbation_to_control[cond_idx] for cond_idx in condition_idcs] + source_cells_mask = [self._data.split_covariates_mask == source_idx for source_idx in source_idcs] + source_cells = [self._data.cell_data[mask] for mask in source_cells_mask] + conditions = [self._get_condition_data(cond_idx) for cond_idx in condition_idcs] + cell_rep_dict = {} + cond_dict = {} + for i in range(len(condition_idcs)): + k = self._get_key(condition_idcs[i]) + cell_rep_dict[k] = source_cells[i] + cond_dict[k] = conditions[i] + + return { + "source": cell_rep_dict, + "condition": cond_dict, + } + + @property + def data(self) -> PredictionData: + """The training data.""" + return self._data diff --git a/src/cellflow/data/_datamanager.py b/src/scaleflow/data/_datamanager.py similarity index 98% rename from src/cellflow/data/_datamanager.py rename to src/scaleflow/data/_datamanager.py index 065cddd2..d13c2c4c 100644 --- a/src/cellflow/data/_datamanager.py +++ b/src/scaleflow/data/_datamanager.py @@ -2,20 +2,34 @@ from collections.abc import Sequence from typing import Any +import scipy.sparse as sp +import sklearn.preprocessing as preprocessing + +import numpy as np +import pandas as pd +from pandas.api.types import is_numeric_dtype +import tqdm +import threading +from concurrent.futures import ThreadPoolExecutor, Future +import os import anndata import dask import dask.dataframe as dd -import dask.delayed -import numpy as np -import pandas as pd -import scipy.sparse as sp -import sklearn.preprocessing as preprocessing from dask.diagnostics import ProgressBar -from pandas.api.types import is_numeric_dtype -from cellflow._logging import logger -from cellflow._types import ArrayLike -from cellflow.data._data import ConditionData, PredictionData, ReturnData, TrainingData, ValidationData +<<<<<<< HEAD +from scaleflow.data._data import ( + PredictionData, + TrainingData, + ValidationData, + MappedCellData, +) + +======= +>>>>>>> main +from scaleflow._logging import logger +from scaleflow._types import ArrayLike +from scaleflow.data._data import ConditionData, PredictionData, ReturnData, TrainingData, ValidationData from ._utils import _flatten_list, _to_list @@ -223,8 +237,8 @@ def get_prediction_data( is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. covariate_data A :class:`~pandas.DataFrame` with columns defining the covariates as - in :meth:`cellflow.model.CellFlow.prepare_data` and stored in - :attr:`cellflow.model.CellFlow.data_manager`. + in :meth:`scaleflow.model.CellFlow.prepare_data` and stored in + :attr:`scaleflow.model.CellFlow.data_manager`. rep_dict Dictionary with representations of the covariates. If not provided, :attr:`~anndata.AnnData.uns` is used. 
@@ -759,15 +773,15 @@ def _get_cell_data( if sample_rep == "X": sample_rep = adata.X if isinstance(sample_rep, sp.csr_matrix): - return np.asarray(sample_rep.toarray()) + return sample_rep.toarray() else: - return np.asarray(sample_rep) + return sample_rep if isinstance(self._sample_rep, str): if self._sample_rep not in adata.obsm: raise KeyError(f"Sample representation '{self._sample_rep}' not found in `adata.obsm`.") - return np.asarray(adata.obsm[self._sample_rep]) + return adata.obsm[self._sample_rep] attr, key = next(iter(sample_rep.items())) # type: ignore[union-attr] - return np.asarray(getattr(adata, attr)[key]) + return getattr(adata, attr)[key] def _verify_control_data(self, adata: anndata.AnnData | None) -> None: if adata is None: diff --git a/src/scaleflow/data/_jax_dataloader.py b/src/scaleflow/data/_jax_dataloader.py new file mode 100644 index 00000000..84966a21 --- /dev/null +++ b/src/scaleflow/data/_jax_dataloader.py @@ -0,0 +1,110 @@ +import queue +import threading +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from scaleflow.data._data import ( + TrainingData, +) +from scaleflow.data._dataloader import TrainSampler + + +def _prefetch_to_device( + sampler: TrainSampler, + seed: int, + num_iterations: int, + prefetch_factor: int = 2, + num_workers: int = 4, +) -> Generator[dict[str, Any], None, None]: + import jax + + seq = np.random.SeedSequence(seed) + random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] + + q: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=prefetch_factor * num_workers) + sem = threading.Semaphore(num_iterations) + stop_event = threading.Event() + + def worker(rng: np.random.Generator): + while not stop_event.is_set() and sem.acquire(blocking=False): + batch = sampler.sample(rng) + batch = jax.device_put(batch, jax.devices()[0], donate=True) + jax.block_until_ready(batch) + while not stop_event.is_set(): + try: + q.put(batch, timeout=1.0) + break # Batch successfully put into the queue; break out of retry loop + except queue.Full: + continue + + return + + # Start multiple worker threads + ts = [] + for i in range(num_workers): + t = threading.Thread(target=worker, daemon=True, name=f"worker-{i}", args=(random_generators[i],)) + t.start() + ts.append(t) + + try: + for _ in range(num_iterations): + # Yield batches from the queue; blocks waiting for available batch + yield q.get() + finally: + # When the generator is closed or garbage collected, clean up the worker threads + stop_event.set() # Signal all workers to exit + for t in ts: + t.join() # Wait for all worker threads to finish + + +@dataclass +class JaxOutOfCoreTrainSampler: + """ + A sampler that prefetches batches to the GPU for out-of-core training. + + Here out-of-core means that data can be more than the GPU memory. + + Parameters + ---------- + data + The training data. + seed + The seed for the random number generator. + batch_size + The batch size. + num_workers + The number of workers to use for prefetching. + prefetch_factor + The prefetch factor similar to PyTorch's DataLoader. 
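+
+    Example
+    -------
+    A minimal sketch (``train_data`` is assumed to be an existing
+    :class:`~scaleflow.data.TrainingData` instance):
+
+    .. code-block:: python
+
+        sampler = JaxOutOfCoreTrainSampler(data=train_data, seed=0, batch_size=512)
+        sampler.set_sampler(num_iterations=1_000)
+        batch = sampler.sample()  # blocks until a prefetched batch is ready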
+ + """ + + data: TrainingData + seed: int + batch_size: int = 1024 + num_workers: int = 4 + prefetch_factor: int = 2 + + def __post_init__(self): + self.inner = TrainSampler(data=self.data, batch_size=self.batch_size) + self._iterator = None + + def set_sampler(self, num_iterations: int) -> None: + self._iterator = _prefetch_to_device( + sampler=self.inner, seed=self.seed, num_iterations=num_iterations, + prefetch_factor=self.prefetch_factor, num_workers=self.num_workers + ) + + def sample(self, rng=None) -> dict[str, Any]: + if self._iterator is None: + raise ValueError( + "Sampler not set. Use `set_sampler` to set the sampler with" + "the number of iterations. Without the number of iterations," + " the sampler will not be able to sample the data." + ) + if rng is not None: + del rng + return next(self._iterator) diff --git a/src/scaleflow/data/_torch_dataloader.py b/src/scaleflow/data/_torch_dataloader.py new file mode 100644 index 00000000..832746eb --- /dev/null +++ b/src/scaleflow/data/_torch_dataloader.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass +from functools import partial + +import numpy as np + +from scaleflow.compat import TorchIterableDataset +from scaleflow.data._data import MappedCellData +from scaleflow.data._dataloader import TrainSampler, ReservoirSampler + + +def _worker_init_fn_helper(worker_id, random_generators): + import torch + + del worker_id + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id # type: ignore[union-attr] + rng = random_generators[worker_id] + worker_info.dataset.set_rng(rng) # type: ignore[union-attr] + return rng + + +@dataclass +class TorchCombinedTrainSampler(TorchIterableDataset): + """ + Combined training sampler that iterates over multiple samplers. + + Need to call set_rng(rng) before using the sampler. + + Args: + samplers: List of training samplers. + rng: Random number generator. 
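+        weights: Optional sampling weights over the samplers; normalised to sum to one.
+        dataset_names: Optional names returned with each batch under the ``"dataset_name"`` key.
+
+    Example:
+        A minimal sketch (the ``*.zarr`` paths are placeholders; note that each
+        underlying :class:`ReservoirSampler` still expects its pool to be
+        initialised via ``init_pool`` before batches can be drawn):
+
+        .. code-block:: python
+
+            loader = TorchCombinedTrainSampler.combine_zarr_training_samplers(
+                data_paths=["dataset_a.zarr", "dataset_b.zarr"],
+                batch_size=1024,
+                num_workers=4,
+            )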
+ """ + + samplers: list[TrainSampler] + weights: np.ndarray | None = None + rng: np.random.Generator | None = None + dataset_names: list[str] | None = None + + def __post_init__(self): + if self.weights is None: + self.weights = np.ones(len(self.samplers)) + self.weights = np.asarray(self.weights) + assert len(self.weights) == len(self.samplers) + self.weights = self.weights / self.weights.sum() + + def set_rng(self, rng: np.random.Generator): + self.rng = rng + + def __iter__(self): + return self + + def __next__(self): + if self.rng is None: + raise ValueError("Please call set_rng() before using the sampler.") + dataset_idx = self.rng.choice(len(self.samplers), p=self.weights) + res = self.samplers[dataset_idx].sample(self.rng) + if self.dataset_names is not None: + res["dataset_name"] = self.dataset_names[dataset_idx] + return res + + @classmethod + def combine_zarr_training_samplers( + cls, + data_paths: list[str], + batch_size: int = 1024, + seed: int = 42, + num_workers: int = 4, + prefetch_factor: int = 2, + weights: np.ndarray | None = None, + dataset_names: list[str] | None = None, + ): + import torch + + seq = np.random.SeedSequence(seed) + random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)] + worker_init_fn = partial(_worker_init_fn_helper, random_generators=random_generators) + data = [MappedCellData.read_zarr(path) for path in data_paths] + samplers = [ReservoirSampler(data[i], batch_size) for i in range(len(data))] + combined_sampler = cls(samplers, weights=weights, dataset_names=dataset_names) + return torch.utils.data.DataLoader( + combined_sampler, + batch_size=None, + worker_init_fn=worker_init_fn, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) diff --git a/src/scaleflow/data/_utils.py b/src/scaleflow/data/_utils.py new file mode 100644 index 00000000..d741300b --- /dev/null +++ b/src/scaleflow/data/_utils.py @@ -0,0 +1,93 @@ +from collections.abc import Iterable, Mapping +from typing import Any + +import anndata as ad +import zarr +from zarr.abc.codec import BytesBytesCodec +from zarr.codecs import BloscCodec + + +def write_sharded( + group: zarr.Group, + data: dict[str, Any], + name: str, + chunk_size: int = 4096, + shard_size: int = 65536, + compressors: Iterable[BytesBytesCodec] = ( + BloscCodec( + cname="lz4", + clevel=3, + ), + ), +): + """Function to write data to a zarr group in a sharded format. + + Parameters + ---------- + group + The zarr group to write to. + data + The data to write. + chunk_size + The chunk size. + shard_size + The shard size. 
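+    name
+        Key under which ``data`` is written in the group.
+    compressors
+        Compression codecs applied to each chunk; defaults to a single LZ4
+        :class:`~zarr.codecs.BloscCodec`.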
+ """ + # TODO: this is a copy of the function in arrayloaders + # when it is no longer public we should use the function from arrayloaders + # https://github.com/laminlabs/arrayloaders/blob/main/arrayloaders/io/store_creation.py + ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr + + def get_size(shape: tuple[int, ...], chunk_size: int, shard_size: int) -> tuple[int, int]: + shard_size_used = shard_size + chunk_size_used = chunk_size + if chunk_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + elif chunk_size < shape[0] or shard_size > shape[0]: + chunk_size_used = shard_size_used = shape[0] + return chunk_size_used, shard_size_used + + def callback( + func: ad.experimental.Write, + g: zarr.Group, + k: str, + elem: ad.typing.RWAble, + dataset_kwargs: Mapping[str, Any], + iospec: ad.experimental.IOSpec, + ): + if iospec.encoding_type in {"array"}: + # Calculate greatest common divisor for first dimension + # or use smallest dimension as chunk size + + chunk_size_used, shard_size_used = get_size(elem.shape, chunk_size, shard_size) + + dataset_kwargs = { + "shards": (shard_size_used,) + (elem.shape[1:]), # only shard over 1st dim + "chunks": (chunk_size_used,) + (elem.shape[1:]), # only chunk over 1st dim + "compressors": compressors, + **dataset_kwargs, + } + elif iospec.encoding_type in {"csr_matrix", "csc_matrix"}: + dataset_kwargs = { + "shards": (shard_size,), + "chunks": (chunk_size,), + "compressors": compressors, + **dataset_kwargs, + } + + func(g, k, elem, dataset_kwargs=dataset_kwargs) + + ad.experimental.write_dispatched(group, name, data, callback=callback) + zarr.consolidate_metadata(group.store) + + +def _to_list(x: list[Any] | tuple[Any] | Any) -> list[Any] | tuple[Any]: + """Converts x to a list if it is not already a list or tuple.""" + if isinstance(x, (list | tuple)): + return x + return [x] + + +def _flatten_list(x: Iterable[Iterable[Any]]) -> list[Any]: + """Flattens a list of lists.""" + return [item for sublist in x for item in sublist] diff --git a/src/cellflow/datasets.py b/src/scaleflow/datasets.py similarity index 94% rename from src/cellflow/datasets.py rename to src/scaleflow/datasets.py index d07ce340..f8bace3a 100644 --- a/src/cellflow/datasets.py +++ b/src/scaleflow/datasets.py @@ -4,7 +4,7 @@ import anndata as ad from scanpy.readwrite import _check_datafile_present_and_download -from cellflow._types import PathLike +from scaleflow._types import PathLike __all__ = [ "ineurons", @@ -13,7 +13,7 @@ def ineurons( - path: PathLike = "~/.cache/cellflow/ineurons.h5ad", + path: PathLike = "~/.cache/scaleflow/ineurons.h5ad", force_download: bool = False, **kwargs: Any, ) -> ad.AnnData: @@ -45,7 +45,7 @@ def ineurons( def pbmc_cytokines( - path: PathLike = "~/.cache/cellflow/pbmc_parse.h5ad", + path: PathLike = "~/.cache/scaleflow/pbmc_parse.h5ad", force_download: bool = False, **kwargs: Any, ) -> ad.AnnData: @@ -78,7 +78,7 @@ def pbmc_cytokines( def zesta( - path: PathLike = "~/.cache/cellflow/zesta.h5ad", + path: PathLike = "~/.cache/scaleflow/zesta.h5ad", force_download: bool = False, **kwargs: Any, ) -> ad.AnnData: diff --git a/src/scaleflow/external/__init__.py b/src/scaleflow/external/__init__.py new file mode 100644 index 00000000..f6d2f0e3 --- /dev/null +++ b/src/scaleflow/external/__init__.py @@ -0,0 +1,6 @@ +try: + from scaleflow.external._scvi import CFJaxSCVI +except ImportError as e: + raise ImportError( + "scaleflow.external requires more dependencies. 
Please install via pip install 'scaleflow[external]'" + ) from e diff --git a/src/cellflow/external/_scvi.py b/src/scaleflow/external/_scvi.py similarity index 98% rename from src/cellflow/external/_scvi.py rename to src/scaleflow/external/_scvi.py index f979d93c..e84a912f 100644 --- a/src/cellflow/external/_scvi.py +++ b/src/scaleflow/external/_scvi.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import numpy as np -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike if TYPE_CHECKING: from typing import Literal @@ -25,7 +25,7 @@ class CFJaxSCVI(JaxSCVI): - from cellflow.external._scvi_utils import CFJaxVAE + from scaleflow.external._scvi_utils import CFJaxVAE _module_cls = CFJaxVAE diff --git a/src/cellflow/external/_scvi_utils.py b/src/scaleflow/external/_scvi_utils.py similarity index 100% rename from src/cellflow/external/_scvi_utils.py rename to src/scaleflow/external/_scvi_utils.py diff --git a/src/cellflow/metrics/__init__.py b/src/scaleflow/metrics/__init__.py similarity index 92% rename from src/cellflow/metrics/__init__.py rename to src/scaleflow/metrics/__init__.py index 79cb1738..63a2aa52 100644 --- a/src/cellflow/metrics/__init__.py +++ b/src/scaleflow/metrics/__init__.py @@ -1,4 +1,4 @@ -from cellflow.metrics._metrics import ( +from scaleflow.metrics._metrics import ( compute_e_distance, compute_e_distance_fast, compute_mean_metrics, diff --git a/src/cellflow/metrics/_metrics.py b/src/scaleflow/metrics/_metrics.py similarity index 100% rename from src/cellflow/metrics/_metrics.py rename to src/scaleflow/metrics/_metrics.py diff --git a/src/scaleflow/model/__init__.py b/src/scaleflow/model/__init__.py new file mode 100644 index 00000000..7f4ac88b --- /dev/null +++ b/src/scaleflow/model/__init__.py @@ -0,0 +1,7 @@ +<<<<<<< HEAD +from scaleflow.model._scaleflow import CellFlow +======= +from scaleflow.model._cellflow import CellFlow +>>>>>>> main + +__all__ = ["CellFlow"] diff --git a/src/cellflow/model/_cellflow.py b/src/scaleflow/model/_cellflow.py similarity index 75% rename from src/cellflow/model/_cellflow.py rename to src/scaleflow/model/_cellflow.py index a535ddf7..487526c2 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/scaleflow/model/_cellflow.py @@ -15,18 +15,18 @@ import pandas as pd from ott.neural.methods.flows import dynamics -from cellflow import _constants -from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t -from cellflow.data._data import ConditionData, TrainingData, ValidationData -from cellflow.data._dataloader import OOCTrainSampler, PredictionSampler, TrainSampler, ValidationSampler -from cellflow.data._datamanager import DataManager -from cellflow.model._utils import _write_predictions -from cellflow.networks import _velocity_field -from cellflow.plotting import _utils -from cellflow.solvers import _genot, _otfm -from cellflow.training._callbacks import BaseCallback -from cellflow.training._trainer import CellFlowTrainer -from cellflow.utils import match_linear +from scaleflow import _constants +from scaleflow._types import ArrayLike, Layers_separate_input_t, Layers_t +from scaleflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler +from scaleflow.data._data import ConditionData, TrainingData, ValidationData +from scaleflow.data._datamanager import DataManager +from scaleflow.model._utils import _write_predictions +from scaleflow.networks import _velocity_field +from scaleflow.plotting import _utils +from scaleflow.solvers import _genot, _otfm, _eqm +from 
scaleflow.training._callbacks import BaseCallback +from scaleflow.training._trainer import CellFlowTrainer +from scaleflow.utils import match_linear __all__ = ["CellFlow"] @@ -43,23 +43,28 @@ class CellFlow: adata An :class:`~anndata.AnnData` object to extract the training data from. solver - Solver to use for training. Either ``'otfm'`` or ``'genot'``. + Solver to use for training. Either ``'otfm'``, ``'genot'`` or ``'eqm'``. """ - def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot"] = "otfm"): + def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot", "eqm"] = "otfm"): self._adata = adata - self._solver_class = _otfm.OTFlowMatching if solver == "otfm" else _genot.GENOT - self._vf_class = ( - _velocity_field.ConditionalVelocityField - if solver == "otfm" - else _velocity_field.GENOTConditionalVelocityField - ) - self._dataloader: TrainSampler | OOCTrainSampler | None = None + if solver == "otfm": + self._solver_class = _otfm.OTFlowMatching + self._vf_class = _velocity_field.ConditionalVelocityField + elif solver == "genot": + self._solver_class = _genot.GENOT + self._vf_class = _velocity_field.GENOTConditionalVelocityField + elif solver == "eqm": + self._solver_class = _eqm.EquilibriumMatching + self._vf_class = _velocity_field.EquilibriumVelocityField + else: + raise ValueError(f"Unknown solver: {solver}. Must be 'otfm', 'genot', or 'eqm'.") + self._dataloader: TrainSampler | JaxOutOfCoreTrainSampler | None = None self._trainer: CellFlowTrainer | None = None self._validation_data: dict[str, ValidationData] = {"predict_kwargs": {}} - self._solver: _otfm.OTFlowMatching | _genot.GENOT | None = None + self._solver: _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching | None = None self._condition_dim: int | None = None - self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None = None + self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | _velocity_field.EquilibriumVelocityField | None = None def prepare_data( self, @@ -73,19 +78,19 @@ def prepare_data( max_combination_length: int | None = None, null_value: float = 0.0, ) -> None: - """Prepare the dataloader for training from :attr:`~cellflow.model.CellFlow.adata`. + """Prepare the dataloader for training from :attr:`~scaleflow.model.CellFlow.adata`. Parameters ---------- sample_rep - Key in :attr:`~anndata.AnnData.obsm` of :attr:`cellflow.model.CellFlow.adata` where + Key in :attr:`~anndata.AnnData.obsm` of :attr:`scaleflow.model.CellFlow.adata` where the sample representation is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. control_key Key of a boolean column in :attr:`~anndata.AnnData.obs` of - :attr:`cellflow.model.CellFlow.adata` that defines the control samples. + :attr:`scaleflow.model.CellFlow.adata` that defines the control samples. perturbation_covariates A dictionary where the keys indicate the name of the covariate group and the values are - keys in :attr:`~anndata.AnnData.obs` of :attr:`cellflow.model.CellFlow.adata`. The + keys in :attr:`~anndata.AnnData.obs` of :attr:`scaleflow.model.CellFlow.adata`. The corresponding columns can be of the following types: - categorial: The column contains categories whose representation is stored in @@ -126,8 +131,8 @@ def prepare_data( ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.data_manager` - the :class:`cellflow.data.DataManager` object. - - :attr:`cellflow.model.CellFlow.train_data` - the training data. 
+ - :attr:`scaleflow.model.CellFlow.data_manager` - the :class:`scaleflow.data.DataManager` object. + - :attr:`scaleflow.model.CellFlow.train_data` - the training data. Example ------- @@ -203,7 +208,7 @@ def prepare_validation_data( An :class:`~anndata.AnnData` object. name Name of the validation data defining the key in - :attr:`cellflow.model.CellFlow.validation_data`. + :attr:`scaleflow.model.CellFlow.validation_data`. n_conditions_on_log_iteration Number of conditions to use for computation callbacks at each logged iteration. If :obj:`None`, use all conditions. @@ -212,14 +217,14 @@ def prepare_validation_data( If :obj:`None`, use all conditions. predict_kwargs Keyword arguments for the prediction function - :func:`cellflow.solvers._otfm.OTFlowMatching.predict` or - :func:`cellflow.solvers._genot.GENOT.predict` used during validation. + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict` or + :func:`scaleflow.solvers._genot.GENOT.predict` used during validation. Returns ------- :obj:`None`, and updates the following fields: - - :attr:`cellflow.model.CellFlow.validation_data` - a dictionary with the validation data. + - :attr:`scaleflow.model.CellFlow.validation_data` - a dictionary with the validation data. """ if self.train_data is None: @@ -238,7 +243,8 @@ def prepare_validation_data( predict_kwargs = predict_kwargs or {} # Check if predict_kwargs is alreday provided from an earlier call if "predict_kwargs" in self._validation_data and len(predict_kwargs): - predict_kwargs = self._validation_data["predict_kwargs"].update(predict_kwargs) + self._validation_data["predict_kwargs"].update(predict_kwargs) + predict_kwargs = self._validation_data["predict_kwargs"] # Set batched prediction to False if split_val is True if split_val: predict_kwargs["batched"] = False @@ -279,7 +285,7 @@ def prepare_model( """Prepare the model for training. This function sets up the neural network architecture and specificities of the - :attr:`solver`. When :attr:`solver` is an instance of :class:`cellflow.solvers._genot.GENOT`, + :attr:`solver`. When :attr:`solver` is an instance of :class:`scaleflow.solvers._genot.GENOT`, the following arguments have to be passed to ``'condition_encoder_kwargs'``: @@ -309,9 +315,9 @@ def prepare_model( pooling_kwargs Keyword arguments for the pooling method corresponding to: - - :class:`cellflow.networks.TokenAttentionPooling` if ``'pooling'`` is + - :class:`scaleflow.networks.TokenAttentionPooling` if ``'pooling'`` is ``'attention_token'``. - - :class:`cellflow.networks.SeedAttentionPooling` if ``'pooling'`` is ``'attention_seed'``. + - :class:`scaleflow.networks.SeedAttentionPooling` if ``'pooling'`` is ``'attention_seed'``. layers_before_pool Layers applied to the condition embeddings before pooling. Can be of type @@ -320,8 +326,8 @@ def prepare_model( - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be ``'mlp'`` or ``'self_attention'``. - - Further keyword arguments for the layer type :class:`cellflow.networks.MLPBlock` or - :class:`cellflow.networks.SelfAttentionBlock`. + - Further keyword arguments for the layer type :class:`scaleflow.networks.MLPBlock` or + :class:`scaleflow.networks.SelfAttentionBlock`. - :class:`dict` with keys corresponding to perturbation covariate keys, and values correspondinng to the above mentioned tuples. @@ -333,16 +339,16 @@ def prepare_model( - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be ``'mlp'`` or ``'self_attention'``. 
- - Further keys depend on the layer type, either for :class:`cellflow.networks.MLPBlock` or - for :class:`cellflow.networks.SelfAttentionBlock`. + - Further keys depend on the layer type, either for :class:`scaleflow.networks.MLPBlock` or + for :class:`scaleflow.networks.SelfAttentionBlock`. condition_embedding_dim Dimensions of the condition embedding, i.e. the last layer of the - :class:`cellflow.networks.ConditionEncoder`. + :class:`scaleflow.networks.ConditionEncoder`. cond_output_dropout - Dropout rate for the last layer of the :class:`cellflow.networks.ConditionEncoder`. + Dropout rate for the last layer of the :class:`scaleflow.networks.ConditionEncoder`. condition_encoder_kwargs - Keyword arguments for the :class:`cellflow.networks.ConditionEncoder`. + Keyword arguments for the :class:`scaleflow.networks.ConditionEncoder`. pool_sample_covariates Whether to include sample covariates in the pooling. time_freqs @@ -350,17 +356,17 @@ def prepare_model( (:func:`ott.neural.networks.layers.sinusoidal_time_encoder`). time_max_period Controls the frequency of the time embeddings, see - :func:`cellflow.networks.utils.sinusoidal_time_encoder`. + :func:`scaleflow.networks.utils.sinusoidal_time_encoder`. time_encoder_dims Dimensions of the layers processing the time embedding in - :attr:`cellflow.networks.ConditionalVelocityField.time_encoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. time_encoder_dropout - Dropout rate for the :attr:`cellflow.networks.ConditionalVelocityField.time_encoder`. + Dropout rate for the :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. hidden_dims Dimensions of the layers processing the input to the velocity field - via :attr:`cellflow.networks.ConditionalVelocityField.x_encoder`. + via :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. hidden_dropout - Dropout rate for :attr:`cellflow.networks.ConditionalVelocityField.x_encoder`. + Dropout rate for :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. conditioning Conditioning method, should be one of: @@ -373,18 +379,18 @@ def prepare_model( Keyword arguments for the conditioning method. decoder_dims Dimensions of the output layers in - :attr:`cellflow.networks.ConditionalVelocityField.decoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. decoder_dropout Dropout rate for the output layer - :attr:`cellflow.networks.ConditionalVelocityField.decoder`. + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. vf_act_fn - Activation function of the :class:`cellflow.networks.ConditionalVelocityField`. + Activation function of the :class:`scaleflow.networks.ConditionalVelocityField`. vf_kwargs Additional keyword arguments for the solver-specific vector field. For instance, when ``'solver==genot'``, the following keyword argument can be passed: - ``'genot_source_dims'`` of type :class:`tuple` with the dimensions - of the :class:`cellflow.networks.MLPBlock` processing the source cell. + of the :class:`scaleflow.networks.MLPBlock` processing the source cell. - ``'genot_source_dropout'`` of type :class:`float` indicating the dropout rate for the source cell processing. probability_path @@ -397,12 +403,12 @@ def prepare_model( match_fn Matching function between unperturbed and perturbed cells. Should take as input source and target data and return the optimal transport matrix, see e.g. - :func:`cellflow.utils.match_linear`. + :func:`scaleflow.utils.match_linear`. optimizer Optimizer used for training. 
solver_kwargs - Keyword arguments for the solver :class:`cellflow.solvers.OTFlowMatching` or - :class:`cellflow.solvers.GENOT`. + Keyword arguments for the solver :class:`scaleflow.solvers.OTFlowMatching` or + :class:`scaleflow.solvers.GENOT`. layer_norm_before_concatenation If :obj:`True`, applies layer normalization before concatenating the embedded time, embedded data, and condition embeddings. @@ -416,23 +422,26 @@ def prepare_model( ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.velocity_field` - an instance of the - :class:`cellflow.networks.ConditionalVelocityField`. - - :attr:`cellflow.model.CellFlow.solver` - an instance of :class:`cellflow.solvers.OTFlowMatching` - or :class:`cellflow.solvers.GENOT`. - - :attr:`cellflow.model.CellFlow.trainer` - an instance of the - :class:`cellflow.training.CellFlowTrainer`. + - :attr:`scaleflow.model.CellFlow.velocity_field` - an instance of the + :class:`scaleflow.networks.ConditionalVelocityField`. + - :attr:`scaleflow.model.CellFlow.solver` - an instance of :class:`scaleflow.solvers.OTFlowMatching` + or :class:`scaleflow.solvers.GENOT`. + - :attr:`scaleflow.model.CellFlow.trainer` - an instance of the + :class:`scaleflow.training.CellFlowTrainer`. """ if self.train_data is None: raise ValueError("Dataloader not initialized. Please call `prepare_data` first.") + # Store the seed for use in train method + self._seed = seed + if condition_mode == "stochastic": if regularization == 0.0: raise ValueError("Stochastic condition embeddings require `regularization`>0.") condition_encoder_kwargs = condition_encoder_kwargs or {} - if self._solver_class == _otfm.OTFlowMatching and vf_kwargs is not None: - raise ValueError("For `solver='otfm'`, `vf_kwargs` must be `None`.") + if (self._solver_class == _otfm.OTFlowMatching or self._solver_class == _eqm.EquilibriumMatching) and vf_kwargs is not None: + raise ValueError("For `solver='otfm'` or `solver='eqm'`, `vf_kwargs` must be `None`.") if self._solver_class == _genot.GENOT: if vf_kwargs is None: vf_kwargs = {"genot_source_dims": [1024, 1024, 1024], "genot_source_dropout": 0.0} @@ -446,34 +455,59 @@ def prepare_model( solver_kwargs = solver_kwargs or {} probability_path = probability_path or {"constant_noise": 0.0} - self.vf = self._vf_class( - output_dim=self._data_dim, - max_combination_length=self.train_data.max_combination_length, - condition_mode=condition_mode, - regularization=regularization, - condition_embedding_dim=condition_embedding_dim, - covariates_not_pooled=covariates_not_pooled, - pooling=pooling, - pooling_kwargs=pooling_kwargs, - layers_before_pool=layers_before_pool, - layers_after_pool=layers_after_pool, - cond_output_dropout=cond_output_dropout, - condition_encoder_kwargs=condition_encoder_kwargs, - act_fn=vf_act_fn, - time_freqs=time_freqs, - time_max_period=time_max_period, - time_encoder_dims=time_encoder_dims, - time_encoder_dropout=time_encoder_dropout, - hidden_dims=hidden_dims, - hidden_dropout=hidden_dropout, - conditioning=conditioning, - conditioning_kwargs=conditioning_kwargs, - decoder_dims=decoder_dims, - decoder_dropout=decoder_dropout, - layer_norm_before_concatenation=layer_norm_before_concatenation, - linear_projection_before_concatenation=linear_projection_before_concatenation, - **vf_kwargs, - ) + if self._solver_class == _eqm.EquilibriumMatching: + self.vf = self._vf_class( + output_dim=self._data_dim, + max_combination_length=self.train_data.max_combination_length, + condition_mode=condition_mode, + regularization=regularization, + 
condition_embedding_dim=condition_embedding_dim, + covariates_not_pooled=covariates_not_pooled, + pooling=pooling, + pooling_kwargs=pooling_kwargs, + layers_before_pool=layers_before_pool, + layers_after_pool=layers_after_pool, + cond_output_dropout=cond_output_dropout, + condition_encoder_kwargs=condition_encoder_kwargs, + act_fn=vf_act_fn, + hidden_dims=hidden_dims, + hidden_dropout=hidden_dropout, + conditioning=conditioning, + conditioning_kwargs=conditioning_kwargs, + decoder_dims=decoder_dims, + decoder_dropout=decoder_dropout, + layer_norm_before_concatenation=layer_norm_before_concatenation, + linear_projection_before_concatenation=linear_projection_before_concatenation, + ) + else: + self.vf = self._vf_class( + output_dim=self._data_dim, + max_combination_length=self.train_data.max_combination_length, + condition_mode=condition_mode, + regularization=regularization, + condition_embedding_dim=condition_embedding_dim, + covariates_not_pooled=covariates_not_pooled, + pooling=pooling, + pooling_kwargs=pooling_kwargs, + layers_before_pool=layers_before_pool, + layers_after_pool=layers_after_pool, + cond_output_dropout=cond_output_dropout, + condition_encoder_kwargs=condition_encoder_kwargs, + act_fn=vf_act_fn, + time_freqs=time_freqs, + time_max_period=time_max_period, + time_encoder_dims=time_encoder_dims, + time_encoder_dropout=time_encoder_dropout, + hidden_dims=hidden_dims, + hidden_dropout=hidden_dropout, + conditioning=conditioning, + conditioning_kwargs=conditioning_kwargs, + decoder_dims=decoder_dims, + decoder_dropout=decoder_dropout, + layer_norm_before_concatenation=layer_norm_before_concatenation, + linear_projection_before_concatenation=linear_projection_before_concatenation, + **vf_kwargs, + ) probability_path, noise = next(iter(probability_path.items())) if probability_path == "constant_noise": @@ -495,6 +529,16 @@ def prepare_model( rng=jax.random.PRNGKey(seed), **solver_kwargs, ) + elif self._solver_class == _eqm.EquilibriumMatching: + # EqM doesn't use probability_path, only match_fn + self._solver = self._solver_class( + vf=self.vf, + match_fn=match_fn, + optimizer=optimizer, + conditions=self.train_data.condition_data, + rng=jax.random.PRNGKey(seed), + **solver_kwargs, + ) elif self._solver_class == _genot.GENOT: self._solver = self._solver_class( vf=self.vf, @@ -508,7 +552,7 @@ def prepare_model( **solver_kwargs, ) else: - raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}") + raise NotImplementedError(f"Solver must be an instance of OTFlowMatching, EquilibriumMatching, or GENOT, got {type(self.solver)}") self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=self.validation_data["predict_kwargs"]) # type: ignore[arg-type] @@ -517,9 +561,12 @@ def train( num_iterations: int, batch_size: int = 1024, valid_freq: int = 1000, + validation_batch_size: int | None = None, callbacks: Sequence[BaseCallback] = [], monitor_metrics: Sequence[str] = [], out_of_core_dataloading: bool = False, + num_workers: int = 8, # Increased from default 4 + prefetch_factor: int = 4, # Increased from default 2 ) -> None: """Train the model. @@ -539,21 +586,21 @@ def train( callbacks Callbacks to perform at each validation step. There are two types of callbacks: - Callbacks for computations should inherit from - :class:`~cellflow.training.ComputationCallback` see e.g. :class:`cellflow.training.Metrics`. - - Callbacks for logging should inherit from :class:`~cellflow.training.LoggingCallback` see - e.g. 
:class:`~cellflow.training.WandbLogger`. + :class:`~scaleflow.training.ComputationCallback` see e.g. :class:`scaleflow.training.Metrics`. + - Callbacks for logging should inherit from :class:`~scaleflow.training.LoggingCallback` see + e.g. :class:`~scaleflow.training.WandbLogger`. monitor_metrics Metrics to monitor. out_of_core_dataloading - If :obj:`True`, use out-of-core dataloading. Uses the :class:`cellflow.data._dataloader.OOCTrainSampler` + If :obj:`True`, use out-of-core dataloading. Uses the :class:`scaleflow.data.JaxOutOfCoreTrainSampler` to load data that does not fit into GPU memory. Returns ------- Updates the following fields: - - :attr:`cellflow.model.CellFlow.dataloader` - the training dataloader. - - :attr:`cellflow.model.CellFlow.solver` - the trained solver. + - :attr:`scaleflow.model.CellFlow.dataloader` - the training dataloader. + - :attr:`scaleflow.model.CellFlow.solver` - the trained solver. """ if self.train_data is None: raise ValueError("Data not initialized. Please call `prepare_data` first.") @@ -562,10 +609,22 @@ def train( raise ValueError("Model not initialized. Please call `prepare_model` first.") if out_of_core_dataloading: - self._dataloader = OOCTrainSampler(data=self.train_data, batch_size=batch_size) + self._dataloader = JaxOutOfCoreTrainSampler( + data=self.train_data, + batch_size=batch_size, + seed=self._seed, + num_workers=num_workers, + prefetch_factor=prefetch_factor + ) else: self._dataloader = TrainSampler(data=self.train_data, batch_size=batch_size) - validation_loaders = {k: ValidationSampler(v) for k, v in self.validation_data.items() if k != "predict_kwargs"} + + # Pass validation_batch_size to ValidationSampler + validation_loaders = { + k: ValidationSampler(v, validation_batch_size=validation_batch_size) + for k, v in self.validation_data.items() + if k != "predict_kwargs" + } self._solver = self.trainer.train( dataloader=self._dataloader, @@ -595,8 +654,8 @@ def predict( covariate_data Covariate data defining the condition to predict. This :class:`~pandas.DataFrame` should have the same columns as :attr:`~anndata.AnnData.obs` of - :attr:`cellflow.model.CellFlow.adata`, and as registered in - :attr:`cellflow.model.CellFlow.data_manager`. + :attr:`scaleflow.model.CellFlow.adata`, and as registered in + :attr:`scaleflow.model.CellFlow.data_manager`. sample_rep Key in :attr:`~anndata.AnnData.obsm` where the sample representation is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. If :obj:`None`, the key is assumed to be @@ -608,12 +667,12 @@ def predict( If :obj:`None`, the predictions are not stored, and the predictions are returned as a :class:`dict`. rng - Random number generator. If :obj:`None` and :attr:`cellflow.model.CellFlow.conditino_mode` + Random number generator. If :obj:`None` and :attr:`scaleflow.model.CellFlow.conditino_mode` is ``'stochastic'``, the condition vector will be the mean of the learnt distributions, otherwise samples from the distribution. kwargs Keyword arguments for the predict function, i.e. - :meth:`cellflow.solvers.OTFlowMatching.predict` or :meth:`cellflow.solvers.GENOT.predict`. + :meth:`scaleflow.solvers.OTFlowMatching.predict` or :meth:`scaleflow.solvers.GENOT.predict`. Returns ------- @@ -679,7 +738,7 @@ def get_condition_embedding( """Get the embedding of the conditions. Outputs the mean and variance of the learnt embeddings - generated by the :class:`~cellflow.networks.ConditionEncoder`. + generated by the :class:`~scaleflow.networks.ConditionEncoder`. 
Parameters ---------- @@ -687,8 +746,8 @@ def get_condition_embedding( Can be one of - a :class:`~pandas.DataFrame` defining the conditions with the same columns as the - :class:`~anndata.AnnData` used for the initialisation of :class:`~cellflow.model.CellFlow`. - - an instance of :class:`~cellflow.data.ConditionData`. + :class:`~anndata.AnnData` used for the initialisation of :class:`~scaleflow.model.CellFlow`. + - an instance of :class:`~scaleflow.data.ConditionData`. rep_dict Dictionary containing the representations of the perturbation covariates. Will be considered an @@ -756,7 +815,7 @@ def save( """ Save the model. - Pickles the :class:`~cellflow.model.CellFlow` object. + Pickles the :class:`~scaleflow.model.CellFlow` object. Parameters ---------- @@ -789,7 +848,7 @@ def load( filename: str, ) -> "CellFlow": """ - Load a :class:`~cellflow.model.CellFlow` model from a saved instance. + Load a :class:`~scaleflow.model.CellFlow` model from a saved instance. Parameters ---------- @@ -816,12 +875,12 @@ def adata(self) -> ad.AnnData: return self._adata @property - def solver(self) -> _otfm.OTFlowMatching | _genot.GENOT | None: + def solver(self) -> _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching | None: """The solver.""" return self._solver @property - def dataloader(self) -> TrainSampler | OOCTrainSampler | None: + def dataloader(self) -> TrainSampler | JaxOutOfCoreTrainSampler | None: """The dataloader used for training.""" return self._dataloader @@ -837,13 +896,13 @@ def validation_data(self) -> dict[str, ValidationData]: @property def data_manager(self) -> DataManager: - """The data manager, initialised with :attr:`cellflow.model.CellFlow.adata`.""" + """The data manager, initialised with :attr:`scaleflow.model.CellFlow.adata`.""" return self._dm @property def velocity_field( self, - ) -> _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None: + ) -> _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | _velocity_field.EquilibriumVelocityField | None: """The conditional velocity field.""" return self._vf diff --git a/src/scaleflow/model/_scaleflow.py b/src/scaleflow/model/_scaleflow.py new file mode 100644 index 00000000..487526c2 --- /dev/null +++ b/src/scaleflow/model/_scaleflow.py @@ -0,0 +1,931 @@ +import functools +import os +import types +from collections.abc import Callable, Sequence +from dataclasses import field as dc_field +from typing import Any, Literal + +import anndata as ad +import cloudpickle +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import optax +import pandas as pd +from ott.neural.methods.flows import dynamics + +from scaleflow import _constants +from scaleflow._types import ArrayLike, Layers_separate_input_t, Layers_t +from scaleflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler, ValidationSampler +from scaleflow.data._data import ConditionData, TrainingData, ValidationData +from scaleflow.data._datamanager import DataManager +from scaleflow.model._utils import _write_predictions +from scaleflow.networks import _velocity_field +from scaleflow.plotting import _utils +from scaleflow.solvers import _genot, _otfm, _eqm +from scaleflow.training._callbacks import BaseCallback +from scaleflow.training._trainer import CellFlowTrainer +from scaleflow.utils import match_linear + +__all__ = ["CellFlow"] + + +class CellFlow: + """CellFlow model for perturbation prediction using Flow Matching and Optimal Transport. 
+ + CellFlow builds upon neural optimal transport estimators extending :cite:`tong:23`, + :cite:`pooladian:23`, :cite:`eyring:24`, :cite:`klein:23` which are all based on + Flow Matching :cite:`lipman:22`. + + Parameters + ---------- + adata + An :class:`~anndata.AnnData` object to extract the training data from. + solver + Solver to use for training. Either ``'otfm'``, ``'genot'`` or ``'eqm'``. + """ + + def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot", "eqm"] = "otfm"): + self._adata = adata + if solver == "otfm": + self._solver_class = _otfm.OTFlowMatching + self._vf_class = _velocity_field.ConditionalVelocityField + elif solver == "genot": + self._solver_class = _genot.GENOT + self._vf_class = _velocity_field.GENOTConditionalVelocityField + elif solver == "eqm": + self._solver_class = _eqm.EquilibriumMatching + self._vf_class = _velocity_field.EquilibriumVelocityField + else: + raise ValueError(f"Unknown solver: {solver}. Must be 'otfm', 'genot', or 'eqm'.") + self._dataloader: TrainSampler | JaxOutOfCoreTrainSampler | None = None + self._trainer: CellFlowTrainer | None = None + self._validation_data: dict[str, ValidationData] = {"predict_kwargs": {}} + self._solver: _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching | None = None + self._condition_dim: int | None = None + self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | _velocity_field.EquilibriumVelocityField | None = None + + def prepare_data( + self, + sample_rep: str, + control_key: str, + perturbation_covariates: dict[str, Sequence[str]], + perturbation_covariate_reps: dict[str, str] | None = None, + sample_covariates: Sequence[str] | None = None, + sample_covariate_reps: dict[str, str] | None = None, + split_covariates: Sequence[str] | None = None, + max_combination_length: int | None = None, + null_value: float = 0.0, + ) -> None: + """Prepare the dataloader for training from :attr:`~scaleflow.model.CellFlow.adata`. + + Parameters + ---------- + sample_rep + Key in :attr:`~anndata.AnnData.obsm` of :attr:`scaleflow.model.CellFlow.adata` where + the sample representation is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. + control_key + Key of a boolean column in :attr:`~anndata.AnnData.obs` of + :attr:`scaleflow.model.CellFlow.adata` that defines the control samples. + perturbation_covariates + A dictionary where the keys indicate the name of the covariate group and the values are + keys in :attr:`~anndata.AnnData.obs` of :attr:`scaleflow.model.CellFlow.adata`. The + corresponding columns can be of the following types: + + - categorial: The column contains categories whose representation is stored in + :attr:`~anndata.AnnData.uns`, see ``'perturbation_covariate_reps'``. + - boolean: The perturbation is present or absent. + - numeric: The perturbation is given as a numeric value, possibly linked to + a categorical perturbation, e.g. dosages for a drug. + + If multiple groups are provided, the first is interpreted as the primary + perturbation and the others as covariates corresponding to these perturbations. + perturbation_covariate_reps + A :class:`dict` where the keys indicate the name of the covariate group and the values + are keys in :attr:`~anndata.AnnData.uns` storing a dictionary with the representation + of the covariates. + sample_covariates + Keys in :attr:`~anndata.AnnData.obs` indicating sample covariates. 
Sample covariates + are defined such that each cell has only one value for each sample covariate (in + constrast to ``'perturbation_covariates'`` which can have multiple values for each + cell). If :obj:`None`, no sample + sample_covariate_reps + A dictionary where the keys indicate the name of the covariate group and the values + are keys in :attr:`~anndata.AnnData.uns` storing a dictionary with the representation + of the covariates. + split_covariates + Covariates in :attr:`~anndata.AnnData.obs` to split all control cells into different + control populations. The perturbed cells are also split according to these columns, + but if any of the ``'split_covariates'`` has a representation which should be + incorporated by the model, the corresponding column should also be used in + ``'perturbation_covariates'``. + max_combination_length + Maximum number of combinations of primary ``'perturbation_covariates'``. If + :obj:`None`, the value is inferred from the provided ``'perturbation_covariates'`` + as the maximal number of perturbations a cell has been treated with. + null_value + Value to use for padding to ``'max_combination_length'``. + + Returns + ------- + Updates the following fields: + + - :attr:`scaleflow.model.CellFlow.data_manager` - the :class:`scaleflow.data.DataManager` object. + - :attr:`scaleflow.model.CellFlow.train_data` - the training data. + + Example + ------- + Consider the case where we have combinations of drugs along with dosages, saved in + :attr:`~anndata.AnnData.obs` as columns ``drug_1`` and ``drug_2`` with three different + drugs ``DrugA``, ``DrugB``, and ``DrugC``, and ``dose_1`` and ``dose_2`` for their + dosages, respectively. We store the embeddings of the drugs in + :attr:`~anndata.AnnData.uns` under the key ``drug_embeddings``, while the dosage + columns are numeric. Moreover, we have a covariate ``cell_type`` with values + ``cell_typeA`` and ``cell_typeB``, with embeddings stored in + :attr:`~anndata.AnnData.uns` under the key ``cell_type_embeddings``. Note that we then + also have to set ``'split_covariates'`` as we assume we have an unperturbed population + for each cell type. + + .. 
code-block:: python + + perturbation_covariates = {{"drug": ("drug_1", "drug_2"), "dose": ("dose_1", "dose_2")}} + perturbation_covariate_reps = {"drug": "drug_embeddings"} + adata.uns["drug_embeddings"] = { + "drugA": np.array([0.1, 0.2, 0.3]), + "drugB": np.array([0.4, 0.5, 0.6]), + "drugC": np.array([-0.2, 0.3, 0.0]), + } + + sample_covariates = {"cell_type": "cell_type_embeddings"} + adata.uns["cell_type_embeddings"] = { + "cell_typeA": np.array([0.0, 1.0]), + "cell_typeB": np.array([0.0, 2.0]), + } + + split_covariates = ["cell_type"] + + cf = CellFlow(adata) + cf = cf.prepare_data( + sample_rep="X", + control_key="control", + perturbation_covariates=perturbation_covariates, + perturbation_covariate_reps=perturbation_covariate_reps, + sample_covariates=sample_covariates, + sample_covariate_reps=sample_covariate_reps, + split_covariates=split_covariates, + ) + """ + self._dm = DataManager( + self.adata, + sample_rep=sample_rep, + control_key=control_key, + perturbation_covariates=perturbation_covariates, + perturbation_covariate_reps=perturbation_covariate_reps, + sample_covariates=sample_covariates, + sample_covariate_reps=sample_covariate_reps, + split_covariates=split_covariates, + max_combination_length=max_combination_length, + null_value=null_value, + ) + + self.train_data = self._dm.get_train_data(self.adata) + self._data_dim = self.train_data.cell_data.shape[-1] # type: ignore[union-attr] + + def prepare_validation_data( + self, + adata: ad.AnnData, + name: str, + n_conditions_on_log_iteration: int | None = None, + n_conditions_on_train_end: int | None = None, + predict_kwargs: dict[str, Any] | None = None, + ) -> None: + """Prepare the validation data. + + Parameters + ---------- + adata + An :class:`~anndata.AnnData` object. + name + Name of the validation data defining the key in + :attr:`scaleflow.model.CellFlow.validation_data`. + n_conditions_on_log_iteration + Number of conditions to use for computation callbacks at each logged iteration. + If :obj:`None`, use all conditions. + n_conditions_on_train_end + Number of conditions to use for computation callbacks at the end of training. + If :obj:`None`, use all conditions. + predict_kwargs + Keyword arguments for the prediction function + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict` or + :func:`scaleflow.solvers._genot.GENOT.predict` used during validation. + + Returns + ------- + :obj:`None`, and updates the following fields: + + - :attr:`scaleflow.model.CellFlow.validation_data` - a dictionary with the validation data. + + """ + if self.train_data is None: + raise ValueError( + "Dataloader not initialized. Training data needs to be set up before preparing validation data. Please call prepare_data first." 
+ ) + val_data = self._dm.get_validation_data( + adata, + n_conditions_on_log_iteration=n_conditions_on_log_iteration, + n_conditions_on_train_end=n_conditions_on_train_end, + ) + self._validation_data[name] = val_data + # Batched prediction is not compatible with split covariates + # as all conditions need to be the same size + split_val = len(val_data.control_to_perturbation) > 1 + predict_kwargs = predict_kwargs or {} + # Check if predict_kwargs is alreday provided from an earlier call + if "predict_kwargs" in self._validation_data and len(predict_kwargs): + self._validation_data["predict_kwargs"].update(predict_kwargs) + predict_kwargs = self._validation_data["predict_kwargs"] + # Set batched prediction to False if split_val is True + if split_val: + predict_kwargs["batched"] = False + self._validation_data["predict_kwargs"] = predict_kwargs + + def prepare_model( + self, + condition_mode: Literal["deterministic", "stochastic"] = "deterministic", + regularization: float = 0.0, + pooling: Literal["mean", "attention_token", "attention_seed"] = "attention_token", + pooling_kwargs: dict[str, Any] = types.MappingProxyType({}), + layers_before_pool: Layers_separate_input_t | Layers_t = dc_field(default_factory=lambda: []), + layers_after_pool: Layers_t = dc_field(default_factory=lambda: []), + condition_embedding_dim: int = 256, + cond_output_dropout: float = 0.9, + condition_encoder_kwargs: dict[str, Any] | None = None, + pool_sample_covariates: bool = True, + time_freqs: int = 1024, + time_max_period: int | None = 10000, + time_encoder_dims: Sequence[int] = (2048, 2048, 2048), + time_encoder_dropout: float = 0.0, + hidden_dims: Sequence[int] = (2048, 2048, 2048), + hidden_dropout: float = 0.0, + conditioning: Literal["concatenation", "film", "resnet"] = "concatenation", + conditioning_kwargs: dict[str, Any] = dc_field(default_factory=lambda: {}), + decoder_dims: Sequence[int] = (4096, 4096, 4096), + decoder_dropout: float = 0.0, + vf_act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu, + vf_kwargs: dict[str, Any] | None = None, + probability_path: dict[Literal["constant_noise", "bridge"], float] | None = None, + match_fn: Callable[[ArrayLike, ArrayLike], ArrayLike] = match_linear, + optimizer: optax.GradientTransformation = optax.MultiSteps(optax.adam(5e-5), 20), + solver_kwargs: dict[str, Any] | None = None, + layer_norm_before_concatenation: bool = False, + linear_projection_before_concatenation: bool = False, + seed=0, + ) -> None: + """Prepare the model for training. + + This function sets up the neural network architecture and specificities of the + :attr:`solver`. When :attr:`solver` is an instance of :class:`scaleflow.solvers._genot.GENOT`, + the following arguments have to be passed to ``'condition_encoder_kwargs'``: + + + Parameters + ---------- + condition_mode + Mode of the encoder, should be one of: + + - ``'deterministic'``: Learns condition encoding point-wise. + - ``'stochastic'``: Learns a Gaussian distribution for representing conditions. + + regularization + Regularization strength in the latent space: + + - For deterministic mode, it is the strength of the L2 regularization. + - For stochastic mode, it is the strength of the VAE regularization. + + pooling + Pooling method, should be one of: + + - ``'mean'``: Aggregates combinations of covariates by the mean of their + learned embeddings. + - ``'attention_token'``: Aggregates combinations of covariates by an attention + mechanism with a class token. 
+ - ``'attention_seed'``: Aggregates combinations of covariates by seed attention. + + pooling_kwargs + Keyword arguments for the pooling method corresponding to: + + - :class:`scaleflow.networks.TokenAttentionPooling` if ``'pooling'`` is + ``'attention_token'``. + - :class:`scaleflow.networks.SeedAttentionPooling` if ``'pooling'`` is ``'attention_seed'``. + + layers_before_pool + Layers applied to the condition embeddings before pooling. Can be of type + + - :class:`tuple` with elements corresponding to dictionaries with keys: + + - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be + ``'mlp'`` or ``'self_attention'``. + - Further keyword arguments for the layer type :class:`scaleflow.networks.MLPBlock` or + :class:`scaleflow.networks.SelfAttentionBlock`. + + - :class:`dict` with keys corresponding to perturbation covariate keys, and values + correspondinng to the above mentioned tuples. + + layers_after_pool + Layers applied to the condition embeddings after pooling, and before applying the last + layer of size ``'condition_embedding_dim'``. Should be of type :class:`tuple` with + elements corresponding to dictionaries with keys: + + - ``'layer_type'`` of type :class:`str` indicating the type of the layer, can be + ``'mlp'`` or ``'self_attention'``. + - Further keys depend on the layer type, either for :class:`scaleflow.networks.MLPBlock` or + for :class:`scaleflow.networks.SelfAttentionBlock`. + + condition_embedding_dim + Dimensions of the condition embedding, i.e. the last layer of the + :class:`scaleflow.networks.ConditionEncoder`. + cond_output_dropout + Dropout rate for the last layer of the :class:`scaleflow.networks.ConditionEncoder`. + condition_encoder_kwargs + Keyword arguments for the :class:`scaleflow.networks.ConditionEncoder`. + pool_sample_covariates + Whether to include sample covariates in the pooling. + time_freqs + Frequency of the sinusoidal time encoding + (:func:`ott.neural.networks.layers.sinusoidal_time_encoder`). + time_max_period + Controls the frequency of the time embeddings, see + :func:`scaleflow.networks.utils.sinusoidal_time_encoder`. + time_encoder_dims + Dimensions of the layers processing the time embedding in + :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. + time_encoder_dropout + Dropout rate for the :attr:`scaleflow.networks.ConditionalVelocityField.time_encoder`. + hidden_dims + Dimensions of the layers processing the input to the velocity field + via :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. + hidden_dropout + Dropout rate for :attr:`scaleflow.networks.ConditionalVelocityField.x_encoder`. + conditioning + Conditioning method, should be one of: + + - ``'concatenation'``: Concatenate the time, data, and condition embeddings. + - ``'film'``: Use FiLM conditioning, i.e. learn FiLM weights from time and condition embedding + to scale the data embeddings. + - ``'resnet'``: Use residual conditioning. + + conditioning_kwargs + Keyword arguments for the conditioning method. + decoder_dims + Dimensions of the output layers in + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. + decoder_dropout + Dropout rate for the output layer + :attr:`scaleflow.networks.ConditionalVelocityField.decoder`. + vf_act_fn + Activation function of the :class:`scaleflow.networks.ConditionalVelocityField`. + vf_kwargs + Additional keyword arguments for the solver-specific vector field. 
+            For instance, when ``solver='genot'``, the following keyword arguments can be passed:
+
+            - ``'genot_source_dims'`` of type :class:`tuple` with the dimensions
+              of the :class:`scaleflow.networks.MLPBlock` processing the source cell.
+            - ``'genot_source_dropout'`` of type :class:`float` indicating the dropout rate
+              for the source cell processing.
+        probability_path
+            Probability path to use for training. Should be a :class:`dict` of the form
+
+            - ``{"constant_noise": noise_val}``
+            - ``{"bridge": noise_val}``
+
+            If :obj:`None`, defaults to ``{"constant_noise": 0.0}``.
+        match_fn
+            Matching function between unperturbed and perturbed cells. Should take as input source
+            and target data and return the optimal transport matrix, see e.g.
+            :func:`scaleflow.utils.match_linear`.
+        optimizer
+            Optimizer used for training.
+        solver_kwargs
+            Keyword arguments for the solver :class:`scaleflow.solvers.OTFlowMatching` or
+            :class:`scaleflow.solvers.GENOT`.
+        layer_norm_before_concatenation
+            If :obj:`True`, applies layer normalization before concatenating
+            the embedded time, embedded data, and condition embeddings.
+        linear_projection_before_concatenation
+            If :obj:`True`, applies a linear projection before concatenating
+            the embedded time, embedded data, and embedded condition.
+        seed
+            Random seed.
+
+        Returns
+        -------
+        Updates the following fields:
+
+        - :attr:`scaleflow.model.CellFlow.velocity_field` - an instance of the
+          :class:`scaleflow.networks.ConditionalVelocityField`.
+        - :attr:`scaleflow.model.CellFlow.solver` - an instance of :class:`scaleflow.solvers.OTFlowMatching`,
+          :class:`scaleflow.solvers.EquilibriumMatching`, or :class:`scaleflow.solvers.GENOT`.
+        - :attr:`scaleflow.model.CellFlow.trainer` - an instance of the
+          :class:`scaleflow.training.CellFlowTrainer`.
+        """
+        if self.train_data is None:
+            raise ValueError("Dataloader not initialized.
Please call `prepare_data` first.") + + # Store the seed for use in train method + self._seed = seed + + if condition_mode == "stochastic": + if regularization == 0.0: + raise ValueError("Stochastic condition embeddings require `regularization`>0.") + + condition_encoder_kwargs = condition_encoder_kwargs or {} + if (self._solver_class == _otfm.OTFlowMatching or self._solver_class == _eqm.EquilibriumMatching) and vf_kwargs is not None: + raise ValueError("For `solver='otfm'` or `solver='eqm'`, `vf_kwargs` must be `None`.") + if self._solver_class == _genot.GENOT: + if vf_kwargs is None: + vf_kwargs = {"genot_source_dims": [1024, 1024, 1024], "genot_source_dropout": 0.0} + else: + assert isinstance(vf_kwargs, dict) + assert "genot_source_dims" in vf_kwargs + assert "genot_source_dropout" in vf_kwargs + else: + vf_kwargs = {} + covariates_not_pooled = [] if pool_sample_covariates else self._dm.sample_covariates + solver_kwargs = solver_kwargs or {} + probability_path = probability_path or {"constant_noise": 0.0} + + if self._solver_class == _eqm.EquilibriumMatching: + self.vf = self._vf_class( + output_dim=self._data_dim, + max_combination_length=self.train_data.max_combination_length, + condition_mode=condition_mode, + regularization=regularization, + condition_embedding_dim=condition_embedding_dim, + covariates_not_pooled=covariates_not_pooled, + pooling=pooling, + pooling_kwargs=pooling_kwargs, + layers_before_pool=layers_before_pool, + layers_after_pool=layers_after_pool, + cond_output_dropout=cond_output_dropout, + condition_encoder_kwargs=condition_encoder_kwargs, + act_fn=vf_act_fn, + hidden_dims=hidden_dims, + hidden_dropout=hidden_dropout, + conditioning=conditioning, + conditioning_kwargs=conditioning_kwargs, + decoder_dims=decoder_dims, + decoder_dropout=decoder_dropout, + layer_norm_before_concatenation=layer_norm_before_concatenation, + linear_projection_before_concatenation=linear_projection_before_concatenation, + ) + else: + self.vf = self._vf_class( + output_dim=self._data_dim, + max_combination_length=self.train_data.max_combination_length, + condition_mode=condition_mode, + regularization=regularization, + condition_embedding_dim=condition_embedding_dim, + covariates_not_pooled=covariates_not_pooled, + pooling=pooling, + pooling_kwargs=pooling_kwargs, + layers_before_pool=layers_before_pool, + layers_after_pool=layers_after_pool, + cond_output_dropout=cond_output_dropout, + condition_encoder_kwargs=condition_encoder_kwargs, + act_fn=vf_act_fn, + time_freqs=time_freqs, + time_max_period=time_max_period, + time_encoder_dims=time_encoder_dims, + time_encoder_dropout=time_encoder_dropout, + hidden_dims=hidden_dims, + hidden_dropout=hidden_dropout, + conditioning=conditioning, + conditioning_kwargs=conditioning_kwargs, + decoder_dims=decoder_dims, + decoder_dropout=decoder_dropout, + layer_norm_before_concatenation=layer_norm_before_concatenation, + linear_projection_before_concatenation=linear_projection_before_concatenation, + **vf_kwargs, + ) + + probability_path, noise = next(iter(probability_path.items())) + if probability_path == "constant_noise": + probability_path = dynamics.ConstantNoiseFlow(noise) + elif probability_path == "bridge": + probability_path = dynamics.BrownianBridge(noise) + else: + raise NotImplementedError( + f"The key of `probability_path` must be `'constant_noise'` or `'bridge'` but found {probability_path}." 
+ ) + + if self._solver_class == _otfm.OTFlowMatching: + self._solver = self._solver_class( + vf=self.vf, + match_fn=match_fn, + probability_path=probability_path, + optimizer=optimizer, + conditions=self.train_data.condition_data, + rng=jax.random.PRNGKey(seed), + **solver_kwargs, + ) + elif self._solver_class == _eqm.EquilibriumMatching: + # EqM doesn't use probability_path, only match_fn + self._solver = self._solver_class( + vf=self.vf, + match_fn=match_fn, + optimizer=optimizer, + conditions=self.train_data.condition_data, + rng=jax.random.PRNGKey(seed), + **solver_kwargs, + ) + elif self._solver_class == _genot.GENOT: + self._solver = self._solver_class( + vf=self.vf, + data_match_fn=match_fn, + probability_path=probability_path, + source_dim=self._data_dim, + target_dim=self._data_dim, + optimizer=optimizer, + conditions=self.train_data.condition_data, + rng=jax.random.PRNGKey(seed), + **solver_kwargs, + ) + else: + raise NotImplementedError(f"Solver must be an instance of OTFlowMatching, EquilibriumMatching, or GENOT, got {type(self.solver)}") + + self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=self.validation_data["predict_kwargs"]) # type: ignore[arg-type] + + def train( + self, + num_iterations: int, + batch_size: int = 1024, + valid_freq: int = 1000, + validation_batch_size: int | None = None, + callbacks: Sequence[BaseCallback] = [], + monitor_metrics: Sequence[str] = [], + out_of_core_dataloading: bool = False, + num_workers: int = 8, # Increased from default 4 + prefetch_factor: int = 4, # Increased from default 2 + ) -> None: + """Train the model. + + Note + ---- + A low value of ``'valid_freq'`` results in long training + because predictions are time-consuming compared to training steps. + + Parameters + ---------- + num_iterations + Number of iterations to train the model. + batch_size + Batch size. + valid_freq + Frequency of validation. + callbacks + Callbacks to perform at each validation step. There are two types of callbacks: + - Callbacks for computations should inherit from + :class:`~scaleflow.training.ComputationCallback` see e.g. :class:`scaleflow.training.Metrics`. + - Callbacks for logging should inherit from :class:`~scaleflow.training.LoggingCallback` see + e.g. :class:`~scaleflow.training.WandbLogger`. + monitor_metrics + Metrics to monitor. + out_of_core_dataloading + If :obj:`True`, use out-of-core dataloading. Uses the :class:`scaleflow.data.JaxOutOfCoreTrainSampler` + to load data that does not fit into GPU memory. + + Returns + ------- + Updates the following fields: + + - :attr:`scaleflow.model.CellFlow.dataloader` - the training dataloader. + - :attr:`scaleflow.model.CellFlow.solver` - the trained solver. + """ + if self.train_data is None: + raise ValueError("Data not initialized. Please call `prepare_data` first.") + + if self.trainer is None: + raise ValueError("Model not initialized. 
Please call `prepare_model` first.") + + if out_of_core_dataloading: + self._dataloader = JaxOutOfCoreTrainSampler( + data=self.train_data, + batch_size=batch_size, + seed=self._seed, + num_workers=num_workers, + prefetch_factor=prefetch_factor + ) + else: + self._dataloader = TrainSampler(data=self.train_data, batch_size=batch_size) + + # Pass validation_batch_size to ValidationSampler + validation_loaders = { + k: ValidationSampler(v, validation_batch_size=validation_batch_size) + for k, v in self.validation_data.items() + if k != "predict_kwargs" + } + + self._solver = self.trainer.train( + dataloader=self._dataloader, + num_iterations=num_iterations, + valid_freq=valid_freq, + valid_loaders=validation_loaders, + callbacks=callbacks, + monitor_metrics=monitor_metrics, + ) + + def predict( + self, + adata: ad.AnnData, + covariate_data: pd.DataFrame, + sample_rep: str | None = None, + condition_id_key: str | None = None, + key_added_prefix: str | None = None, + rng: ArrayLike | None = None, + **kwargs: Any, + ) -> dict[str, ArrayLike] | None: + """Predict perturbation responses. + + Parameters + ---------- + adata + An :class:`~anndata.AnnData` object with the source representation. + covariate_data + Covariate data defining the condition to predict. This :class:`~pandas.DataFrame` + should have the same columns as :attr:`~anndata.AnnData.obs` of + :attr:`scaleflow.model.CellFlow.adata`, and as registered in + :attr:`scaleflow.model.CellFlow.data_manager`. + sample_rep + Key in :attr:`~anndata.AnnData.obsm` where the sample representation is stored or + ``'X'`` to use :attr:`~anndata.AnnData.X`. If :obj:`None`, the key is assumed to be + the same as for the training data. + condition_id_key + Key in ``'covariate_data'`` defining the condition name. + key_added_prefix + If not :obj:`None`, prefix to store the prediction in :attr:`~anndata.AnnData.obsm`. + If :obj:`None`, the predictions are not stored, and the predictions are returned as a + :class:`dict`. + rng + Random number generator. If :obj:`None` and :attr:`scaleflow.model.CellFlow.conditino_mode` + is ``'stochastic'``, the condition vector will be the mean of the learnt distributions, + otherwise samples from the distribution. + kwargs + Keyword arguments for the predict function, i.e. + :meth:`scaleflow.solvers.OTFlowMatching.predict` or :meth:`scaleflow.solvers.GENOT.predict`. + + Returns + ------- + If ``'key_added_prefix'`` is :obj:`None`, a :class:`dict` with the predicted sample + representation for each perturbation, otherwise stores the predictions in + :attr:`~anndata.AnnData.obsm` and returns :obj:`None`. + """ + if self.solver is None or not self.solver.is_trained: + raise ValueError("Model not trained. Please call `train` first.") + + if sample_rep is None: + sample_rep = self._dm.sample_rep + + if adata is not None and covariate_data is not None: + if covariate_data.empty: + raise ValueError("`covariate_data` is empty.") + if self._dm.control_key not in adata.obs.columns: + raise ValueError( + f"If both `adata` and `covariate_data` are given, the control key `{self._dm.control_key}` must be in `adata.obs`." + ) + if not adata.obs[self._dm.control_key].all(): + raise ValueError( + f"If both `adata` and `covariate_data` are given, all samples in `adata` must be control samples, and thus `adata.obs[`{self._dm.control_key}`] must be set to `True` everywhere." 
+ ) + pred_data = self._dm.get_prediction_data( + adata, + sample_rep=sample_rep, # type: ignore[arg-type] + covariate_data=covariate_data, + condition_id_key=condition_id_key, + ) + pred_loader = PredictionSampler(pred_data) + batch = pred_loader.sample() + src = batch["source"] + condition = batch.get("condition", None) + # using jax.tree.map to batch the prediction + # because PredictionSampler can return a different number of cells for each condition + out = jax.tree.map( + functools.partial(self.solver.predict, rng=rng, **kwargs), + src, + condition, # type: ignore[attr-defined] + ) + if key_added_prefix is None: + return out + if len(pred_data.control_to_perturbation) > 1: + raise ValueError( + f"When saving predictions to `adata`, all control cells must be from the same control \ + population, but found {len(pred_data.control_to_perturbation)} control populations." + ) + out_np = {k: np.array(v) for k, v in out.items()} + _write_predictions( + adata=adata, + predictions=out_np, + key_added_prefix=key_added_prefix, + ) + + def get_condition_embedding( + self, + covariate_data: pd.DataFrame | ConditionData, + rep_dict: dict[str, str] | None = None, + condition_id_key: str | None = None, + key_added: str | None = _constants.CONDITION_EMBEDDING, + ) -> tuple[pd.DataFrame, pd.DataFrame]: + """Get the embedding of the conditions. + + Outputs the mean and variance of the learnt embeddings + generated by the :class:`~scaleflow.networks.ConditionEncoder`. + + Parameters + ---------- + covariate_data + Can be one of + + - a :class:`~pandas.DataFrame` defining the conditions with the same columns as the + :class:`~anndata.AnnData` used for the initialisation of :class:`~scaleflow.model.CellFlow`. + - an instance of :class:`~scaleflow.data.ConditionData`. + + rep_dict + Dictionary containing the representations of the perturbation covariates. Will be considered an + empty dictionary if :obj:`None`. + condition_id_key + Key defining the name of the condition. Only available + if ``'covariate_data'`` is a :class:`~pandas.DataFrame`. + key_added + Key to store the condition embedding in :attr:`~anndata.AnnData.uns`. + + Returns + ------- + A :class:`tuple` of :class:`~pandas.DataFrame` with the mean and variance of the condition embeddings. + """ + if self.solver is None or not self.solver.is_trained: + raise ValueError("Model not trained. 
Please call `train` first.") + + if hasattr(covariate_data, "condition_data"): + cond_data = covariate_data + elif isinstance(covariate_data, pd.DataFrame): + cond_data = self._dm.get_condition_data( + covariate_data=covariate_data, + rep_dict=rep_dict, + condition_id_key=condition_id_key, + ) + else: + raise ValueError("Covariate data must be a `pandas.DataFrame` or an instance of `BaseData`.") + + condition_embeddings_mean: dict[str, ArrayLike] = {} + condition_embeddings_var: dict[str, ArrayLike] = {} + n_conditions = len(next(iter(cond_data.condition_data.values()))) + for i in range(n_conditions): + condition = {k: v[[i], :] for k, v in cond_data.condition_data.items()} + if condition_id_key: + c_key = cond_data.perturbation_idx_to_id[i] + else: + cov_combination = cond_data.perturbation_idx_to_covariates[i] + c_key = tuple(cov_combination[i] for i in range(len(cov_combination))) + condition_embeddings_mean[c_key], condition_embeddings_var[c_key] = self.solver.get_condition_embedding( + condition + ) + + df_mean = pd.DataFrame.from_dict({k: v[0] for k, v in condition_embeddings_mean.items()}).T + df_var = pd.DataFrame.from_dict({k: v[0] for k, v in condition_embeddings_var.items()}).T + + if condition_id_key: + df_mean.index.set_names([condition_id_key], inplace=True) + df_var.index.set_names([condition_id_key], inplace=True) + else: + df_mean.index.set_names(list(self._dm.perturb_covar_keys), inplace=True) + df_var.index.set_names(list(self._dm.perturb_covar_keys), inplace=True) + + if key_added is not None: + _utils.set_plotting_vars(self.adata, key=key_added, value=df_mean) + _utils.set_plotting_vars(self.adata, key=key_added, value=df_var) + + return df_mean, df_var + + def save( + self, + dir_path: str, + file_prefix: str | None = None, + overwrite: bool = False, + ) -> None: + """ + Save the model. + + Pickles the :class:`~scaleflow.model.CellFlow` object. + + Parameters + ---------- + dir_path + Path to a directory, defaults to current directory + file_prefix + Prefix to prepend to the file name. + overwrite + Overwrite existing data or not. + + Returns + ------- + :obj:`None` + """ + file_name = ( + f"{file_prefix}_{self.__class__.__name__}.pkl" + if file_prefix is not None + else f"{self.__class__.__name__}.pkl" + ) + file_dir = os.path.join(dir_path, file_name) if dir_path is not None else file_name + + if not overwrite and os.path.exists(file_dir): + raise RuntimeError(f"Unable to save to an existing file `{file_dir}` use `overwrite=True` to overwrite it.") + with open(file_dir, "wb") as f: + cloudpickle.dump(self, f) + + @classmethod + def load( + cls, + filename: str, + ) -> "CellFlow": + """ + Load a :class:`~scaleflow.model.CellFlow` model from a saved instance. + + Parameters + ---------- + filename + Path to the saved file. + + Returns + ------- + Loaded instance of the model. 
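+
+        Example
+        -------
+        A minimal sketch; the path is illustrative and assumes the model was
+        previously written with :meth:`save`:
+
+        .. code-block:: python
+
+            cf = CellFlow.load("/path/to/saved_models/CellFlow.pkl")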
+ """ + # Check if filename is a directory + file_name = os.path.join(filename, f"{cls.__name__}.pkl") if os.path.isdir(filename) else filename + + with open(file_name, "rb") as f: + model = cloudpickle.load(f) + + if type(model) is not cls: + raise TypeError(f"Expected the model to be type of `{cls}`, found `{type(model)}`.") + return model + + @property + def adata(self) -> ad.AnnData: + """The :class:`~anndata.AnnData` object used for training.""" + return self._adata + + @property + def solver(self) -> _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching | None: + """The solver.""" + return self._solver + + @property + def dataloader(self) -> TrainSampler | JaxOutOfCoreTrainSampler | None: + """The dataloader used for training.""" + return self._dataloader + + @property + def trainer(self) -> CellFlowTrainer | None: + """The trainer used for training.""" + return self._trainer + + @property + def validation_data(self) -> dict[str, ValidationData]: + """The validation data.""" + return self._validation_data + + @property + def data_manager(self) -> DataManager: + """The data manager, initialised with :attr:`scaleflow.model.CellFlow.adata`.""" + return self._dm + + @property + def velocity_field( + self, + ) -> _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | _velocity_field.EquilibriumVelocityField | None: + """The conditional velocity field.""" + return self._vf + + @property + def train_data(self) -> TrainingData | None: + """The training data.""" + return self._train_data + + @train_data.setter + def train_data(self, data: TrainingData) -> None: + """Set the training data.""" + if not isinstance(data, TrainingData): + raise ValueError(f"Expected `data` to be an instance of `TrainingData`, found `{type(data)}`.") + self._train_data = data + + @velocity_field.setter # type: ignore[attr-defined,no-redef] + def velocity_field(self, vf: _velocity_field.ConditionalVelocityField) -> None: + """Set the velocity field.""" + if not isinstance(vf, _velocity_field.ConditionalVelocityField): + raise ValueError(f"Expected `vf` to be an instance of `ConditionalVelocityField`, found `{type(vf)}`.") + self._vf = vf + + @property + def condition_mode(self) -> Literal["deterministic", "stochastic"]: + """The mode of the encoder.""" + return self.velocity_field.condition_mode diff --git a/src/cellflow/model/_utils.py b/src/scaleflow/model/_utils.py similarity index 96% rename from src/cellflow/model/_utils.py rename to src/scaleflow/model/_utils.py index 76384b38..920bbf77 100644 --- a/src/cellflow/model/_utils.py +++ b/src/scaleflow/model/_utils.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike def _multivariate_normal( diff --git a/src/cellflow/networks/__init__.py b/src/scaleflow/networks/__init__.py similarity index 64% rename from src/cellflow/networks/__init__.py rename to src/scaleflow/networks/__init__.py index e8051b1c..48285109 100644 --- a/src/cellflow/networks/__init__.py +++ b/src/scaleflow/networks/__init__.py @@ -1,7 +1,7 @@ -from cellflow.networks._set_encoders import ( +from scaleflow.networks._set_encoders import ( ConditionEncoder, ) -from cellflow.networks._utils import ( +from scaleflow.networks._utils import ( FilmBlock, MLPBlock, ResNetBlock, @@ -10,11 +10,12 @@ SelfAttentionBlock, TokenAttentionPooling, ) -from cellflow.networks._velocity_field import ConditionalVelocityField, GENOTConditionalVelocityField +from scaleflow.networks._velocity_field 
import ConditionalVelocityField, GENOTConditionalVelocityField, EquilibriumVelocityField __all__ = [ "ConditionalVelocityField", "GENOTConditionalVelocityField", + "EquilibriumVelocityField", "ConditionEncoder", "MLPBlock", "SelfAttention", diff --git a/src/cellflow/networks/_set_encoders.py b/src/scaleflow/networks/_set_encoders.py similarity index 98% rename from src/cellflow/networks/_set_encoders.py rename to src/scaleflow/networks/_set_encoders.py index 74279872..8c334233 100644 --- a/src/cellflow/networks/_set_encoders.py +++ b/src/scaleflow/networks/_set_encoders.py @@ -9,8 +9,8 @@ from flax.training import train_state from flax.typing import FrozenDict -from cellflow._types import ArrayLike, Layers_separate_input_t, Layers_t -from cellflow.networks import _utils as nn_utils +from scaleflow._types import ArrayLike, Layers_separate_input_t, Layers_t +from scaleflow.networks import _utils as nn_utils __all__ = [ "ConditionEncoder", diff --git a/src/cellflow/networks/_utils.py b/src/scaleflow/networks/_utils.py similarity index 99% rename from src/cellflow/networks/_utils.py rename to src/scaleflow/networks/_utils.py index 3441330c..a6a72da5 100644 --- a/src/cellflow/networks/_utils.py +++ b/src/scaleflow/networks/_utils.py @@ -7,7 +7,7 @@ from flax import linen as nn from flax.linen import initializers -from cellflow._types import Layers_t +from scaleflow._types import Layers_t __all__ = [ "SelfAttention", diff --git a/src/cellflow/networks/_velocity_field.py b/src/scaleflow/networks/_velocity_field.py similarity index 76% rename from src/cellflow/networks/_velocity_field.py rename to src/scaleflow/networks/_velocity_field.py index 157ad4d8..113b3b40 100644 --- a/src/cellflow/networks/_velocity_field.py +++ b/src/scaleflow/networks/_velocity_field.py @@ -9,11 +9,11 @@ from flax import linen as nn from flax.training import train_state -from cellflow._types import Layers_separate_input_t, Layers_t -from cellflow.networks._set_encoders import ConditionEncoder -from cellflow.networks._utils import FilmBlock, MLPBlock, ResNetBlock, sinusoidal_time_encoder +from scaleflow._types import Layers_separate_input_t, Layers_t +from scaleflow.networks._set_encoders import ConditionEncoder +from scaleflow.networks._utils import FilmBlock, MLPBlock, ResNetBlock, sinusoidal_time_encoder -__all__ = ["ConditionalVelocityField", "GENOTConditionalVelocityField"] +__all__ = ["ConditionalVelocityField", "GENOTConditionalVelocityField", "EquilibriumVelocityField"] class ConditionalVelocityField(nn.Module): @@ -238,7 +238,7 @@ def get_condition_embedding(self, condition: dict[str, jnp.ndarray]) -> tuple[jn Returns ------- Learnt mean and log-variance of the condition embedding. - If :attr:`cellflow.model.CellFlow.condition_mode` is ``'deterministic'``, the log-variance + If :attr:`scaleflow.model.CellFlow.condition_mode` is ``'deterministic'``, the log-variance is set to zero. 
""" condition_mean, condition_logvar = self.condition_encoder(condition, training=False) @@ -495,7 +495,7 @@ def setup(self): elif self.conditioning == "resnet": self.resnet_block = ResNetBlock( input_dim=self.hidden_dims[-1], - **self.conditioning_kwargs, + **conditioning_kwargs, ) elif self.conditioning == "concatenation": if len(conditioning_kwargs) > 0: @@ -587,3 +587,160 @@ def create_train_state( train=False, )["params"] return train_state.TrainState.create(apply_fn=self.apply, params=params, tx=optimizer) + + +class EquilibriumVelocityField(nn.Module): + """Parameterized neural gradient field for Equilibrium Matching (no time conditioning). + + Same as ConditionalVelocityField but without time encoder. + """ + + output_dim: int + max_combination_length: int + condition_mode: Literal["deterministic", "stochastic"] = "deterministic" + regularization: float = 1.0 + condition_embedding_dim: int = 32 + covariates_not_pooled: Sequence[str] = dc_field(default_factory=lambda: []) + pooling: Literal["mean", "attention_token", "attention_seed"] = "attention_token" + pooling_kwargs: dict[str, Any] = dc_field(default_factory=lambda: {}) + layers_before_pool: Layers_separate_input_t | Layers_t = dc_field(default_factory=lambda: []) + layers_after_pool: Layers_t = dc_field(default_factory=lambda: []) + cond_output_dropout: float = 0.0 + mask_value: float = 0.0 + condition_encoder_kwargs: dict[str, Any] = dc_field(default_factory=lambda: {}) + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.silu + hidden_dims: Sequence[int] = (1024, 1024, 1024) + hidden_dropout: float = 0.0 + conditioning: Literal["concatenation", "film", "resnet"] = "concatenation" + conditioning_kwargs: dict[str, Any] = dc_field(default_factory=lambda: {}) + decoder_dims: Sequence[int] = (1024, 1024, 1024) + decoder_dropout: float = 0.0 + layer_norm_before_concatenation: bool = False + linear_projection_before_concatenation: bool = False + + def setup(self): + """Initialize the network.""" + if isinstance(self.conditioning_kwargs, dataclasses.Field): + conditioning_kwargs = dict(self.conditioning_kwargs.default_factory()) + else: + conditioning_kwargs = dict(self.conditioning_kwargs) + self.condition_encoder = ConditionEncoder( + condition_mode=self.condition_mode, + regularization=self.regularization, + output_dim=self.condition_embedding_dim, + pooling=self.pooling, + pooling_kwargs=self.pooling_kwargs, + layers_before_pool=self.layers_before_pool, + layers_after_pool=self.layers_after_pool, + covariates_not_pooled=self.covariates_not_pooled, + mask_value=self.mask_value, + **self.condition_encoder_kwargs, + ) + + self.layer_cond_output_dropout = nn.Dropout(rate=self.cond_output_dropout) + self.layer_norm_condition = nn.LayerNorm() if self.layer_norm_before_concatenation else lambda x: x + + self.x_encoder = MLPBlock( + dims=self.hidden_dims, + act_fn=self.act_fn, + dropout_rate=self.hidden_dropout, + act_last_layer=(False if self.linear_projection_before_concatenation else True), + ) + self.layer_norm_x = nn.LayerNorm() if self.layer_norm_before_concatenation else lambda x: x + + self.decoder = MLPBlock( + dims=self.decoder_dims, + act_fn=self.act_fn, + dropout_rate=self.decoder_dropout, + act_last_layer=(False if self.linear_projection_before_concatenation else True), + ) + + self.output_layer = nn.Dense(self.output_dim) + + if self.conditioning == "film": + self.film_block = FilmBlock( + input_dim=self.hidden_dims[-1], + cond_dim=self.condition_embedding_dim, # No time encoder! 
+ **conditioning_kwargs, + ) + elif self.conditioning == "resnet": + self.resnet_block = ResNetBlock( + input_dim=self.hidden_dims[-1], + **conditioning_kwargs, + ) + elif self.conditioning == "concatenation": + if len(conditioning_kwargs) > 0: + raise ValueError("If `conditioning=='concatenation' mode, no conditioning kwargs can be passed.") + else: + raise ValueError(f"Unknown conditioning mode: {self.conditioning}") + + def __call__( + self, + x: jnp.ndarray, + cond: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + train: bool = True, + ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + squeeze = x.ndim == 1 + cond_mean, cond_logvar = self.condition_encoder(cond, training=train) + if self.condition_mode == "deterministic": + cond_embedding = cond_mean + else: + cond_embedding = cond_mean + encoder_noise * jnp.exp(cond_logvar / 2.0) + + cond_embedding = self.layer_cond_output_dropout(cond_embedding, deterministic=not train) + x_encoded = self.x_encoder(x, training=train) + + x_encoded = self.layer_norm_x(x_encoded) + cond_embedding = self.layer_norm_condition(cond_embedding) + + if squeeze: + cond_embedding = jnp.squeeze(cond_embedding) + elif cond_embedding.shape[0] != x.shape[0]: + cond_embedding = jnp.tile(cond_embedding, (x.shape[0], 1)) + + if self.conditioning == "concatenation": + out = jnp.concatenate((x_encoded, cond_embedding), axis=-1) # No time! + elif self.conditioning == "film": + out = self.film_block(x_encoded, cond_embedding) # No time! + elif self.conditioning == "resnet": + out = self.resnet_block(x_encoded, cond_embedding) # No time! + else: + raise ValueError(f"Unknown conditioning mode: {self.conditioning}.") + + out = self.decoder(out, training=train) + return self.output_layer(out), cond_mean, cond_logvar + + def get_condition_embedding(self, condition: dict[str, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]: + """Get the embedding of the condition.""" + condition_mean, condition_logvar = self.condition_encoder(condition, training=False) + return condition_mean, condition_logvar + + def create_train_state( + self, + rng: jax.Array, + optimizer: optax.OptState, + input_dim: int, + conditions: dict[str, jnp.ndarray], + ) -> train_state.TrainState: + """Create the training state.""" + x = jnp.ones((1, input_dim)) # No time variable! 
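+        # Parameters are initialized from dummy inputs: a single cell state of
+        # `input_dim`, encoder noise of width `condition_embedding_dim`, and one
+        # condition array per covariate padded to `max_combination_length`.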
+ encoder_noise = jnp.ones((1, self.condition_embedding_dim)) + cond = { + pert_cov: jnp.ones((1, self.max_combination_length, condition.shape[-1])) + for pert_cov, condition in conditions.items() + } + params_rng, condition_encoder_rng = jax.random.split(rng, 2) + params = self.init( + {"params": params_rng, "condition_encoder": condition_encoder_rng}, + x=x, + cond=cond, + encoder_noise=encoder_noise, + train=False, + )["params"] + return train_state.TrainState.create(apply_fn=self.apply, params=params, tx=optimizer) + + @property + def output_dims(self): + """Dimensions of the output layers.""" + return tuple(self.decoder_dims) + (self.output_dim,) diff --git a/src/scaleflow/plotting/__init__.py b/src/scaleflow/plotting/__init__.py new file mode 100644 index 00000000..45364b19 --- /dev/null +++ b/src/scaleflow/plotting/__init__.py @@ -0,0 +1,3 @@ +from scaleflow.plotting._plotting import plot_condition_embedding + +__all__ = ["plot_condition_embedding"] diff --git a/src/cellflow/plotting/_plotting.py b/src/scaleflow/plotting/_plotting.py similarity index 96% rename from src/cellflow/plotting/_plotting.py rename to src/scaleflow/plotting/_plotting.py index 389f37d0..6cfbeef7 100644 --- a/src/cellflow/plotting/_plotting.py +++ b/src/scaleflow/plotting/_plotting.py @@ -7,8 +7,8 @@ import seaborn as sns from adjustText import adjust_text -from cellflow import _constants -from cellflow.plotting._utils import ( +from scaleflow import _constants +from scaleflow.plotting._utils import ( _compute_kernel_pca_from_df, _compute_pca_from_df, _compute_umap_from_df, @@ -38,7 +38,7 @@ def plot_condition_embedding( df A :class:`pandas.DataFrame` with embedding and metadata. Column names of embedding dimensions should be consecutive integers starting from 0, - e.g. as output from :meth:`~cellflow.model.CellFlow.get_condition_embedding`, and + e.g. as output from :meth:`~scaleflow.model.CellFlow.get_condition_embedding`, and metadata should be in columns with strings. embedding Embedding to plot. Options are "raw_embedding", "UMAP", "PCA", "Kernel_PCA". 
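
For reference, a minimal usage sketch tying the renamed plotting entry point to the learnt
condition embeddings. It assumes a trained `CellFlow` instance `cf` and a covariate
`pandas.DataFrame` `cov_df`; the column name `"condition_id"` is illustrative, and only
arguments visible in this diff are passed:

    from scaleflow.plotting import plot_condition_embedding

    # Mean and variance of the learnt condition embeddings, one row per condition.
    df_mean, df_var = cf.get_condition_embedding(cov_df, condition_id_key="condition_id")

    # Project the mean embeddings (e.g. with UMAP) and plot them.
    plot_condition_embedding(df_mean, embedding="UMAP")
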
diff --git a/src/cellflow/plotting/_utils.py b/src/scaleflow/plotting/_utils.py similarity index 98% rename from src/cellflow/plotting/_utils.py rename to src/scaleflow/plotting/_utils.py index 6378122d..e89b805d 100644 --- a/src/cellflow/plotting/_utils.py +++ b/src/scaleflow/plotting/_utils.py @@ -9,7 +9,7 @@ from sklearn.decomposition import KernelPCA from sklearn.metrics.pairwise import cosine_similarity -from cellflow import _constants, _logging +from scaleflow import _constants, _logging def set_plotting_vars( diff --git a/src/scaleflow/preprocessing/__init__.py b/src/scaleflow/preprocessing/__init__.py new file mode 100644 index 00000000..36e1bff1 --- /dev/null +++ b/src/scaleflow/preprocessing/__init__.py @@ -0,0 +1,9 @@ +from scaleflow.preprocessing._gene_emb import ( + GeneInfo, + get_esm_embedding, + prot_sequence_from_ensembl, + protein_features_from_genes, +) +from scaleflow.preprocessing._pca import centered_pca, project_pca, reconstruct_pca +from scaleflow.preprocessing._preprocessing import annotate_compounds, encode_onehot, get_molecular_fingerprints +from scaleflow.preprocessing._wknn import compute_wknn, transfer_labels diff --git a/src/cellflow/preprocessing/_gene_emb.py b/src/scaleflow/preprocessing/_gene_emb.py similarity index 99% rename from src/cellflow/preprocessing/_gene_emb.py rename to src/scaleflow/preprocessing/_gene_emb.py index cbddb59f..376f25e7 100644 --- a/src/cellflow/preprocessing/_gene_emb.py +++ b/src/scaleflow/preprocessing/_gene_emb.py @@ -7,7 +7,7 @@ import anndata as ad import pandas as pd -from cellflow._logging import logger +from scaleflow._logging import logger try: import requests # type: ignore[import-untyped] @@ -21,7 +21,7 @@ EsmModel = None raise ImportError( "To use gene embedding, please install `transformers` and `torch` \ - e.g. via `pip install cellflow['embedding']`." + e.g. via `pip install scaleflow['embedding']`." 
) from e __all__ = [ diff --git a/src/cellflow/preprocessing/_pca.py b/src/scaleflow/preprocessing/_pca.py similarity index 99% rename from src/cellflow/preprocessing/_pca.py rename to src/scaleflow/preprocessing/_pca.py index b6b72238..6a0dc886 100644 --- a/src/cellflow/preprocessing/_pca.py +++ b/src/scaleflow/preprocessing/_pca.py @@ -3,7 +3,7 @@ import scanpy as sc from scipy.sparse import csr_matrix -from cellflow._types import ArrayLike +from scaleflow._types import ArrayLike __all__ = ["centered_pca", "reconstruct_pca", "project_pca"] diff --git a/src/cellflow/preprocessing/_preprocessing.py b/src/scaleflow/preprocessing/_preprocessing.py similarity index 98% rename from src/cellflow/preprocessing/_preprocessing.py rename to src/scaleflow/preprocessing/_preprocessing.py index a12bd627..96149d01 100644 --- a/src/cellflow/preprocessing/_preprocessing.py +++ b/src/scaleflow/preprocessing/_preprocessing.py @@ -5,9 +5,9 @@ import numpy as np import sklearn.preprocessing as preprocessing -from cellflow._logging import logger -from cellflow._types import ArrayLike -from cellflow.data._utils import _to_list +from scaleflow._logging import logger +from scaleflow._types import ArrayLike +from scaleflow.data._utils import _to_list __all__ = ["encode_onehot", "annotate_compounds", "get_molecular_fingerprints"] diff --git a/src/cellflow/preprocessing/_wknn.py b/src/scaleflow/preprocessing/_wknn.py similarity index 99% rename from src/cellflow/preprocessing/_wknn.py rename to src/scaleflow/preprocessing/_wknn.py index 222a9dcf..5430f926 100644 --- a/src/cellflow/preprocessing/_wknn.py +++ b/src/scaleflow/preprocessing/_wknn.py @@ -6,8 +6,8 @@ import pandas as pd from scipy import sparse -from cellflow._logging import logger -from cellflow._types import ArrayLike +from scaleflow._logging import logger +from scaleflow._types import ArrayLike __all__ = ["compute_wknn", "transfer_labels"] diff --git a/src/scaleflow/solvers/__init__.py b/src/scaleflow/solvers/__init__.py new file mode 100644 index 00000000..6c7aa964 --- /dev/null +++ b/src/scaleflow/solvers/__init__.py @@ -0,0 +1,5 @@ +from scaleflow.solvers._genot import GENOT +from scaleflow.solvers._otfm import OTFlowMatching +from scaleflow.solvers._eqm import EquilibriumMatching + +__all__ = ["GENOT", "OTFlowMatching", "EquilibriumMatching"] diff --git a/src/scaleflow/solvers/_eqm.py b/src/scaleflow/solvers/_eqm.py new file mode 100644 index 00000000..af436bf8 --- /dev/null +++ b/src/scaleflow/solvers/_eqm.py @@ -0,0 +1,356 @@ +# /home/icb/alejandro.tejada/CellFlow2/src/scaleflow/solvers/_eqm.py + +from collections.abc import Callable +from functools import partial +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core import frozen_dict +from flax.training import train_state +from ott.solvers import utils as solver_utils + +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.networks._velocity_field import ConditionalVelocityField +from scaleflow.solvers.utils import ema_update + +__all__ = ["EquilibriumMatching"] + + +class EquilibriumMatching: + """Equilibrium Matching for generative modeling. + + Based on "Equilibrium Matching" (Wang & Du, 2024). + Learns a time-invariant equilibrium gradient field instead of + time-conditional velocities. + + Parameters + ---------- + vf + Vector field parameterized by a neural network (without time conditioning). + match_fn + Function to match samples from the source and the target + distributions. 
It has a ``(src, tgt) -> matching`` signature, + see e.g. :func:`scaleflow.utils.match_linear`. If :obj:`None`, no + matching is performed. + gamma_sampler + Noise level sampler with a ``(rng, n_samples) -> gamma`` signature. + Defaults to uniform sampling on [0, 1]. + c_fn + Weighting function c(gamma). Defaults to c(gamma) = 1 - gamma. + kwargs + Keyword arguments for :meth:`scaleflow.networks.ConditionalVelocityField.create_train_state`. + """ + + def __init__( + self, + vf: ConditionalVelocityField, + match_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None, + gamma_sampler: Callable[[jax.Array, int], jnp.ndarray] = solver_utils.uniform_sampler, + c_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None, + **kwargs: Any, + ): + self._is_trained: bool = False + self.vf = vf + self.condition_encoder_mode = self.vf.condition_mode + self.condition_encoder_regularization = self.vf.regularization + self.gamma_sampler = gamma_sampler + self.c_fn = c_fn if c_fn is not None else lambda gamma: 1.0 - gamma + self.match_fn = jax.jit(match_fn) if match_fn is not None else None + self.ema = kwargs.pop("ema", 1.0) + + self.vf_state = self.vf.create_train_state(input_dim=self.vf.output_dims[-1], **kwargs) + self.vf_state_inference = self.vf.create_train_state(input_dim=self.vf.output_dims[-1], **kwargs) + self.vf_step_fn = self._get_vf_step_fn() + + def _get_vf_step_fn(self) -> Callable: + @jax.jit + def vf_step_fn( + rng: jax.Array, + vf_state: train_state.TrainState, + gamma: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + ): + def loss_fn( + params: jnp.ndarray, + gamma: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + rng: jax.Array, + ) -> jnp.ndarray: + rng_encoder, rng_dropout = jax.random.split(rng, 2) + + # Interpolate between source (noise) and target (data) + gamma_expanded = gamma[:, jnp.newaxis] + x_gamma = gamma_expanded * target + (1.0 - gamma_expanded) * source + + # Predict gradient field (no time input) + f_pred, mean_cond, logvar_cond = vf_state.apply_fn( + {"params": params}, + x_gamma, + conditions, + encoder_noise=encoder_noise, + rngs={"dropout": rng_dropout, "condition_encoder": rng_encoder}, + ) + + # Target gradient: (source - target) * c(gamma) + c_gamma = self.c_fn(gamma)[:, jnp.newaxis] + target_gradient = (source - target) * c_gamma + + # EqM loss + eqm_loss = jnp.mean((f_pred - target_gradient) ** 2) + + # Condition encoder regularization (same as flow matching) + condition_mean_regularization = 0.5 * jnp.mean(mean_cond**2) + condition_var_regularization = -0.5 * jnp.mean(1 + logvar_cond - jnp.exp(logvar_cond)) + + if self.condition_encoder_mode == "stochastic": + encoder_loss = condition_mean_regularization + condition_var_regularization + elif (self.condition_encoder_mode == "deterministic") and (self.condition_encoder_regularization > 0): + encoder_loss = condition_mean_regularization + else: + encoder_loss = 0.0 + + return eqm_loss + encoder_loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(vf_state.params, gamma, source, target, conditions, encoder_noise, rng) + return vf_state.apply_gradients(grads=grads), loss + + return vf_step_fn + + def step_fn( + self, + rng: jnp.ndarray, + batch: dict[str, ArrayLike], + ) -> float: + """Single step function of the solver. + + Parameters + ---------- + rng + Random number generator. 
+ batch + Data batch with keys ``src_cell_data``, ``tgt_cell_data``, and + optionally ``condition``. + + Returns + ------- + Loss value. + """ + src, tgt = batch["src_cell_data"], batch["tgt_cell_data"] + condition = batch.get("condition") + rng_resample, rng_gamma, rng_step_fn, rng_encoder_noise = jax.random.split(rng, 4) + n = src.shape[0] + gamma = self.gamma_sampler(rng_gamma, n).squeeze() + encoder_noise = jax.random.normal(rng_encoder_noise, (n, self.vf.condition_embedding_dim)) + + if self.match_fn is not None: + tmat = self.match_fn(src, tgt) + src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat) + src, tgt = src[src_ixs], tgt[tgt_ixs] + + self.vf_state, loss = self.vf_step_fn( + rng_step_fn, + self.vf_state, + gamma, + src, + tgt, + condition, + encoder_noise, + ) + + if self.ema == 1.0: + self.vf_state_inference = self.vf_state + else: + self.vf_state_inference = self.vf_state_inference.replace( + params=ema_update(self.vf_state_inference.params, self.vf_state.params, self.ema) + ) + return loss + + def get_condition_embedding(self, condition: dict[str, ArrayLike], return_as_numpy=True) -> ArrayLike: + """Get learnt embeddings of the conditions. + + Parameters + ---------- + condition + Conditions to encode + return_as_numpy + Whether to return the embeddings as numpy arrays. + + Returns + ------- + Mean and log-variance of encoded conditions. + """ + cond_mean, cond_logvar = self.vf.apply( + {"params": self.vf_state_inference.params}, + condition, + method="get_condition_embedding", + ) + if return_as_numpy: + return np.asarray(cond_mean), np.asarray(cond_logvar) + return cond_mean, cond_logvar + + def _predict_jit( + self, + x: ArrayLike, + condition: dict[str, ArrayLike], + rng: jax.Array | None = None, + eta: float = 0.003, + max_steps: int = 250, + use_nesterov: bool = True, + mu: float = 0.35, + **kwargs: Any, + ) -> ArrayLike: + """Predict using gradient descent sampling. + + Parameters + ---------- + x + Initial samples (typically noise). + condition + Conditioning information. + rng + Random number generator for stochastic conditioning. + eta + Step size for gradient descent. + max_steps + Maximum number of gradient descent steps. + use_nesterov + Whether to use Nesterov accelerated gradient. + mu + Momentum parameter for Nesterov. + + Returns + ------- + Generated samples. 
+ """ + noise_dim = (1, self.vf.condition_embedding_dim) + use_mean = rng is None or self.condition_encoder_mode == "deterministic" + rng = utils.default_prng_key(rng) + encoder_noise = jnp.zeros(noise_dim) if use_mean else jax.random.normal(rng, noise_dim) + + def gradient_field(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + params = self.vf_state_inference.params + return self.vf_state_inference.apply_fn({"params": params}, x, condition, encoder_noise, train=False)[0] + + def sample_gd(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + """Basic gradient descent sampler.""" + for _ in range(max_steps): + f = gradient_field(x, condition, encoder_noise) + x = x - eta * f + return x + + def sample_nag(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + """Nesterov accelerated gradient descent sampler.""" + velocity = jnp.zeros_like(x) + for _ in range(max_steps): + x_lookahead = x - mu * velocity + f = gradient_field(x_lookahead, condition, encoder_noise) + velocity = mu * velocity + eta * f + x = x - velocity + return x + + sampler = sample_nag if use_nesterov else sample_gd + x_pred = jax.jit(jax.vmap(sampler, in_axes=[0, None, None]))(x, condition, encoder_noise) + return x_pred + + def predict( + self, + x: ArrayLike | dict[str, ArrayLike], + condition: dict[str, ArrayLike] | dict[str, dict[str, ArrayLike]], + rng: jax.Array | None = None, + batched: bool = False, + eta: float = 0.003, + max_steps: int = 250, + use_nesterov: bool = True, + mu: float = 0.35, + **kwargs: Any, + ) -> ArrayLike | dict[str, ArrayLike]: + """Predict the translated source ``x`` under condition ``condition``. + + This function performs gradient descent on the learned equilibrium landscape. + + Parameters + ---------- + x + A dictionary with keys indicating the name of the condition and values containing + the input data as arrays. If ``batched=False`` provide an array of shape [batch_size, ...]. + condition + A dictionary with keys indicating the name of the condition and values containing + the condition of input data as arrays. If ``batched=False`` provide an array of shape + [batch_size, ...]. + rng + Random number generator to sample from the latent distribution, + only used if ``condition_mode='stochastic'``. If :obj:`None`, the + mean embedding is used. + batched + Whether to use batched prediction. + eta + Step size for gradient descent (default: 0.003 as in paper). + max_steps + Number of gradient descent steps (default: 250 as in paper). + use_nesterov + Whether to use Nesterov accelerated gradient (recommended). + mu + Momentum parameter for Nesterov (default: 0.35 as in paper). + kwargs + Additional keyword arguments (for compatibility). + + Returns + ------- + The push-forward distribution of ``x`` under condition ``condition``. 
+ """ + if batched and not x: + return {} + + predict_fn = partial( + self._predict_jit, + rng=rng, + eta=eta, + max_steps=max_steps, + use_nesterov=use_nesterov, + mu=mu, + **kwargs, + ) + + if batched: + keys = sorted(x.keys()) + condition_keys = sorted(set().union(*(condition[k].keys() for k in keys))) + _predict_jit = jax.jit(lambda x, condition: predict_fn(x, condition)) + batched_predict = jax.vmap(_predict_jit, in_axes=(0, dict.fromkeys(condition_keys, 0))) + n_cells = x[keys[0]].shape[0] + for k in keys: + assert x[k].shape[0] == n_cells, "The number of cells must be the same for each condition" + src_inputs = jnp.stack([x[k] for k in keys], axis=0) + batched_conditions = {} + for cond_key in condition_keys: + batched_conditions[cond_key] = jnp.stack([condition[k][cond_key] for k in keys]) + + pred_targets = batched_predict(src_inputs, batched_conditions) + return {k: pred_targets[i] for i, k in enumerate(keys)} + elif isinstance(x, dict): + return jax.tree.map( + predict_fn, + x, + condition, + ) + else: + x_pred = predict_fn(x, condition) + return np.array(x_pred) + + @property + def is_trained(self) -> bool: + """Whether the model is trained.""" + return self._is_trained + + @is_trained.setter + def is_trained(self, value: bool) -> None: + self._is_trained = value diff --git a/src/cellflow/solvers/_genot.py b/src/scaleflow/solvers/_genot.py similarity index 96% rename from src/cellflow/solvers/_genot.py rename to src/scaleflow/solvers/_genot.py index 7270ad7f..588a6036 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/scaleflow/solvers/_genot.py @@ -11,9 +11,9 @@ from ott.neural.networks import velocity_field from ott.solvers import utils as solver_utils -from cellflow import utils -from cellflow._types import ArrayLike -from cellflow.model._utils import _multivariate_normal +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.model._utils import _multivariate_normal __all__ = ["GENOT"] @@ -240,7 +240,7 @@ def predict( """Generate the push-forward of ``x`` under condition ``condition``. This function solves the ODE learnt with - the :class:`~cellflow.networks.ConditionalVelocityField`. + the :class:`~scaleflow.networks.ConditionalVelocityField`. Parameters ---------- @@ -257,7 +257,7 @@ def predict( batched Whether to use batched prediction. This is only supported if the input has the same number of cells for each condition. For example, this works when using - :class:`~cellflow.data.ValidationSampler` to sample the validation data. + :class:`~scaleflow.data.ValidationSampler` to sample the validation data. kwargs Keyword arguments for :func:`diffrax.diffeqsolve`. 
@@ -284,6 +284,13 @@ def predict( pred_targets = batched_predict(src_inputs, batched_conditions) return {k: pred_targets[i] for i, k in enumerate(keys)} + elif isinstance(x, dict): + predict_fn = functools.partial(self._predict_jit, rng=rng, rng_genot=rng_genot, **kwargs) + return jax.tree.map( + predict_fn, + x, + condition, + ) else: x_pred = self._predict_jit(x, condition, rng, rng_genot, **kwargs) return np.array(x_pred) diff --git a/src/scaleflow/solvers/_multitask_otfm.py b/src/scaleflow/solvers/_multitask_otfm.py new file mode 100644 index 00000000..c8903df6 --- /dev/null +++ b/src/scaleflow/solvers/_multitask_otfm.py @@ -0,0 +1,399 @@ +from collections.abc import Callable +from functools import partial +from typing import Any + +import diffrax +import jax +import jax.numpy as jnp +import numpy as np +from flax.core import frozen_dict +from flax.training import train_state +from ott.neural.methods.flows import dynamics +from ott.solvers import utils as solver_utils + +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.networks._velocity_field import MultiTaskConditionalVelocityField +from scaleflow.solvers.utils import ema_update + +__all__ = ["MultiTaskOTFlowMatching"] + + +class MultiTaskOTFlowMatching: + """Multi-task OT Flow Matching for both single-cell and phenotype prediction. + + This solver extends the standard OT Flow Matching to handle both flow matching + for single-cell data and phenotype prediction tasks, enabling transfer learning + between the two modalities through shared condition encodings. + + Parameters + ---------- + vf + Multi-task velocity field parameterized by a neural network. + probability_path + Probability path between the source and the target distributions. + match_fn + Function to match samples from the source and the target distributions. + time_sampler + Time sampler with a ``(rng, n_samples) -> time`` signature. + phenotype_loss_weight + Weight for the phenotype prediction loss relative to flow matching loss. + ema + Exponential moving average parameter for inference state. + kwargs + Keyword arguments for velocity field initialization. 
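+
+    Example
+    -------
+    A sketch of the two batch layouts dispatched by :meth:`step_fn`
+    (array names and shapes are illustrative):
+
+    .. code-block:: python
+
+        fm_batch = {
+            "src_cell_data": src,      # control cells, shape (n, d)
+            "tgt_cell_data": tgt,      # perturbed cells, shape (n, d)
+            "condition": condition,    # dict of condition arrays
+        }
+        phenotype_batch = {
+            "condition": condition,
+            "phenotype_target": targets,  # per-condition regression targets
+            "task": "phenotype",
+        }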
+ """ + + def __init__( + self, + vf: MultiTaskConditionalVelocityField, + probability_path: dynamics.BaseFlow, + match_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] | None = None, + time_sampler: Callable[[jax.Array, int], jnp.ndarray] = solver_utils.uniform_sampler, + phenotype_loss_weight: float = 1.0, + ema: float = 0.999, + **kwargs: Any, + ): + self._is_trained: bool = False + self.vf = vf + self.condition_encoder_mode = self.vf.condition_mode + self.condition_encoder_regularization = self.vf.regularization + self.probability_path = probability_path + self.match_fn = match_fn + self.time_sampler = time_sampler + self.phenotype_loss_weight = phenotype_loss_weight + self.ema = ema + + self.vf_state = self.vf.create_train_state(**kwargs) + self.vf_state_inference = self.vf_state + self.vf_step_fn = self._get_vf_step_fn() + + def _get_vf_step_fn(self) -> Callable: + @jax.jit + def vf_step_fn( + rng: jax.Array, + vf_state: train_state.TrainState, + time: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + ): + def loss_fn( + params: jnp.ndarray, + t: jnp.ndarray, + source: jnp.ndarray, + target: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + encoder_noise: jnp.ndarray, + rng: jax.Array, + ) -> jnp.ndarray: + rng_flow, rng_encoder, rng_dropout = jax.random.split(rng, 3) + x_t = self.probability_path.compute_xt(rng_flow, t, source, target) + v_t, mean_cond, logvar_cond, _ = vf_state.apply_fn( + {"params": params}, + t, + x_t, + conditions, + encoder_noise=encoder_noise, + rngs={"dropout": rng_dropout, "condition_encoder": rng_encoder}, + ) + u_t = self.probability_path.compute_ut(t, x_t, source, target) + flow_matching_loss = jnp.mean((v_t - u_t) ** 2) + condition_mean_regularization = 0.5 * jnp.mean(mean_cond**2) + condition_var_regularization = -0.5 * jnp.mean(1 + logvar_cond - jnp.exp(logvar_cond)) + if self.condition_encoder_mode == "stochastic": + encoder_loss = condition_mean_regularization + condition_var_regularization + elif (self.condition_encoder_mode == "deterministic") and (self.condition_encoder_regularization > 0): + encoder_loss = condition_mean_regularization + else: + encoder_loss = 0.0 + return flow_matching_loss + encoder_loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(vf_state.params, time, source, target, conditions, encoder_noise, rng) + return vf_state.apply_gradients(grads=grads), loss + + return vf_step_fn + + def _get_phenotype_step_fn(self) -> Callable: + @jax.jit + def phenotype_step_fn( + rng: jax.Array, + vf_state: train_state.TrainState, + conditions: dict[str, jnp.ndarray], + phenotype_targets: jnp.ndarray, + encoder_noise: jnp.ndarray, + ): + def phenotype_loss_fn( + params: jnp.ndarray, + conditions: dict[str, jnp.ndarray], + phenotype_targets: jnp.ndarray, + encoder_noise: jnp.ndarray, + rng: jax.Array, + ) -> jnp.ndarray: + rng_encoder, rng_dropout = jax.random.split(rng, 2) + + # Create dummy inputs for flow matching components + n = phenotype_targets.shape[0] + dummy_t = jnp.zeros(n) + dummy_x = jnp.zeros((n, self.vf.output_dim)) + + # Forward pass through multi-task velocity field + _, mean_cond, logvar_cond, phenotype_pred = vf_state.apply_fn( + {"params": params}, + dummy_t, + dummy_x, + conditions, + encoder_noise=encoder_noise, + rngs={"dropout": rng_dropout, "condition_encoder": rng_encoder}, + ) + + # Phenotype prediction loss (MSE for regression) + phenotype_loss = jnp.mean((phenotype_pred.squeeze() - phenotype_targets) ** 2) + + # 
Same condition regularization as flow matching + condition_mean_regularization = 0.5 * jnp.mean(mean_cond**2) + condition_var_regularization = -0.5 * jnp.mean(1 + logvar_cond - jnp.exp(logvar_cond)) + if self.condition_encoder_mode == "stochastic": + encoder_loss = condition_mean_regularization + condition_var_regularization + elif (self.condition_encoder_mode == "deterministic") and (self.condition_encoder_regularization > 0): + encoder_loss = condition_mean_regularization + else: + encoder_loss = 0.0 + + return self.phenotype_loss_weight * phenotype_loss + encoder_loss + + grad_fn = jax.value_and_grad(phenotype_loss_fn) + loss, grads = grad_fn(vf_state.params, conditions, phenotype_targets, encoder_noise, rng) + return vf_state.apply_gradients(grads=grads), loss + + return phenotype_step_fn + + def step_fn( + self, + rng: jnp.ndarray, + batch: dict[str, ArrayLike], + ) -> float: + """Single step function handling both flow matching and phenotype tasks. + + Parameters + ---------- + rng + Random number generator. + batch + Data batch. For flow matching: ``src_cell_data``, ``tgt_cell_data``, ``condition``. + For phenotype: ``condition``, ``phenotype_target``, ``task``. + + Returns + ------- + Loss value. + """ + task = batch.get("task", "flow_matching") + + if task == "phenotype": + return self._phenotype_step(rng, batch) + else: + return self._flow_matching_step(rng, batch) + + def _flow_matching_step(self, rng: jnp.ndarray, batch: dict[str, ArrayLike]) -> float: + """Handle flow matching step.""" + src, tgt = batch["src_cell_data"], batch["tgt_cell_data"] + condition = batch.get("condition") + rng_resample, rng_time, rng_step_fn, rng_encoder_noise = jax.random.split(rng, 4) + n = src.shape[0] + time = self.time_sampler(rng_time, n) + encoder_noise = jax.random.normal(rng_encoder_noise, (n, self.vf.condition_embedding_dim)) + + if self.match_fn is not None: + tmat = self.match_fn(src, tgt) + src_ixs, tgt_ixs = solver_utils.sample_joint(rng_resample, tmat) + src, tgt = src[src_ixs], tgt[tgt_ixs] + + self.vf_state, loss = self.vf_step_fn( + rng_step_fn, + self.vf_state, + time, + src, + tgt, + condition, + encoder_noise, + ) + + if self.ema == 1.0: + self.vf_state_inference = self.vf_state + else: + self.vf_state_inference = self.vf_state_inference.replace( + params=ema_update(self.vf_state_inference.params, self.vf_state.params, self.ema) + ) + return loss + + def _phenotype_step(self, rng: jnp.ndarray, batch: dict[str, ArrayLike]) -> float: + """Handle phenotype prediction step.""" + condition = batch["condition"] + phenotype_target = batch["phenotype_target"] + rng_step_fn, rng_encoder_noise = jax.random.split(rng, 2) + n = phenotype_target.shape[0] + encoder_noise = jax.random.normal(rng_encoder_noise, (n, self.vf.condition_embedding_dim)) + + phenotype_step_fn = self._get_phenotype_step_fn() + self.vf_state, loss = phenotype_step_fn( + rng_step_fn, + self.vf_state, + condition, + phenotype_target, + encoder_noise, + ) + + if self.ema == 1.0: + self.vf_state_inference = self.vf_state + else: + self.vf_state_inference = self.vf_state_inference.replace( + params=ema_update(self.vf_state_inference.params, self.vf_state.params, self.ema) + ) + return loss + + def get_condition_embedding(self, condition: dict[str, ArrayLike], return_as_numpy=True) -> ArrayLike: + """Get learnt embeddings of the conditions.""" + cond_mean, cond_logvar = self.vf.apply( + {"params": self.vf_state_inference.params}, + condition, + method="get_condition_embedding", + ) + if return_as_numpy: + return 
np.asarray(cond_mean), np.asarray(cond_logvar) + return cond_mean, cond_logvar + + def predict( + self, + x: ArrayLike | dict[str, ArrayLike], + condition: dict[str, ArrayLike] | dict[str, dict[str, ArrayLike]], + rng: jax.Array | None = None, + batched: bool = False, + task: str = "flow_matching", + **kwargs: Any, + ) -> ArrayLike | dict[str, ArrayLike]: + """Predict either flow matching or phenotype outcomes. + + Parameters + ---------- + x + Input data (ignored for phenotype prediction). + condition + Condition dictionary. + rng + Random number generator. + batched + Whether to use batched prediction. + task + Either "flow_matching" or "phenotype". + kwargs + Additional arguments for ODE solver. + + Returns + ------- + Predictions based on the specified task. + """ + if task == "phenotype": + return self._predict_phenotype(condition, rng) + else: + return self._predict_flow_matching(x, condition, rng, batched, **kwargs) + + def _predict_phenotype( + self, + condition: dict[str, ArrayLike], + rng: jax.Array | None = None + ) -> ArrayLike: + """Predict phenotype values.""" + use_mean = rng is None or self.condition_encoder_mode == "deterministic" + rng = utils.default_prng_key(rng) + + # Get condition shape + first_cond = next(iter(condition.values())) + n_samples = first_cond.shape[0] + + encoder_noise = jnp.zeros((n_samples, self.vf.condition_embedding_dim)) if use_mean else \ + jax.random.normal(rng, (n_samples, self.vf.condition_embedding_dim)) + + phenotype_pred = self.vf_state_inference.apply_fn( + {"params": self.vf_state_inference.params}, + method="predict_phenotype", + cond=condition, + encoder_noise=encoder_noise, + train=False + ) + return np.array(phenotype_pred) + + def _predict_flow_matching( + self, + x: ArrayLike | dict[str, ArrayLike], + condition: dict[str, ArrayLike] | dict[str, dict[str, ArrayLike]], + rng: jax.Array | None = None, + batched: bool = False, + **kwargs: Any, + ) -> ArrayLike | dict[str, ArrayLike]: + """Predict flow matching outcomes (same as original OTFM).""" + if batched and not x: + return {} + + if batched: + keys = sorted(x.keys()) + condition_keys = sorted(set().union(*(condition[k].keys() for k in keys))) + _predict_jit = jax.jit(lambda x, condition: self._predict_jit(x, condition, rng, **kwargs)) + batched_predict = jax.vmap(_predict_jit, in_axes=(0, dict.fromkeys(condition_keys, 0))) + n_cells = x[keys[0]].shape[0] + for k in keys: + assert x[k].shape[0] == n_cells, "The number of cells must be the same for each condition" + src_inputs = jnp.stack([x[k] for k in keys], axis=0) + batched_conditions = {} + for cond_key in condition_keys: + batched_conditions[cond_key] = jnp.stack([condition[k][cond_key] for k in keys]) + pred_targets = batched_predict(src_inputs, batched_conditions) + return {k: pred_targets[i] for i, k in enumerate(keys)} + else: + x_pred = self._predict_jit(x, condition, rng, **kwargs) + return np.array(x_pred) + + def _predict_jit( + self, x: ArrayLike, condition: dict[str, ArrayLike], rng: jax.Array | None = None, **kwargs: Any + ) -> ArrayLike: + """JIT-compiled prediction for flow matching.""" + kwargs.setdefault("dt0", None) + kwargs.setdefault("solver", diffrax.Tsit5()) + kwargs.setdefault("stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5)) + + noise_dim = (1, self.vf.condition_embedding_dim) + use_mean = rng is None or self.condition_encoder_mode == "deterministic" + rng = utils.default_prng_key(rng) + encoder_noise = jnp.zeros(noise_dim) if use_mean else jax.random.normal(rng, noise_dim) + + def vf(t: 
jnp.ndarray, x: jnp.ndarray, args: tuple[dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray: + params = self.vf_state_inference.params + condition, encoder_noise = args + # Only use flow matching output (first element) + return self.vf_state_inference.apply_fn({"params": params}, t, x, condition, encoder_noise, train=False)[0] + + def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray: + ode_term = diffrax.ODETerm(vf) + result = diffrax.diffeqsolve( + ode_term, + t0=0.0, + t1=1.0, + y0=x, + args=(condition, encoder_noise), + **kwargs, + ) + return result.ys[0] + + x_pred = jax.jit(jax.vmap(solve_ode, in_axes=[0, None, None]))(x, condition, encoder_noise) + return x_pred + + @property + def is_trained(self) -> bool: + """Whether the model has been trained.""" + return self._is_trained + + @is_trained.setter + def is_trained(self, value: bool) -> None: + """Set the trained status.""" + self._is_trained = value diff --git a/src/cellflow/solvers/_otfm.py b/src/scaleflow/solvers/_otfm.py similarity index 95% rename from src/cellflow/solvers/_otfm.py rename to src/scaleflow/solvers/_otfm.py index 31114a6b..e987b8e1 100644 --- a/src/cellflow/solvers/_otfm.py +++ b/src/scaleflow/solvers/_otfm.py @@ -11,10 +11,10 @@ from ott.neural.methods.flows import dynamics from ott.solvers import utils as solver_utils -from cellflow import utils -from cellflow._types import ArrayLike -from cellflow.networks._velocity_field import ConditionalVelocityField -from cellflow.solvers.utils import ema_update +from scaleflow import utils +from scaleflow._types import ArrayLike +from scaleflow.networks._velocity_field import ConditionalVelocityField +from scaleflow.solvers.utils import ema_update __all__ = ["OTFlowMatching"] @@ -34,14 +34,14 @@ class OTFlowMatching: match_fn Function to match samples from the source and the target distributions. It has a ``(src, tgt) -> matching`` signature, - see e.g. :func:`cellflow.utils.match_linear`. If :obj:`None`, no + see e.g. :func:`scaleflow.utils.match_linear`. If :obj:`None`, no matching is performed, and pure probability_path matching :cite:`lipman:22` is applied. time_sampler Time sampler with a ``(rng, n_samples) -> time`` signature, see e.g. :func:`ott.solvers.utils.uniform_sampler`. kwargs - Keyword arguments for :meth:`cellflow.networks.ConditionalVelocityField.create_train_state`. + Keyword arguments for :meth:`scaleflow.networks.ConditionalVelocityField.create_train_state`. """ def __init__( @@ -231,7 +231,7 @@ def predict( """Predict the translated source ``x`` under condition ``condition``. This function solves the ODE learnt with - the :class:`~cellflow.networks.ConditionalVelocityField`. + the :class:`~scaleflow.networks.ConditionalVelocityField`. Parameters ---------- @@ -249,7 +249,7 @@ def predict( batched Whether to use batched prediction. This is only supported if the input has the same number of cells for each condition. For example, this works when using - :class:`~cellflow.data.ValidationSampler` to sample the validation data. + :class:`~scaleflow.data.ValidationSampler` to sample the validation data. kwargs Keyword arguments for :func:`diffrax.diffeqsolve`. 
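Editor's note (illustrative sketch, not part of the patch): the multi-task solver added above routes each training batch either to the flow-matching objective or to the phenotype head depending on ``batch["task"]``, and ``predict(...)`` selects the output the same way. The snippet below sketches that usage under stated assumptions: ``solver`` is an already-constructed instance of the multi-task solver defined above (its velocity field, probability path and optimizer setup are omitted), the batch keys follow the ``step_fn`` docstring, and the helper names ``run_mixed_steps`` / ``predict_phenotype`` are hypothetical.

    # Sketch only: assumes `solver` is an instance of the multi-task solver above
    # and that batches use the key conventions documented in step_fn.
    import jax
    import numpy as np


    def run_mixed_steps(solver, batches, seed: int = 0):
        """Feed a mix of flow-matching and phenotype batches through solver.step_fn."""
        rng = jax.random.PRNGKey(seed)
        losses = []
        for batch in batches:
            rng, rng_step = jax.random.split(rng)
            # batch["task"] == "phenotype" routes to the phenotype head;
            # any other value (or a missing key) falls back to flow matching.
            losses.append(solver.step_fn(rng_step, batch))
        return np.asarray(losses)


    def predict_phenotype(solver, condition):
        """Phenotype prediction ignores the cell input, so x can be None."""
        return solver.predict(None, condition=condition, task="phenotype")
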
diff --git a/src/cellflow/solvers/utils.py b/src/scaleflow/solvers/utils.py similarity index 100% rename from src/cellflow/solvers/utils.py rename to src/scaleflow/solvers/utils.py diff --git a/src/cellflow/training/__init__.py b/src/scaleflow/training/__init__.py similarity index 79% rename from src/cellflow/training/__init__.py rename to src/scaleflow/training/__init__.py index 387411d2..c19a50dd 100644 --- a/src/cellflow/training/__init__.py +++ b/src/scaleflow/training/__init__.py @@ -1,4 +1,4 @@ -from cellflow.training._callbacks import ( +from scaleflow.training._callbacks import ( BaseCallback, CallbackRunner, ComputationCallback, @@ -8,7 +8,7 @@ VAEDecodedMetrics, WandbLogger, ) -from cellflow.training._trainer import CellFlowTrainer +from scaleflow.training._trainer import CellFlowTrainer __all__ = [ "CellFlowTrainer", diff --git a/src/cellflow/training/_callbacks.py b/src/scaleflow/training/_callbacks.py similarity index 91% rename from src/cellflow/training/_callbacks.py rename to src/scaleflow/training/_callbacks.py index 5b65f33f..82fef53c 100644 --- a/src/cellflow/training/_callbacks.py +++ b/src/scaleflow/training/_callbacks.py @@ -7,14 +7,14 @@ import jax.tree_util as jtu import numpy as np -from cellflow._types import ArrayLike -from cellflow.metrics._metrics import ( +from scaleflow._types import ArrayLike +from scaleflow.metrics._metrics import ( compute_e_distance_fast, compute_r_squared, compute_scalar_mmd, compute_sinkhorn_div, ) -from cellflow.solvers import _genot, _otfm +from scaleflow.solvers import _genot, _otfm __all__ = [ "BaseCallback", @@ -42,7 +42,7 @@ class BaseCallback(abc.ABC): - """Base class for callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self, *args: Any, **kwargs: Any) -> None: @@ -61,7 +61,7 @@ def on_train_end(self, *args: Any, **kwargs: Any) -> Any: class LoggingCallback(BaseCallback, abc.ABC): - """Base class for logging callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for logging callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self) -> Any: @@ -92,7 +92,7 @@ def on_train_end(self, dict_to_log: dict[str, Any]) -> Any: class ComputationCallback(BaseCallback, abc.ABC): - """Base class for computation callbacks in the :class:`~cellflow.training.CellFlowTrainer`""" + """Base class for computation callbacks in the :class:`~scaleflow.training.CellFlowTrainer`""" @abc.abstractmethod def on_train_begin(self) -> Any: @@ -118,7 +118,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -146,7 +146,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. 
Returns @@ -205,7 +205,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -240,7 +240,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -299,7 +299,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -328,7 +328,7 @@ class VAEDecodedMetrics(Metrics): ---------- vae A VAE model object with a ``'get_reconstruction'`` method, can be an instance - of :class:`cellflow.external.CFJaxSCVI`. + of :class:`scaleflow.external.CFJaxSCVI`. adata An :class:`~anndata.AnnData` object in the same format as the ``vae``. metrics @@ -374,7 +374,7 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns @@ -477,14 +477,14 @@ def on_train_end(self, dict_to_log: dict[str, float]) -> Any: class CallbackRunner: - """Runs a set of computational and logging callbacks in the :class:`~cellflow.training.CellFlowTrainer` + """Runs a set of computational and logging callbacks in the :class:`~scaleflow.training.CellFlowTrainer` Parameters ---------- callbacks List of callbacks to run. Callbacks should be of type - :class:`~cellflow.training.ComputationCallback` or - :class:`~cellflow.training.LoggingCallback` + :class:`~scaleflow.training.ComputationCallback` or + :class:`~scaleflow.training.LoggingCallback` Returns ------- @@ -517,6 +517,7 @@ def on_log_iteration( valid_data: dict[str, dict[str, ArrayLike]], pred_data: dict[str, dict[str, ArrayLike]], solver: _otfm.OTFlowMatching | _genot.GENOT, + additional_metrics: dict[str, Any] | None = None, ) -> dict[str, Any]: """Called at each validation/log iteration to run callbacks. First computes metrics with computation callbacks and then logs data with logging callbacks. @@ -529,8 +530,10 @@ def on_log_iteration( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. 
+ additional_metrics + Optional dictionary of metrics to include before computing validation metrics (e.g., train_loss) Returns ------- @@ -538,6 +541,10 @@ def on_log_iteration( """ dict_to_log: dict[str, Any] = {} + # Add additional metrics first + if additional_metrics is not None: + dict_to_log.update(additional_metrics) + for callback in self.computation_callbacks: results = callback.on_log_iteration(valid_source_data, valid_data, pred_data, solver) dict_to_log.update(results) @@ -565,7 +572,7 @@ def on_train_end( valid_pred_data Predicted data in nested dictionary format with same keys as ``valid_true_data`` solver - :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT` + :class:`~scaleflow.solvers.OTFlowMatching` solver or :class:`~scaleflow.solvers.GENOT` solver with a conditional velocity field. Returns diff --git a/src/cellflow/training/_trainer.py b/src/scaleflow/training/_trainer.py similarity index 67% rename from src/cellflow/training/_trainer.py rename to src/scaleflow/training/_trainer.py index f0130690..ad41b9d5 100644 --- a/src/cellflow/training/_trainer.py +++ b/src/scaleflow/training/_trainer.py @@ -6,25 +6,27 @@ from numpy.typing import ArrayLike from tqdm import tqdm -from cellflow.data._dataloader import OOCTrainSampler, TrainSampler, ValidationSampler -from cellflow.solvers import _genot, _otfm -from cellflow.training._callbacks import BaseCallback, CallbackRunner +from scaleflow.data import JaxOutOfCoreTrainSampler, TrainSampler, ValidationSampler +from scaleflow.solvers import _eqm, _genot, _otfm +from scaleflow.training._callbacks import BaseCallback, CallbackRunner class CellFlowTrainer: - """Trainer for the OTFM/GENOT solver with a conditional velocity field. + """Trainer for the OTFM/GENOT/EqM solver with a conditional velocity field. Parameters ---------- dataloader Data sampler. solver - :class:`~cellflow.solvers._otfm.OTFlowMatching` or - :class:`~cellflow.solvers._genot.GENOT` solver with a conditional velocity field. + :class:`~scaleflow.solvers._otfm.OTFlowMatching`, + :class:`~scaleflow.solvers._genot.GENOT`, or + :class:`~scaleflow.solvers._eqm.EquilibriumMatching` solver with a conditional velocity field. predict_kwargs Keyword arguments for the prediction functions - :func:`cellflow.solvers._otfm.OTFlowMatching.predict` or - :func:`cellflow.solvers._genot.GENOT.predict` used during validation. + :func:`scaleflow.solvers._otfm.OTFlowMatching.predict`, + :func:`scaleflow.solvers._genot.GENOT.predict`, or + :func:`scaleflow.solvers._eqm.EquilibriumMatching.predict` used during validation. seed Random seed for subsampling validation data. 
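Editor's note (illustrative sketch, not part of the patch): with the extended trainer signature above, a minimal training call might look as follows. The ``solver``, ``dataloader`` and ``valid_loaders`` objects are assumed to have been constructed elsewhere (e.g. an OTFlowMatching solver with a TrainSampler and ValidationSampler instances), and the ``predict_kwargs`` values are placeholders rather than recommended settings; valid keys depend on the chosen solver's ``predict``.

    # Sketch only: solver and samplers are assumed to exist already.
    from scaleflow.training import CellFlowTrainer


    def fit(solver, dataloader, valid_loaders=None):
        trainer = CellFlowTrainer(
            solver=solver,
            # Forwarded to solver.predict(...) during validation.
            predict_kwargs={"max_steps": 100, "throw": False},
            seed=0,
        )
        return trainer.train(
            dataloader,
            num_iterations=10_000,
            valid_freq=500,
            valid_loaders=valid_loaders,
            monitor_metrics=["train_loss"],  # "train_loss" is injected via additional_metrics
        )
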
@@ -35,12 +37,12 @@ class CellFlowTrainer:
 
     def __init__(
         self,
-        solver: _otfm.OTFlowMatching | _genot.GENOT,
+        solver: _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching,
         predict_kwargs: dict[str, Any] | None = None,
         seed: int = 0,
     ):
-        if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT)):
-            raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(solver)}")
+        if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching)):
+            raise NotImplementedError(f"Solver must be an instance of OTFlowMatching, GENOT, or EquilibriumMatching, got {type(solver)}")
 
         self.solver = solver
         self.predict_kwargs = predict_kwargs or {}
@@ -61,15 +63,21 @@ def _validation_step(
         valid_source_data: dict[str, dict[str, ArrayLike]] = {}
         valid_pred_data: dict[str, dict[str, ArrayLike]] = {}
         valid_true_data: dict[str, dict[str, ArrayLike]] = {}
-        for val_key, vdl in val_data.items():
+
+        # Add progress bar for validation
+        val_pbar = tqdm(val_data.items(), desc="Validation", leave=False)
+        for val_key, vdl in val_pbar:
             batch = vdl.sample(mode=mode)
             src = batch["source"]
             condition = batch.get("condition", None)
             true_tgt = batch["target"]
             valid_source_data[val_key] = src
             valid_pred_data[val_key] = self.solver.predict(src, condition=condition, **self.predict_kwargs)
             valid_true_data[val_key] = true_tgt
+            # Update progress bar description with current validation set
+            val_pbar.set_description(f"Validation ({val_key})")
+
 
         return valid_source_data, valid_true_data, valid_pred_data
 
     def _update_logs(self, logs: dict[str, Any]) -> None:
@@ -81,13 +102,13 @@ def train(
         self,
-        dataloader: TrainSampler | OOCTrainSampler,
+        dataloader: TrainSampler | JaxOutOfCoreTrainSampler,
         num_iterations: int,
         valid_freq: int,
         valid_loaders: dict[str, ValidationSampler] | None = None,
         monitor_metrics: Sequence[str] = [],
         callbacks: Sequence[BaseCallback] = [],
-    ) -> _otfm.OTFlowMatching | _genot.GENOT:
+    ) -> _otfm.OTFlowMatching | _genot.GENOT | _eqm.EquilibriumMatching:
         """Trains the model.
Parameters @@ -122,7 +143,7 @@ def train( pbar = tqdm(range(num_iterations)) sampler = dataloader - if isinstance(dataloader, OOCTrainSampler): + if isinstance(dataloader, JaxOutOfCoreTrainSampler): dataloader.set_sampler(num_iterations=num_iterations) for it in pbar: rng_jax, rng_step_fn = jax.random.split(rng_jax, 2) @@ -136,14 +157,19 @@ def train( valid_loaders, mode="on_log_iteration" ) - # Run callbacks - metrics = crun.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, self.solver) # type: ignore[arg-type] + # Calculate mean loss + mean_loss = np.mean(self.training_logs["loss"][-valid_freq:]) + + # Run callbacks with loss as additional metric + metrics = crun.on_log_iteration( + valid_source_data, valid_true_data, valid_pred_data, self.solver, + additional_metrics={"train_loss": mean_loss} + ) self._update_logs(metrics) # Update progress bar - mean_loss = np.mean(self.training_logs["loss"][-valid_freq:]) postfix_dict = {metric: round(self.training_logs[metric][-1], 3) for metric in monitor_metrics} - postfix_dict["loss"] = round(mean_loss, 3) + postfix_dict["train_loss"] = round(mean_loss, 3) # or keep as "loss" pbar.set_postfix(postfix_dict) if num_iterations > 0: diff --git a/src/cellflow/training/_utils.py b/src/scaleflow/training/_utils.py similarity index 100% rename from src/cellflow/training/_utils.py rename to src/scaleflow/training/_utils.py diff --git a/src/cellflow/utils.py b/src/scaleflow/utils.py similarity index 100% rename from src/cellflow/utils.py rename to src/scaleflow/utils.py diff --git a/tests/conftest.py b/tests/conftest.py index d6d95fdd..e793734d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from cellflow.data._dataloader import ValidationSampler +from scaleflow.data._dataloader import ValidationSampler @pytest.fixture diff --git a/tests/data/test_cfsampler.py b/tests/data/test_cfsampler.py index 0ae5405a..7c803e46 100644 --- a/tests/data/test_cfsampler.py +++ b/tests/data/test_cfsampler.py @@ -1,13 +1,16 @@ +from pathlib import Path + import numpy as np import pytest -from cellflow.data._dataloader import OOCTrainSampler, PredictionSampler, TrainSampler -from cellflow.data._datamanager import DataManager +from scaleflow.data import JaxOutOfCoreTrainSampler, PredictionSampler, TrainSampler +from scaleflow.data._data import ZarrTrainingData +from scaleflow.data._datamanager import DataManager class TestTrainSampler: @pytest.mark.parametrize("batch_size", [1, 31]) - def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): + def test_sampling_no_combinations(self, adata_perturbation, batch_size: int, tmp_path): sample_rep = "X" split_covariates = ["cell_type"] control_key = "control" @@ -27,25 +30,36 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): ) train_data = dm.get_train_data(adata_perturbation) + train_data.write_zarr(Path(tmp_path) / "test_train_data.zarr") sampler = TrainSampler(data=train_data, batch_size=batch_size) + zarr_sampler = TrainSampler( + ZarrTrainingData.read_zarr(Path(tmp_path) / "test_train_data.zarr"), batch_size=batch_size + ) rng_1 = np.random.default_rng(0) rng_2 = np.random.default_rng(1) + rng_3 = np.random.default_rng(2) sample_1 = sampler.sample(rng_1) sample_2 = sampler.sample(rng_2) + sample_3 = zarr_sampler.sample(rng_3) assert "src_cell_data" in sample_1 assert "tgt_cell_data" in sample_1 assert "condition" in sample_1 + assert "src_cell_data" in sample_3 + assert "tgt_cell_data" in sample_3 + 
assert "condition" in sample_3 assert sample_1["src_cell_data"].shape[0] == batch_size assert sample_2["src_cell_data"].shape[0] == batch_size + assert sample_3["src_cell_data"].shape[0] == batch_size assert sample_1["tgt_cell_data"].shape[0] == batch_size assert sample_2["tgt_cell_data"].shape[0] == batch_size + assert sample_3["tgt_cell_data"].shape[0] == batch_size assert sample_1["condition"]["dosage"].shape[0] == 1 assert sample_2["condition"]["dosage"].shape[0] == 1 -class TestOOCTrainSampler: +class TestJaxOutOfCoreTrainSampler: @pytest.mark.parametrize("batch_size", [1, 31]) def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): sample_rep = "X" @@ -67,7 +81,7 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): ) train_data = dm.get_train_data(adata_perturbation) - sampler = OOCTrainSampler(data=train_data, batch_size=batch_size, seed=0) + sampler = JaxOutOfCoreTrainSampler(data=train_data, batch_size=batch_size, seed=0) sampler.set_sampler(num_iterations=2) sample_1 = sampler.sample() sample_2 = sampler.sample() @@ -86,8 +100,8 @@ def test_sampling_no_combinations(self, adata_perturbation, batch_size: int): class TestValidationSampler: @pytest.mark.parametrize("n_conditions_on_log_iteration", [None, 1, 3]) def test_valid_sampler(self, adata_perturbation, n_conditions_on_log_iteration): - from cellflow.data._dataloader import ValidationSampler - from cellflow.data._datamanager import DataManager + from scaleflow.data._dataloader import ValidationSampler + from scaleflow.data._datamanager import DataManager control_key = "control" sample_covariates = ["cell_type"] @@ -136,7 +150,7 @@ def test_pred_sampler( split_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager perturbation_covariates = {"drug": ["drug1", "drug2"]} diff --git a/tests/data/test_datamanager.py b/tests/data/test_datamanager.py index 237af9c7..91a64859 100644 --- a/tests/data/test_datamanager.py +++ b/tests/data/test_datamanager.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from cellflow.data._datamanager import DataManager +from scaleflow.data._datamanager import DataManager perturbation_covariates_args = [ {"drug": ["drug1"]}, @@ -38,7 +38,7 @@ def test_init_DataManager( perturbation_covariate_reps, sample_covariates, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager dm = DataManager( adata_perturbation, @@ -58,7 +58,7 @@ def test_init_DataManager( @pytest.mark.parametrize("el_to_delete", ["drug", "cell_type"]) def test_raise_false_uns_dict(self, adata_perturbation: ad.AnnData, el_to_delete): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager sample_rep = "X" split_covariates = ["cell_type"] @@ -87,7 +87,7 @@ def test_raise_false_uns_dict(self, adata_perturbation: ad.AnnData, el_to_delete @pytest.mark.parametrize("el_to_delete", ["drug_b", "dosage_a"]) def test_raise_covar_mismatch(self, adata_perturbation: ad.AnnData, el_to_delete): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager sample_rep = "X" split_covariates = ["cell_type"] @@ -113,7 +113,7 @@ def test_raise_covar_mismatch(self, adata_perturbation: ad.AnnData, el_to_delete ) def test_raise_target_without_source(self, adata_perturbation: ad.AnnData): - from cellflow.data._datamanager import DataManager + from 
scaleflow.data._datamanager import DataManager sample_rep = "X" split_covariates = ["cell_type"] @@ -156,8 +156,8 @@ def test_get_train_data( perturbation_covariate_reps, sample_covariates, ): - from cellflow.data._data import TrainingData - from cellflow.data._datamanager import DataManager + from scaleflow.data._data import TrainingData + from scaleflow.data._datamanager import DataManager dm = DataManager( adata_perturbation, @@ -211,7 +211,7 @@ def test_get_train_data_with_combinations( perturbation_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager dm = DataManager( adata_perturbation, @@ -300,7 +300,7 @@ def test_get_validation_data( perturbation_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager control_key = "control" sample_covariates = ["cell_type"] @@ -338,7 +338,7 @@ def test_get_validation_data( @pytest.mark.skip(reason="To discuss: why should it raise an error?") def test_raises_wrong_max_combination_length(self, adata_perturbation): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager max_combination_length = 3 adata = adata_perturbation @@ -378,7 +378,7 @@ def test_get_prediction_data( perturbation_covariates, perturbation_covariate_reps, ): - from cellflow.data._datamanager import DataManager + from scaleflow.data._datamanager import DataManager control_key = "control" sample_covariates = ["cell_type"] diff --git a/tests/data/test_datasplitter.py b/tests/data/test_datasplitter.py new file mode 100644 index 00000000..3957e508 --- /dev/null +++ b/tests/data/test_datasplitter.py @@ -0,0 +1,729 @@ +from pathlib import Path + +import numpy as np +import pytest + +from scaleflow.data import DataManager +from scaleflow.data._data_splitter import DataSplitter + + +class TestDataSplitterValidation: + def test_mismatched_datasets_and_names(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="training_datasets length.*must match.*dataset_names length"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1", "dataset2"], + split_ratios=[[0.8, 0.1, 0.1]], + ) + + def test_mismatched_datasets_and_ratios(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="split_ratios length.*must match.*training_datasets length"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1], [0.7, 0.2, 0.1]], + ) + + def test_invalid_split_ratios_format(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="must be a list of 3 values"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.2]], + ) + + def 
test_split_ratios_dont_sum_to_one(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="must sum to 1.0"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.2]], + ) + + def test_negative_split_ratios(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="must be non-negative"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.9, 0.2, -0.1]], + ) + + def test_holdout_groups_requires_split_key(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="split_key must be provided"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_groups", + ) + + def test_holdout_combinations_requires_control_value(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1", "drug2"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + with pytest.raises(ValueError, match="control_value must be provided"): + DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_combinations", + split_key="drug", + ) + + +class TestRandomSplit: + @pytest.mark.parametrize("hard_test_split", [True, False]) + @pytest.mark.parametrize("split_ratios", [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [1.0, 0.0, 0.0]]) + def test_random_split_ratios(self, adata_perturbation, hard_test_split, split_ratios): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[split_ratios], + split_type="random", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + indices = results["dataset1"]["indices"] + + n_cells = train_data.perturbation_covariates_mask.shape[0] + total_assigned = len(indices["train"]) + len(indices["val"]) + len(indices["test"]) + assert total_assigned == n_cells + + train_ratio, val_ratio, test_ratio = split_ratios + assert len(indices["train"]) == pytest.approx(train_ratio * n_cells, abs=1) + if val_ratio > 0: + assert len(indices["val"]) > 0 + if test_ratio > 0: + assert len(indices["test"]) > 0 + + def test_random_split_reproducibility(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = 
dm.get_train_data(adata_perturbation) + + splitter1 = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + random_state=42, + ) + results1 = splitter1.split_all_datasets() + + splitter2 = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + random_state=42, + ) + results2 = splitter2.split_all_datasets() + + assert np.array_equal(results1["dataset1"]["indices"]["train"], results2["dataset1"]["indices"]["train"]) + assert np.array_equal(results1["dataset1"]["indices"]["val"], results2["dataset1"]["indices"]["val"]) + assert np.array_equal(results1["dataset1"]["indices"]["test"], results2["dataset1"]["indices"]["test"]) + + def test_random_split_no_overlap_hard(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.7, 0.2, 0.1]], + split_type="random", + hard_test_split=True, + random_state=42, + ) + results = splitter.split_all_datasets() + indices = results["dataset1"]["indices"] + + train_set = set(indices["train"]) + val_set = set(indices["val"]) + test_set = set(indices["test"]) + + assert len(train_set & val_set) == 0 + assert len(train_set & test_set) == 0 + assert len(val_set & test_set) == 0 + + +class TestHoldoutGroups: + @pytest.mark.parametrize("hard_test_split", [True, False]) + def test_holdout_groups_basic(self, adata_perturbation, hard_test_split): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_groups", + split_key="drug", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + assert "split_values" in results["dataset1"] + + split_values = results["dataset1"]["split_values"] + assert "train" in split_values + assert "val" in split_values + assert "test" in split_values + + def test_holdout_groups_force_training_values(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=[], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + # Get available perturbation values (not control) + unique_values = set() + for covariates in train_data.perturbation_idx_to_covariates.values(): + unique_values.update(covariates) + + # Use "drug_a" instead of "control" since control cells are filtered out + force_value = "drug_a" + if force_value not in unique_values: + pytest.skip("drug_a not in perturbation values") + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_groups", + split_key="drug", + force_training_values=[force_value], + random_state=42, + ) + + results = splitter.split_all_datasets() + split_values = results["dataset1"]["split_values"] + + assert force_value in split_values["train"] + assert 
force_value not in split_values["val"] + assert force_value not in split_values["test"] + + def test_holdout_groups_fixed_test_seed(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + results_list = [] + for seed in [42, 43, 44]: + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_groups", + split_key="drug", + test_random_state=999, + val_random_state=seed, + random_state=seed, + ) + results = splitter.split_all_datasets() + results_list.append(results) + + test_values_1 = set(results_list[0]["dataset1"]["split_values"]["test"]) + test_values_2 = set(results_list[1]["dataset1"]["split_values"]["test"]) + test_values_3 = set(results_list[2]["dataset1"]["split_values"]["test"]) + + assert test_values_1 == test_values_2 == test_values_3 + + val_values_1 = set(results_list[0]["dataset1"]["split_values"]["val"]) + val_values_2 = set(results_list[1]["dataset1"]["split_values"]["val"]) + + if len(val_values_1) > 0 and len(val_values_2) > 0: + assert val_values_1 != val_values_2 + + +class TestHoldoutCombinations: + @pytest.mark.parametrize("hard_test_split", [True, False]) + def test_holdout_combinations_basic(self, adata_perturbation, hard_test_split): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1", "drug2"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_combinations", + split_key=["drug1", "drug2"], + control_value="control", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + indices = results["dataset1"]["indices"] + + assert len(indices["train"]) > 0 + assert len(indices["val"]) >= 0 + assert len(indices["test"]) >= 0 + + def test_holdout_combinations_singletons_in_train(self): + # Create test data with a good number of combinations + import anndata as ad + + n_obs = 1000 # Increased to accommodate more combinations + n_vars = 50 + n_pca = 10 + + X_data = np.random.rand(n_obs, n_vars) + my_counts = np.random.rand(n_obs, n_vars) + X_pca = np.random.rand(n_obs, n_pca) + + # Use 5 drugs to get 20 unique combinations (5 * 4) + drugs = ["drug_a", "drug_b", "drug_c", "drug_d", "drug_e"] + cell_lines = ["cell_line_a", "cell_line_b", "cell_line_c"] + + # Create structured data with known combinations + drug1_list = [] + drug2_list = [] + + # Control cells (100) + drug1_list.extend(["control"] * 100) + drug2_list.extend(["control"] * 100) + + # Singleton on drug1 (250 cells: 50 per drug) + for drug in drugs: + drug1_list.extend([drug] * 50) + drug2_list.extend(["control"] * 50) + + # Singleton on drug2 (250 cells: 50 per drug) + for drug in drugs: + drug1_list.extend(["control"] * 50) + drug2_list.extend([drug] * 50) + + # Combinations (400 cells distributed across 20 combinations = 20 cells each) + # Create all possible non-control combinations + combinations = [] + for d1 in drugs: + for d2 in drugs: + if d1 != d2: # Different drugs (true combinations) + combinations.append((d1, d2)) + + # Distribute 400 cells evenly across combinations (20 cells per 
combination) + cells_per_combo = 400 // len(combinations) + + for d1, d2 in combinations: + drug1_list.extend([d1] * cells_per_combo) + drug2_list.extend([d2] * cells_per_combo) + + # Create cell line assignments + import pandas as pd + cell_type_list = np.random.choice(cell_lines, n_obs) + dosages = np.random.choice([10.0, 100.0, 1000.0], n_obs) + + obs_data = pd.DataFrame({ + "cell_type": cell_type_list, + "dosage": dosages, + "drug1": drug1_list, + "drug2": drug2_list, + "drug3": ["control"] * n_obs, + "dosage_a": np.random.choice([10.0, 100.0, 1000.0], n_obs), + "dosage_b": np.random.choice([10.0, 100.0, 1000.0], n_obs), + "dosage_c": np.random.choice([10.0, 100.0, 1000.0], n_obs), + }) + + # Create an AnnData object + adata_combinations = ad.AnnData(X=X_data, obs=obs_data) + adata_combinations.layers["my_counts"] = my_counts + adata_combinations.obsm["X_pca"] = X_pca + + # Add boolean columns for each drug + for drug in drugs: + adata_combinations.obs[drug] = ( + (adata_combinations.obs["drug1"] == drug) | + (adata_combinations.obs["drug2"] == drug) | + (adata_combinations.obs["drug3"] == drug) + ) + + adata_combinations.obs["control"] = ( + (adata_combinations.obs["drug1"] == "control") & + (adata_combinations.obs["drug2"] == "control") + ) + + # Convert to categorical EXCEPT for control and boolean drug columns + for col in adata_combinations.obs.columns: + if col not in ["control"] + drugs: + adata_combinations.obs[col] = adata_combinations.obs[col].astype("category") + + # Add embeddings + drug_emb = {} + for drug in adata_combinations.obs["drug1"].cat.categories: + drug_emb[drug] = np.random.randn(5, 1) + adata_combinations.uns["drug"] = drug_emb + + cell_type_emb = {} + for cell_type in adata_combinations.obs["cell_type"].cat.categories: + cell_type_emb[cell_type] = np.random.randn(3, 1) + adata_combinations.uns["cell_type"] = cell_type_emb + + # Now run the actual test + dm = DataManager( + adata_combinations, + sample_rep="X", + split_covariates=[], + control_key="control", + perturbation_covariates={"drug": ["drug1", "drug2"]}, + ) + train_data = dm.get_train_data(adata_combinations) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.6, 0.2, 0.2]], + split_type="holdout_combinations", + split_key=["drug1", "drug2"], + control_value="control", + random_state=42, + ) + + results = splitter.split_all_datasets() + + perturbation_covariates_mask = train_data.perturbation_covariates_mask + perturbation_idx_to_covariates = train_data.perturbation_idx_to_covariates + + train_indices = results["dataset1"]["indices"]["train"] + val_indices = results["dataset1"]["indices"]["val"] + test_indices = results["dataset1"]["indices"]["test"] + + # Verify that ALL singletons and controls are in training + all_singletons = [] + all_combinations = [] + + for idx in range(len(perturbation_covariates_mask)): + pert_idx = perturbation_covariates_mask[idx] + if pert_idx >= 0: + covariates = perturbation_idx_to_covariates[pert_idx] + non_control_count = sum(1 for c in covariates if c != "control") + if non_control_count == 1: + all_singletons.append(idx) + elif non_control_count > 1: + all_combinations.append(idx) + + train_set = set(train_indices) + + # All singletons should be in training + for singleton_idx in all_singletons: + assert singleton_idx in train_set, "All singleton perturbations should be in training" + + # Some (but not all) combinations should be in training according to split_ratios + combinations_in_train = [idx for idx in 
all_combinations if idx in train_set] + combinations_in_val = [idx for idx in all_combinations if idx in set(val_indices)] + combinations_in_test = [idx for idx in all_combinations if idx in set(test_indices)] + + # With enough combinations, we should see proper distribution + assert len(all_combinations) > 0, "Test data should have combination perturbations" + + train_combo_ratio = len(combinations_in_train) / len(all_combinations) + val_combo_ratio = len(combinations_in_val) / len(all_combinations) + test_combo_ratio = len(combinations_in_test) / len(all_combinations) + + # With 0.6, 0.2, 0.2 ratios, allow some tolerance + assert 0.4 < train_combo_ratio < 0.8, f"Expected ~60% of combinations in training, got {train_combo_ratio:.2%}" + assert 0.05 < val_combo_ratio < 0.35, f"Expected ~20% of combinations in val, got {val_combo_ratio:.2%}" + assert 0.05 < test_combo_ratio < 0.35, f"Expected ~20% of combinations in test, got {test_combo_ratio:.2%}" + + +class TestStratifiedSplit: + @pytest.mark.parametrize("hard_test_split", [True, False]) + def test_stratified_split_basic(self, adata_perturbation, hard_test_split): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="stratified", + split_key="drug", + hard_test_split=hard_test_split, + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + indices = results["dataset1"]["indices"] + + n_cells = train_data.perturbation_covariates_mask.shape[0] + total_assigned = len(indices["train"]) + len(indices["val"]) + len(indices["test"]) + assert total_assigned == n_cells + + +class TestMultipleDatasets: + def test_multiple_datasets_different_ratios(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data1 = dm.get_train_data(adata_perturbation) + train_data2 = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data1, train_data2], + dataset_names=["dataset1", "dataset2"], + split_ratios=[[0.8, 0.1, 0.1], [0.7, 0.2, 0.1]], + split_type="random", + random_state=42, + ) + + results = splitter.split_all_datasets() + + assert "dataset1" in results + assert "dataset2" in results + + n_cells = train_data1.perturbation_covariates_mask.shape[0] + + assert len(results["dataset1"]["indices"]["train"]) == pytest.approx(0.8 * n_cells, abs=1) + assert len(results["dataset2"]["indices"]["train"]) == pytest.approx(0.7 * n_cells, abs=1) + + +class TestSaveAndLoad: + def test_save_and_load_splits(self, adata_perturbation, tmp_path): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_groups", + split_key="drug", + random_state=42, + ) + + results = splitter.split_all_datasets() + splitter.save_splits(tmp_path / "splits") + + assert (tmp_path / "splits" / "split_summary.json").exists() + assert (tmp_path / 
"splits" / "dataset1" / "metadata.json").exists() + assert (tmp_path / "splits" / "dataset1" / "split_info.pkl").exists() + + loaded_info = DataSplitter.load_split_info(tmp_path / "splits", "dataset1") + + assert "indices" in loaded_info + assert "metadata" in loaded_info + + assert np.array_equal(loaded_info["indices"]["train"], results["dataset1"]["indices"]["train"]) + assert np.array_equal(loaded_info["indices"]["val"], results["dataset1"]["indices"]["val"]) + assert np.array_equal(loaded_info["indices"]["test"], results["dataset1"]["indices"]["test"]) + + def test_load_nonexistent_split(self, tmp_path): + with pytest.raises(FileNotFoundError): + DataSplitter.load_split_info(tmp_path / "nonexistent", "dataset1") + + +class TestSplitSummary: + def test_generate_split_summary(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="holdout_groups", + split_key="drug", + random_state=42, + ) + + splitter.split_all_datasets() + summary = splitter.generate_split_summary() + + assert "dataset1" in summary + assert "configuration" in summary["dataset1"] + assert "statistics" in summary["dataset1"] + assert "observations_per_condition" in summary["dataset1"] + + config = summary["dataset1"]["configuration"] + assert config["split_type"] == "holdout_groups" + assert config["random_state"] == 42 + + stats = summary["dataset1"]["statistics"] + assert "total_observations" in stats + assert "train_observations" in stats + assert "val_observations" in stats + assert "test_observations" in stats + + def test_summary_before_split_raises(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + random_state=42, + ) + + with pytest.raises(ValueError, match="No split results available"): + splitter.generate_split_summary() + + +class TestExtractPerturbationInfo: + def test_extract_perturbation_info(self, adata_perturbation): + dm = DataManager( + adata_perturbation, + sample_rep="X", + split_covariates=["cell_type"], + control_key="control", + perturbation_covariates={"drug": ["drug1"]}, + ) + train_data = dm.get_train_data(adata_perturbation) + + splitter = DataSplitter( + training_datasets=[train_data], + dataset_names=["dataset1"], + split_ratios=[[0.8, 0.1, 0.1]], + split_type="random", + ) + + pert_info = splitter.extract_perturbation_info(train_data) + + assert "perturbation_covariates_mask" in pert_info + assert "perturbation_idx_to_covariates" in pert_info + assert "n_cells" in pert_info + + assert isinstance(pert_info["perturbation_covariates_mask"], np.ndarray) + assert isinstance(pert_info["perturbation_idx_to_covariates"], dict) + assert pert_info["n_cells"] == len(train_data.perturbation_covariates_mask) diff --git a/tests/data/test_old_get_condition_data.py b/tests/data/test_old_get_condition_data.py index a546281a..a7e94575 100644 --- a/tests/data/test_old_get_condition_data.py +++ b/tests/data/test_old_get_condition_data.py 
@@ -9,8 +9,8 @@ import pytest from tqdm import tqdm -from cellflow._types import ArrayLike -from cellflow.data._datamanager import ( +from scaleflow._types import ArrayLike +from scaleflow.data._datamanager import ( DataManager, ReturnData, _to_list, diff --git a/tests/data/test_torch_dataloader.py b/tests/data/test_torch_dataloader.py new file mode 100644 index 00000000..ced4220e --- /dev/null +++ b/tests/data/test_torch_dataloader.py @@ -0,0 +1,55 @@ +import scaleflow +from scaleflow.data import TorchCombinedTrainSampler + + +class TestTorchDataloader: + def test_torch_dataloader_shapes( + self, + adata_perturbation, + tmp_path, + ): + solver = "otfm" + sample_rep = "X" + control_key = "control" + perturbation_covariates = {"drug": ["drug1", "drug2"]} + perturbation_covariate_reps = {"drug": "drug"} + batch_size = 18 + + cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver) + cf.prepare_data( + sample_rep=sample_rep, + control_key=control_key, + perturbation_covariates=perturbation_covariates, + perturbation_covariate_reps=perturbation_covariate_reps, + ) + assert cf.train_data is not None + assert hasattr(cf, "_data_dim") + cf.train_data.write_zarr(tmp_path / "train_data1.zarr") + cf.train_data.write_zarr(tmp_path / "train_data2.zarr") + cf.train_data.write_zarr(tmp_path / "train_data3.zarr") + + combined_dataloader = TorchCombinedTrainSampler.combine_zarr_training_samplers( + data_paths=[ + tmp_path / "train_data1.zarr", + tmp_path / "train_data2.zarr", + tmp_path / "train_data3.zarr", + ], + batch_size=batch_size, + num_workers=2, + weights=[0.3, 0.3, 0.4], + seed=42, + dataset_names=["train_data1", "train_data2", "train_data3"], + ) + iter_dl = iter(combined_dataloader) + batch = next(iter_dl) + assert "dataset_name" in batch + assert batch["dataset_name"] in ["train_data1", "train_data2", "train_data3"] + assert "src_cell_data" in batch + assert "tgt_cell_data" in batch + assert "condition" in batch + dim = adata_perturbation.shape[1] + assert batch["src_cell_data"].shape == (batch_size, dim) + assert batch["tgt_cell_data"].shape == (batch_size, dim) + assert "drug" in batch["condition"] + drug_dim = adata_perturbation.uns["drug"]["drug_a"].shape[0] + assert batch["condition"]["drug"].shape == (1, len(perturbation_covariates["drug"]), drug_dim) diff --git a/tests/external/test_CFJaxSCVI.py b/tests/external/test_CFJaxSCVI.py index d25d1907..9238afd4 100644 --- a/tests/external/test_CFJaxSCVI.py +++ b/tests/external/test_CFJaxSCVI.py @@ -1,7 +1,7 @@ import pytest from scvi.data import synthetic_iid -from cellflow.external import CFJaxSCVI +from scaleflow.external import CFJaxSCVI class TestCFJaxSCVI: diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 9239fcf1..e67f4c0e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -2,7 +2,7 @@ import numpy as np import pytest -import cellflow +import scaleflow class TestMetrics: @@ -11,8 +11,8 @@ def test_compute_metrics(self, metrics_data, prefix): x_test = metrics_data["x_test"] y_test = metrics_data["y_test"] - metrics = jtu.tree_map(cellflow.metrics.compute_metrics, x_test, y_test) - mean_metrics = cellflow.metrics.compute_mean_metrics(metrics, prefix) + metrics = jtu.tree_map(scaleflow.metrics.compute_metrics, x_test, y_test) + mean_metrics = scaleflow.metrics.compute_mean_metrics(metrics, prefix) assert "Alvespimycin+Pirarubicin" in metrics.keys() assert {"r_squared", "sinkhorn_div_1", "sinkhorn_div_10", "sinkhorn_div_100", "e_distance", "mmd"} == set( @@ -32,12 +32,12 
diff --git a/tests/external/test_CFJaxSCVI.py b/tests/external/test_CFJaxSCVI.py
index d25d1907..9238afd4 100644
--- a/tests/external/test_CFJaxSCVI.py
+++ b/tests/external/test_CFJaxSCVI.py
@@ -1,7 +1,7 @@
 import pytest
 from scvi.data import synthetic_iid

-from cellflow.external import CFJaxSCVI
+from scaleflow.external import CFJaxSCVI


 class TestCFJaxSCVI:
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 9239fcf1..e67f4c0e 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -2,7 +2,7 @@
 import numpy as np
 import pytest

-import cellflow
+import scaleflow


 class TestMetrics:
@@ -11,8 +11,8 @@ def test_compute_metrics(self, metrics_data, prefix):
         x_test = metrics_data["x_test"]
         y_test = metrics_data["y_test"]

-        metrics = jtu.tree_map(cellflow.metrics.compute_metrics, x_test, y_test)
-        mean_metrics = cellflow.metrics.compute_mean_metrics(metrics, prefix)
+        metrics = jtu.tree_map(scaleflow.metrics.compute_metrics, x_test, y_test)
+        mean_metrics = scaleflow.metrics.compute_mean_metrics(metrics, prefix)

         assert "Alvespimycin+Pirarubicin" in metrics.keys()
         assert {"r_squared", "sinkhorn_div_1", "sinkhorn_div_10", "sinkhorn_div_100", "e_distance", "mmd"} == set(
@@ -32,12 +32,12 @@ def test_function_output(self, metrics_data, epsilon):
         x_test = metrics_data["x_test"]["Alvespimycin+Pirarubicin"]
         y_test = metrics_data["y_test"]["Alvespimycin+Pirarubicin"]

-        r_squared = cellflow.metrics.compute_r_squared(x_test, y_test)
-        sinkhorn_div = cellflow.metrics.compute_sinkhorn_div(x_test, y_test, epsilon=epsilon)
-        e_distance = cellflow.metrics.compute_e_distance(x_test, y_test)
-        e_distance_fast = cellflow.metrics.compute_e_distance_fast(x_test, y_test)
-        scalar_mmd = cellflow.metrics.compute_scalar_mmd(x_test, y_test)
-        mmd_fast = cellflow.metrics.maximum_mean_discrepancy(x_test, y_test, exact=False)
+        r_squared = scaleflow.metrics.compute_r_squared(x_test, y_test)
+        sinkhorn_div = scaleflow.metrics.compute_sinkhorn_div(x_test, y_test, epsilon=epsilon)
+        e_distance = scaleflow.metrics.compute_e_distance(x_test, y_test)
+        e_distance_fast = scaleflow.metrics.compute_e_distance_fast(x_test, y_test)
+        scalar_mmd = scaleflow.metrics.compute_scalar_mmd(x_test, y_test)
+        mmd_fast = scaleflow.metrics.maximum_mean_discrepancy(x_test, y_test, exact=False)

         assert -1000 <= r_squared <= 1
         assert sinkhorn_div >= 0
@@ -51,11 +51,11 @@ def test_fast_metrics(self, metrics_data, gamma):
         x_test = metrics_data["x_test"]["Alvespimycin+Pirarubicin"]
         y_test = metrics_data["y_test"]["Alvespimycin+Pirarubicin"]

-        e_distance = cellflow.metrics.compute_e_distance(x_test, y_test)
-        e_distance_fast = cellflow.metrics.compute_e_distance_fast(x_test, y_test)
+        e_distance = scaleflow.metrics.compute_e_distance(x_test, y_test)
+        e_distance_fast = scaleflow.metrics.compute_e_distance_fast(x_test, y_test)

-        mmd = cellflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=True)
-        mmd_fast = cellflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=False)
+        mmd = scaleflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=True)
+        mmd_fast = scaleflow.metrics.maximum_mean_discrepancy(x_test, y_test, gamma, exact=False)

         assert np.allclose(e_distance, e_distance_fast, rtol=1e-4, atol=1e-4)
         assert np.allclose(mmd, mmd_fast, rtol=1e-4, atol=1e-4)
diff --git a/tests/model/test_cellflow.py b/tests/model/test_scaleflow.py
similarity index 88%
rename from tests/model/test_cellflow.py
rename to tests/model/test_scaleflow.py
index 024349cb..90566fcc 100644
--- a/tests/model/test_cellflow.py
+++ b/tests/model/test_scaleflow.py
@@ -2,8 +2,8 @@
 import pandas as pd
 import pytest

-import cellflow
-from cellflow.networks import _velocity_field
+import scaleflow
+from scaleflow.networks import _velocity_field

 perturbation_covariate_comb_args = [
     {"drug": ["drug1"]},
@@ -17,11 +17,11 @@

 class TestCellFlow:
     @pytest.mark.slow
-    @pytest.mark.parametrize("solver", ["otfm"])  # , "genot"])
+    @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"])
     @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"])
     @pytest.mark.parametrize("regularization", [0.0, 0.1])
     @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"])
-    def test_cellflow_solver(
+    def test_scaleflow_solver(
         self,
         adata_perturbation,
         solver,
@@ -29,7 +29,7 @@ def test_cellflow_solver(
         regularization,
         conditioning,
     ):
-        if solver == "genot" and ((condition_mode == "stochastic") or (regularization > 0.0)):
+        if solver in ["genot", "eqm"] and ((condition_mode == "stochastic") or (regularization > 0.0)):
             return None
         sample_rep = "X"
         control_key = "control"
@@ -38,7 +38,7 @@
         condition_embedding_dim = 32

         vf_kwargs = {"genot_source_dims": (32, 32), "genot_source_dropout": 0.1} if solver == "genot" else None
-        cf = cellflow.model.CellFlow(adata_perturbation, solver=solver)
+        cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver)
         cf.prepare_data(
             sample_rep=sample_rep,
             control_key=control_key,
@@ -77,15 +77,14 @@ def test_cellflow_solver(
         cf.train(num_iterations=3)
         assert cf._dataloader is not None

-        # we assume these are all source cells now in adata_perturbation
         adata_perturbation_pred = adata_perturbation.copy()
         adata_perturbation_pred.obs["control"] = True
+        predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False}
         pred = cf.predict(
             adata_perturbation_pred,
             sample_rep=sample_rep,
             covariate_data=adata_perturbation_pred.obs,
-            max_steps=3,
-            throw=False,
+            **predict_kwargs,
         )
         assert isinstance(pred, dict)
         key, out = next(iter(pred.items()))
@@ -97,16 +96,14 @@ def test_cellflow_solver(
             sample_rep=sample_rep,
             covariate_data=adata_perturbation_pred.obs,
             key_added_prefix="MY_PREDICTION_",
-            max_steps=3,
-            throw=False,
+            **predict_kwargs,
         )
         assert pred_stored is None

-        if solver == "otfm":
+        if solver in ["otfm", "genot", "eqm"]:
             assert "MY_PREDICTION_" + str(key) in adata_perturbation_pred.obsm

         if solver == "genot":
-            assert "MY_PREDICTION_" + str(key) in adata_perturbation_pred.obsm
             pred2 = cf.predict(
                 adata_perturbation_pred,
                 sample_rep=sample_rep,
@@ -133,9 +130,9 @@
         assert cond_embed_var.shape[1] == condition_embedding_dim

     @pytest.mark.slow
-    @pytest.mark.parametrize("solver", ["otfm", "genot"])
+    @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"])
     @pytest.mark.parametrize("perturbation_covariate_reps", [{}, {"drug": "drug"}])
-    def test_cellflow_covar_reps(
+    def test_scaleflow_covar_reps(
         self,
         adata_perturbation,
         perturbation_covariate_reps,
@@ -148,7 +145,7 @@
         condition_embedding_dim = 32

         vf_kwargs = {"genot_source_dims": (32, 32), "genot_source_dropout": 0.1} if solver == "genot" else None
-        cf = cellflow.model.CellFlow(adata_perturbation, solver=solver)
+        cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver)
         cf.prepare_data(
             sample_rep=sample_rep,
             control_key=control_key,
@@ -166,24 +163,24 @@
         )
         assert cf._trainer is not None

-        vector_field_class = (
-            _velocity_field.ConditionalVelocityField
-            if solver == "otfm"
-            else _velocity_field.GENOTConditionalVelocityField
-        )
+        if solver == "otfm":
+            vector_field_class = _velocity_field.ConditionalVelocityField
+        elif solver == "genot":
+            vector_field_class = _velocity_field.GENOTConditionalVelocityField
+        else:
+            vector_field_class = _velocity_field.EquilibriumVelocityField
         assert cf._vf_class == vector_field_class

        cf.train(num_iterations=3)
         assert cf._dataloader is not None

-        # we assume these are all source cells now in adata_perturbation
         adata_perturbation_pred = adata_perturbation.copy()
         adata_perturbation_pred.obs["control"] = True
+        predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False}
         pred = cf.predict(
             adata_perturbation_pred,
             sample_rep=sample_rep,
             covariate_data=adata_perturbation_pred.obs,
-            max_steps=3,
-            throw=False,
+            **predict_kwargs,
         )
         assert isinstance(pred, dict)
         out = next(iter(pred.values()))
@@ -203,7 +200,7 @@
     @pytest.mark.parametrize("perturbation_covariates", perturbation_covariate_comb_args)
     @pytest.mark.parametrize("n_conditions_on_log_iteration", [None, 0, 2])
     @pytest.mark.parametrize("n_conditions_on_train_end", [None, 0, 2])
-    def test_cellflow_val_data_loading(
+    def test_scaleflow_val_data_loading(
         self,
         adata_perturbation,
         split_covariates,
@@ -211,7 +208,7 @@
         n_conditions_on_log_iteration,
         n_conditions_on_train_end,
     ):
-        cf = cellflow.model.CellFlow(adata_perturbation)
+        cf = scaleflow.model.CellFlow(adata_perturbation)
         cf.prepare_data(
             sample_rep="X",
             control_key="control",
@@ -248,10 +245,10 @@
         assert cond_data[k].shape[1] == cf.train_data.max_combination_length

     @pytest.mark.slow
-    @pytest.mark.parametrize("solver", ["otfm", "genot"])
+    @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"])
     @pytest.mark.parametrize("n_conditions_on_log_iteration", [None, 0, 1])
     @pytest.mark.parametrize("n_conditions_on_train_end", [None, 0, 1])
-    def test_cellflow_with_validation(
+    def test_scaleflow_with_validation(
         self,
         adata_perturbation,
         solver,
@@ -259,7 +256,8 @@
         n_conditions_on_train_end,
     ):
         vf_kwargs = {"genot_source_dims": (2, 2), "genot_source_dropout": 0.1} if solver == "genot" else None
-        cf = cellflow.model.CellFlow(adata_perturbation, solver=solver)
+        predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False}
+        cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver)
         cf.prepare_data(
             sample_rep="X",
             control_key="control",
@@ -274,7 +272,7 @@
             name="val",
             n_conditions_on_log_iteration=n_conditions_on_log_iteration,
             n_conditions_on_train_end=n_conditions_on_train_end,
-            predict_kwargs={"max_steps": 3, "throw": False},
+            predict_kwargs=predict_kwargs,
         )
         assert isinstance(cf._validation_data, dict)
         assert "val" in cf._validation_data
@@ -300,24 +298,26 @@
         assert cf._trainer is not None

         metric_to_compute = "r_squared"
-        metrics_callback = cellflow.training.Metrics(metrics=[metric_to_compute])
+        metrics_callback = scaleflow.training.Metrics(metrics=[metric_to_compute])

         cf.train(num_iterations=3, callbacks=[metrics_callback], valid_freq=1)
         assert cf._dataloader is not None
         assert f"val_{metric_to_compute}_mean" in cf._trainer.training_logs

     @pytest.mark.slow
-    @pytest.mark.parametrize("solver", ["otfm", "genot"])
+    @pytest.mark.parametrize("solver", ["otfm", "genot", "eqm"])
     @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"])
     @pytest.mark.parametrize("regularization", [0.0, 0.1])
-    def test_cellflow_predict(
+    def test_scaleflow_predict(
         self,
         adata_perturbation,
         solver,
         condition_mode,
         regularization,
     ):
-        cf = cellflow.model.CellFlow(adata_perturbation, solver=solver)
+        if solver in ["genot", "eqm"] and ((condition_mode == "stochastic") or (regularization > 0.0)):
+            return None
+        cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver)
         cf.prepare_data(
             sample_rep="X",
             control_key="control",
@@ -354,7 +354,8 @@
         adata_pred.obs["control"] = True

         covariate_data = adata_perturbation.obs.iloc[:3]
-        pred = cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, max_steps=3, throw=False)
+        predict_kwargs = {"max_steps": 3, "eta": 0.01} if solver == "eqm" else {"max_steps": 3, "throw": False}
+        pred = cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, **predict_kwargs)

         assert isinstance(pred, dict)
         out = next(iter(pred.values()))
@@ -365,11 +366,11 @@
             ValueError,
             match=r".*If both `adata` and `covariate_data` are given, all samples in `adata` must be control samples*",
         ):
-            cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, max_steps=3, throw=False)
+            cf.predict(adata_pred, sample_rep="X", covariate_data=covariate_data, **predict_kwargs)

         with pytest.raises(ValueError, match="`covariate_data` is empty."):
             empty_covariate_data = covariate_data.head(0)
-            cf.predict(adata_pred, sample_rep="X", covariate_data=empty_covariate_data, max_steps=3, throw=False)
+            cf.predict(adata_pred, sample_rep="X", covariate_data=empty_covariate_data, **predict_kwargs)

         with pytest.raises(
             ValueError,
@@ -381,12 +382,12 @@
         adata_pred_cell_type_2 = adata_pred[adata_pred.obs["cell_type"] == "cell_line_b"]
         adata_pred_cell_type_2.obs["control"] = True
         cf.predict(
-            adata_pred_cell_type_2, sample_rep="X", covariate_data=cov_data_cell_type_1, max_steps=3, throw=False
+            adata_pred_cell_type_2, sample_rep="X", covariate_data=cov_data_cell_type_1, **predict_kwargs
         )

     def test_raise_otfm_vf_kwargs_passed(self, adata_perturbation):
         vf_kwargs = {"genot_source_dims": (2, 2), "genot_source_dropouts": 0.1}
-        cf = cellflow.model.CellFlow(adata_perturbation, solver="otfm")
+        cf = scaleflow.model.CellFlow(adata_perturbation, solver="otfm")
         cf.prepare_data(
             sample_rep="X",
             control_key="control",
@@ -395,7 +396,7 @@
         )
         with pytest.raises(
             ValueError,
-            match=r".*For `solver='otfm'`, `vf_kwargs` must be `None`.*",
+            match=r".*For `solver='otfm'` or `solver='eqm'`, `vf_kwargs` must be `None`.*",
         ):
             cf.prepare_model(
                 condition_embedding_dim=2,
@@ -413,7 +414,7 @@
     @pytest.mark.parametrize("perturbation_covariates", perturbation_covariate_comb_args)
     @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"])
     @pytest.mark.parametrize("regularization", [0.0, 0.1])
-    def test_cellflow_get_condition_embedding(
+    def test_scaleflow_get_condition_embedding(
         self,
         adata_perturbation,
         sample_covariate_and_reps,
@@ -430,7 +431,7 @@
         condition_embedding_dim = 2
         solver = "otfm"

-        cf = cellflow.model.CellFlow(adata_perturbation, solver=solver)
+        cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver)
         cf.prepare_data(
             sample_rep=sample_rep,
             control_key=control_key,
@@ -504,7 +505,7 @@ def test_time_embedding(
         solver = "otfm"
         time_freqs = 1024

-        cf = cellflow.model.CellFlow(adata_perturbation, solver=solver)
+        cf = scaleflow.model.CellFlow(adata_perturbation, solver=solver)
         cf.prepare_data(
             sample_rep=sample_rep,
             control_key=control_key,
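A pattern worth noting in the updated model tests above: every call into `cf.predict` now switches its keyword arguments on the solver, because the new `eqm` solver is stepped with an `eta` step size while the ODE-based `otfm`/`genot` path keeps the `max_steps`/`throw` arguments. A small helper along these lines (a sketch only, not part of this patch) would avoid repeating the conditional in each test:

def make_predict_kwargs(solver: str, max_steps: int = 3) -> dict:
    # Values mirror the ones hard-coded in the tests above; `eta` is the eqm step size.
    if solver == "eqm":
        return {"max_steps": max_steps, "eta": 0.01}
    return {"max_steps": max_steps, "throw": False}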
diff --git a/tests/networks/test_aggregators.py b/tests/networks/test_aggregators.py
index a2ce546a..682ae6d1 100644
--- a/tests/networks/test_aggregators.py
+++ b/tests/networks/test_aggregators.py
@@ -2,8 +2,8 @@
 import jax.numpy as jnp
 import pytest

-from cellflow.networks._set_encoders import ConditionEncoder
-from cellflow.networks._utils import SeedAttentionPooling, TokenAttentionPooling
+from scaleflow.networks._set_encoders import ConditionEncoder
+from scaleflow.networks._utils import SeedAttentionPooling, TokenAttentionPooling


 class TestAggregator:
diff --git a/tests/networks/test_condencoder.py b/tests/networks/test_condencoder.py
index 54589d62..325278ee 100644
--- a/tests/networks/test_condencoder.py
+++ b/tests/networks/test_condencoder.py
@@ -3,7 +3,7 @@
 import optax
 import pytest

-import cellflow
+import scaleflow

 cond = {
     "pert1": jnp.ones((1, 3, 3)),
@@ -54,7 +54,7 @@ class TestConditionEncoder:
     def test_condition_encoder_init(
         self, pooling, covariates_not_pooled, layers_before_pool, layers_after_pool, condition_mode, regularization
     ):
-        cond_encoder = cellflow.networks.ConditionEncoder(
+        cond_encoder = scaleflow.networks.ConditionEncoder(
             output_dim=5,
             condition_mode=condition_mode,
             regularization=regularization,
diff --git a/tests/networks/test_velocityfield.py b/tests/networks/test_velocityfield.py
index 4e651d26..96db1f2d 100644
--- a/tests/networks/test_velocityfield.py
+++ b/tests/networks/test_velocityfield.py
@@ -4,7 +4,7 @@
 import pytest
 from flax.linen import activation

-from cellflow.networks import _velocity_field
+from scaleflow.networks import _velocity_field

 x_test = jnp.ones((10, 5)) * 10
 t_test = jnp.ones((10, 1))
@@ -19,7 +19,7 @@ class TestVelocityField:
     @pytest.mark.parametrize("linear_projection_before_concatenation", [True, False])
     @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"])
     @pytest.mark.parametrize(
-        "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField]
+        "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField, _velocity_field.EquilibriumVelocityField]
     )
     @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"])
     def test_velocity_field_init(
@@ -62,6 +62,15 @@
                 train=True,
                 rngs={"condition_encoder": apply_rng},
             )
+        elif isinstance(vf, _velocity_field.EquilibriumVelocityField):
+            out, out_mean, out_logvar = vf_state.apply_fn(
+                {"params": vf_state.params},
+                x_test,
+                cond,
+                encoder_noise,
+                train=True,
+                rngs={"condition_encoder": apply_rng},
+            )
         elif isinstance(vf, _velocity_field.ConditionalVelocityField):
             out, out_mean, out_logvar = vf_state.apply_fn(
                 {"params": vf_state.params},
@@ -84,7 +93,7 @@

     @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"])
     @pytest.mark.parametrize(
-        "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField]
+        "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField, _velocity_field.EquilibriumVelocityField]
     )
     @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"])
     def test_velocityfield_conditioning_kwargs(self, condition_mode, velocity_field_cls, conditioning):
@@ -127,6 +136,15 @@
                 train=True,
                 rngs={"condition_encoder": apply_rng, "dropout": dropout_rng},
             )
+        elif isinstance(vf, _velocity_field.EquilibriumVelocityField):
+            out, out_mean, out_logvar = vf_state.apply_fn(
+                {"params": vf_state.params},
+                x_test,
+                cond,
+                encoder_noise,
+                train=True,
+                rngs={"condition_encoder": apply_rng, "dropout": dropout_rng},
+            )
         elif isinstance(vf, _velocity_field.ConditionalVelocityField):
             out, out_mean, out_logvar = vf_state.apply_fn(
                 {"params": vf_state.params},
@@ -145,7 +163,7 @@

     @pytest.mark.parametrize("condition_mode", ["deterministic", "stochastic"])
     @pytest.mark.parametrize(
-        "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField]
+        "velocity_field_cls", [_velocity_field.ConditionalVelocityField, _velocity_field.GENOTConditionalVelocityField, _velocity_field.EquilibriumVelocityField]
     )
     @pytest.mark.parametrize("conditioning", ["concatenation", "film", "resnet"])
     def test_velocityfield_conditioning_raises(self, condition_mode, velocity_field_cls, conditioning):
diff --git a/tests/plotting/test_plotting.py b/tests/plotting/test_plotting.py
index 7f431ff2..b3b0b0e6 100644
--- a/tests/plotting/test_plotting.py
+++ b/tests/plotting/test_plotting.py
@@ -1,7 +1,7 @@
 import matplotlib.pyplot as plt
 import pytest

-from cellflow.plotting import plot_condition_embedding
+from scaleflow.plotting import plot_condition_embedding


 class TestCallbacks:
diff --git a/tests/preprocessing/test_gene_emb.py b/tests/preprocessing/test_gene_emb.py
index 4f79277e..3a1e3be0 100644
--- a/tests/preprocessing/test_gene_emb.py
+++ b/tests/preprocessing/test_gene_emb.py
@@ -6,7 +6,7 @@
 import pytest
 import torch

-from cellflow.preprocessing._gene_emb import get_esm_embedding
+from scaleflow.preprocessing._gene_emb import get_esm_embedding

 IS_PROT_CODING = Counter(["ENSG00000139618", "ENSG00000206450", "ENSG00000049192"])
 ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../test_artifacts/")
diff --git a/tests/preprocessing/test_pca.py b/tests/preprocessing/test_pca.py
index 662c721f..dccc5ff0 100644
--- a/tests/preprocessing/test_pca.py
+++ b/tests/preprocessing/test_pca.py
@@ -5,9 +5,9 @@
 class TestPCA:
     def test_centered_pca(self, adata_pca: ad.AnnData):
-        import cellflow
+        import scaleflow

-        cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)
+        scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)

         assert "X_pca" in adata_pca.obsm
         assert "PCs" in adata_pca.varm
         assert "X_mean" in adata_pca.varm
@@ -17,10 +17,10 @@ def test_centered_pca(self, adata_pca: ad.AnnData):

     @pytest.mark.parametrize("layers_key_added", ["X_recon", "X_rec"])
     def test_reconstruct_pca(self, adata_pca: ad.AnnData, layers_key_added):
-        import cellflow
+        import scaleflow

-        cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)
-        cellflow.pp.reconstruct_pca(
+        scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)
+        scaleflow.pp.reconstruct_pca(
             adata_pca,
             ref_adata=adata_pca,
             use_rep="X_pca",
@@ -36,18 +36,18 @@ def test_reconstruct_pca(self, adata_pca: ad.AnnData, layers_key_added):
         )

     def test_reconstruct_pca_with_array_input(self, adata_pca: ad.AnnData):
-        import cellflow
+        import scaleflow

-        cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)
-        cellflow.pp.reconstruct_pca(adata_pca, ref_means=adata_pca.varm["X_mean"], ref_pcs=adata_pca.varm["PCs"])
+        scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)
+        scaleflow.pp.reconstruct_pca(adata_pca, ref_means=adata_pca.varm["X_mean"], ref_pcs=adata_pca.varm["PCs"])
         assert "X_recon" in adata_pca.layers

     @pytest.mark.parametrize("obsm_key_added", ["X_pca", "X_pca_projected"])
     def test_project_pca(self, adata_pca: ad.AnnData, obsm_key_added):
-        import cellflow
+        import scaleflow

-        cellflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)
-        adata_pca_project = cellflow.pp.project_pca(
+        scaleflow.pp.centered_pca(adata_pca, n_comps=50, copy=False)
+        adata_pca_project = scaleflow.pp.project_pca(
             adata_pca, ref_adata=adata_pca, obsm_key_added=obsm_key_added, copy=True
         )
         assert obsm_key_added in adata_pca_project.obsm
diff --git a/tests/preprocessing/test_preprocessing.py b/tests/preprocessing/test_preprocessing.py
index e5423663..ac9af815 100644
--- a/tests/preprocessing/test_preprocessing.py
+++ b/tests/preprocessing/test_preprocessing.py
@@ -13,10 +13,10 @@ class TestPreprocessing:
         ],
     )
     def test_annotate_compounds(self, adata_with_compounds: ad.AnnData, compound_key_and_type):
-        import cellflow
+        import scaleflow

         try:
-            cellflow.pp.annotate_compounds(
+            scaleflow.pp.annotate_compounds(
                 adata_with_compounds,
                 compound_keys=compound_key_and_type[0],
                 query_id_type=compound_key_and_type[1],
@@ -45,11 +45,11 @@ def test_annotate_compounds(self, adata_with_compounds: ad.AnnData, compound_key
         ],
     )
     def test_get_molecular_fingerprints(self, adata_with_compounds: ad.AnnData, n_bits, compound_and_smiles_keys):
-        import cellflow
+        import scaleflow

         uns_key_added = "compound_fingerprints"

-        cellflow.pp.get_molecular_fingerprints(
+        scaleflow.pp.get_molecular_fingerprints(
             adata_with_compounds,
             compound_keys=compound_and_smiles_keys[0],
             smiles_keys=compound_and_smiles_keys[1],
@@ -67,9 +67,9 @@ def test_get_molecular_fingerprints(self, adata_with_compounds: ad.AnnData, n_bi
     @pytest.mark.parametrize("uns_key_added", ["compounds", "compounds_onehot"])
     @pytest.mark.parametrize("exclude_values", [None, "GW0742"])
     def test_encode_onehot(self, adata_with_compounds: ad.AnnData, uns_key_added, exclude_values):
-        import cellflow
+        import scaleflow

-        cellflow.pp.encode_onehot(
+        scaleflow.pp.encode_onehot(
             adata_with_compounds,
             covariate_keys="compound_name",
             uns_key_added=uns_key_added,
diff --git a/tests/preprocessing/test_wknn.py b/tests/preprocessing/test_wknn.py
index 2bbe12ab..be52cd58 100644
--- a/tests/preprocessing/test_wknn.py
+++ b/tests/preprocessing/test_wknn.py
@@ -6,9 +6,9 @@ class TestWKNN:
     @pytest.mark.parametrize("n_neighbors", [50, 100])
     def test_compute_wknn_k(self, adata_perturbation: ad.AnnData, n_neighbors):
-        import cellflow
+        import scaleflow

-        cellflow.pp.compute_wknn(
+        scaleflow.pp.compute_wknn(
             ref_adata=adata_perturbation,
             query_adata=adata_perturbation,
             n_neighbors=n_neighbors,
@@ -24,12 +24,12 @@ def test_compute_wknn_k(self, adata_perturbation: ad.AnnData, n_neighbors):

     @pytest.mark.parametrize("weighting_scheme", ["top_n", "jaccard", "jaccard_square"])
     def test_compute_wknn_weighting(self, adata_perturbation: ad.AnnData, weighting_scheme):
-        import cellflow
+        import scaleflow

         n_neighbors = 50
         top_n = 10

-        cellflow.pp.compute_wknn(
+        scaleflow.pp.compute_wknn(
             ref_adata=adata_perturbation,
             query_adata=adata_perturbation,
             n_neighbors=n_neighbors,
@@ -50,11 +50,11 @@ def test_compute_wknn_weighting(self, adata_perturbation: ad.AnnData, weighting_

     @pytest.mark.parametrize("uns_key_added", ["wknn", "wknn2"])
     def test_compute_wknn_key_added(self, adata_perturbation: ad.AnnData, uns_key_added):
-        import cellflow
+        import scaleflow

         n_neighbors = 50

-        cellflow.pp.compute_wknn(
+        scaleflow.pp.compute_wknn(
             ref_adata=adata_perturbation,
             query_adata=adata_perturbation,
             n_neighbors=n_neighbors,
@@ -71,16 +71,16 @@ def test_compute_wknn_key_added(self, adata_perturbation: ad.AnnData, uns_key_ad
     @pytest.mark.parametrize("label_key", ["drug1", "cell_type"])
     def test_transfer_labels(self, adata_perturbation: ad.AnnData, label_key):
-        import cellflow
+        import scaleflow

-        cellflow.pp.compute_wknn(
+        scaleflow.pp.compute_wknn(
             ref_adata=adata_perturbation,
             query_adata=adata_perturbation,
             n_neighbors=50,
             copy=False,
         )

-        cellflow.pp.transfer_labels(
+        scaleflow.pp.transfer_labels(
             adata_perturbation,
             adata_perturbation,
             label_key=label_key,
diff --git a/tests/solver/test_solver.py b/tests/solver/test_solver.py
index 6bf10401..08454141 100644
--- a/tests/solver/test_solver.py
+++ b/tests/solver/test_solver.py
@@ -2,14 +2,15 @@
 import time

 import jax
+import jax.numpy as jnp
 import numpy as np
 import optax
 import pytest
 from ott.neural.methods.flows import dynamics

-import cellflow
-from cellflow.solvers import _genot, _otfm
-from cellflow.utils import match_linear
+import scaleflow
+from scaleflow.solvers import _eqm, _genot, _otfm
+from scaleflow.utils import match_linear

 src = {
     ("drug_1",): np.random.rand(10, 5),
@@ -22,13 +23,30 @@
 vf_rng = jax.random.PRNGKey(111)


+@pytest.fixture
+def eqm_dataloader():
+    class DataLoader:
+        n_conditions = 10
+
+        def sample(self, rng):
+            return {
+                "src_cell_data": jnp.ones((10, 5)) * 10,
+                "tgt_cell_data": jnp.ones((10, 5)),
+                "condition": {"pert1": jnp.ones((10, 2, 3))},
+            }
+
+    return DataLoader()
+
+
 class TestSolver:
-    @pytest.mark.parametrize("solver_class", ["otfm", "genot"])
-    def test_predict_batch(self, dataloader, solver_class):
+    @pytest.mark.parametrize("solver_class", ["otfm", "genot", "eqm"])
+    def test_predict_batch(self, dataloader, eqm_dataloader, solver_class):
         if solver_class == "otfm":
-            vf_class = cellflow.networks.ConditionalVelocityField
+            vf_class = scaleflow.networks.ConditionalVelocityField
+        elif solver_class == "genot":
+            vf_class = scaleflow.networks.GENOTConditionalVelocityField
         else:
-            vf_class = cellflow.networks.GENOTConditionalVelocityField
+            vf_class = scaleflow.networks.EquilibriumVelocityField

         opt = optax.adam(1e-3)
         vf = vf_class(
@@ -47,7 +65,7 @@ def test_predict_batch(self, dataloader, solver_class):
                 conditions={"drug": np.random.rand(2, 1, 3)},
                 rng=vf_rng,
             )
-        else:
+        elif solver_class == "genot":
             solver = _genot.GENOT(
                 vf=vf,
                 data_match_fn=match_linear,
@@ -58,11 +76,20 @@ def test_predict_batch(self, dataloader, solver_class):
                 conditions={"drug": np.random.rand(2, 1, 3)},
                 rng=vf_rng,
             )
+        else:
+            solver = _eqm.EquilibriumMatching(
+                vf=vf,
+                match_fn=match_linear,
+                optimizer=opt,
+                conditions={"pert1": np.random.rand(2, 2, 3)},
+                rng=vf_rng,
+            )

-        predict_kwargs = {"max_steps": 3, "throw": False}
-        trainer = cellflow.training.CellFlowTrainer(solver=solver, predict_kwargs=predict_kwargs)
+        predict_kwargs = {"max_steps": 3, "throw": False} if solver_class != "eqm" else {"max_steps": 3, "eta": 0.01}
+        trainer = scaleflow.training.CellFlowTrainer(solver=solver, predict_kwargs=predict_kwargs)
+        train_dataloader = eqm_dataloader if solver_class == "eqm" else dataloader
         trainer.train(
-            dataloader=dataloader,
+            dataloader=train_dataloader,
             num_iterations=2,
             valid_freq=1,
         )
@@ -89,10 +116,18 @@ def test_predict_batch(self, dataloader, solver_class):
         )
         assert diff_nonbatched - diff_batched > 0.5

+    @pytest.mark.parametrize("solver_class", ["otfm", "eqm"])
     @pytest.mark.parametrize("ema", [0.5, 1.0])
-    def test_EMA(self, dataloader, ema):
-        vf_class = cellflow.networks.ConditionalVelocityField
-        drug = np.random.rand(2, 1, 3)
+    def test_EMA(self, dataloader, eqm_dataloader, solver_class, ema):
+        if solver_class == "otfm":
+            vf_class = scaleflow.networks.ConditionalVelocityField
+            drug = np.random.rand(2, 1, 3)
+            condition_key = "drug"
+        else:
+            vf_class = scaleflow.networks.EquilibriumVelocityField
+            drug = np.random.rand(2, 2, 3)
+            condition_key = "pert1"
+
         opt = optax.adam(1e-3)
         vf1 = vf_class(
             output_dim=5,
@@ -102,18 +137,30 @@
             decoder_dims=(5, 5),
         )

-        solver1 = _otfm.OTFlowMatching(
-            vf=vf1,
-            match_fn=match_linear,
-            probability_path=dynamics.ConstantNoiseFlow(0.0),
-            optimizer=opt,
-            conditions={"drug": drug},
-            rng=vf_rng,
-            ema=ema,
-        )
-        trainer1 = cellflow.training.CellFlowTrainer(solver=solver1)
+        if solver_class == "otfm":
+            solver1 = _otfm.OTFlowMatching(
+                vf=vf1,
+                match_fn=match_linear,
+                probability_path=dynamics.ConstantNoiseFlow(0.0),
+                optimizer=opt,
+                conditions={condition_key: drug},
+                rng=vf_rng,
+                ema=ema,
+            )
+        else:
+            solver1 = _eqm.EquilibriumMatching(
+                vf=vf1,
+                match_fn=match_linear,
+                optimizer=opt,
+                conditions={condition_key: drug},
+                rng=vf_rng,
+                ema=ema,
+            )
+
+        trainer1 = scaleflow.training.CellFlowTrainer(solver=solver1)
+        train_dataloader = eqm_dataloader if solver_class == "eqm" else dataloader
         trainer1.train(
-            dataloader=dataloader,
+            dataloader=train_dataloader,
             num_iterations=5,
             valid_freq=10,
         )
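The solver tests above also show, implicitly, how the new equilibrium-matching path is wired together. Reduced to its essentials (a sketch assuming the classes this patch introduces, `scaleflow.networks.EquilibriumVelocityField` and `_eqm.EquilibriumMatching`; the dimensions, condition array and dataloader are illustrative placeholders):

import jax
import numpy as np
import optax

import scaleflow
from scaleflow.solvers import _eqm
from scaleflow.utils import match_linear

vf = scaleflow.networks.EquilibriumVelocityField(
    output_dim=5,  # illustrative sizes, chosen to match the tests
    max_combination_length=2,
    condition_embedding_dim=12,
    hidden_dims=(32, 32),
    decoder_dims=(32, 32),
)
solver = _eqm.EquilibriumMatching(
    vf=vf,
    match_fn=match_linear,
    optimizer=optax.adam(1e-3),
    conditions={"pert1": np.random.rand(2, 2, 3)},
    rng=jax.random.PRNGKey(0),
)
trainer = scaleflow.training.CellFlowTrainer(solver=solver, predict_kwargs={"max_steps": 3, "eta": 0.01})
# trainer.train(dataloader=..., num_iterations=..., valid_freq=...) as exercised by the tests above.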
diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py
index f1346ce6..6c00ad75 100644
--- a/tests/trainer/test_callbacks.py
+++ b/tests/trainer/test_callbacks.py
@@ -7,7 +7,7 @@
 class TestCallbacks:
     @pytest.mark.parametrize("metrics", [["r_squared"]])
     def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics):
-        from cellflow.training import PCADecodedMetrics
+        from scaleflow.training import PCADecodedMetrics

         decoded_metrics_callback = PCADecodedMetrics(
             metrics=metrics,
@@ -22,8 +22,8 @@
     def test_vae_reconstruction(self, metrics):
         from scvi.data import synthetic_iid

-        from cellflow.external import CFJaxSCVI
-        from cellflow.training import VAEDecodedMetrics
+        from scaleflow.external import CFJaxSCVI
+        from scaleflow.training import VAEDecodedMetrics

         adata = synthetic_iid()
         CFJaxSCVI.setup_anndata(
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index beef4eb1..fb693bf2 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -7,10 +7,10 @@
 import pytest
 from ott.neural.methods.flows import dynamics

-import cellflow
-from cellflow.solvers import _otfm
-from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics
-from cellflow.utils import match_linear
+import scaleflow
+from scaleflow.solvers import _otfm
+from scaleflow.training import CellFlowTrainer, ComputationCallback, Metrics
+from scaleflow.utils import match_linear

 x_test = jnp.ones((10, 5)) * 10
 t_test = jnp.ones((10, 1))
@@ -43,9 +43,9 @@

 class TestTrainer:
     @pytest.mark.parametrize("valid_freq", [10, 1])
-    def test_cellflow_trainer(self, dataloader, valid_freq):
+    def test_scaleflow_trainer(self, dataloader, valid_freq):
         opt = optax.adam(1e-3)
-        vf = cellflow.networks.ConditionalVelocityField(
+        vf = scaleflow.networks.ConditionalVelocityField(
             output_dim=5,
             max_combination_length=2,
             condition_embedding_dim=12,
@@ -79,9 +79,9 @@
         assert out[1].shape == (1, 12)

     @pytest.mark.parametrize("use_validdata", [True, False])
-    def test_cellflow_trainer_with_callback(self, dataloader, valid_loader, use_validdata):
+    def test_scaleflow_trainer_with_callback(self, dataloader, valid_loader, use_validdata):
         opt = optax.adam(1e-3)
-        vf = cellflow.networks.ConditionalVelocityField(
+        vf = scaleflow.networks.ConditionalVelocityField(
             output_dim=5,
             max_combination_length=2,
             condition_embedding_dim=12,
@@ -124,9 +124,9 @@
         assert isinstance(out[1], np.ndarray)
         assert out[1].shape == (1, 12)

-    def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader):
+    def test_scaleflow_trainer_with_custom_callback(self, dataloader, valid_loader):
         opt = optax.adam(1e-3)
-        vf = cellflow.networks.ConditionalVelocityField(
+        vf = scaleflow.networks.ConditionalVelocityField(
             condition_mode="stochastic",
             output_dim=5,
             max_combination_length=2,
@@ -164,14 +164,14 @@ def test_cellflow_trainer_with_custom_callback(self, dataloader, valid_loader):
     def test_predict_kwargs_iter(self, dataloader, valid_loader):
         opt_1 = optax.adam(1e-3)
         opt_2 = optax.adam(1e-3)
-        vf_1 = cellflow.networks.ConditionalVelocityField(
+        vf_1 = scaleflow.networks.ConditionalVelocityField(
             output_dim=5,
             max_combination_length=2,
             condition_embedding_dim=12,
             hidden_dims=(32, 32),
             decoder_dims=(32, 32),
         )
-        vf_2 = cellflow.networks.ConditionalVelocityField(
+        vf_2 = scaleflow.networks.ConditionalVelocityField(
             output_dim=5,
             max_combination_length=2,
             condition_embedding_dim=12,
@@ -196,13 +196,13 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader):
         )

         metric_to_compute = "e_distance"
-        metrics_callback = cellflow.training.Metrics(metrics=[metric_to_compute])
+        metrics_callback = scaleflow.training.Metrics(metrics=[metric_to_compute])

         predict_kwargs_1 = {"max_steps": 3, "throw": False}
         predict_kwargs_2 = {"max_steps": 500, "throw": False}

-        trainer_1 = cellflow.training.CellFlowTrainer(solver=model_1, predict_kwargs=predict_kwargs_1)
-        trainer_2 = cellflow.training.CellFlowTrainer(solver=model_2, predict_kwargs=predict_kwargs_2)
+        trainer_1 = scaleflow.training.CellFlowTrainer(solver=model_1, predict_kwargs=predict_kwargs_1)
+        trainer_2 = scaleflow.training.CellFlowTrainer(solver=model_2, predict_kwargs=predict_kwargs_2)

         start_1 = time.time()
         trainer_1.train(