diff --git a/src/cfp/data/_datamanager.py b/src/cfp/data/_datamanager.py index 8f365469..fed87365 100644 --- a/src/cfp/data/_datamanager.py +++ b/src/cfp/data/_datamanager.py @@ -227,7 +227,7 @@ def get_prediction_data( is stored or ``'X'`` to use :attr:`~anndata.AnnData.X`. covariate_data A :class:`~pandas.DataFrame` with columns defining the covariates as - in :meth:`cfp.model.CellFlow.prepare_data` and stored in + in :meth:`cfp.model.CellFlow.prepare_data` and stored in :attr:`cfp.model.CellFlow.data_manager`. rep_dict Dictionary with representations of the covariates. @@ -799,10 +799,26 @@ def _check_covariate_type(adata: anndata.AnnData, covars: Sequence[str]) -> bool col_is_cat.append(False) continue if adata.obs[covariate].isin(["True", "False", True, False]).all(): - adata.obs[covariate] = adata.obs[covariate].astype(int) + if not isinstance(adata.obs[covariate].dtype, pd.CategoricalDtype): + logger.warning( + ( + f"Converting boolean covariate '{covariate}' to categorical, which requires " + f"instantiation of the `adata` object and thus leads to higher memory requirements. " + f"Consider changing '{covariate}' to to categorical before calling this function.", + ) + ) + adata.obs[covariate] = adata.obs[covariate].astype(int) col_is_cat.append(False) continue try: + if not isinstance(adata.obs[covariate].dtype, pd.CategoricalDtype): + logger.warning( + ( + f"Converting boolean covariate '{covariate}' to categorical, which requires " + f"instantiation of the `adata` object and thus leads to higher memory requirements. " + f"Consider changing '{covariate}' to to categorical before calling this function.", + ) + ) adata.obs[covariate] = adata.obs[covariate].astype("category") col_is_cat.append(True) except ValueError as e: diff --git a/src/cfp/external/_scvi.py b/src/cfp/external/_scvi.py index 988ee080..08830fa6 100644 --- a/src/cfp/external/_scvi.py +++ b/src/cfp/external/_scvi.py @@ -66,8 +66,8 @@ def get_latent_representation( Parameters ---------- adata - :class:`~anndata.AnnData` object with equivalent structure to initial - :class:`~anndata.AnnData` object. If `:obj:`None`, defaults to the + :class:`~anndata.AnnData` object with equivalent structure to initial + :class:`~anndata.AnnData` object. If `:obj:`None`, defaults to the :class:`~anndata.AnnData` object used to initialize the model. indices Indices of cells in adata to use. If :obj:`None`, all cells are used. @@ -76,7 +76,7 @@ def get_latent_representation( n_samples Number of samples to use for computing the latent representation. batch_size - Minibatch size for data loading into model. Defaults to + Minibatch size for data loading into model. Defaults to :attr:`scvi.settings.ScviConfig.batch_size`. Returns diff --git a/src/cfp/model/_cellflow.py b/src/cfp/model/_cellflow.py index 6161e90e..79354419 100644 --- a/src/cfp/model/_cellflow.py +++ b/src/cfp/model/_cellflow.py @@ -16,8 +16,8 @@ from cfp import _constants from cfp._logging import logger -from cfp._types import Layers_separate_input_t, Layers_t, ArrayLike -from cfp.data._data import ConditionData, ValidationData, TrainingData +from cfp._types import ArrayLike, Layers_separate_input_t, Layers_t +from cfp.data._data import ConditionData, TrainingData, ValidationData from cfp.data._dataloader import PredictionSampler, TrainSampler, ValidationSampler from cfp.data._datamanager import DataManager from cfp.model._utils import _write_predictions @@ -131,9 +131,9 @@ def prepare_data( :attr:`~anndata.AnnData.obs` as columns ``drug_1`` and ``drug_2`` with three different drugs ``DrugA``, ``DrugB``, and ``DrugC``, and ``dose_1`` and ``dose_2`` for their dosages, respectively. We store the embeddings of the drugs in - :attr:`~anndata.AnnData.uns` under the key ``drug_embeddings``, while the dosage - columns are numeric. Moreover, we have a covariate ``cell_type`` with values - ``cell_typeA`` and ``cell_typeB``, with embeddings stored in + :attr:`~anndata.AnnData.uns` under the key ``drug_embeddings``, while the dosage + columns are numeric. Moreover, we have a covariate ``cell_type`` with values + ``cell_typeA`` and ``cell_typeB``, with embeddings stored in :attr:`~anndata.AnnData.uns` under the key ``cell_type_embeddings``. Note that we then also have to set ``'split_covariates'`` as we assume we have an unperturbed population for each cell type. @@ -567,7 +567,7 @@ def predict( covariate_data Covariate data defining the condition to predict. This :class:`~pandas.DataFrame` should have the same columns as :attr:`~anndata.AnnData.obs` of - :attr:`cfp.model.CellFlow.adata`, and as registered in + :attr:`cfp.model.CellFlow.adata`, and as registered in :attr:`cfp.model.CellFlow.data_manager`. sample_rep Key in :attr:`~anndata.AnnData.obsm` where the sample representation is stored or @@ -838,7 +838,6 @@ def train_data(self, data: TrainingData) -> None: ) self._train_data = data - @velocity_field.setter def velocity_field(self, vf: ConditionalVelocityField) -> None: """Set the velocity field.""" diff --git a/src/cfp/networks/__init__.py b/src/cfp/networks/__init__.py index bfb56103..8a1dfa17 100644 --- a/src/cfp/networks/__init__.py +++ b/src/cfp/networks/__init__.py @@ -3,8 +3,8 @@ MLPBlock, SeedAttentionPooling, SelfAttention, - TokenAttentionPooling, SelfAttentionBlock, + TokenAttentionPooling, ) from cfp.networks._velocity_field import ConditionalVelocityField diff --git a/src/cfp/networks/_set_encoders.py b/src/cfp/networks/_set_encoders.py index 59296e61..7bd95476 100644 --- a/src/cfp/networks/_set_encoders.py +++ b/src/cfp/networks/_set_encoders.py @@ -405,7 +405,7 @@ class ConditionEncoder(BaseModule): pooling Pooling method, should be one of: - - ``'mean'``: Aggregates combinations of covariates by the mean of their learned + - ``'mean'``: Aggregates combinations of covariates by the mean of their learned embeddings. - ``'attention_token'``: Aggregates combinations of covariates by an attention mechanism with a token. diff --git a/src/cfp/networks/_velocity_field.py b/src/cfp/networks/_velocity_field.py index 9f1e9768..24a422eb 100644 --- a/src/cfp/networks/_velocity_field.py +++ b/src/cfp/networks/_velocity_field.py @@ -263,31 +263,32 @@ def create_train_state( def output_dims(self): """Dimensions of the output layers.""" return tuple(self.decoder_dims) + (self.output_dim,) + @property def time_encoder(self): """The time encoder used.""" return self._time_encoder - + @time_encoder.setter def time_encoder(self, encoder): """Set the time encoder.""" self._time_encoder = encoder - + @property def x_encoder(self): """The x encoder used.""" return self._x_encoder - + @x_encoder.setter def x_encoder(self, encoder): """Set the x encoder.""" self._x_encoder = encoder - + @property def decoder(self): """The decoder used.""" return self._decoder - + @decoder.setter def decoder(self, decoder): """Set the decoder.""" diff --git a/src/cfp/preprocessing/_gene_emb.py b/src/cfp/preprocessing/_gene_emb.py index aab798b8..7b9473c1 100644 --- a/src/cfp/preprocessing/_gene_emb.py +++ b/src/cfp/preprocessing/_gene_emb.py @@ -228,7 +228,9 @@ def get_model_and_tokenizer( model_name: str, use_cuda: bool, cache_dir: None | str ) -> tuple[EsmModel, AutoTokenizer]: model_path = os.path.join("facebook", model_name) - model = EsmModel.from_pretrained(model_path, cache_dir=cache_dir, add_pooling_layer=False) + model = EsmModel.from_pretrained( + model_path, cache_dir=cache_dir, add_pooling_layer=False + ) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir) if use_cuda: diff --git a/src/cfp/preprocessing/_pca.py b/src/cfp/preprocessing/_pca.py index 0c8bf782..6aeed09a 100644 --- a/src/cfp/preprocessing/_pca.py +++ b/src/cfp/preprocessing/_pca.py @@ -2,6 +2,7 @@ import numpy as np import scanpy as sc from scipy.sparse import csr_matrix + from cfp._types import ArrayLike __all__ = ["centered_pca", "reconstruct_pca", "project_pca"] @@ -38,7 +39,7 @@ def centered_pca( Returns ------- - If ``copy`` is :obj:`True`, returns a new :class:`~anndata.AnnData` object with the PCA + If ``copy`` is :obj:`True`, returns a new :class:`~anndata.AnnData` object with the PCA results stored in :attr:`~anndata.AnnData.obsm`. Otherwise, updates ``adata`` in place. Sets the following fields: @@ -111,7 +112,7 @@ def reconstruct_pca( query_adata An :class:`~anndata.AnnData` object with the query data. use_rep : str - Representation to use for PCA. If ``'X'``, uses :attr:`~anndata.AnnData.X`. Otherwise, uses + Representation to use for PCA. If ``'X'``, uses :attr:`~anndata.AnnData.X`. Otherwise, uses ``adata.obsm[use_rep]``. ref_adata An :class:`~anndata.AnnData` object with the reference data containing diff --git a/src/cfp/preprocessing/_preprocessing.py b/src/cfp/preprocessing/_preprocessing.py index 455ca71f..d45d64d1 100644 --- a/src/cfp/preprocessing/_preprocessing.py +++ b/src/cfp/preprocessing/_preprocessing.py @@ -4,9 +4,9 @@ import anndata as ad import numpy as np import sklearn.preprocessing as preprocessing -from cfp._types import ArrayLike from cfp._logging import logger +from cfp._types import ArrayLike from cfp.data._utils import _to_list __all__ = ["encode_onehot", "annotate_compounds", "get_molecular_fingerprints"] @@ -30,7 +30,7 @@ def annotate_compounds( query_id_type Type of the compound identifiers. Either ``'name'`` or ``'cid'``. obs_key_prefixes - Prefix for the keys in :attr:`~anndata.AnnData.obs` to store the annotations. If :obj:`None`, + Prefix for the keys in :attr:`~anndata.AnnData.obs` to store the annotations. If :obj:`None`, uses ``compound_keys`` as prefixes. copy Return a copy of ``adata`` instead of updating it in place. diff --git a/src/cfp/preprocessing/_wknn.py b/src/cfp/preprocessing/_wknn.py index 7284aa0c..502c2e7a 100644 --- a/src/cfp/preprocessing/_wknn.py +++ b/src/cfp/preprocessing/_wknn.py @@ -4,10 +4,10 @@ import jax import numpy as np import pandas as pd -from cfp._types import ArrayLike from scipy import sparse from cfp._logging import logger +from cfp._types import ArrayLike __all__ = ["compute_wknn", "transfer_labels"] @@ -65,7 +65,7 @@ def compute_wknn( Returns ------- If ``copy`` is :obj:`True`, returns a new :class:`~anndata.AnnData` object with the - weighted k-nearest neighbors stored in :attr:`~anndata.AnnData.uns`. Otherwise, updates + weighted k-nearest neighbors stored in :attr:`~anndata.AnnData.uns`. Otherwise, updates ``adata`` in place. Sets the following fields: @@ -111,7 +111,7 @@ def transfer_labels( label_key : str Key in :attr:`~anndata.AnnData.obs` of ``ref_adata`` containing the labels wknn_key : str - Key in :attr:`~anndata.AnnData.uns` of ``ref_adata`` containing the weighted k-nearest + Key in :attr:`~anndata.AnnData.uns` of ``ref_adata`` containing the weighted k-nearest neighbors graph copy : bool Return a copy of ``query_adata`` instead of updating it in place @@ -119,7 +119,7 @@ def transfer_labels( Returns ------- If ``copy`` is :obj:`True`, returns a new :class:`~anndata.AnnData` object with the - transferred labels stored in :attr:`~anndata.AnnData.obs`. Otherwise, updates ``adata`` in + transferred labels stored in :attr:`~anndata.AnnData.obs`. Otherwise, updates ``adata`` in place. Sets the following fields: diff --git a/src/cfp/training/_callbacks.py b/src/cfp/training/_callbacks.py index 68f938a9..565cee99 100644 --- a/src/cfp/training/_callbacks.py +++ b/src/cfp/training/_callbacks.py @@ -6,8 +6,8 @@ import jax.tree as jt import jax.tree_util as jtu import numpy as np -from cfp._types import ArrayLike +from cfp._types import ArrayLike from cfp.metrics._metrics import compute_e_distance, compute_r_squared, compute_scalar_mmd, compute_sinkhorn_div __all__ = [ @@ -225,7 +225,7 @@ class PCADecodedMetrics(Metrics): An :class:`~anndata.AnnData` object with the reference data containing ``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``. metrics - List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``, + List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``, ``"sinkhorn_div"``, and ``"e_distance"``. metric_aggregations List of aggregation functions to use for each metric. Supported aggregations are ``"mean"`` @@ -430,8 +430,8 @@ class CallbackRunner: Parameters ---------- callbacks - List of callbacks to run. Callbacks should be of type - :class:`~cfp.training.ComputationCallback` or + List of callbacks to run. Callbacks should be of type + :class:`~cfp.training.ComputationCallback` or :class:`~cfp.training.LoggingCallback` Returns