diff --git a/CHANGELOG.md b/CHANGELOG.md index c261cc775e..0c89a6484e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Added +- Add MuData Minification option to {class}`~scvi.model.MULTIVI` and {class}`~scvi.model.TOTALVI` {pr}`30XX`. - Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`. diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index fc6228a29f..3fff82a13b 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None: return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None) -def _is_minified(adata: AnnData | str) -> bool: +def _is_minified(adata: AnnOrMuData | str) -> bool: uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY if isinstance(adata, AnnData): return adata.uns.get(uns_key, None) is not None + elif isinstance(adata, MuData): + return adata.uns.get(uns_key, None) is not None elif isinstance(adata, str): with h5py.File(adata) as fp: return uns_key in read_elem(fp["uns"]).keys() diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 0bc5dfe517..0e625e9ee1 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -13,13 +13,17 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields +from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE +from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( CategoricalJointObsField, CategoricalObsField, LayerField, NumericalJointObsField, NumericalObsField, + ObsmField, ProteinObsmField, + StringUnsField, ) from scvi.model._utils import ( _get_batch_code_from_category, @@ -28,11 +32,12 @@ ) from scvi.model.base import ( ArchesMixin, - BaseModelClass, + BaseMudataMinifiedModeModelClass, UnsupervisedTrainingMixin, VAEMixin, ) from scvi.model.base._de_core import _de_core +from scvi.model.utils import get_minified_mudata from scvi.module import MULTIVAE from scvi.train import AdversarialTrainingPlan from scvi.train._callbacks import SaveBestState @@ -45,12 +50,19 @@ from anndata import AnnData from mudata import MuData - from scvi._types import AnnOrMuData, Number + from scvi._types import AnnOrMuData, MinifiedDataType, Number + from scvi.data.fields import ( + BaseAnnDataField, + ) + +_MULTIVI_LATENT_QZM = "_multivi_latent_qzm" +_MULTIVI_LATENT_QZV = "_multivi_latent_qzv" +_MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size" logger = logging.getLogger(__name__) -class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): +class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, ArchesMixin, BaseMudataMinifiedModeModelClass): """Integration of multi-modal and single-modality data :cite:p:`AshuachGabitto21`. MultiVI is used to integrate multiomic datasets with single-modality (expression @@ -174,6 +186,10 @@ def __init__( use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry + # TODO: ADD MINIFICATION CONSIDERATION HERE? 
+ # if not use_size_factor_key and self.minified_data_type is None: + # library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) + if "n_proteins" in self.summary_stats: n_proteins = self.summary_stats.n_proteins else: @@ -224,6 +240,7 @@ def __init__( self.n_genes = n_genes self.n_regions = n_regions self.n_proteins = n_proteins + self.module.minified_data_type = self.minified_data_type @devices_dsp.dedent def train( @@ -414,6 +431,7 @@ def get_latent_representation( indices: Sequence[int] | None = None, give_mean: bool = True, batch_size: int | None = None, + return_dist: bool = False, ) -> np.ndarray: r"""Return the latent representation for each cell. @@ -430,6 +448,9 @@ def get_latent_representation( Give mean of distribution or sample from it. batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_dist + If ``True``, returns the mean and variance of the latent distribution. Otherwise, + returns the mean of the latent distribution. Returns ------- @@ -457,6 +478,8 @@ def get_latent_representation( adata = self._validate_anndata(adata) scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) latent = [] + qz_means = [] + qz_vars = [] for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) @@ -473,8 +496,17 @@ def get_latent_representation( else: z = qz_m + if return_dist: + qz_means.append(qz_m.cpu()) + qz_vars.append(qz_v.cpu()) + continue + latent += [z.cpu()] - return torch.cat(latent).numpy() + + if return_dist: + return torch.cat(qz_means).numpy(), torch.cat(qz_vars).numpy() + else: + return torch.cat(latent).numpy() @torch.inference_mode() def get_accessibility_estimates( @@ -1113,6 +1145,87 @@ def setup_mudata( mod_required=True, ) ) + # TODO: register new fields if the adata is minified + mdata_minify_type = _get_adata_minify_type(mdata) + if mdata_minify_type is not None: + mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type) adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) + + @staticmethod + def _get_fields_for_mudata_minification( + minified_data_type: MinifiedDataType, + ) -> list[BaseAnnDataField]: + """Return the fields required for adata minification of the given minified_data_type.""" + if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + fields = [ + ObsmField( + REGISTRY_KEYS.LATENT_QZM_KEY, + _MULTIVI_LATENT_QZM, + ), + ObsmField( + REGISTRY_KEYS.LATENT_QZV_KEY, + _MULTIVI_LATENT_QZV, + ), + NumericalObsField( + REGISTRY_KEYS.OBSERVED_LIB_SIZE, + _MULTIVI_OBSERVED_LIB_SIZE, + ), + ] + else: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + fields.append( + StringUnsField( + REGISTRY_KEYS.MINIFY_TYPE_KEY, + _ADATA_MINIFY_TYPE_UNS_KEY, + ), + ) + return fields + + def minify_mudata( + self, + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + ) -> None: + """Minifies the model's mudata. + + Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns + containing minified-adata type, and library size. + This also sets the appropriate property on the module to indicate that the mudata is + minified. 
+ + Parameters + ---------- + minified_data_type + How to minify the data. Currently only supports `latent_posterior_parameters`. + If minified_data_type == `latent_posterior_parameters`: + + * the original count data is removed (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + use_latent_qzm_key + Key to use in `adata.obsm` where the latent qzm params are stored + use_latent_qzv_key + Key to use in `adata.obsm` where the latent qzv params are stored + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the adata. + """ + # without removing the original counts. + if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + + # if self.module.use_observed_lib_size is False: + # raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") + + minified_adata = get_minified_mudata(self.adata, minified_data_type) + minified_adata.obsm[_MULTIVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] + minified_adata.obsm[_MULTIVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] + counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + minified_adata.obs[_MULTIVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) + self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type) + self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index a41123af72..0e3b2cd965 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -12,7 +12,9 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields -from scvi.data._utils import _check_nonnegative_integers +from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE +from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type +from scvi.data.fields import NumericalObsField, ObsmField, StringUnsField from scvi.dataloaders import DataSplitter from scvi.model._utils import ( _get_batch_code_from_category, @@ -22,11 +24,12 @@ get_max_epochs_heuristic, ) from scvi.model.base._de_core import _de_core +from scvi.model.utils import get_minified_mudata from scvi.module import TOTALVAE from scvi.train import AdversarialTrainingPlan, TrainRunner from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp -from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin +from .base import ArchesMixin, BaseMudataMinifiedModeModelClass, RNASeqMixin, VAEMixin if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -35,12 +38,19 @@ from anndata import AnnData from mudata import MuData - from scvi._types import AnnOrMuData, Number + from scvi._types import AnnOrMuData, MinifiedDataType, Number + from scvi.data.fields import ( + BaseAnnDataField, + ) + +_TOTALVI_LATENT_QZM = "_totalvi_latent_qzm" +_TOTALVI_LATENT_QZV = "_totalvi_latent_qzv" +_TOTALVI_OBSERVED_LIB_SIZE = "_totalvi_observed_lib_size" logger = logging.getLogger(__name__) -class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): +class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMudataMinifiedModeModelClass): """total Variational Inference :cite:p:`GayosoSteier21`. 
Parameters @@ -162,7 +172,8 @@ def __init__( n_batch = self.summary_stats.n_batch use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None - if not use_size_factor_key: + # TODO: ADD MINIFICATION CONSIDERATION + if not use_size_factor_key and self.minified_data_type is None: library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( @@ -184,6 +195,7 @@ def __init__( library_log_vars=library_log_vars, **model_kwargs, ) + self.module.minified_data_type = self.minified_data_type self._model_summary_string = ( f"TotalVI Model with the following params: \nn_latent: {n_latent}, " f"gene_dispersion: {gene_dispersion}, protein_dispersion: {protein_dispersion}, " @@ -1331,6 +1343,87 @@ def setup_mudata( mod_required=True, ), ] + # TODO: register new fields if the mudata is minified + mdata_minify_type = _get_adata_minify_type(mdata) + if mdata_minify_type is not None: + mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type) adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) + + @staticmethod + def _get_fields_for_mudata_minification( + minified_data_type: MinifiedDataType, + ) -> list[BaseAnnDataField]: + """Return the fields required for mudata minification of the given minified_data_type.""" + if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + fields = [ + ObsmField( + REGISTRY_KEYS.LATENT_QZM_KEY, + _TOTALVI_LATENT_QZM, + ), + ObsmField( + REGISTRY_KEYS.LATENT_QZV_KEY, + _TOTALVI_LATENT_QZV, + ), + NumericalObsField( + REGISTRY_KEYS.OBSERVED_LIB_SIZE, + _TOTALVI_OBSERVED_LIB_SIZE, + ), + ] + else: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + fields.append( + StringUnsField( + REGISTRY_KEYS.MINIFY_TYPE_KEY, + _ADATA_MINIFY_TYPE_UNS_KEY, + ), + ) + return fields + + def minify_mudata( + self, + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + ) -> None: + """Minifies the model's mudata. + + Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns + containing minified-adata type, and library size. + This also sets the appropriate property on the module to indicate that the mudata is + minified. + + Parameters + ---------- + minified_data_type + How to minify the data. Currently only supports `latent_posterior_parameters`. + If minified_data_type == `latent_posterior_parameters`: + + * the original count data is removed (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + use_latent_qzm_key + Key to use in `adata.obsm` where the latent qzm params are stored + use_latent_qzv_key + Key to use in `adata.obsm` where the latent qzv params are stored + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the adata. + """ + # without removing the original counts. 
+ if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + + if self.module.use_observed_lib_size is False: + raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") + + minified_adata = get_minified_mudata(self.adata, minified_data_type) + minified_adata.obsm[_TOTALVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] + minified_adata.obsm[_TOTALVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] + counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + minified_adata.obs[_TOTALVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) + self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type) + self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/base/__init__.py b/src/scvi/model/base/__init__.py index e8573f8d53..4b38494caf 100644 --- a/src/scvi/model/base/__init__.py +++ b/src/scvi/model/base/__init__.py @@ -1,5 +1,9 @@ from ._archesmixin import ArchesMixin -from ._base_model import BaseMinifiedModeModelClass, BaseModelClass +from ._base_model import ( + BaseMinifiedModeModelClass, + BaseModelClass, + BaseMudataMinifiedModeModelClass, +) from ._differential import DifferentialComputation from ._embedding_mixin import EmbeddingMixin from ._jaxmixin import JaxTrainingMixin @@ -26,5 +30,6 @@ "DifferentialComputation", "JaxTrainingMixin", "BaseMinifiedModeModelClass", + "BaseMudataMinifiedModeModelClass", "EmbeddingMixin", ] diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index fd47bf1926..137bc9e04a 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -944,3 +944,63 @@ def summary_string(self): hasattr(self, "minified_data_type") and self.minified_data_type is not None ) return summary_string + + +class BaseMudataMinifiedModeModelClass(BaseModelClass): + """Abstract base class for scvi-tools models that can handle minified data.""" + + @property + def minified_data_type(self) -> MinifiedDataType | None: + """The type of minified data associated with this model, if applicable.""" + return ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) + if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry + else None + ) + + @abstractmethod + def minify_mudata( + self, + *args, + **kwargs, + ): + """Minifies the model's mudata. + + Minifies the mudata, and registers new mudata fields as required (can be model-specific). + This also sets the appropriate property on the module to indicate that the adata is + minified. + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the adata. 
+ """ + + @staticmethod + @abstractmethod + def _get_fields_for_mudata_minification(minified_data_type: MinifiedDataType): + """Return the mudata fields required for adata minification of the given type.""" + + def _update_mudata_and_manager_post_minification( + self, minified_adata: AnnOrMuData, minified_data_type: MinifiedDataType + ): + """Update the mudata and manager inplace after creating a minified adata.""" + # Register this new adata with the model, creating a new manager in the cache + self._validate_anndata(minified_adata) + new_adata_manager = self.get_anndata_manager(minified_adata, required=True) + # This inplace edits the manager + new_adata_manager.register_new_fields( + self._get_fields_for_mudata_minification(minified_data_type) + ) + # We set the adata attribute of the model as this will update self.registry_ + # and self.adata_manager with the new adata manager + self.adata = minified_adata + + @property + def summary_string(self): + """Summary string of the model.""" + summary_string = super().summary_string + summary_string += "\nModel's adata is minified?: {}".format( + hasattr(self, "minified_data_type") and self.minified_data_type is not None + ) + return summary_string diff --git a/src/scvi/model/utils/__init__.py b/src/scvi/model/utils/__init__.py index 003b763e5e..0ee147802d 100644 --- a/src/scvi/model/utils/__init__.py +++ b/src/scvi/model/utils/__init__.py @@ -1,4 +1,4 @@ from ._mde import mde -from ._minification import get_minified_adata_scrna +from ._minification import get_minified_adata_scrna, get_minified_mudata -__all__ = ["mde", "get_minified_adata_scrna"] +__all__ = ["mde", "get_minified_adata_scrna", "get_minified_mudata"] diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index cf84687bc5..aab9cb79ff 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -1,4 +1,5 @@ from anndata import AnnData +from mudata import MuData from scipy.sparse import csr_matrix from scvi._types import MinifiedDataType @@ -41,3 +42,32 @@ def get_minified_adata_scrna( del bdata.uns[_SCVI_UUID_KEY] bdata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type return bdata + + +def get_minified_mudata( + mdata: MuData, + minified_data_type: MinifiedDataType, +) -> MuData: + """Returns a minified adata that works for most multi modality models (MULTIVI, TOTALVI). + + Parameters + ---------- + mdata + Original adata, of which we to create a minified version. + minified_data_type + How to minify the data. + """ + if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + + bdata = mdata.copy() + for modality in mdata.mod_names: + all_zeros = csr_matrix(mdata[modality].X.shape) + bdata[modality].X = all_zeros + if len(mdata[modality].layers) > 0: + layers = {layer: all_zeros for layer in mdata[modality].layers} + bdata[modality].layers = layers + # Remove scvi uuid key to make bdata fresh w.r.t. 
the model's manager + del bdata.uns[_SCVI_UUID_KEY] + bdata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type + return bdata diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index 6ad24b65d2..cec268d184 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -15,7 +15,7 @@ ZeroInflatedNegativeBinomial, ) from scvi.module._peakvae import Decoder as DecoderPeakVI -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderSCVI, Encoder, FCLayers from ._utils import masked_softmax @@ -179,7 +179,7 @@ def forward(self, z: torch.Tensor, *cat_list: int): return py_, log_pro_back_mean -class MULTIVAE(BaseModuleClass): +class MULTIVAE(BaseMinifiedModeModuleClass): """Variational auto-encoder model for joint paired + unpaired RNA-seq and ATAC-seq data. Parameters @@ -533,6 +533,9 @@ def __init__( def _get_inference_input(self, tensors): """Get input tensors for the inference model.""" + # from scvi.data._constants import ADATA_MINIFY_TYPE + # TODO: ADD MINIFICATION CONSIDERATION + x = tensors[REGISTRY_KEYS.X_KEY] if self.n_input_proteins == 0: y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) diff --git a/src/scvi/module/_totalvae.py b/src/scvi/module/_totalvae.py index d3fb5488da..ba54ec0b6f 100644 --- a/src/scvi/module/_totalvae.py +++ b/src/scvi/module/_totalvae.py @@ -18,7 +18,7 @@ ZeroInflatedNegativeBinomial, ) from scvi.model.base import BaseModelClass -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderTOTALVI, EncoderTOTALVI from scvi.nn._utils import ExpActivation @@ -26,7 +26,7 @@ # VAE model -class TOTALVAE(BaseModuleClass): +class TOTALVAE(BaseMinifiedModeModuleClass): """Total variational inference for CITE-seq data. Implements the totalVI model of :cite:p:`GayosoSteier21`. 
@@ -325,6 +325,9 @@ def get_reconstruction_loss( return reconst_loss_gene, reconst_loss_protein def _get_inference_input(self, tensors): + # from scvi.data._constants import ADATA_MINIFY_TYPE + # TODO: ADD MINIFICATION CONSIDERATION + x = tensors[REGISTRY_KEYS.X_KEY] y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py new file mode 100644 index 0000000000..fab87eb2a5 --- /dev/null +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -0,0 +1,358 @@ +import numpy as np +import pytest + +import scvi +from scvi.data import synthetic_iid +from scvi.data._constants import ADATA_MINIFY_TYPE +from scvi.data._utils import _is_minified +from scvi.model import MULTIVI, TOTALVI + +_TOTALVI_OBSERVED_LIB_SIZE = "_totalvi_observed_lib_size" +_MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size" + + +def prep_model_mudata(cls=TOTALVI, layer=None, use_size_factor=False): + # create a synthetic dataset + mdata = synthetic_iid(return_mudata=True) + if use_size_factor: + mdata.obs["size_factor"] = np.random.randint(1, 5, size=(mdata.shape[0],)) + if layer is not None: + for mod in mdata.mod_names: + mdata[mod].layers[layer] = mdata[mod].X.copy() + mdata[mod].X = np.zeros_like(mdata[mod].X) + mdata.var["n_counts"] = np.squeeze( + np.concatenate( + [ + np.asarray(np.sum(mdata["rna"].X, axis=0)), + np.asarray(np.sum(mdata["protein_expression"].X, axis=0)), + np.asarray(np.sum(mdata["accessibility"].X, axis=0)), + ] + ) + ) + mdata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(mdata.shape[1], 3)) + mdata["rna"].layers["my_layer"] = np.ones_like(mdata["rna"].X) + mdata_before_setup = mdata.copy() + + # run setup_anndata + setup_kwargs = { + "batch_key": "batch", + } + if use_size_factor: + setup_kwargs["size_factor_key"] = "size_factor" + + if cls == TOTALVI: + # create and train the model + cls.setup_mudata( + mdata, + modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, + **setup_kwargs, + ) + model = cls(mdata, n_latent=5) + elif cls == MULTIVI: + # create and train the model + cls.setup_mudata( + mdata, + modalities={ + "rna_layer": "rna", + "protein_layer": "protein_expression", + "atac_layer": "accessibility", + }, + **setup_kwargs, + ) + model = cls(mdata, n_latent=5, n_genes=50, n_regions=50) + else: + raise ValueError("Bad Model name as input to test") + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + + # get the mdata lib size + mdata_lib_size = np.squeeze(np.asarray(mdata["rna"].X.sum(axis=1))) + assert ( + np.min(mdata_lib_size) > 0 + ) # make sure it's not all zeros and there are no negative values + + return model, mdata, mdata_lib_size, mdata_before_setup + + +def assert_approx_equal(a, b): + # Allclose because on GPU, the values are not exactly the same + # as some values are moved to cpu during data minification + np.testing.assert_allclose(a, b, rtol=3e-1, atol=5e-1) + + +def run_test_for_model_with_minified_mudata( + cls=TOTALVI, + layer: str = None, + use_size_factor=False, +): + model, mdata, mdata_lib_size, _ = prep_model_mudata(cls, layer, use_size_factor) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + scvi.settings.seed = 1 + mdata_orig = mdata.copy() + + model.minify_mudata() + assert model.minified_data_type == 
ADATA_MINIFY_TYPE.LATENT_POSTERIOR + assert model.adata_manager.registry is model.registry_ + + # make sure the original mdata we set up the model with was not changed + assert mdata is not model.adata + assert _is_minified(mdata) is False + assert _is_minified(model.adata) is True + + assert mdata_orig["rna"].layers.keys() == model.adata["rna"].layers.keys() + orig_obs_df = mdata_orig.obs + obs_keys = _TOTALVI_OBSERVED_LIB_SIZE if cls == TOTALVI else _MULTIVI_OBSERVED_LIB_SIZE + orig_obs_df[obs_keys] = mdata_lib_size + assert model.adata.obs.equals(orig_obs_df) + assert model.adata.var_names.equals(mdata_orig.var_names) + assert model.adata.var.equals(mdata_orig.var) + assert model.adata.varm.keys() == mdata_orig.varm.keys() + np.testing.assert_array_equal(model.adata.varm["my_varm"], mdata_orig.varm["my_varm"]) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +@pytest.mark.parametrize("layer", [None, "data_layer"]) +@pytest.mark.parametrize("use_size_factor", [False, True]) +def test_with_minified_mudata(cls, layer: str, use_size_factor: bool): + run_test_for_model_with_minified_mudata(cls=cls, layer=layer, use_size_factor=use_size_factor) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_get_normalized_expression(cls): + model, mdata, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + exprs_orig = model.get_normalized_expression() + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + exprs_new = model.get_normalized_expression() + for ii in range(len(exprs_new)): + assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape + + for ii in range(len(exprs_new)): + np.testing.assert_array_equal(exprs_new[ii], exprs_orig[ii]) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_get_normalized_expression_non_default_gene_list(cls): + model, mdata, _, _ = prep_model_mudata(cls=cls) + + # non-default gene list and n_samples > 1 + gl = mdata.var_names[:5].to_list() + n_samples = 10 + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + exprs_orig = model.get_normalized_expression( + gene_list=gl, n_samples=n_samples, library_size="latent" + ) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + # do this so that we generate the same sequence of random numbers in the + # minified and non-minified cases (purely to get the tests to pass). 
this is + # because in the non-minified case we sample once more (in the call to z_encoder + # during inference) + exprs_new = model.get_normalized_expression( + gene_list=gl, n_samples=n_samples + 1, return_mean=False, library_size="latent" + ) + exprs_new = exprs_new[0][:, :, 1:].mean(2) + + assert exprs_new.shape == (mdata.shape[0], 5) + np.testing.assert_allclose(exprs_new, exprs_orig[0], rtol=3e-1, atol=5e-1) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_validate_unsupported_if_minified(cls): + model, _, _, _ = prep_model_mudata(cls=cls) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + common_err_msg = "The {} function currently does not support minified data." + + with pytest.raises(ValueError) as e: + model.get_elbo() + assert str(e.value) == common_err_msg.format("VAEMixin.get_elbo") + + with pytest.raises(ValueError) as e: + model.get_reconstruction_error() + assert str(e.value) == common_err_msg.format("VAEMixin.get_reconstruction_error") + + with pytest.raises(ValueError) as e: + model.get_marginal_ll() + assert str(e.value) == common_err_msg.format("VAEMixin.get_marginal_ll") + + if cls != TOTALVI: + with pytest.raises(AttributeError) as e: + model.get_latent_library_size() + assert str(e.value) == "'MULTIVI' object has no attribute 'get_latent_library_size'" + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_save_then_load(cls, save_path): + # create a model and minify its mdata, then save it and its mdata. + # Load it back up using the same (minified) mdata. Validate that the + # loaded model has the minified_data_type attribute set as expected. + model, mdata, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True) + # load saved model with saved (minified) mdata + loaded_model = cls.load(save_path, adata=mdata) + + assert loaded_model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path): + # create a model and minify its mdata, then save it and its mdata. + # Load it back up using a non-minified mdata. Validate that the + # loaded model does not have the minified_data_type attribute set.
+ model, mdata, _, mdata_before_setup = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True) + # load saved model with a non-minified mdata + loaded_model = cls.load(save_path, adata=mdata_before_setup) + + assert loaded_model.minified_data_type is None + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_save_then_load_with_minified_mdata(cls, save_path): + # create a model, then save it and its mdata (non-minified). + # Load it back up using a minified mdata. Validate that this + # fails, as expected because we don't have a way to validate + # whether the minified-mdata was set up correctly + model, _, _, _ = prep_model_mudata(cls=cls) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + # loading this model with a minified mdata is not allowed because + # we don't have a way to validate whether the minified-mdata was + # set up correctly + with pytest.raises(KeyError): + cls.load(save_path, adata=model.adata) + + +@pytest.mark.parametrize("cls", [TOTALVI, MULTIVI]) +def test_scvi_with_minified_mdata_get_latent_representation(cls): + model, _, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + latent_repr_orig = model.get_latent_representation() + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + latent_repr_new = model.get_latent_representation() + + np.testing.assert_array_equal(latent_repr_new, latent_repr_orig) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_scvi_with_minified_mdata_posterior_predictive_sample(cls): + model, _, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + sample_orig = model.posterior_predictive_sample( + indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] + ) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + sample_new = model.posterior_predictive_sample( + indices=[1, 2, 3], gene_list=["gene_1", "gene_2"] + ) + assert sample_new.shape == (3, 2) + + np.testing.assert_array_equal(sample_new.todense(), sample_orig.todense()) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_scvi_with_minified_mdata_get_feature_correlation_matrix(cls): + model, _, _, _ = prep_model_mudata(cls=cls) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + fcm_orig = 
model.get_feature_correlation_matrix( + correlation_type="pearson", + n_samples=1, + transform_batch=["batch_0", "batch_1"], + ) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + fcm_new = model.get_feature_correlation_matrix( + correlation_type="pearson", + n_samples=1, + transform_batch=["batch_0", "batch_1"], + ) + + assert_approx_equal(fcm_new, fcm_orig)
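Usage sketch: the following is a minimal end-to-end example of the workflow this patch enables, distilled from the new tests above. The synthetic data helper, the default "X_latent_qzm"/"X_latent_qzv" obsm keys, and the TOTALVI modality mapping all come from the test code in this diff (MULTIVI additionally maps an "atac_layer" modality); treat it as an illustrative sketch rather than official documentation.

from scvi.data import synthetic_iid
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.model import TOTALVI

# Toy MuData with "rna", "protein_expression" (and "accessibility") modalities
mdata = synthetic_iid(return_mudata=True)

TOTALVI.setup_mudata(
    mdata,
    batch_key="batch",
    modalities={"rna_layer": "rna", "protein_layer": "protein_expression"},
)
model = TOTALVI(mdata, n_latent=5)
model.train(1, train_size=0.5)

# Store the posterior parameters under the default obsm keys expected by minify_mudata()
qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
model.adata.obsm["X_latent_qzm"] = qzm
model.adata.obsm["X_latent_qzv"] = qzv

# Replace the per-modality counts with empty sparse matrices and register the
# minified fields (latent qzm/qzv, observed library size, minify-type uns key)
model.minify_mudata()
assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR

# Queries that only need the stored posterior still work on the minified MuData
latent = model.get_latent_representation()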