Skip to content

Commit

Permalink
Added mudata minification models for MULTIVI & TOTALVI as well as tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Nov 12, 2024
1 parent 9622f88 commit 3299355
Show file tree
Hide file tree
Showing 11 changed files with 685 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Add MuData Minification option to {class}`~scvi.model.MULTIVI` and {class}`~scvi.model.TOTALVI` {pr}`30XX`.
- Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method
{meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`.

Expand Down
4 changes: 3 additions & 1 deletion src/scvi/data/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None:
return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None)


def _is_minified(adata: AnnData | str) -> bool:
def _is_minified(adata: AnnOrMuData | str) -> bool:
uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY
if isinstance(adata, AnnData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, MuData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, str):
with h5py.File(adata) as fp:
return uns_key in read_elem(fp["uns"]).keys()
Expand Down
121 changes: 117 additions & 4 deletions src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager, fields
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
ProteinObsmField,
StringUnsField,
)
from scvi.model._utils import (
_get_batch_code_from_category,
Expand All @@ -28,11 +32,12 @@
)
from scvi.model.base import (
ArchesMixin,
BaseModelClass,
BaseMudataMinifiedModeModelClass,
UnsupervisedTrainingMixin,
VAEMixin,
)
from scvi.model.base._de_core import _de_core
from scvi.model.utils import get_minified_mudata
from scvi.module import MULTIVAE
from scvi.train import AdversarialTrainingPlan
from scvi.train._callbacks import SaveBestState
Expand All @@ -45,12 +50,19 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import AnnOrMuData, Number
from scvi._types import AnnOrMuData, MinifiedDataType, Number
from scvi.data.fields import (
BaseAnnDataField,
)

_MULTIVI_LATENT_QZM = "_multivi_latent_qzm"
_MULTIVI_LATENT_QZV = "_multivi_latent_qzv"
_MULTIVI_OBSERVED_LIB_SIZE = "_multivi_observed_lib_size"

logger = logging.getLogger(__name__)


class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, ArchesMixin, BaseMudataMinifiedModeModelClass):
"""Integration of multi-modal and single-modality data :cite:p:`AshuachGabitto21`.
MultiVI is used to integrate multiomic datasets with single-modality (expression
Expand Down Expand Up @@ -174,6 +186,10 @@ def __init__(

use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry

# TODO: ADD MINIFICATION CONSIDERATION HERE?
# if not use_size_factor_key and self.minified_data_type is None:
# library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

if "n_proteins" in self.summary_stats:
n_proteins = self.summary_stats.n_proteins
else:
Expand Down Expand Up @@ -224,6 +240,7 @@ def __init__(
self.n_genes = n_genes
self.n_regions = n_regions
self.n_proteins = n_proteins
self.module.minified_data_type = self.minified_data_type

@devices_dsp.dedent
def train(
Expand Down Expand Up @@ -414,6 +431,7 @@ def get_latent_representation(
indices: Sequence[int] | None = None,
give_mean: bool = True,
batch_size: int | None = None,
return_dist: bool = False,
) -> np.ndarray:
r"""Return the latent representation for each cell.
Expand All @@ -430,6 +448,9 @@ def get_latent_representation(
Give mean of distribution or sample from it.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
return_dist
If ``True``, returns the mean and variance of the latent distribution. Otherwise,
returns the mean of the latent distribution.
Returns
-------
Expand Down Expand Up @@ -457,6 +478,8 @@ def get_latent_representation(
adata = self._validate_anndata(adata)
scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
latent = []
qz_means = []
qz_vars = []
for tensors in scdl:
inference_inputs = self.module._get_inference_input(tensors)
outputs = self.module.inference(**inference_inputs)
Expand All @@ -473,8 +496,17 @@ def get_latent_representation(
else:
z = qz_m

if return_dist:
qz_means.append(qz_m.cpu())
qz_vars.append(qz_v.cpu())
continue

latent += [z.cpu()]
return torch.cat(latent).numpy()

if return_dist:
return torch.cat(qz_means).numpy(), torch.cat(qz_vars).numpy()
else:
return torch.cat(latent).numpy()

@torch.inference_mode()
def get_accessibility_estimates(
Expand Down Expand Up @@ -1113,6 +1145,87 @@ def setup_mudata(
mod_required=True,
)
)
# TODO: register new fields if the adata is minified
mdata_minify_type = _get_adata_minify_type(mdata)
if mdata_minify_type is not None:
mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_mudata_minification(
minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
"""Return the fields required for adata minification of the given minified_data_type."""
if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
fields = [
ObsmField(
REGISTRY_KEYS.LATENT_QZM_KEY,
_MULTIVI_LATENT_QZM,
),
ObsmField(
REGISTRY_KEYS.LATENT_QZV_KEY,
_MULTIVI_LATENT_QZV,
),
NumericalObsField(
REGISTRY_KEYS.OBSERVED_LIB_SIZE,
_MULTIVI_OBSERVED_LIB_SIZE,
),
]
else:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
fields.append(
StringUnsField(
REGISTRY_KEYS.MINIFY_TYPE_KEY,
_ADATA_MINIFY_TYPE_UNS_KEY,
),
)
return fields

def minify_mudata(
self,
minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
use_latent_qzm_key: str = "X_latent_qzm",
use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
"""Minifies the model's mudata.
Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns
containing minified-adata type, and library size.
This also sets the appropriate property on the module to indicate that the mudata is
minified.
Parameters
----------
minified_data_type
How to minify the data. Currently only supports `latent_posterior_parameters`.
If minified_data_type == `latent_posterior_parameters`:
* the original count data is removed (`adata.X`, adata.raw, and any layers)
* the parameters of the latent representation of the original data is stored
* everything else is left untouched
use_latent_qzm_key
Key to use in `adata.obsm` where the latent qzm params are stored
use_latent_qzv_key
Key to use in `adata.obsm` where the latent qzv params are stored
Notes
-----
The modification is not done inplace -- instead the model is assigned a new (minified)
version of the adata.
"""
# without removing the original counts.
if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

# if self.module.use_observed_lib_size is False:
# raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

minified_adata = get_minified_mudata(self.adata, minified_data_type)
minified_adata.obsm[_MULTIVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key]
minified_adata.obsm[_MULTIVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key]
counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
minified_adata.obs[_MULTIVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1)))
self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type)
self.module.minified_data_type = minified_data_type
103 changes: 98 additions & 5 deletions src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager, fields
from scvi.data._utils import _check_nonnegative_integers
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type
from scvi.data.fields import NumericalObsField, ObsmField, StringUnsField
from scvi.dataloaders import DataSplitter
from scvi.model._utils import (
_get_batch_code_from_category,
Expand All @@ -22,11 +24,12 @@
get_max_epochs_heuristic,
)
from scvi.model.base._de_core import _de_core
from scvi.model.utils import get_minified_mudata
from scvi.module import TOTALVAE
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin
from .base import ArchesMixin, BaseMudataMinifiedModeModelClass, RNASeqMixin, VAEMixin

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand All @@ -35,12 +38,19 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import AnnOrMuData, Number
from scvi._types import AnnOrMuData, MinifiedDataType, Number
from scvi.data.fields import (
BaseAnnDataField,
)

_TOTALVI_LATENT_QZM = "_totalvi_latent_qzm"
_TOTALVI_LATENT_QZV = "_totalvi_latent_qzv"
_TOTALVI_OBSERVED_LIB_SIZE = "_totalvi_observed_lib_size"

logger = logging.getLogger(__name__)


class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMudataMinifiedModeModelClass):
"""total Variational Inference :cite:p:`GayosoSteier21`.
Parameters
Expand Down Expand Up @@ -162,7 +172,8 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key:
# TODO: ADD MINIFICATION CONSIDERATION
if not use_size_factor_key and self.minified_data_type is None:
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
Expand All @@ -184,6 +195,7 @@ def __init__(
library_log_vars=library_log_vars,
**model_kwargs,
)
self.module.minified_data_type = self.minified_data_type
self._model_summary_string = (
f"TotalVI Model with the following params: \nn_latent: {n_latent}, "
f"gene_dispersion: {gene_dispersion}, protein_dispersion: {protein_dispersion}, "
Expand Down Expand Up @@ -1331,6 +1343,87 @@ def setup_mudata(
mod_required=True,
),
]
# TODO: register new fields if the mudata is minified
mdata_minify_type = _get_adata_minify_type(mdata)
if mdata_minify_type is not None:
mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_mudata_minification(
minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
"""Return the fields required for mudata minification of the given minified_data_type."""
if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
fields = [
ObsmField(
REGISTRY_KEYS.LATENT_QZM_KEY,
_TOTALVI_LATENT_QZM,
),
ObsmField(
REGISTRY_KEYS.LATENT_QZV_KEY,
_TOTALVI_LATENT_QZV,
),
NumericalObsField(
REGISTRY_KEYS.OBSERVED_LIB_SIZE,
_TOTALVI_OBSERVED_LIB_SIZE,
),
]
else:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
fields.append(
StringUnsField(
REGISTRY_KEYS.MINIFY_TYPE_KEY,
_ADATA_MINIFY_TYPE_UNS_KEY,
),
)
return fields

def minify_mudata(
self,
minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
use_latent_qzm_key: str = "X_latent_qzm",
use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
"""Minifies the model's mudata.
Minifies the mudata, and registers new mudata fields: latent qzm, latent qzv, adata uns
containing minified-adata type, and library size.
This also sets the appropriate property on the module to indicate that the mudata is
minified.
Parameters
----------
minified_data_type
How to minify the data. Currently only supports `latent_posterior_parameters`.
If minified_data_type == `latent_posterior_parameters`:
* the original count data is removed (`adata.X`, adata.raw, and any layers)
* the parameters of the latent representation of the original data is stored
* everything else is left untouched
use_latent_qzm_key
Key to use in `adata.obsm` where the latent qzm params are stored
use_latent_qzv_key
Key to use in `adata.obsm` where the latent qzv params are stored
Notes
-----
The modification is not done inplace -- instead the model is assigned a new (minified)
version of the adata.
"""
# without removing the original counts.
if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

if self.module.use_observed_lib_size is False:
raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

minified_adata = get_minified_mudata(self.adata, minified_data_type)
minified_adata.obsm[_TOTALVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key]
minified_adata.obsm[_TOTALVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key]
counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
minified_adata.obs[_TOTALVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1)))
self._update_mudata_and_manager_post_minification(minified_adata, minified_data_type)
self.module.minified_data_type = minified_data_type
7 changes: 6 additions & 1 deletion src/scvi/model/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from ._archesmixin import ArchesMixin
from ._base_model import BaseMinifiedModeModelClass, BaseModelClass
from ._base_model import (
BaseMinifiedModeModelClass,
BaseModelClass,
BaseMudataMinifiedModeModelClass,
)
from ._differential import DifferentialComputation
from ._embedding_mixin import EmbeddingMixin
from ._jaxmixin import JaxTrainingMixin
Expand All @@ -26,5 +30,6 @@
"DifferentialComputation",
"JaxTrainingMixin",
"BaseMinifiedModeModelClass",
"BaseMudataMinifiedModeModelClass",
"EmbeddingMixin",
]
Loading

0 comments on commit 3299355

Please sign in to comment.