From a5822514048033516492e8facdf2e8c0cefe0c45 Mon Sep 17 00:00:00 2001 From: Hrovatin Date: Sun, 14 Jan 2024 21:46:18 +0100 Subject: [PATCH 01/60] add model and tests --- scvi/external/__init__.py | 2 + scvi/external/csi/model/__init__.py | 7 + scvi/external/csi/model/_model.py | 323 +++++++++++++++ scvi/external/csi/model/_training.py | 14 + scvi/external/csi/model/_utils.py | 149 +++++++ scvi/external/csi/module/__init__.py | 11 + scvi/external/csi/module/_loss_recorder.py | 39 ++ scvi/external/csi/module/_module.py | 440 +++++++++++++++++++++ scvi/external/csi/module/_priors.py | 118 ++++++ scvi/external/csi/nn/__init__.py | 9 + scvi/external/csi/nn/_base_components.py | 341 ++++++++++++++++ scvi/external/csi/train/__init__.py | 7 + scvi/external/csi/train/_trainingplans.py | 278 +++++++++++++ tests/external/csi/test_model.py | 178 +++++++++ 14 files changed, 1916 insertions(+) create mode 100644 scvi/external/csi/model/__init__.py create mode 100644 scvi/external/csi/model/_model.py create mode 100644 scvi/external/csi/model/_training.py create mode 100644 scvi/external/csi/model/_utils.py create mode 100644 scvi/external/csi/module/__init__.py create mode 100644 scvi/external/csi/module/_loss_recorder.py create mode 100644 scvi/external/csi/module/_module.py create mode 100644 scvi/external/csi/module/_priors.py create mode 100644 scvi/external/csi/nn/__init__.py create mode 100644 scvi/external/csi/nn/_base_components.py create mode 100644 scvi/external/csi/train/__init__.py create mode 100644 scvi/external/csi/train/_trainingplans.py create mode 100644 tests/external/csi/test_model.py diff --git a/scvi/external/__init__.py b/scvi/external/__init__.py index 8eb7cc0567..560b5770a5 100644 --- a/scvi/external/__init__.py +++ b/scvi/external/__init__.py @@ -1,5 +1,6 @@ from .cellassign import CellAssign from .contrastivevi import ContrastiveVI +from .csi.model import Model as SysVI from .gimvi import GIMVI from .poissonvi import POISSONVI from .scar import SCAR @@ -19,4 +20,5 @@ "SCBASSET", "POISSONVI", "ContrastiveVI", + "SysVI", ] diff --git a/scvi/external/csi/model/__init__.py b/scvi/external/csi/model/__init__.py new file mode 100644 index 0000000000..406e1b34e6 --- /dev/null +++ b/scvi/external/csi/model/__init__.py @@ -0,0 +1,7 @@ +from ._model import ( + Model, +) + +__all__ = [ + "Model", +] diff --git a/scvi/external/csi/model/_model.py b/scvi/external/csi/model/_model.py new file mode 100644 index 0000000000..eae48b80af --- /dev/null +++ b/scvi/external/csi/model/_model.py @@ -0,0 +1,323 @@ +import logging +from collections.abc import Sequence +from typing import Optional, Union + +import numpy as np +import pandas as pd +import torch +from anndata import AnnData +from typing_extensions import Literal + +from scvi import REGISTRY_KEYS +from scvi.data import AnnDataManager +from scvi.data.fields import ( + LayerField, + ObsmField, +) +from scvi.external.csi.module import Module +from scvi.model.base import BaseModelClass +from scvi.utils import setup_anndata_dsp + +from ._training import TrainingCustom +from ._utils import prepare_metadata + +logger = logging.getLogger(__name__) + + +class Model(TrainingCustom, BaseModelClass): + def __init__( + self, + adata: AnnData, + prior: Literal["standard_normal", "vamp"] = "vamp", + n_prior_components=5, + pseudoinputs_data_indices: Optional[np.array] = None, + **model_kwargs, + ): + """CVAE integration model with optional VampPrior and latent cycle-consistency loss + + Parameters + ---------- + adata + AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`. + prior + The prior to be used. You can choose between "standard_normal" and "vamp". + n_prior_components + Number of prior components in VampPrior. + pseudoinputs_data_indices + By default (based on pseudoinputs_data_init), + VAMP prior pseudoinputs are randomly selected from data. + Alternatively, one can specify pseudoinput indices using this parameter. + **model_kwargs + Keyword args for :class:`~scvi.external.csi.module.Module` + """ + super().__init__(adata) + + if prior == "vamp": + if pseudoinputs_data_indices is None: + pseudoinputs_data_indices = np.random.randint( + 0, adata.shape[0], n_prior_components + ) + pseudoinput_data = next( + iter( + self._make_data_loader( + adata=adata, + indices=pseudoinputs_data_indices, + batch_size=n_prior_components, + shuffle=False, + ) + ) + ) + else: + pseudoinput_data = None + + n_cov_const = ( + adata.obsm["covariates"].shape[1] if "covariates" in adata.obsm else 0 + ) + cov_embed_sizes = ( + pd.DataFrame(adata.obsm["covariates_embed"]).nunique(axis=0).to_list() + if "covariates_embed" in adata.obsm + else [] + ) + + # self.summary_stats provides information about anndata dimensions and other tensor info + self.module = Module( + n_input=adata.shape[1], + n_cov_const=n_cov_const, + cov_embed_sizes=cov_embed_sizes, + n_system=adata.obsm["system"].shape[1], + prior=prior, + n_prior_components=n_prior_components, + pseudoinput_data=pseudoinput_data, + **model_kwargs, + ) + + self._model_summary_string = ( + "cVAE model with optional VampPrior and latent cycle-consistency loss" + ) + # necessary line to get params that will be used for saving/loading + self.init_params_ = self._get_init_params(locals()) + + logger.info("The model has been initialized") + + @torch.no_grad() + def embed( + self, + adata: AnnData, + indices: Optional[Sequence[int]] = None, + cycle: bool = False, + give_mean: bool = True, + batch_size: Optional[int] = None, + as_numpy: bool = True, + ) -> Union[np.ndarray, torch.Tensor]: + """ + Return the latent representation for each cell. + + Parameters + ---------- + adata + Input adata based on which latent representation is obtained. + indices + Data indices to embed. If None embedd all. + cycle + Return latent embedding of cycle pass. + give_mean + Return posterior mean instead of a sample from posterior. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + as_numpy + Return iin numpy rather than torch format. + + Returns + ------- + Latent Embedding + """ + # Check model and adata + self._check_if_trained(warn=False) + adata = self._validate_anndata(adata) + if indices is None: + indices = np.arange(adata.n_obs) + # Prediction + # Do not shuffle to retain order + tensors_fwd = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size, shuffle=False + ) + predicted = [] + for tensors in tensors_fwd: + # Inference + inference_inputs = self.module._get_inference_input(tensors) + inference_outputs = self.module.inference(**inference_inputs) + if cycle: + selected_system = self.module.random_select_systems(tensors["system"]) + generative_inputs = self.module._get_generative_input( + tensors, + inference_outputs, + selected_system=selected_system, + ) + generative_outputs = self.module.generative( + **generative_inputs, x_x=False, x_y=True + ) + inference_cycle_inputs = self.module._get_inference_cycle_input( + tensors=tensors, + generative_outputs=generative_outputs, + selected_system=selected_system, + ) + inference_outputs = self.module.inference(**inference_cycle_inputs) + if give_mean: + predicted += [inference_outputs["z_m"]] + else: + predicted += [inference_outputs["z"]] + + predicted = torch.cat(predicted) + + if as_numpy: + predicted = predicted.cpu().numpy() + return predicted + + @classmethod + @setup_anndata_dsp.dedent + def setup_anndata( + cls, + adata: AnnData, + system_key: str, + layer: Optional[str] = None, + categorical_covariate_keys: Optional[list[str]] = None, + categorical_covariate_embed_keys: Optional[list[str]] = None, + continuous_covariate_keys: Optional[list[str]] = None, + covariate_categ_orders: Optional[dict] = None, + covariate_key_orders: Optional[dict] = None, + system_order: Optional[list[str]] = None, + **kwargs, + ) -> AnnData: + """ + Prepare adata for input to Model + + Parameters + ---------- + adata + Adata object - will be modified in place. + system_key + Name of obs column with categorical system information. + layer + AnnData layer to use, default X. Should contain normalized and log+1 transformed expression. + categorical_covariate_keys + Name of obs column with additional categorical covariate information. Will be one hot encoded. + categorical_covariate_embed_keys + Name of obs column with additional categorical covariate information. Embedding will be learned. + This can be useful if the number of categories is very large, which would increase memory usage. + If using this type of covariate representation please also cite + `scPoli <[https://doi.org/10.1038/s41592-023-02035-2]>`_ . + continuous_covariate_keys + Name of obs column with additional continuous covariate information. + covariate_categ_orders + Covariate encoding information. Should be used if a new adata is to be set up according + to setup of an existing adata. Access via adata.uns['covariate_categ_orders'] of already setup adata. + covariate_key_orders + Covariate encoding information. Should be used if a new adata is to be set up according + to setup of an existing adata. Access via adata.uns['covariate_key_orders'] of already setup adata. + system_order + Same as covariate_orders, but for system. Access via adata.uns['system_order'] + """ + setup_method_args = cls._get_setup_method_args(**locals()) + + # Make sure var names are unique + if adata.shape[1] != len(set(adata.var_names)): + raise ValueError("Adata var_names are not unique") + + # If setup is to be prepared wtr another adata specs make sure all relevant info is present + if covariate_categ_orders or covariate_key_orders or system_order: + assert system_order is not None + if ( + categorical_covariate_keys is not None + or categorical_covariate_embed_keys is not None + or continuous_covariate_keys is not None + ): + assert covariate_categ_orders is not None + assert covariate_key_orders is not None + + # Make system embedding with specific category order + + # Define order of system categories + if system_order is None: + system_order = sorted(adata.obs[system_key].unique()) + # Validate that the provided system_order matches the categories in adata.obs[system_key] + if set(system_order) != set(adata.obs[system_key].unique()): + raise ValueError( + "Provided system_order does not match the categories in adata.obs[system_key]" + ) + + # Make one-hot embedding with specified order + systems_dict = dict( + zip(system_order, ([float(i) for i in range(0, len(system_order))])) + ) + adata.uns["system_order"] = system_order + system_cat = pd.Series( + pd.Categorical( + values=adata.obs[system_key], categories=system_order, ordered=True + ), + index=adata.obs.index, + name="system", + ) + adata.obsm["system"] = pd.get_dummies(system_cat, dtype=float) + + # Set up covariates + # TODO this could be handled by specific field type in registry + + # System must not be in cov + if categorical_covariate_keys is not None: + if system_key in categorical_covariate_keys: + raise ValueError("system_key should not be within covariate keys") + if categorical_covariate_embed_keys is not None: + if system_key in categorical_covariate_embed_keys: + raise ValueError("system_key should not be within covariate keys") + if continuous_covariate_keys is not None: + if system_key in continuous_covariate_keys: + raise ValueError("system_key should not be within covariate keys") + + # Prepare covariate training representations/embedding + covariates, covariates_embed, orders_dict, cov_dict = prepare_metadata( + meta_data=adata.obs, + cov_cat_keys=categorical_covariate_keys, + cov_cat_embed_keys=categorical_covariate_embed_keys, + cov_cont_keys=continuous_covariate_keys, + categ_orders=covariate_categ_orders, + key_orders=covariate_key_orders, + ) + + # Save covariate representation and order information + adata.uns["covariate_categ_orders"] = orders_dict + adata.uns["covariate_key_orders"] = cov_dict + if ( + continuous_covariate_keys is not None + or categorical_covariate_keys is not None + ): + adata.obsm["covariates"] = covariates + else: + # Remove if present since the presence of this key + # is in model used to determine if cov should be used or not + if "covariates" in adata.obsm: + del adata.obsm["covariates"] + if categorical_covariate_embed_keys is not None: + adata.obsm["covariates_embed"] = covariates_embed + else: + if "covariates_embed" in adata.obsm: + del adata.obsm["covariates_embed"] + + # Anndata setup + + anndata_fields = [ + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=False), + ObsmField("system", "system"), + ] + # Covariate fields are optional + if ( + continuous_covariate_keys is not None + or categorical_covariate_keys is not None + ): + anndata_fields.append(ObsmField("covariates", "covariates")) + if categorical_covariate_embed_keys is not None: + anndata_fields.append(ObsmField("covariates_embed", "covariates_embed")) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) diff --git a/scvi/external/csi/model/_training.py b/scvi/external/csi/model/_training.py new file mode 100644 index 0000000000..4ac9543a34 --- /dev/null +++ b/scvi/external/csi/model/_training.py @@ -0,0 +1,14 @@ +from scvi.external.csi.train import TrainingPlanCustom +from scvi.model.base import UnsupervisedTrainingMixin + + +class TrainingCustom(UnsupervisedTrainingMixin): + """Train method with custom TrainingPlan.""" + + # TODO could make custom Trainer (in a custom TrainRunner) to have in init params for early stopping + # "loss" rather than "elbo" components in available param specifications - for now just use + # a loss that is against the param specification + + # TODO run and log val before training - already tried some solutions by calling trainer.validate before + # fit and num_sanity_val_steps (designed not to log) + _training_plan_cls = TrainingPlanCustom diff --git a/scvi/external/csi/model/_utils.py b/scvi/external/csi/model/_utils.py new file mode 100644 index 0000000000..9f7dc5fe98 --- /dev/null +++ b/scvi/external/csi/model/_utils.py @@ -0,0 +1,149 @@ +from typing import Optional, Union + +import pandas as pd + + +def prepare_metadata( + meta_data: pd.DataFrame, + cov_cat_keys: Optional[list] = None, + cov_cat_embed_keys: Optional[list] = None, + cov_cont_keys: Optional[list] = None, + categ_orders: Optional[dict] = None, + key_orders: Optional[dict] = None, +): + """ + Prepare content of dataframe columns for model training (one hot encoding, encoding for embedding, ...) + + Parameters + ---------- + meta_data + Dataframe containing metadata columns. + cov_cat_keys + List of categorical covariates column names to be one-hot encoded. + cov_cat_embed_keys + List of categorical covariates column names to be embedded. + cov_cont_keys + List of continuous covariates column names. + categ_orders + Defined orders for categorical covariates. Dict with keys being + categorical covariates keys and values being lists of categories. May contain more + categories than data. + key_orders + Defines order of covariate columns. Dict with keys being 'categorical', 'categorical_embed', 'continuous' + and values being lists of keys. + + Returns + ------- + Tuple of: covariate data that does not require embedding, + covariate data that requires embedding, + dict with order of categories per covariate (as orders), + dict with keys (categorical, categorical_embed, and continuous) specifying as values + order of covariates used to construct the two covariate datas + + """ + + def get_categories_order(values: pd.Series, categories: Union[list, None] = None): + """ + Helper to get order of categories based on values and optional list of categories + + Parameters + ---------- + values + Categorical values + categories + Optional order of categories + + Returns + ------- + Categories order + """ + if categories is None: + categories = pd.Categorical(values).categories.values + else: + missing = set(values.unique()) - set(categories) + if len(missing) > 0: + raise ValueError( + f"Some values of {values.name} are not in the specified categories order: {missing}" + ) + return list(categories) + + if cov_cat_keys is None: + cov_cat_keys = [] + if cov_cat_embed_keys is None: + cov_cat_embed_keys = [] + if cov_cont_keys is None: + cov_cont_keys = [] + + # Check & set order of covariates and categories + if key_orders is not None: + assert set(key_orders["categorical"]) == set(cov_cat_keys) + cov_cat_keys = key_orders["categorical"] + assert set(key_orders["categorical_embed"]) == set(cov_cat_embed_keys) + cov_cat_embed_keys = key_orders["categorical_embed"] + assert set(key_orders["continuous"]) == set(cov_cont_keys) + cov_cont_keys = key_orders["continuous"] + cov_dict = { + "categorical": cov_cat_keys, + "categorical_embed": cov_cat_embed_keys, + "continuous": cov_cont_keys, + } + + if categ_orders is None: + categ_orders = {} + for cov_key in cov_cat_keys + cov_cat_embed_keys: + categ_orders[cov_key] = get_categories_order( + values=meta_data[cov_key], categories=None + ) + + def dummies_categories(values: pd.Series, categories: list): + """ + Make dummies of categorical covariates. Use specified order of categories. + + Parameters + ---------- + values + Categorical vales for each observation. + categories + Order of categories to use + + Returns + ------- + dummies - one-hot encoding of categories in same order as categories. + """ + # Get dummies + # Ensure ordering + values = pd.Series( + pd.Categorical(values=values, categories=categories, ordered=True), + index=values.index, + name=values.name, + ) + # This is problematic if many covariates + dummies = pd.get_dummies(values, prefix=values.name) + + return dummies + + # Covs that are not embedded: continuous and one-hot encoded categorical covariates + if len(cov_cat_keys) > 0 or len(cov_cont_keys) > 0: + cov_cat_data = [] + for cov_cat_key in cov_cat_keys: + cat_dummies = dummies_categories( + values=meta_data[cov_cat_key], categories=categ_orders[cov_cat_key] + ) + cov_cat_data.append(cat_dummies) + # Prepare single cov array for all covariates + cov_data_parsed = pd.concat(cov_cat_data + [meta_data[cov_cont_keys]], axis=1) + else: + cov_data_parsed = None + + # Data of covariates to be embedded + if len(cov_cat_embed_keys) > 0: + cov_embed_data = [] + for cov_cat_embed_key in cov_cat_embed_keys: + cat_order = categ_orders[cov_cat_embed_key] + cat_map = dict(zip(cat_order, range(len(cat_order)))) + cov_embed_data.append(meta_data[cov_cat_embed_key].map(cat_map)) + cov_embed_data = pd.concat(cov_embed_data, axis=1) + else: + cov_embed_data = None + + return cov_data_parsed, cov_embed_data, categ_orders, cov_dict diff --git a/scvi/external/csi/module/__init__.py b/scvi/external/csi/module/__init__.py new file mode 100644 index 0000000000..eeac8e7d9c --- /dev/null +++ b/scvi/external/csi/module/__init__.py @@ -0,0 +1,11 @@ +from ._loss_recorder import ( + LossRecorder, +) +from ._module import ( + Module, +) + +__all__ = [ + "LossRecorder", + "Module", +] diff --git a/scvi/external/csi/module/_loss_recorder.py b/scvi/external/csi/module/_loss_recorder.py new file mode 100644 index 0000000000..b699d80031 --- /dev/null +++ b/scvi/external/csi/module/_loss_recorder.py @@ -0,0 +1,39 @@ +class LossRecorder: + """ + Loss signature for models. + + This class provides an organized way to record the model loss, as well as + the components of the ELBO. This may also be used in MLE, MAP, EM methods. + The loss is used for backpropagation during inference. The other parameters + are used for logging/early stopping during inference. + + Parameters + ---------- + loss + Tensor with loss for minibatch. Should be one dimensional with one value. + Note that loss should be a :class:`~torch.Tensor` and not the result of ``.item()``. + reconstruction_loss + Reconstruction loss for each observation in the minibatch. + kl_local + KL divergence associated with each observation in the minibatch. + kl_global + Global kl divergence term. Should be one dimensional with one value. + **kwargs + Additional metrics can be passed as keyword arguments and will + be available as attributes of the object. + """ + + def __init__( + self, + n_obs: int, + loss: float, + loss_sum: float, + **kwargs, + ): + self.n_obs = n_obs + self.loss = loss + self.loss_sum = loss_sum + self.extra_metric_attrs = [] + for key, value in kwargs.items(): + setattr(self, key, value) + self.extra_metric_attrs.append(key) diff --git a/scvi/external/csi/module/_module.py b/scvi/external/csi/module/_module.py new file mode 100644 index 0000000000..4031599121 --- /dev/null +++ b/scvi/external/csi/module/_module.py @@ -0,0 +1,440 @@ +from typing import Optional, Union + +import torch +from typing_extensions import Literal + +from scvi import REGISTRY_KEYS +from scvi.external.csi.nn import Embedding, EncoderDecoder +from scvi.module.base import BaseModuleClass, auto_move_data + +from . import LossRecorder +from ._priors import StandardPrior, VampPrior + +torch.backends.cudnn.benchmark = True + + +class Module(BaseModuleClass): + # TODO could disable computation of cycle if predefined that cycle wil not be used + + def __init__( + self, + n_input: int, + n_cov_const: int, + cov_embed_sizes: list, + n_system: int, + cov_embed_dims: int = 10, + prior: Literal["standard_normal", "vamp"] = "vamp", + n_prior_components: int = 5, + trainable_priors: bool = True, + pseudoinput_data: Optional[dict[str, torch.Tensor]] = None, + n_latent: int = 15, + n_hidden: int = 256, + n_layers: int = 2, + dropout_rate: float = 0.1, + out_var_mode: str = "feature", + **kwargs, + ): + """CVAE with optional VampPrior and latent cycle consistency loss. + + Parameters + ---------- + n_input + Number of input genes + n_cov_const + Dimensionality of covariate data that will not be further embedded + cov_embed_sizes + Number of categories per every cov to be embedded, e.g. [cov1_n_categ, cov2_n_categ, ...] + n_system + Number of systems + cov_embed_dims + Dimension for covariate embedding + prior + Which prior to use + n_prior_components + If VampPrior - how many prior components to use + trainable_priors + If VampPrior- should prior components be trainable + pseudoinput_data + Initialisation data for VampPrior. Should match input tensors structure + n_latent + n_hidden + n_layers + dropout_rate + out_var_mode + kwargs + """ + super().__init__() + + self.embed_cov = len(cov_embed_sizes) > 0 # Will any covs be embedded + self.n_cov_const = n_cov_const # Dimension of covariates that are not embedded + n_cov = ( + n_cov_const + len(cov_embed_sizes) * cov_embed_dims + ) # Total size of covs (embedded & not embedded) + n_cov_encoder = n_cov + n_system # N covariates passed to Module (cov & system) + + if self.embed_cov: + self.cov_embeddings = torch.nn.ModuleList( + [ + Embedding(size=size, cov_embed_dims=cov_embed_dims) + for size in cov_embed_sizes + ] + ) + + self.encoder = EncoderDecoder( + n_input=n_input, + n_output=n_latent, + n_cov=n_cov_encoder, + n_hidden=n_hidden, + n_layers=n_layers, + dropout_rate=dropout_rate, + sample=True, + var_mode="sample_feature", + **kwargs, + ) + + self.decoder = EncoderDecoder( + n_input=n_latent, + n_output=n_input, + n_cov=n_cov_encoder, + n_hidden=n_hidden, + n_layers=n_layers, + dropout_rate=dropout_rate, + sample=True, + var_mode=out_var_mode, + **kwargs, + ) + + if prior == "standard_normal": + self.prior = StandardPrior() + elif prior == "vamp": + if pseudoinput_data is not None: + pseudoinput_data = self._get_inference_input(pseudoinput_data) + self.prior = VampPrior( + n_components=n_prior_components, + n_input=n_input, + n_cov=n_cov_encoder, + encoder=self.encoder, + data=( + pseudoinput_data["expr"], + self._merge_cov( + cov=pseudoinput_data["cov"], system=pseudoinput_data["system"] + ), + ), + trainable_priors=trainable_priors, + ) + else: + raise ValueError("Prior not recognised") + + def _get_inference_input(self, tensors, **kwargs) -> dict[str, torch.Tensor]: + """Parse the input tensors to get inference inputs""" + expr = tensors[REGISTRY_KEYS.X_KEY] + cov = self._get_cov(tensors=tensors) + system = tensors["system"] + input_dict = {"expr": expr, "cov": cov, "system": system} + return input_dict + + def _get_inference_cycle_input( + self, tensors, generative_outputs, selected_system: torch.Tensor, **kwargs + ) -> dict[str, torch.Tensor]: + """Parse the input tensors and cycle system info to get cycle inference inputs""" + expr = generative_outputs["y_m"] + cov = self._mock_cov(self._get_cov(tensors=tensors)) + system = selected_system + input_dict = {"expr": expr, "cov": cov, "system": system} + return input_dict + + def _get_generative_input( + self, tensors, inference_outputs, selected_system: torch.Tensor, **kwargs + ) -> dict[str, torch.Tensor]: + """Parse the input tensors, inference inputs, and cycle system to get generative inputs""" + z = inference_outputs["z"] + + cov = self._get_cov(tensors=tensors) + cov = {"x": cov, "y": self._mock_cov(cov)} + + system = {"x": tensors["system"], "y": selected_system} + + input_dict = {"z": z, "cov": cov, "system": system} + return input_dict + + def _get_cov(self, tensors: dict[str, torch.Tensor]) -> Optional[torch.Tensor]: + """Merge all covariates into single tensor, including embedding of covariates""" + cov = [] + if self.n_cov_const > 0: + cov.append(tensors["covariates"]) + if self.embed_cov: + cov.extend( + [ + embedding(tensors["covariates_embed"][:, idx].int()) + for idx, embedding in enumerate(self.cov_embeddings) + ] + ) + cov = torch.concat(cov, dim=1) if len(cov) > 0 else None + return cov + + @staticmethod + def _merge_cov(cov: Optional[torch.Tensor], system: torch.Tensor) -> torch.Tensor: + """Merge full covariate data and system data to get cov for model input""" + return torch.cat([cov, system], dim=1) if cov is not None else system + + @staticmethod + def _mock_cov(cov: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Make mock (all 0) covariates for cycle""" + return torch.zeros_like(cov) if cov is not None else None + + @auto_move_data + def inference(self, expr, cov, system) -> dict: + """ + expression & cov -> latent representation + + Parameters + ---------- + expr + Expression data + cov + Full covariate data (categorical, categorical embedded, and continuous + system + System representation + + Returns + ------- + Posterior parameters and sample + """ + z = self.encoder(x=expr, cov=self._merge_cov(cov=cov, system=system)) + return {"z": z["y"], "z_m": z["y_m"], "z_v": z["y_v"]} + + @auto_move_data + def generative(self, z, cov, system, x_x: bool = True, x_y: bool = True) -> dict: + """ + latent representation & cov -> expression + + Parameters + ---------- + z + Latent embedding + cov + Full covariate data (categorical, categorical embedded, and continuous + system + System representation + x_x + Decode to original system + x_y + Decode to replacement system + + Returns + ------- + Decoded distribution parameters and sample + """ + + def outputs(compute, name, res, x, cov, system): + if compute: + res_sub = self.decoder(x=x, cov=self._merge_cov(cov=cov, system=system)) + res[name] = res_sub["y"] + res[name + "_m"] = res_sub["y_m"] + res[name + "_v"] = res_sub["y_v"] + + res = {} + outputs(compute=x_x, name="x", res=res, x=z, cov=cov["x"], system=system["x"]) + outputs(compute=x_y, name="y", res=res, x=z, cov=cov["y"], system=system["y"]) + return res + + @auto_move_data + def forward( + self, + tensors, + get_inference_input_kwargs: Optional[dict] = None, + get_generative_input_kwargs: Optional[dict] = None, + inference_kwargs: Optional[dict] = None, + generative_kwargs: Optional[dict] = None, + loss_kwargs: Optional[dict] = None, + compute_loss=True, + ) -> Union[ + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], LossRecorder], + ]: + """ + Forward pass through the network. + + Parameters + ---------- + tensors + tensors to pass through + get_inference_input_kwargs + Keyword args for ``_get_inference_input()`` + get_generative_input_kwargs + Keyword args for ``_get_generative_input()`` + inference_kwargs + Keyword args for ``inference()`` + generative_kwargs + Keyword args for ``generative()`` + loss_kwargs + Keyword args for ``loss()`` + compute_loss + Whether to compute loss on forward pass. This adds + another return value. + """ + """Core of the forward call shared by PyTorch- and Jax-based modules.""" + + # TODO currently some forward paths are computed despite potentially having loss weight=0 - + # don't compute if not needed + + # Parse kwargs + inference_kwargs = _get_dict_if_none(inference_kwargs) + generative_kwargs = _get_dict_if_none(generative_kwargs) + loss_kwargs = _get_dict_if_none(loss_kwargs) + get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) + get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) + + # Inference + inference_inputs = self._get_inference_input( + tensors, **get_inference_input_kwargs + ) + inference_outputs = self.inference(**inference_inputs, **inference_kwargs) + # Generative + selected_system = self.random_select_systems(tensors["system"]) + generative_inputs = self._get_generative_input( + tensors, + inference_outputs, + selected_system=selected_system, + **get_generative_input_kwargs, + ) + generative_outputs = self.generative( + **generative_inputs, x_x=True, x_y=True, **generative_kwargs + ) + # Inference cycle + inference_cycle_inputs = self._get_inference_cycle_input( + tensors=tensors, + generative_outputs=generative_outputs, + selected_system=selected_system, + **get_inference_input_kwargs, + ) + inference_cycle_outputs = self.inference( + **inference_cycle_inputs, **inference_kwargs + ) + + # Combine outputs of all forward pass components + inference_outputs_merged = dict(**inference_outputs) + inference_outputs_merged.update( + **{k.replace("z", "z_cyc"): v for k, v in inference_cycle_outputs.items()} + ) + generative_outputs_merged = dict(**generative_outputs) + + if compute_loss: + losses = self.loss( + tensors=tensors, + inference_outputs=inference_outputs_merged, + generative_outputs=generative_outputs_merged, + **loss_kwargs, + ) + return inference_outputs_merged, generative_outputs_merged, losses + else: + return inference_outputs_merged, generative_outputs_merged + + def loss( + self, + tensors, + inference_outputs, + generative_outputs, + kl_weight: float = 1.0, + reconstruction_weight: float = 1.0, + z_distance_cycle_weight: float = 2.0, + ): + x_true = tensors[REGISTRY_KEYS.X_KEY] + + # Reconstruction loss + + def reconst_loss_part(x_m, x, x_v): + """Compute reconstruction loss""" + return torch.nn.GaussianNLLLoss(reduction="none")(x_m, x, x_v).sum(dim=1) + + # Reconstruction loss + reconst_loss_x = reconst_loss_part( + x_m=generative_outputs["x_m"], x=x_true, x_v=generative_outputs["x_v"] + ) + reconst_loss = reconst_loss_x + + # Kl divergence on latent space + kl_divergence_z = self.prior.kl( + m_q=inference_outputs["z_m"], + v_q=inference_outputs["z_v"], + z=inference_outputs["z"], + ) + + def z_dist(z_x_m: torch.Tensor, z_y_m: torch.Tensor): + """MSE loss between standardised inputs with standardizer fitted on concatenation of both inputs""" + # Standardise data (jointly both z-s) before MSE calculation + z = torch.concat([z_x_m, z_y_m]) + means = z.mean(dim=0, keepdim=True) + stds = z.std(dim=0, keepdim=True) + + def standardize(x): + return (x - means) / stds + + return torch.nn.MSELoss(reduction="none")( + standardize(z_x_m), standardize(z_y_m) + ).sum(dim=1) + + z_distance_cyc = z_dist( + z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"] + ) + + # Overall loss + loss = ( + reconst_loss * reconstruction_weight + + kl_divergence_z * kl_weight + + z_distance_cyc * z_distance_cycle_weight + ) + + return LossRecorder( + n_obs=loss.shape[0], + loss=loss.mean(), + loss_sum=loss.sum(), + reconstruction_loss=reconst_loss.sum(), + kl_local=kl_divergence_z.sum(), + z_distance_cycle=z_distance_cyc.sum(), + ) + + @staticmethod + def random_select_systems(system: torch.Tensor) -> torch.Tensor: + """For every cell randomly selects a new system that is different from the original system + + Parameters + ---------- + system + One hot encoded system information for each cell + + Returns + ------- + One hot encoding of newly selected system for each cell + + """ + # get available systems - those that are zero will become nonzero and vice versa + available_systems = 1 - system + # Get nonzero indices for each cell + row_indices, col_indices = torch.nonzero(available_systems, as_tuple=True) + col_pairs = col_indices.view(-1, system.shape[1] - 1) + # Select system for every cell from available systems + randomly_selected_indices = col_pairs.gather( + 1, + torch.randint( + 0, + system.shape[1] - 1, + size=(col_pairs.size(0), 1), + device=col_pairs.device, + dtype=col_pairs.dtype, + ), + ) + new_tensor = torch.zeros_like(available_systems) + # generate system covariate tensor + new_tensor.scatter_(1, randomly_selected_indices, 1) + + return new_tensor + + def sample(self, *args, **kwargs): + raise NotImplementedError("") + + +def _get_dict_if_none(param: Optional[dict]) -> dict: + """If not a dict return empty dict""" + param = {} if not isinstance(param, dict) else param + return param diff --git a/scvi/external/csi/module/_priors.py b/scvi/external/csi/module/_priors.py new file mode 100644 index 0000000000..7db63fbd1b --- /dev/null +++ b/scvi/external/csi/module/_priors.py @@ -0,0 +1,118 @@ +import abc +from abc import abstractmethod +from typing import Optional + +import torch +from torch.distributions import Normal, kl_divergence + + +class Prior(torch.nn.Module, abc.ABC): + @abstractmethod + def kl(self, m_q, v_q, z): + pass + + +class StandardPrior(Prior): + def kl(self, m_q, v_q, z=None): + # 1 x N + return kl_divergence( + Normal(m_q, v_q.sqrt()), Normal(torch.zeros_like(m_q), torch.ones_like(v_q)) + ).sum(dim=1) + + +class VampPrior(Prior): + # K - components, I - inputs, L - latent, N - samples + + def __init__( + self, + n_components, + n_input, + n_cov, + encoder, + data: Optional[tuple[torch.tensor, torch.tensor]] = None, + trainable_priors=True, + ): + """VampPrior adapted from https://github.com/jmtomczak/intro_dgm/main/vaes/vae_priors_example.ipynb + + Parameters + ---------- + n_components + Prior components + n_input + Model input dimensions + n_cov + Model input covariate dimensions + encoder + The encoder + data + Data for pseudoinputs initialisation tuple(input,covs) + trainable_priors + Are pseudoinput parameters trainable or fixed + """ + super().__init__() + + self.encoder = encoder + + # Get pseudoinputs + if data is None: + u = torch.rand(n_components, n_input) # K * I + u_cov = torch.zeros(n_components, n_cov) # K * C + else: + u = data[0] + u_cov = data[1] + assert n_components == data[0].shape[0] == data[1].shape[0] + assert n_input == data[0].shape[1] + assert n_cov == data[1].shape[1] + self.u = torch.nn.Parameter(u, requires_grad=trainable_priors) + self.u_cov = torch.nn.Parameter(u_cov, requires_grad=trainable_priors) + + # mixing weights + self.w = torch.nn.Parameter(torch.zeros(self.u.shape[0], 1, 1)) # K x 1 x 1 + + def get_params(self) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get posterior of pseudoinputs + + Returns + ------- + Posterior mean, var + """ + # u, u_cov -> encoder -> mean, var + original_mode = self.encoder.training + self.encoder.train(False) + z = self.encoder(x=self.u, cov=self.u_cov) + self.encoder.train(original_mode) + return z["y_m"], z["y_v"] # (K x L), (K x L) + + def log_prob(self, z) -> torch.Tensor: + """ + Log probability of z under the prior + + Parameters + ---------- + z + Latent embedding of samples + + Returns + ------- + Log probability of every sample: samples * latent dimensions + """ + # Mixture of gaussian computed on K x N x L + z = z.unsqueeze(0) # 1 x N x L + + # Get pseudoinputs posteriors - prior params + m_p, v_p = self.get_params() # (K x L), (K x L) + m_p = m_p.unsqueeze(1) # K x 1 x L + v_p = v_p.unsqueeze(1) # K x 1 x L + + # mixing probabilities + w = torch.nn.functional.softmax(self.w, dim=0) # K x 1 x 1 + + # sum of log_p across components weighted by w + log_prob = Normal(m_p, v_p.sqrt()).log_prob(z) + torch.log(w) # K x N x L + log_prob = torch.logsumexp(log_prob, dim=0, keepdim=False) # N x L + + return log_prob # N x L + + def kl(self, m_q, v_q, z): + return (Normal(m_q, v_q.sqrt()).log_prob(z) - self.log_prob(z)).sum(1) diff --git a/scvi/external/csi/nn/__init__.py b/scvi/external/csi/nn/__init__.py new file mode 100644 index 0000000000..449b333424 --- /dev/null +++ b/scvi/external/csi/nn/__init__.py @@ -0,0 +1,9 @@ +from ._base_components import ( + Embedding, + EncoderDecoder, +) + +__all__ = [ + "EncoderDecoder", + "Embedding", +] diff --git a/scvi/external/csi/nn/_base_components.py b/scvi/external/csi/nn/_base_components.py new file mode 100644 index 0000000000..c44f422f89 --- /dev/null +++ b/scvi/external/csi/nn/_base_components.py @@ -0,0 +1,341 @@ +from collections import OrderedDict +from typing import Literal, Union + +import torch +from torch.distributions import Normal +from torch.nn import ( + BatchNorm1d, + Dropout, + LayerNorm, + Linear, + Module, + Parameter, + ReLU, + Sequential, +) + + +class Embedding(Module): + def __init__(self, size, cov_embed_dims: int = 10, normalize: bool = True): + """Module for obtaining embedding of categorical covariates + + Parameters + ---------- + size + N categories + cov_embed_dims + Dimensions of embedding + normalize + Apply layer normalization + """ + super().__init__() + + self.normalize = normalize + + self.embedding = torch.nn.Embedding(size, cov_embed_dims) + + if self.normalize: + # TODO this could probably be implemented more efficiently as embed gives same result for every sample in + # a give class. However, if we have many balanced classes there wont be many repetitions within minibatch + self.layer_norm = torch.nn.LayerNorm( + cov_embed_dims, elementwise_affine=False + ) + + def forward(self, x): + x = self.embedding(x) + if self.normalize: + x = self.layer_norm(x) + + return x + + +class EncoderDecoder(Module): + def __init__( + self, + n_input: int, + n_output: int, + n_cov: int, + n_hidden: int = 256, + n_layers: int = 3, + var_eps: float = 1e-4, + var_mode: str = "feature", + sample: bool = False, + **kwargs, + ): + """Module that can be used as probabilistic encoder or decoder + + Based on inputs and optional covariates predicts output mean and var + + Parameters + ---------- + n_input + n_output + n_cov + n_hidden + n_layers + var_eps + See :class:`~scvi.external.csi.nn.VarEncoder` + var_mode + See :class:`~scvi.external.csi.nn.VarEncoder` + sample + Return samples from predicted distribution + kwargs + Passed to :class:`~scvi.external.csi.nn.Layers` + """ + super().__init__() + self.sample = sample + + self.var_eps = var_eps + + self.decoder_y = Layers( + n_in=n_input, + n_cov=n_cov, + n_out=n_hidden, + n_hidden=n_hidden, + n_layers=n_layers, + **kwargs, + ) + + self.mean_encoder = Linear(n_hidden, n_output) + self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps) + + def forward(self, x, cov: Union[torch.Tensor, None] = None): + y = self.decoder_y(x=x, cov=cov) + # TODO better handling of inappropriate edge-case values than nan_to_num or at least warn + y_m = torch.nan_to_num(self.mean_encoder(y)) + y_v = self.var_encoder(y, x_m=y_m) + + outputs = {"y_m": y_m, "y_v": y_v} + + # Sample from latent distribution + if self.sample: + y = Normal(y_m, y_v.sqrt()).rsample() + outputs["y"] = y + + return outputs + + +class Layers(Module): + def __init__( + self, + n_in: int, + n_out: int, + n_cov: Union[int, None] = None, + n_layers: int = 1, + n_hidden: int = 128, + dropout_rate: float = 0.1, + use_batch_norm: bool = True, + use_layer_norm: bool = False, + use_activation: bool = True, + bias: bool = True, + inject_covariates: bool = True, + activation_fn: Module = ReLU, + ): + """A helper class to build fully-connected layers for a neural network. + + Adapted from scVI FCLayers to use covariates more flexibly + + Parameters + ---------- + n_in + The dimensionality of the main input + n_out + The dimensionality of the output + n_cov + Dimensionality of covariates. + If there are no cov this should be set to None - + in this case cov will not be used. + n_layers + The number of fully-connected hidden layers + n_hidden + The number of nodes per hidden layer + dropout_rate + Dropout rate to apply to each of the hidden layers + use_batch_norm + Whether to have `BatchNorm` layers or not + use_layer_norm + Whether to have `LayerNorm` layers or not + use_activation + Whether to have layer activation or not + bias + Whether to learn bias in linear layers or not + inject_covariates + Whether to inject covariates in each layer, or just the first. + activation_fn + Which activation function to use + """ + super().__init__() + + self.inject_covariates = inject_covariates + self.n_cov = n_cov if n_cov is not None else 0 + + layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out] + + self.fc_layers = Sequential( + OrderedDict( + [ + ( + f"Layer {i}", + Sequential( + Linear( + n_in + self.n_cov * self.inject_into_layer(i), + n_out, + bias=bias, + ), + # non-default params come from defaults in original Tensorflow implementation + BatchNorm1d(n_out, momentum=0.01, eps=0.001) + if use_batch_norm + else None, + LayerNorm(n_out, elementwise_affine=False) + if use_layer_norm + else None, + activation_fn() if use_activation else None, + Dropout(p=dropout_rate) if dropout_rate > 0 else None, + ), + ) + for i, (n_in, n_out) in enumerate( + zip(layers_dim[:-1], layers_dim[1:]) + ) + ] + ) + ) + + def inject_into_layer(self, layer_num) -> bool: + """Helper to determine if covariates should be injected.""" + user_cond = layer_num == 0 or (layer_num > 0 and self.inject_covariates) + return user_cond + + def set_online_update_hooks(self, hook_first_layer=True): + self.hooks = [] + + def _hook_fn_weight(grad): + new_grad = torch.zeros_like(grad) + if self.n_cov > 0: + new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] + return new_grad + + def _hook_fn_zero_out(grad): + return grad * 0 + + for i, layers in enumerate(self.fc_layers): + for layer in layers: + if i == 0 and not hook_first_layer: + continue + if isinstance(layer, Linear): + if self.inject_into_layer(i): + w = layer.weight.register_hook(_hook_fn_weight) + else: + w = layer.weight.register_hook(_hook_fn_zero_out) + self.hooks.append(w) + b = layer.bias.register_hook(_hook_fn_zero_out) + self.hooks.append(b) + + def forward(self, x: torch.Tensor, cov: Union[torch.Tensor, None] = None): + """ + Forward computation on ``x``. + + Parameters + ---------- + x + tensor of values with shape ``(n_in,)`` + cov + tensor of covariate values with shape ``(n_cov,)`` or None + + Returns + ------- + py:class:`torch.Tensor` + tensor of shape ``(n_out,)`` + + """ + for i, layers in enumerate(self.fc_layers): + for layer in layers: + if layer is not None: + if isinstance(layer, BatchNorm1d): + if x.dim() == 3: + x = torch.cat( + [(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0 + ) + else: + x = layer(x) + else: + # Injection of covariates + if ( + self.n_cov > 0 + and isinstance(layer, Linear) + and self.inject_into_layer(i) + ): + x = torch.cat((x, cov), dim=-1) + x = layer(x) + return x + + +class VarEncoder(Module): + def __init__( + self, + n_input: int, + n_output: int, + mode: Literal["sample_feature", "feature", "linear"], + eps: float = 1e-4, + ): + """Encode variance (strictly positive). + + Parameters + ---------- + n_input + Number of input dimensions, used if mode is sample_feature + n_output + Number of variances to predict + mode + How to compute var + 'sample_feature' - learn per sample and feature + 'feature' - learn per feature, constant across samples + 'linear' - linear with respect to input mean, var = a1 * mean + a0; + not suggested to be used due to bad implementation for positive constraining + eps + """ + super().__init__() + + self.eps = eps + self.mode = mode + if self.mode == "sample_feature": + self.encoder = Linear(n_input, n_output) + elif self.mode == "feature": + self.var_param = Parameter(torch.zeros(1, n_output)) + elif self.mode == "linear": + self.var_param_a1 = Parameter(torch.tensor([1.0])) + self.var_param_a0 = Parameter(torch.tensor([self.eps])) + else: + raise ValueError("Mode not recognised.") + self.activation = torch.exp + + def forward(self, x: torch.Tensor, x_m: torch.Tensor): + """Forward pass through model + + Parameters + ---------- + x + Used to encode var if mode is sample_feature; dim = n_samples x n_input + x_m + Used to predict var instead of x if mode is linear; dim = n_samples x 1 + + Returns + ------- + Predicted var + """ + # Force to be non nan - TODO come up with better way to do so + if self.mode == "sample_feature": + v = self.encoder(x) + v = ( + torch.nan_to_num(self.activation(v)) + self.eps + ) # Ensure that var is strictly positive + elif self.mode == "feature": + v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size + v = ( + torch.nan_to_num(self.activation(v)) + self.eps + ) # Ensure that var is strictly positive + elif self.mode == "linear": + v = self.var_param_a1 * x_m.detach().clone() + self.var_param_a0 + # TODO come up with a better way to constrain this to positive while having lin relationship + # Could activation be used for log-lin relationship? + v = torch.clamp(torch.nan_to_num(v), min=self.eps) + return v diff --git a/scvi/external/csi/train/__init__.py b/scvi/external/csi/train/__init__.py new file mode 100644 index 0000000000..cacfa20b74 --- /dev/null +++ b/scvi/external/csi/train/__init__.py @@ -0,0 +1,7 @@ +from ._trainingplans import ( + TrainingPlanCustom, +) + +__all__ = [ + "TrainingPlanCustom", +] diff --git a/scvi/external/csi/train/_trainingplans.py b/scvi/external/csi/train/_trainingplans.py new file mode 100644 index 0000000000..7793045db1 --- /dev/null +++ b/scvi/external/csi/train/_trainingplans.py @@ -0,0 +1,278 @@ +from inspect import getfullargspec +from typing import Literal, Union + +import torch +from torchmetrics import MetricCollection + +from scvi.external.csi.module import LossRecorder +from scvi.module.base import BaseModuleClass +from scvi.train import TrainingPlan + +# TODO could make new metric class to not be called elbo metric as used for other metrics as well +from scvi.train._metrics import ElboMetric + + +class WeightScaling: + def __init__( + self, + weight_start: float, + weight_end: float, + point_start: int, + point_end: int, + update_on: Literal["epoch", "step"] = "step", + ): + """Linearly scale loss weights between start and end weight accordingly to the current training stage + + Parameters + ---------- + weight_start + Starting weight value + weight_end + End weight vlue + point_start + Training point to start scaling - before weight is weight_start + Since the epochs are counted after they are run, + the start point must be set to 0 to represent 1st epoch + point_end + Training point to end scaling - after weight is weight_end + Since the epochs are counted after they are run, + the start point must be set to n-1 to represent the last epoch + update_on + Define training progression based on epochs or steps + + """ + self.weight_start = weight_start + self.weight_end = weight_end + self.point_start = point_start + self.point_end = point_end + if update_on not in ["step", "epoch"]: + raise ValueError("update_on not recognized") + self.update_on = update_on + + weight_diff = self.weight_end - self.weight_start + n_points = self.point_end - self.point_start + self.slope = weight_diff / n_points + + if ( + self.weight(epoch=self.point_start, step=self.point_start) < 0 + or self.weight(epoch=self.point_end, step=self.point_end) < 0 + ): + raise ValueError("Specified weight scaling would lead to negative weights") + + def weight( + self, + epoch: int, + step: int, + ) -> float: + """ + Computes the weight for the current step/epoch depending on which update type was set in init + + Parameters + ---------- + epoch + Current epoch. + step + Current step. + """ + if self.update_on == "epoch": + point = epoch + elif self.update_on == "step": + point = step + else: + # This is ensured not to happen by above init check + raise ValueError("self.update_on not recognised") + + if point < self.point_start: + return self.weight_start + elif point > self.point_end: + return self.weight_end + else: + return self.slope * (point - self.point_start) + self.weight_start + + +class TrainingPlanCustom(TrainingPlan): + def __init__( + self, + module: BaseModuleClass, + loss_weights: Union[None, dict[str, Union[float, WeightScaling]]] = None, + log_on_epoch: bool = True, + log_on_step: bool = False, + **kwargs, + ): + """Extends scvi TrainingPlan for custom support for other losses. + + Parameters + ---------- + args + Passed to parent + log_on_epoch + See on_epoch of lightning Module log method + log_on_step + See on_step of lightning Module log method + loss_weights + Specifies how losses should be weighted and how it may change during training + Dict with keys being loss names and values being loss weights. + Loss weights can be floats for constant weight or dict of params passed to WeightScaling object + Note that other loss weight params from the parent class are ignored + (e.g. n_steps/epochs_kl_warmup and min/max_kl_weight) + kwargs + Passed to parent. + As described in param loss_weights the loss weighting params of parent are ignored + """ + super().__init__(module, **kwargs) + + self.log_on_epoch = log_on_epoch + self.log_on_step = log_on_step + + # automatic handling of loss component weights + if loss_weights is None: + loss_weights = {} + # Make weighting object + for loss, weight in loss_weights.items(): + if isinstance(weight, dict): + loss_weights[loss] = WeightScaling(**weight) + self.loss_weights = loss_weights + + # Ensure that all passed loss weight specifications are in available loss params + # Also update loss kwargs based on specified weights + self._loss_args = getfullargspec(self.module.loss)[0] + # Make sure no loss weights are already in loss kwargs (e.g. from parent init) + for loss in self._loss_args: + if loss in self.loss_kwargs: + del self.loss_kwargs[loss] + for loss, weight in loss_weights.items(): + if loss not in self._loss_args: + raise ValueError( + f"Loss {loss} for which a weight was specified is not in loss parameters" + ) + # This will also overwrite the kl_weight from parent + self.loss_kwargs.update({loss: self.compute_loss_weight(weight=weight)}) + + def compute_loss_weight(self, weight): + if isinstance(weight, float): + return weight + elif isinstance(weight, int): + return float(weight) + elif isinstance(weight, WeightScaling): + return weight.weight(epoch=self.current_epoch, step=self.global_step) + + @staticmethod + def _create_elbo_metric_components( + mode: Literal["train", "validation"], **kwargs + ) -> (ElboMetric, MetricCollection): + """ + Initialize the combined loss collection. + + Parameters + ---------- + mode + train/validation + + Returns + ------- + tuple + Objects for storing the combined loss + + """ + loss = ElboMetric("loss", mode, "obs") + collection = MetricCollection({metric.name: metric for metric in [loss]}) + return loss, collection + + def initialize_train_metrics(self): + """Initialize train combined loss. + + TODO could add other losses + """ + ( + self.loss_train, + self.train_metrics, + ) = self._create_elbo_metric_components( + mode="train", n_total=self.n_obs_training + ) + self.loss_train.reset() + + def initialize_val_metrics(self): + """Initialize val combined loss. + + TODO could add other losses + """ + ( + self.loss_val, + self.val_metrics, + ) = self._create_elbo_metric_components( + mode="validation", n_total=self.n_obs_validation + ) + self.loss_val.reset() + + @torch.no_grad() + def compute_and_log_metrics( + self, + loss_recorder: LossRecorder, + metrics: MetricCollection, + mode: str, + ): + """ + Computes and logs metrics. + + Parameters + ---------- + loss_recorder + LossRecorder object from scvi-tools module + metrics + The loss Metric Collection to update + mode + Postfix string to add to the metric name of + extra metrics. If train also logs the loss in progress bar + """ + n_obs_minibatch = loss_recorder.n_obs + loss_sum = loss_recorder.loss_sum + + # use the torchmetric object + metrics.update( + loss=loss_sum, + n_obs_minibatch=n_obs_minibatch, + ) + + self.log( + f"loss_{mode}", + loss_recorder.loss_sum, + on_step=self.log_on_step, + on_epoch=self.log_on_epoch, + batch_size=n_obs_minibatch, + prog_bar=True if mode == "train" else False, + sync_dist=self.use_sync_dist, + ) + + # accumulate extra metrics passed to loss recorder + for extra_metric in loss_recorder.extra_metric_attrs: + met = getattr(loss_recorder, extra_metric) + if isinstance(met, torch.Tensor): + if met.shape != torch.Size([]): + raise ValueError("Extra tracked metrics should be 0-d tensors.") + met = met.detach() + self.log( + f"{extra_metric}_{mode}", + met, + on_step=self.log_on_step, + on_epoch=self.log_on_epoch, + batch_size=n_obs_minibatch, + sync_dist=self.use_sync_dist, + ) + + def training_step(self, batch, batch_idx): + for loss, weight in self.loss_weights.items(): + self.loss_kwargs.update({loss: self.compute_loss_weight(weight=weight)}) + _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) + # combined loss is logged via compute_and_log_metrics + self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train") + return scvi_loss.loss + + def validation_step(self, batch, batch_idx): + _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) + # Combined loss is logged via compute_and_log_metrics + self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation") + + @property + def kl_weight(self): + # Can not raise not implemented error as used in parent init + pass diff --git a/tests/external/csi/test_model.py b/tests/external/csi/test_model.py new file mode 100644 index 0000000000..d8d6f05cc3 --- /dev/null +++ b/tests/external/csi/test_model.py @@ -0,0 +1,178 @@ +import math + +import numpy as np +import pandas as pd +from anndata import AnnData +from numpy.testing import assert_raises +from scipy import sparse + +from scvi.external.csi.model import Model + + +def mock_adata(): + # Make random data + adata = AnnData( + sparse.csr_matrix( + np.exp( + np.concatenate( + [ + np.random.normal(1, 0.5, (200, 5)), + np.random.normal(1.1, 0.00237, (200, 5)), + np.random.normal(1.3, 0.35, (200, 5)), + np.random.normal(2, 0.111, (200, 5)), + np.random.normal(2.2, 0.3, (200, 5)), + np.random.normal(2.7, 0.01, (200, 5)), + np.random.normal(1, 0.001, (200, 5)), + np.random.normal(0.00001, 0.4, (200, 5)), + np.random.normal(0.2, 0.91, (200, 5)), + np.random.normal(0.1, 0.0234, (200, 5)), + np.random.normal(0.00005, 0.1, (200, 5)), + np.random.normal(0.05, 0.001, (200, 5)), + np.random.normal(0.023, 0.3, (200, 5)), + np.random.normal(0.6, 0.13, (200, 5)), + np.random.normal(0.9, 0.5, (200, 5)), + np.random.normal(1, 0.0001, (200, 5)), + np.random.normal(1.5, 0.05, (200, 5)), + np.random.normal(2, 0.009, (200, 5)), + np.random.normal(1, 0.0001, (200, 5)), + ], + axis=1, + ) + ) + ), + var=pd.DataFrame(index=[str(i) for i in range(95)]), + ) + adata.obs["covariate_cont"] = list(range(200)) + adata.obs["covariate_cat"] = ["a"] * 50 + ["b"] * 50 + ["c"] * 50 + ["d"] * 50 + adata.obs["covariate_cat_emb"] = ["a"] * 50 + ["b"] * 50 + ["c"] * 50 + ["d"] * 50 + adata.obs["system"] = ["a"] * 100 + ["b"] * 50 + ["c"] * 50 + + return adata + + +def test_model(): + adata0 = mock_adata() + + # Run adata setup with all covariates + Model.setup_anndata( + adata0, + system_key="system", + categorical_covariate_keys=["covariate_cat"], + categorical_covariate_embed_keys=["covariate_cat_emb"], + continuous_covariate_keys=["covariate_cont"], + ) + + # Run adata setup transfer + # TODO ensure this is actually done correctly, not just that it runs through + adata = mock_adata() + Model.setup_anndata( + adata, + system_key="system", + categorical_covariate_keys=["covariate_cat"], + categorical_covariate_embed_keys=["covariate_cat_emb"], + continuous_covariate_keys=["covariate_cont"], + covariate_categ_orders=adata0.uns["covariate_categ_orders"], + covariate_key_orders=adata0.uns["covariate_key_orders"], + system_order=adata0.uns["system_order"], + ) + + # Check that setup of adata without covariates works + Model.setup_anndata( + adata0, + system_key="system", + ) + assert "covariates" not in adata0.obsm + assert "covariates_embed" not in adata0.obsm + + # Model + + # Check that model runs through with standard normal prior + model = Model(adata=adata, prior="standard_normal") + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) + + # Check that model runs through without covariates + model = Model(adata=adata0) + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) + + # Check pre-specifying pseudoinput indices for vamp prior + _ = Model( + adata=adata, + prior="vamp", + pseudoinputs_data_indices=list(range(5)), + n_prior_components=5, + ) + + # Check that model runs through with vamp prior without specifying pseudoinput indices, + # all covariates, and weight scaling + model = Model(adata=adata, prior="vamp") + model.train( + max_epochs=2, + batch_size=math.ceil(adata.n_obs / 2.0), + log_every_n_steps=1, + check_val_every_n_epoch=1, + val_check_interval=1, + plan_kwargs={ + "log_on_epoch": False, + "log_on_step": True, + "loss_weights": { + "kl_weight": 2, + "z_distance_cycle_weight": { + "weight_start": 1, + "weight_end": 3, + "point_start": 1, + "point_end": 3, + "update_on": "step", + }, + }, + }, + ) + + # Embedding + + # Check that embedding default works + assert ( + model.embed( + adata=adata, + ).shape[0] + == adata.shape[0] + ) + + # Check that indices in embedding works + idx = [1, 2, 3] + embed = model.embed( + adata=adata, + indices=idx, + give_mean=True, + ) + assert embed.shape[0] == 3 + + # Check predicting mean/sample + np.testing.assert_allclose( + embed, + model.embed( + adata=adata, + indices=idx, + give_mean=True, + ), + ) + with assert_raises(AssertionError): + np.testing.assert_allclose( + embed, + model.embed( + adata=adata, + indices=idx, + give_mean=False, + ), + ) + + # Check predicting cycle + with assert_raises(AssertionError): + np.testing.assert_allclose( + embed, + model.embed( + adata=adata, + indices=idx, + give_mean=True, + cycle=True, + ), + ) From bb359ac4a20b33b57b3f774e6ffe09b70e4acb19 Mon Sep 17 00:00:00 2001 From: Hrovatin Date: Sun, 14 Jan 2024 21:59:52 +0100 Subject: [PATCH 02/60] update documentation --- scvi/external/csi/model/_model.py | 29 +++++++++++---------- scvi/external/csi/module/_module.py | 39 +++++++++++++++++++---------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/scvi/external/csi/model/_model.py b/scvi/external/csi/model/_model.py index eae48b80af..cb71d3e714 100644 --- a/scvi/external/csi/model/_model.py +++ b/scvi/external/csi/model/_model.py @@ -33,19 +33,18 @@ def __init__( pseudoinputs_data_indices: Optional[np.array] = None, **model_kwargs, ): - """CVAE integration model with optional VampPrior and latent cycle-consistency loss + """Integration model based on cVAE with optional VampPrior and latent cycle-consistency loss. Parameters ---------- adata - AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`. + AnnData object that has been registered via :meth:`~scvi-tools.SysVI.setup_anndata`. prior - The prior to be used. You can choose between "standard_normal" and "vamp". + The prior distribution to be used. You can choose between "standard_normal" and "vamp". n_prior_components Number of prior components in VampPrior. pseudoinputs_data_indices - By default (based on pseudoinputs_data_init), - VAMP prior pseudoinputs are randomly selected from data. + By default VampPrior pseudoinputs are randomly selected from data. Alternatively, one can specify pseudoinput indices using this parameter. **model_kwargs Keyword args for :class:`~scvi.external.csi.module.Module` @@ -109,23 +108,22 @@ def embed( batch_size: Optional[int] = None, as_numpy: bool = True, ) -> Union[np.ndarray, torch.Tensor]: - """ - Return the latent representation for each cell. + """Return the latent representation for each cell. Parameters ---------- adata - Input adata based on which latent representation is obtained. + Input adata for which latent representation should be obtained. indices - Data indices to embed. If None embedd all. + Data indices to embed. If None embedd all cells. cycle - Return latent embedding of cycle pass. + Return latent embedding of the cycle pass. give_mean - Return posterior mean instead of a sample from posterior. + Return the posterior mean instead of a sample from the posterior. batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. as_numpy - Return iin numpy rather than torch format. + Return in numpy rather than torch format. Returns ------- @@ -133,6 +131,7 @@ def embed( """ # Check model and adata self._check_if_trained(warn=False) + # TODO extend to check if adata setup is correct wrt training data adata = self._validate_anndata(adata) if indices is None: indices = np.arange(adata.n_obs) @@ -188,8 +187,7 @@ def setup_anndata( system_order: Optional[list[str]] = None, **kwargs, ) -> AnnData: - """ - Prepare adata for input to Model + """Prepare adata for input to Model Parameters ---------- @@ -198,7 +196,8 @@ def setup_anndata( system_key Name of obs column with categorical system information. layer - AnnData layer to use, default X. Should contain normalized and log+1 transformed expression. + AnnData layer to use, default is X. + Should contain normalized and log+1 transformed expression. categorical_covariate_keys Name of obs column with additional categorical covariate information. Will be one hot encoded. categorical_covariate_embed_keys diff --git a/scvi/external/csi/module/_module.py b/scvi/external/csi/module/_module.py index 4031599121..a9fa8431b3 100644 --- a/scvi/external/csi/module/_module.py +++ b/scvi/external/csi/module/_module.py @@ -32,36 +32,49 @@ def __init__( n_layers: int = 2, dropout_rate: float = 0.1, out_var_mode: str = "feature", - **kwargs, + **enc_dec_kwargs, ): """CVAE with optional VampPrior and latent cycle consistency loss. Parameters ---------- n_input - Number of input genes + Number of input features. + Passed directly from Model. n_cov_const - Dimensionality of covariate data that will not be further embedded + Dimensionality of covariate data that will not be further embedded. + Passed directly from Model. cov_embed_sizes - Number of categories per every cov to be embedded, e.g. [cov1_n_categ, cov2_n_categ, ...] + Number of categories per every cov to be embedded, e.g. [cov1_n_categ, cov2_n_categ, ...]. + Passed directly from Model. n_system - Number of systems + Number of systems. + Passed directly from Model. cov_embed_dims - Dimension for covariate embedding + Dimension for covariate embedding. prior - Which prior to use + Which prior distribution to use. + Passed directly from Model. n_prior_components - If VampPrior - how many prior components to use + If VampPrior - how many prior components to use. + Passed directly from Model. trainable_priors - If VampPrior- should prior components be trainable + If VampPrior - should prior components be trainable. pseudoinput_data - Initialisation data for VampPrior. Should match input tensors structure + Initialisation data for VampPrior. Should match input tensors structure. + Passed directly from Model. n_latent + Numer of latent space dimensions. n_hidden + Number of nodes in hidden layers. n_layers + Number of hidden layers. dropout_rate + Dropout rate. out_var_mode - kwargs + See :class:`~scvi.external.csi.nn.VarEncoder` + enc_dec_kwargs + Additional kwargs passed to encoder and decoder. """ super().__init__() @@ -89,7 +102,7 @@ def __init__( dropout_rate=dropout_rate, sample=True, var_mode="sample_feature", - **kwargs, + **enc_dec_kwargs, ) self.decoder = EncoderDecoder( @@ -101,7 +114,7 @@ def __init__( dropout_rate=dropout_rate, sample=True, var_mode=out_var_mode, - **kwargs, + **enc_dec_kwargs, ) if prior == "standard_normal": From 14e41f1b5c259ed57346c6d7e0e84eba9fcf6eb0 Mon Sep 17 00:00:00 2001 From: Hrovatin Date: Sun, 14 Jan 2024 22:53:50 +0100 Subject: [PATCH 03/60] move embedding to device --- scvi/external/csi/module/_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scvi/external/csi/module/_module.py b/scvi/external/csi/module/_module.py index a9fa8431b3..b798916623 100644 --- a/scvi/external/csi/module/_module.py +++ b/scvi/external/csi/module/_module.py @@ -170,6 +170,7 @@ def _get_generative_input( input_dict = {"z": z, "cov": cov, "system": system} return input_dict + @auto_move_data def _get_cov(self, tensors: dict[str, torch.Tensor]) -> Optional[torch.Tensor]: """Merge all covariates into single tensor, including embedding of covariates""" cov = [] From 3f492663af880281845f043c5acfd59f7f1a06d3 Mon Sep 17 00:00:00 2001 From: Hrovatin Date: Sun, 21 Jan 2024 18:21:24 +0100 Subject: [PATCH 04/60] pr comments --- scvi/external/__init__.py | 2 +- scvi/external/csi/model/__init__.py | 7 - scvi/external/csi/model/_training.py | 14 -- scvi/external/csi/module/__init__.py | 11 -- scvi/external/csi/module/_loss_recorder.py | 39 ---- scvi/external/csi/nn/__init__.py | 9 - scvi/external/csi/train/__init__.py | 7 - scvi/external/sysvi/__init__.py | 3 + .../{csi/nn => sysvi}/_base_components.py | 174 +++++++++--------- scvi/external/{csi/model => sysvi}/_model.py | 100 +++++----- .../external/{csi/module => sysvi}/_module.py | 135 +++++++------- .../external/{csi/module => sysvi}/_priors.py | 40 ++-- .../{csi/train => sysvi}/_trainingplans.py | 140 +++++++------- scvi/external/{csi/model => sysvi}/_utils.py | 14 +- tests/external/{csi => sysvi}/test_model.py | 34 ++-- 15 files changed, 330 insertions(+), 399 deletions(-) delete mode 100644 scvi/external/csi/model/__init__.py delete mode 100644 scvi/external/csi/model/_training.py delete mode 100644 scvi/external/csi/module/__init__.py delete mode 100644 scvi/external/csi/module/_loss_recorder.py delete mode 100644 scvi/external/csi/nn/__init__.py delete mode 100644 scvi/external/csi/train/__init__.py create mode 100644 scvi/external/sysvi/__init__.py rename scvi/external/{csi/nn => sysvi}/_base_components.py (74%) rename scvi/external/{csi/model => sysvi}/_model.py (80%) rename scvi/external/{csi/module => sysvi}/_module.py (82%) rename scvi/external/{csi/module => sysvi}/_priors.py (82%) rename scvi/external/{csi/train => sysvi}/_trainingplans.py (67%) rename scvi/external/{csi/model => sysvi}/_utils.py (93%) rename tests/external/{csi => sysvi}/test_model.py (89%) diff --git a/scvi/external/__init__.py b/scvi/external/__init__.py index 560b5770a5..02aff0d3b4 100644 --- a/scvi/external/__init__.py +++ b/scvi/external/__init__.py @@ -1,12 +1,12 @@ from .cellassign import CellAssign from .contrastivevi import ContrastiveVI -from .csi.model import Model as SysVI from .gimvi import GIMVI from .poissonvi import POISSONVI from .scar import SCAR from .scbasset import SCBASSET from .solo import SOLO from .stereoscope import RNAStereoscope, SpatialStereoscope +from .sysvi import SysVI from .tangram import Tangram __all__ = [ diff --git a/scvi/external/csi/model/__init__.py b/scvi/external/csi/model/__init__.py deleted file mode 100644 index 406e1b34e6..0000000000 --- a/scvi/external/csi/model/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._model import ( - Model, -) - -__all__ = [ - "Model", -] diff --git a/scvi/external/csi/model/_training.py b/scvi/external/csi/model/_training.py deleted file mode 100644 index 4ac9543a34..0000000000 --- a/scvi/external/csi/model/_training.py +++ /dev/null @@ -1,14 +0,0 @@ -from scvi.external.csi.train import TrainingPlanCustom -from scvi.model.base import UnsupervisedTrainingMixin - - -class TrainingCustom(UnsupervisedTrainingMixin): - """Train method with custom TrainingPlan.""" - - # TODO could make custom Trainer (in a custom TrainRunner) to have in init params for early stopping - # "loss" rather than "elbo" components in available param specifications - for now just use - # a loss that is against the param specification - - # TODO run and log val before training - already tried some solutions by calling trainer.validate before - # fit and num_sanity_val_steps (designed not to log) - _training_plan_cls = TrainingPlanCustom diff --git a/scvi/external/csi/module/__init__.py b/scvi/external/csi/module/__init__.py deleted file mode 100644 index eeac8e7d9c..0000000000 --- a/scvi/external/csi/module/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from ._loss_recorder import ( - LossRecorder, -) -from ._module import ( - Module, -) - -__all__ = [ - "LossRecorder", - "Module", -] diff --git a/scvi/external/csi/module/_loss_recorder.py b/scvi/external/csi/module/_loss_recorder.py deleted file mode 100644 index b699d80031..0000000000 --- a/scvi/external/csi/module/_loss_recorder.py +++ /dev/null @@ -1,39 +0,0 @@ -class LossRecorder: - """ - Loss signature for models. - - This class provides an organized way to record the model loss, as well as - the components of the ELBO. This may also be used in MLE, MAP, EM methods. - The loss is used for backpropagation during inference. The other parameters - are used for logging/early stopping during inference. - - Parameters - ---------- - loss - Tensor with loss for minibatch. Should be one dimensional with one value. - Note that loss should be a :class:`~torch.Tensor` and not the result of ``.item()``. - reconstruction_loss - Reconstruction loss for each observation in the minibatch. - kl_local - KL divergence associated with each observation in the minibatch. - kl_global - Global kl divergence term. Should be one dimensional with one value. - **kwargs - Additional metrics can be passed as keyword arguments and will - be available as attributes of the object. - """ - - def __init__( - self, - n_obs: int, - loss: float, - loss_sum: float, - **kwargs, - ): - self.n_obs = n_obs - self.loss = loss - self.loss_sum = loss_sum - self.extra_metric_attrs = [] - for key, value in kwargs.items(): - setattr(self, key, value) - self.extra_metric_attrs.append(key) diff --git a/scvi/external/csi/nn/__init__.py b/scvi/external/csi/nn/__init__.py deleted file mode 100644 index 449b333424..0000000000 --- a/scvi/external/csi/nn/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from ._base_components import ( - Embedding, - EncoderDecoder, -) - -__all__ = [ - "EncoderDecoder", - "Embedding", -] diff --git a/scvi/external/csi/train/__init__.py b/scvi/external/csi/train/__init__.py deleted file mode 100644 index cacfa20b74..0000000000 --- a/scvi/external/csi/train/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._trainingplans import ( - TrainingPlanCustom, -) - -__all__ = [ - "TrainingPlanCustom", -] diff --git a/scvi/external/sysvi/__init__.py b/scvi/external/sysvi/__init__.py new file mode 100644 index 0000000000..b0aec05bb2 --- /dev/null +++ b/scvi/external/sysvi/__init__.py @@ -0,0 +1,3 @@ +from ._model import SysVI + +__all__ = ["SysVI"] diff --git a/scvi/external/csi/nn/_base_components.py b/scvi/external/sysvi/_base_components.py similarity index 74% rename from scvi/external/csi/nn/_base_components.py rename to scvi/external/sysvi/_base_components.py index c44f422f89..630e6afd8a 100644 --- a/scvi/external/csi/nn/_base_components.py +++ b/scvi/external/sysvi/_base_components.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from collections import OrderedDict -from typing import Literal, Union +from typing import Literal import torch from torch.distributions import Normal @@ -16,18 +18,19 @@ class Embedding(Module): - def __init__(self, size, cov_embed_dims: int = 10, normalize: bool = True): - """Module for obtaining embedding of categorical covariates + """Module for obtaining embedding of categorical covariates + + Parameters + ---------- + size + N categories + cov_embed_dims + Dimensions of embedding + normalize + Apply layer normalization + """ - Parameters - ---------- - size - N categories - cov_embed_dims - Dimensions of embedding - normalize - Apply layer normalization - """ + def __init__(self, size, cov_embed_dims: int = 10, normalize: bool = True): super().__init__() self.normalize = normalize @@ -50,6 +53,27 @@ def forward(self, x): class EncoderDecoder(Module): + """Module that can be used as probabilistic encoder or decoder + + Based on inputs and optional covariates predicts output mean and var + + Parameters + ---------- + n_input + n_output + n_cov + n_hidden + n_layers + var_eps + See :class:`~scvi.external.sysvi.nn.VarEncoder` + var_mode + See :class:`~scvi.external.sysvi.nn.VarEncoder` + sample + Return samples from predicted distribution + kwargs + Passed to :class:`~scvi.external.sysvi.nn.Layers` + """ + def __init__( self, n_input: int, @@ -62,26 +86,6 @@ def __init__( sample: bool = False, **kwargs, ): - """Module that can be used as probabilistic encoder or decoder - - Based on inputs and optional covariates predicts output mean and var - - Parameters - ---------- - n_input - n_output - n_cov - n_hidden - n_layers - var_eps - See :class:`~scvi.external.csi.nn.VarEncoder` - var_mode - See :class:`~scvi.external.csi.nn.VarEncoder` - sample - Return samples from predicted distribution - kwargs - Passed to :class:`~scvi.external.csi.nn.Layers` - """ super().__init__() self.sample = sample @@ -99,7 +103,7 @@ def __init__( self.mean_encoder = Linear(n_hidden, n_output) self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps) - def forward(self, x, cov: Union[torch.Tensor, None] = None): + def forward(self, x, cov: torch.Tensor | None = None): y = self.decoder_y(x=x, cov=cov) # TODO better handling of inappropriate edge-case values than nan_to_num or at least warn y_m = torch.nan_to_num(self.mean_encoder(y)) @@ -116,11 +120,45 @@ def forward(self, x, cov: Union[torch.Tensor, None] = None): class Layers(Module): + """A helper class to build fully-connected layers for a neural network. + + Adapted from scVI FCLayers to use covariates more flexibly + + Parameters + ---------- + n_in + The dimensionality of the main input + n_out + The dimensionality of the output + n_cov + Dimensionality of covariates. + If there are no cov this should be set to None - + in this case cov will not be used. + n_layers + The number of fully-connected hidden layers + n_hidden + The number of nodes per hidden layer + dropout_rate + Dropout rate to apply to each of the hidden layers + use_batch_norm + Whether to have `BatchNorm` layers or not + use_layer_norm + Whether to have `LayerNorm` layers or not + use_activation + Whether to have layer activation or not + bias + Whether to learn bias in linear layers or not + inject_covariates + Whether to inject covariates in each layer, or just the first. + activation_fn + Which activation function to use + """ + def __init__( self, n_in: int, n_out: int, - n_cov: Union[int, None] = None, + n_cov: int | None = None, n_layers: int = 1, n_hidden: int = 128, dropout_rate: float = 0.1, @@ -131,39 +169,6 @@ def __init__( inject_covariates: bool = True, activation_fn: Module = ReLU, ): - """A helper class to build fully-connected layers for a neural network. - - Adapted from scVI FCLayers to use covariates more flexibly - - Parameters - ---------- - n_in - The dimensionality of the main input - n_out - The dimensionality of the output - n_cov - Dimensionality of covariates. - If there are no cov this should be set to None - - in this case cov will not be used. - n_layers - The number of fully-connected hidden layers - n_hidden - The number of nodes per hidden layer - dropout_rate - Dropout rate to apply to each of the hidden layers - use_batch_norm - Whether to have `BatchNorm` layers or not - use_layer_norm - Whether to have `LayerNorm` layers or not - use_activation - Whether to have layer activation or not - bias - Whether to learn bias in linear layers or not - inject_covariates - Whether to inject covariates in each layer, or just the first. - activation_fn - Which activation function to use - """ super().__init__() self.inject_covariates = inject_covariates @@ -230,7 +235,7 @@ def _hook_fn_zero_out(grad): b = layer.bias.register_hook(_hook_fn_zero_out) self.hooks.append(b) - def forward(self, x: torch.Tensor, cov: Union[torch.Tensor, None] = None): + def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): """ Forward computation on ``x``. @@ -270,6 +275,23 @@ def forward(self, x: torch.Tensor, cov: Union[torch.Tensor, None] = None): class VarEncoder(Module): + """Encode variance (strictly positive). + + Parameters + ---------- + n_input + Number of input dimensions, used if mode is sample_feature + n_output + Number of variances to predict + mode + How to compute var + 'sample_feature' - learn per sample and feature + 'feature' - learn per feature, constant across samples + 'linear' - linear with respect to input mean, var = a1 * mean + a0; + not suggested to be used due to bad implementation for positive constraining + eps + """ + def __init__( self, n_input: int, @@ -277,22 +299,6 @@ def __init__( mode: Literal["sample_feature", "feature", "linear"], eps: float = 1e-4, ): - """Encode variance (strictly positive). - - Parameters - ---------- - n_input - Number of input dimensions, used if mode is sample_feature - n_output - Number of variances to predict - mode - How to compute var - 'sample_feature' - learn per sample and feature - 'feature' - learn per feature, constant across samples - 'linear' - linear with respect to input mean, var = a1 * mean + a0; - not suggested to be used due to bad implementation for positive constraining - eps - """ super().__init__() self.eps = eps diff --git a/scvi/external/csi/model/_model.py b/scvi/external/sysvi/_model.py similarity index 80% rename from scvi/external/csi/model/_model.py rename to scvi/external/sysvi/_model.py index cb71d3e714..a116d8424c 100644 --- a/scvi/external/csi/model/_model.py +++ b/scvi/external/sysvi/_model.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import logging from collections.abc import Sequence -from typing import Optional, Union import numpy as np import pandas as pd @@ -14,41 +15,52 @@ LayerField, ObsmField, ) -from scvi.external.csi.module import Module -from scvi.model.base import BaseModelClass +from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin from scvi.utils import setup_anndata_dsp -from ._training import TrainingCustom +from ._module import SysVAE +from ._trainingplans import TrainingPlanCustom from ._utils import prepare_metadata logger = logging.getLogger(__name__) -class Model(TrainingCustom, BaseModelClass): +class TrainingCustom(UnsupervisedTrainingMixin): + """Train method with custom TrainingPlan.""" + + # TODO could make custom Trainer (in a custom TrainRunner) to have in init params for early stopping + # "loss" rather than "elbo" components in available param specifications - for now just use + # a loss that is against the param specification + + _training_plan_cls = TrainingPlanCustom + + +class SysVI(TrainingCustom, BaseModelClass): + """Integration model based on cVAE with optional VampPrior and latent cycle-consistency loss. + + Parameters + ---------- + adata + AnnData object that has been registered via :meth:`~scvi-tools.SysVI.setup_anndata`. + prior + The prior distribution to be used. You can choose between "standard_normal" and "vamp". + n_prior_components + Number of prior components in VampPrior. + pseudoinputs_data_indices + By default VampPrior pseudoinputs are randomly selected from data. + Alternatively, one can specify pseudoinput indices using this parameter. + **model_kwargs + Keyword args for :class:`~scvi.external.sysvi.module.Module` + """ + def __init__( self, adata: AnnData, prior: Literal["standard_normal", "vamp"] = "vamp", n_prior_components=5, - pseudoinputs_data_indices: Optional[np.array] = None, + pseudoinputs_data_indices: np.array | None = None, **model_kwargs, ): - """Integration model based on cVAE with optional VampPrior and latent cycle-consistency loss. - - Parameters - ---------- - adata - AnnData object that has been registered via :meth:`~scvi-tools.SysVI.setup_anndata`. - prior - The prior distribution to be used. You can choose between "standard_normal" and "vamp". - n_prior_components - Number of prior components in VampPrior. - pseudoinputs_data_indices - By default VampPrior pseudoinputs are randomly selected from data. - Alternatively, one can specify pseudoinput indices using this parameter. - **model_kwargs - Keyword args for :class:`~scvi.external.csi.module.Module` - """ super().__init__(adata) if prior == "vamp": @@ -79,7 +91,7 @@ def __init__( ) # self.summary_stats provides information about anndata dimensions and other tensor info - self.module = Module( + self.module = SysVAE( n_input=adata.shape[1], n_cov_const=n_cov_const, cov_embed_sizes=cov_embed_sizes, @@ -99,15 +111,15 @@ def __init__( logger.info("The model has been initialized") @torch.no_grad() - def embed( + def get_latent_representation( self, adata: AnnData, - indices: Optional[Sequence[int]] = None, + indices: Sequence[int] | None = None, cycle: bool = False, give_mean: bool = True, - batch_size: Optional[int] = None, + batch_size: int | None = None, as_numpy: bool = True, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> np.ndarray | torch.Tensor: """Return the latent representation for each cell. Parameters @@ -177,14 +189,14 @@ def embed( def setup_anndata( cls, adata: AnnData, - system_key: str, - layer: Optional[str] = None, - categorical_covariate_keys: Optional[list[str]] = None, - categorical_covariate_embed_keys: Optional[list[str]] = None, - continuous_covariate_keys: Optional[list[str]] = None, - covariate_categ_orders: Optional[dict] = None, - covariate_key_orders: Optional[dict] = None, - system_order: Optional[list[str]] = None, + batch_key: str, + layer: str | None = None, + categorical_covariate_keys: list[str] | None = None, + categorical_covariate_embed_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + covariate_categ_orders: dict | None = None, + covariate_key_orders: dict | None = None, + system_order: list[str] | None = None, **kwargs, ) -> AnnData: """Prepare adata for input to Model @@ -193,8 +205,10 @@ def setup_anndata( ---------- adata Adata object - will be modified in place. - system_key - Name of obs column with categorical system information. + batch_key + Name of the obs column with the substantial batch effect covariate, + referred to as system in the original publication (Hrovatin, et al., 2023). + Should be categorical. layer AnnData layer to use, default is X. Should contain normalized and log+1 transformed expression. @@ -237,9 +251,9 @@ def setup_anndata( # Define order of system categories if system_order is None: - system_order = sorted(adata.obs[system_key].unique()) + system_order = sorted(adata.obs[batch_key].unique()) # Validate that the provided system_order matches the categories in adata.obs[system_key] - if set(system_order) != set(adata.obs[system_key].unique()): + if set(system_order) != set(adata.obs[batch_key].unique()): raise ValueError( "Provided system_order does not match the categories in adata.obs[system_key]" ) @@ -251,7 +265,7 @@ def setup_anndata( adata.uns["system_order"] = system_order system_cat = pd.Series( pd.Categorical( - values=adata.obs[system_key], categories=system_order, ordered=True + values=adata.obs[batch_key], categories=system_order, ordered=True ), index=adata.obs.index, name="system", @@ -263,13 +277,13 @@ def setup_anndata( # System must not be in cov if categorical_covariate_keys is not None: - if system_key in categorical_covariate_keys: + if batch_key in categorical_covariate_keys: raise ValueError("system_key should not be within covariate keys") if categorical_covariate_embed_keys is not None: - if system_key in categorical_covariate_embed_keys: + if batch_key in categorical_covariate_embed_keys: raise ValueError("system_key should not be within covariate keys") if continuous_covariate_keys is not None: - if system_key in continuous_covariate_keys: + if batch_key in continuous_covariate_keys: raise ValueError("system_key should not be within covariate keys") # Prepare covariate training representations/embedding diff --git a/scvi/external/csi/module/_module.py b/scvi/external/sysvi/_module.py similarity index 82% rename from scvi/external/csi/module/_module.py rename to scvi/external/sysvi/_module.py index b798916623..4c77aedada 100644 --- a/scvi/external/csi/module/_module.py +++ b/scvi/external/sysvi/_module.py @@ -1,19 +1,61 @@ -from typing import Optional, Union +from __future__ import annotations import torch from typing_extensions import Literal from scvi import REGISTRY_KEYS -from scvi.external.csi.nn import Embedding, EncoderDecoder -from scvi.module.base import BaseModuleClass, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data -from . import LossRecorder +from ._base_components import Embedding, EncoderDecoder from ._priors import StandardPrior, VampPrior torch.backends.cudnn.benchmark = True -class Module(BaseModuleClass): +class SysVAE(BaseModuleClass): + """CVAE with optional VampPrior and latent cycle consistency loss. + + Parameters + ---------- + n_input + Number of input features. + Passed directly from Model. + n_cov_const + Dimensionality of covariate data that will not be further embedded. + Passed directly from Model. + cov_embed_sizes + Number of categories per every cov to be embedded, e.g. [cov1_n_categ, cov2_n_categ, ...]. + Passed directly from Model. + n_system + Number of systems. + Passed directly from Model. + cov_embed_dims + Dimension for covariate embedding. + prior + Which prior distribution to use. + Passed directly from Model. + n_prior_components + If VampPrior - how many prior components to use. + Passed directly from Model. + trainable_priors + If VampPrior - should prior components be trainable. + pseudoinput_data + Initialisation data for VampPrior. Should match input tensors structure. + Passed directly from Model. + n_latent + Numer of latent space dimensions. + n_hidden + Number of nodes in hidden layers. + n_layers + Number of hidden layers. + dropout_rate + Dropout rate. + out_var_mode + See :class:`~scvi.external.sysvi.nn.VarEncoder` + enc_dec_kwargs + Additional kwargs passed to encoder and decoder. + """ + # TODO could disable computation of cycle if predefined that cycle wil not be used def __init__( @@ -26,7 +68,7 @@ def __init__( prior: Literal["standard_normal", "vamp"] = "vamp", n_prior_components: int = 5, trainable_priors: bool = True, - pseudoinput_data: Optional[dict[str, torch.Tensor]] = None, + pseudoinput_data: dict[str, torch.Tensor] | None = None, n_latent: int = 15, n_hidden: int = 256, n_layers: int = 2, @@ -34,48 +76,6 @@ def __init__( out_var_mode: str = "feature", **enc_dec_kwargs, ): - """CVAE with optional VampPrior and latent cycle consistency loss. - - Parameters - ---------- - n_input - Number of input features. - Passed directly from Model. - n_cov_const - Dimensionality of covariate data that will not be further embedded. - Passed directly from Model. - cov_embed_sizes - Number of categories per every cov to be embedded, e.g. [cov1_n_categ, cov2_n_categ, ...]. - Passed directly from Model. - n_system - Number of systems. - Passed directly from Model. - cov_embed_dims - Dimension for covariate embedding. - prior - Which prior distribution to use. - Passed directly from Model. - n_prior_components - If VampPrior - how many prior components to use. - Passed directly from Model. - trainable_priors - If VampPrior - should prior components be trainable. - pseudoinput_data - Initialisation data for VampPrior. Should match input tensors structure. - Passed directly from Model. - n_latent - Numer of latent space dimensions. - n_hidden - Number of nodes in hidden layers. - n_layers - Number of hidden layers. - dropout_rate - Dropout rate. - out_var_mode - See :class:`~scvi.external.csi.nn.VarEncoder` - enc_dec_kwargs - Additional kwargs passed to encoder and decoder. - """ super().__init__() self.embed_cov = len(cov_embed_sizes) > 0 # Will any covs be embedded @@ -171,7 +171,7 @@ def _get_generative_input( return input_dict @auto_move_data - def _get_cov(self, tensors: dict[str, torch.Tensor]) -> Optional[torch.Tensor]: + def _get_cov(self, tensors: dict[str, torch.Tensor]) -> torch.Tensor | None: """Merge all covariates into single tensor, including embedding of covariates""" cov = [] if self.n_cov_const > 0: @@ -187,12 +187,12 @@ def _get_cov(self, tensors: dict[str, torch.Tensor]) -> Optional[torch.Tensor]: return cov @staticmethod - def _merge_cov(cov: Optional[torch.Tensor], system: torch.Tensor) -> torch.Tensor: + def _merge_cov(cov: torch.Tensor | None, system: torch.Tensor) -> torch.Tensor: """Merge full covariate data and system data to get cov for model input""" return torch.cat([cov, system], dim=1) if cov is not None else system @staticmethod - def _mock_cov(cov: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + def _mock_cov(cov: torch.Tensor | None) -> torch.Tensor | None: """Make mock (all 0) covariates for cycle""" return torch.zeros_like(cov) if cov is not None else None @@ -256,16 +256,16 @@ def outputs(compute, name, res, x, cov, system): def forward( self, tensors, - get_inference_input_kwargs: Optional[dict] = None, - get_generative_input_kwargs: Optional[dict] = None, - inference_kwargs: Optional[dict] = None, - generative_kwargs: Optional[dict] = None, - loss_kwargs: Optional[dict] = None, + get_inference_input_kwargs: dict | None = None, + get_generative_input_kwargs: dict | None = None, + inference_kwargs: dict | None = None, + generative_kwargs: dict | None = None, + loss_kwargs: dict | None = None, compute_loss=True, - ) -> Union[ - tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], - tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], LossRecorder], - ]: + ) -> ( + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] + | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], LossOutput] + ): """ Forward pass through the network. @@ -399,13 +399,14 @@ def standardize(x): + z_distance_cyc * z_distance_cycle_weight ) - return LossRecorder( - n_obs=loss.shape[0], + return LossOutput( + n_obs_minibatch=loss.shape[0], loss=loss.mean(), - loss_sum=loss.sum(), - reconstruction_loss=reconst_loss.sum(), - kl_local=kl_divergence_z.sum(), - z_distance_cycle=z_distance_cyc.sum(), + extra_metrics={ + "reconstruction_loss": reconst_loss.mean(), + "kl_local": kl_divergence_z.mean(), + "z_distance_cycle": z_distance_cyc.mean(), + }, ) @staticmethod @@ -448,7 +449,7 @@ def sample(self, *args, **kwargs): raise NotImplementedError("") -def _get_dict_if_none(param: Optional[dict]) -> dict: +def _get_dict_if_none(param: dict | None) -> dict: """If not a dict return empty dict""" param = {} if not isinstance(param, dict) else param return param diff --git a/scvi/external/csi/module/_priors.py b/scvi/external/sysvi/_priors.py similarity index 82% rename from scvi/external/csi/module/_priors.py rename to scvi/external/sysvi/_priors.py index 7db63fbd1b..ed2e5acc02 100644 --- a/scvi/external/csi/module/_priors.py +++ b/scvi/external/sysvi/_priors.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import abc from abc import abstractmethod -from typing import Optional import torch from torch.distributions import Normal, kl_divergence @@ -21,6 +22,24 @@ def kl(self, m_q, v_q, z=None): class VampPrior(Prior): + """VampPrior adapted from https://github.com/jmtomczak/intro_dgm/main/vaes/vae_priors_example.ipynb + + Parameters + ---------- + n_components + Prior components + n_input + Model input dimensions + n_cov + Model input covariate dimensions + encoder + The encoder + data + Data for pseudoinputs initialisation tuple(input,covs) + trainable_priors + Are pseudoinput parameters trainable or fixed + """ + # K - components, I - inputs, L - latent, N - samples def __init__( @@ -29,26 +48,9 @@ def __init__( n_input, n_cov, encoder, - data: Optional[tuple[torch.tensor, torch.tensor]] = None, + data: tuple[torch.tensor, torch.tensor] | None = None, trainable_priors=True, ): - """VampPrior adapted from https://github.com/jmtomczak/intro_dgm/main/vaes/vae_priors_example.ipynb - - Parameters - ---------- - n_components - Prior components - n_input - Model input dimensions - n_cov - Model input covariate dimensions - encoder - The encoder - data - Data for pseudoinputs initialisation tuple(input,covs) - trainable_priors - Are pseudoinput parameters trainable or fixed - """ super().__init__() self.encoder = encoder diff --git a/scvi/external/csi/train/_trainingplans.py b/scvi/external/sysvi/_trainingplans.py similarity index 67% rename from scvi/external/csi/train/_trainingplans.py rename to scvi/external/sysvi/_trainingplans.py index 7793045db1..ce4b521c1c 100644 --- a/scvi/external/csi/train/_trainingplans.py +++ b/scvi/external/sysvi/_trainingplans.py @@ -1,11 +1,11 @@ +from __future__ import annotations + from inspect import getfullargspec -from typing import Literal, Union +from typing import Literal import torch -from torchmetrics import MetricCollection -from scvi.external.csi.module import LossRecorder -from scvi.module.base import BaseModuleClass +from scvi.module.base import BaseModuleClass, LossOutput from scvi.train import TrainingPlan # TODO could make new metric class to not be called elbo metric as used for other metrics as well @@ -13,6 +13,27 @@ class WeightScaling: + """Linearly scale loss weights between start and end weight accordingly to the current training stage + + Parameters + ---------- + weight_start + Starting weight value + weight_end + End weight vlue + point_start + Training point to start scaling - before weight is weight_start + Since the epochs are counted after they are run, + the start point must be set to 0 to represent 1st epoch + point_end + Training point to end scaling - after weight is weight_end + Since the epochs are counted after they are run, + the start point must be set to n-1 to represent the last epoch + update_on + Define training progression based on epochs or steps + + """ + def __init__( self, weight_start: float, @@ -21,26 +42,6 @@ def __init__( point_end: int, update_on: Literal["epoch", "step"] = "step", ): - """Linearly scale loss weights between start and end weight accordingly to the current training stage - - Parameters - ---------- - weight_start - Starting weight value - weight_end - End weight vlue - point_start - Training point to start scaling - before weight is weight_start - Since the epochs are counted after they are run, - the start point must be set to 0 to represent 1st epoch - point_end - Training point to end scaling - after weight is weight_end - Since the epochs are counted after they are run, - the start point must be set to n-1 to represent the last epoch - update_on - Define training progression based on epochs or steps - - """ self.weight_start = weight_start self.weight_end = weight_end self.point_start = point_start @@ -91,34 +92,35 @@ def weight( class TrainingPlanCustom(TrainingPlan): + """Extends scvi TrainingPlan for custom support for other losses. + + Parameters + ---------- + args + Passed to parent + log_on_epoch + See on_epoch of lightning Module log method + log_on_step + See on_step of lightning Module log method + loss_weights + Specifies how losses should be weighted and how it may change during training + Dict with keys being loss names and values being loss weights. + Loss weights can be floats for constant weight or dict of params passed to WeightScaling object + Note that other loss weight params from the parent class are ignored + (e.g. n_steps/epochs_kl_warmup and min/max_kl_weight) + kwargs + Passed to parent. + As described in param loss_weights the loss weighting params of parent are ignored + """ + def __init__( self, module: BaseModuleClass, - loss_weights: Union[None, dict[str, Union[float, WeightScaling]]] = None, + loss_weights: None | dict[str, float | WeightScaling] = None, log_on_epoch: bool = True, log_on_step: bool = False, **kwargs, ): - """Extends scvi TrainingPlan for custom support for other losses. - - Parameters - ---------- - args - Passed to parent - log_on_epoch - See on_epoch of lightning Module log method - log_on_step - See on_step of lightning Module log method - loss_weights - Specifies how losses should be weighted and how it may change during training - Dict with keys being loss names and values being loss weights. - Loss weights can be floats for constant weight or dict of params passed to WeightScaling object - Note that other loss weight params from the parent class are ignored - (e.g. n_steps/epochs_kl_warmup and min/max_kl_weight) - kwargs - Passed to parent. - As described in param loss_weights the loss weighting params of parent are ignored - """ super().__init__(module, **kwargs) self.log_on_epoch = log_on_epoch @@ -159,7 +161,7 @@ def compute_loss_weight(self, weight): @staticmethod def _create_elbo_metric_components( mode: Literal["train", "validation"], **kwargs - ) -> (ElboMetric, MetricCollection): + ) -> ElboMetric: """ Initialize the combined loss collection. @@ -174,19 +176,15 @@ def _create_elbo_metric_components( Objects for storing the combined loss """ - loss = ElboMetric("loss", mode, "obs") - collection = MetricCollection({metric.name: metric for metric in [loss]}) - return loss, collection + loss = ElboMetric("loss", mode, "batch") + return loss def initialize_train_metrics(self): """Initialize train combined loss. TODO could add other losses """ - ( - self.loss_train, - self.train_metrics, - ) = self._create_elbo_metric_components( + self.loss_train = self._create_elbo_metric_components( mode="train", n_total=self.n_obs_training ) self.loss_train.reset() @@ -196,10 +194,7 @@ def initialize_val_metrics(self): TODO could add other losses """ - ( - self.loss_val, - self.val_metrics, - ) = self._create_elbo_metric_components( + self.loss_val = self._create_elbo_metric_components( mode="validation", n_total=self.n_obs_validation ) self.loss_val.reset() @@ -207,8 +202,8 @@ def initialize_val_metrics(self): @torch.no_grad() def compute_and_log_metrics( self, - loss_recorder: LossRecorder, - metrics: MetricCollection, + loss_output: LossOutput, + loss_metric: ElboMetric, mode: str, ): """ @@ -216,36 +211,34 @@ def compute_and_log_metrics( Parameters ---------- - loss_recorder + loss_output LossRecorder object from scvi-tools module - metrics - The loss Metric Collection to update + loss_metric + The loss Metric to update mode Postfix string to add to the metric name of extra metrics. If train also logs the loss in progress bar """ - n_obs_minibatch = loss_recorder.n_obs - loss_sum = loss_recorder.loss_sum + n_obs_minibatch = loss_output.n_obs_minibatch + loss = loss_output.loss - # use the torchmetric object - metrics.update( - loss=loss_sum, + loss_metric.update( + loss=loss, n_obs_minibatch=n_obs_minibatch, ) self.log( f"loss_{mode}", - loss_recorder.loss_sum, + loss, on_step=self.log_on_step, on_epoch=self.log_on_epoch, - batch_size=n_obs_minibatch, prog_bar=True if mode == "train" else False, sync_dist=self.use_sync_dist, ) # accumulate extra metrics passed to loss recorder - for extra_metric in loss_recorder.extra_metric_attrs: - met = getattr(loss_recorder, extra_metric) + for extra_metric in loss_output.extra_metrics_keys: + met = loss_output.extra_metrics[extra_metric] if isinstance(met, torch.Tensor): if met.shape != torch.Size([]): raise ValueError("Extra tracked metrics should be 0-d tensors.") @@ -255,7 +248,6 @@ def compute_and_log_metrics( met, on_step=self.log_on_step, on_epoch=self.log_on_epoch, - batch_size=n_obs_minibatch, sync_dist=self.use_sync_dist, ) @@ -264,13 +256,13 @@ def training_step(self, batch, batch_idx): self.loss_kwargs.update({loss: self.compute_loss_weight(weight=weight)}) _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) # combined loss is logged via compute_and_log_metrics - self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train") + self.compute_and_log_metrics(scvi_loss, self.loss_train, "train") return scvi_loss.loss def validation_step(self, batch, batch_idx): _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) # Combined loss is logged via compute_and_log_metrics - self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation") + self.compute_and_log_metrics(scvi_loss, self.loss_val, "validation") @property def kl_weight(self): diff --git a/scvi/external/csi/model/_utils.py b/scvi/external/sysvi/_utils.py similarity index 93% rename from scvi/external/csi/model/_utils.py rename to scvi/external/sysvi/_utils.py index 9f7dc5fe98..5eadcb8819 100644 --- a/scvi/external/csi/model/_utils.py +++ b/scvi/external/sysvi/_utils.py @@ -1,15 +1,15 @@ -from typing import Optional, Union +from __future__ import annotations import pandas as pd def prepare_metadata( meta_data: pd.DataFrame, - cov_cat_keys: Optional[list] = None, - cov_cat_embed_keys: Optional[list] = None, - cov_cont_keys: Optional[list] = None, - categ_orders: Optional[dict] = None, - key_orders: Optional[dict] = None, + cov_cat_keys: list | None = None, + cov_cat_embed_keys: list | None = None, + cov_cont_keys: list | None = None, + categ_orders: list | None = None, + key_orders: list | None = None, ): """ Prepare content of dataframe columns for model training (one hot encoding, encoding for embedding, ...) @@ -42,7 +42,7 @@ def prepare_metadata( """ - def get_categories_order(values: pd.Series, categories: Union[list, None] = None): + def get_categories_order(values: pd.Series, categories: list | None = None): """ Helper to get order of categories based on values and optional list of categories diff --git a/tests/external/csi/test_model.py b/tests/external/sysvi/test_model.py similarity index 89% rename from tests/external/csi/test_model.py rename to tests/external/sysvi/test_model.py index d8d6f05cc3..3fa2089e06 100644 --- a/tests/external/csi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -6,7 +6,7 @@ from numpy.testing import assert_raises from scipy import sparse -from scvi.external.csi.model import Model +from scvi.external import SysVI def mock_adata(): @@ -54,9 +54,9 @@ def test_model(): adata0 = mock_adata() # Run adata setup with all covariates - Model.setup_anndata( + SysVI.setup_anndata( adata0, - system_key="system", + batch_key="system", categorical_covariate_keys=["covariate_cat"], categorical_covariate_embed_keys=["covariate_cat_emb"], continuous_covariate_keys=["covariate_cont"], @@ -65,9 +65,9 @@ def test_model(): # Run adata setup transfer # TODO ensure this is actually done correctly, not just that it runs through adata = mock_adata() - Model.setup_anndata( + SysVI.setup_anndata( adata, - system_key="system", + batch_key="system", categorical_covariate_keys=["covariate_cat"], categorical_covariate_embed_keys=["covariate_cat_emb"], continuous_covariate_keys=["covariate_cont"], @@ -77,9 +77,9 @@ def test_model(): ) # Check that setup of adata without covariates works - Model.setup_anndata( + SysVI.setup_anndata( adata0, - system_key="system", + batch_key="system", ) assert "covariates" not in adata0.obsm assert "covariates_embed" not in adata0.obsm @@ -87,15 +87,15 @@ def test_model(): # Model # Check that model runs through with standard normal prior - model = Model(adata=adata, prior="standard_normal") + model = SysVI(adata=adata, prior="standard_normal") model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) - # Check that model runs through without covariates - model = Model(adata=adata0) + # Check that mode runs through without covariates + model = SysVI(adata=adata0) model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) # Check pre-specifying pseudoinput indices for vamp prior - _ = Model( + _ = SysVI( adata=adata, prior="vamp", pseudoinputs_data_indices=list(range(5)), @@ -104,7 +104,7 @@ def test_model(): # Check that model runs through with vamp prior without specifying pseudoinput indices, # all covariates, and weight scaling - model = Model(adata=adata, prior="vamp") + model = SysVI(adata=adata, prior="vamp") model.train( max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0), @@ -131,7 +131,7 @@ def test_model(): # Check that embedding default works assert ( - model.embed( + model.get_latent_representation( adata=adata, ).shape[0] == adata.shape[0] @@ -139,7 +139,7 @@ def test_model(): # Check that indices in embedding works idx = [1, 2, 3] - embed = model.embed( + embed = model.get_latent_representation( adata=adata, indices=idx, give_mean=True, @@ -149,7 +149,7 @@ def test_model(): # Check predicting mean/sample np.testing.assert_allclose( embed, - model.embed( + model.get_latent_representation( adata=adata, indices=idx, give_mean=True, @@ -158,7 +158,7 @@ def test_model(): with assert_raises(AssertionError): np.testing.assert_allclose( embed, - model.embed( + model.get_latent_representation( adata=adata, indices=idx, give_mean=False, @@ -169,7 +169,7 @@ def test_model(): with assert_raises(AssertionError): np.testing.assert_allclose( embed, - model.embed( + model.get_latent_representation( adata=adata, indices=idx, give_mean=True, From 5edad83bb1f24900d2c514c7a0460c7ad4a28bd0 Mon Sep 17 00:00:00 2001 From: Hrovatin Date: Wed, 7 Feb 2024 13:35:08 +0100 Subject: [PATCH 05/60] updates --- scvi/external/sysvi/__init__.py | 4 +- scvi/external/sysvi/_base_components.py | 15 ++++-- scvi/external/sysvi/_model.py | 68 +++++++++++++++++++++++-- scvi/external/sysvi/_module.py | 1 - scvi/external/sysvi/_priors.py | 2 +- tests/external/sysvi/test_model.py | 15 ++++-- 6 files changed, 89 insertions(+), 16 deletions(-) diff --git a/scvi/external/sysvi/__init__.py b/scvi/external/sysvi/__init__.py index b0aec05bb2..916b81ff2d 100644 --- a/scvi/external/sysvi/__init__.py +++ b/scvi/external/sysvi/__init__.py @@ -1,3 +1,5 @@ +from ._base_components import Layers, VarEncoder from ._model import SysVI +from ._module import SysVAE -__all__ = ["SysVI"] +__all__ = ["SysVI", "VarEncoder", "Layers", "SysVAE"] diff --git a/scvi/external/sysvi/_base_components.py b/scvi/external/sysvi/_base_components.py index 630e6afd8a..6813101571 100644 --- a/scvi/external/sysvi/_base_components.py +++ b/scvi/external/sysvi/_base_components.py @@ -60,18 +60,25 @@ class EncoderDecoder(Module): Parameters ---------- n_input + The dimensionality of the main input n_output + The dimensionality of the output n_cov + Dimensionality of covariates. + If there are no cov this should be set to None - + in this case cov will not be used. n_hidden + The number of fully-connected hidden layers n_layers + Number of hidden layers var_eps - See :class:`~scvi.external.sysvi.nn.VarEncoder` + See :class:`~scvi.external.sysvi.VarEncoder` var_mode - See :class:`~scvi.external.sysvi.nn.VarEncoder` + See :class:`~scvi.external.sysvi.VarEncoder` sample Return samples from predicted distribution kwargs - Passed to :class:`~scvi.external.sysvi.nn.Layers` + Passed to :class:`~scvi.external.sysvi.Layers` """ def __init__( @@ -82,7 +89,7 @@ def __init__( n_hidden: int = 256, n_layers: int = 3, var_eps: float = 1e-4, - var_mode: str = "feature", + var_mode: Literal["sample_feature", "feature", "linear"] = "feature", sample: bool = False, **kwargs, ): diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index a116d8424c..9b19114b9a 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -11,6 +11,8 @@ from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager +from scvi.data._constants import _SCVI_UUID_KEY +from scvi.data._utils import _check_if_view from scvi.data.fields import ( LayerField, ObsmField, @@ -143,7 +145,6 @@ def get_latent_representation( """ # Check model and adata self._check_if_trained(warn=False) - # TODO extend to check if adata setup is correct wrt training data adata = self._validate_anndata(adata) if indices is None: indices = np.arange(adata.n_obs) @@ -178,12 +179,63 @@ def get_latent_representation( else: predicted += [inference_outputs["z"]] - predicted = torch.cat(predicted) + predicted = torch.cat(predicted).cpu() if as_numpy: - predicted = predicted.cpu().numpy() + predicted = predicted.numpy() return predicted + def _validate_anndata( + self, adata: AnnData | None = None, copy_if_view: bool = True + ) -> AnnData: + """Validate anndata has been properly registered" + + Parameters + ---------- + adata + Adata to validate. If None use SysVI's adata. + copy_if_view + Whether to copy adata before + + Returns + ------- + + """ "" + if adata is None: + adata = self.adata + + _check_if_view(adata, copy_if_view=copy_if_view) + + if _SCVI_UUID_KEY not in adata.uns: + raise ValueError("Adata is not set up. Use SysVI.setup_anndata first.") + else: + # Check that all required fields are present and match the Model's adata + assert ( + self.adata.uns["layer_information"]["layer"] + == adata.uns["layer_information"]["layer"] + ) + assert ( + self.adata.uns["layer_information"]["var_names"] + == adata.uns["layer_information"]["var_names"] + ) + assert self.adata.uns["system_order"] == adata.uns["system_order"] + for covariate_type, covariate_keys in self.adata.uns[ + "covariate_key_orders" + ].items(): + assert ( + covariate_keys == adata.uns["covariate_key_orders"][covariate_type] + ) + if "categorical" in covariate_type: + for covariate_key in covariate_keys: + assert ( + self.adata.uns["covariate_categ_orders"][covariate_key] + == adata.uns["covariate_categ_orders"][covariate_key] + ) + # Ensures that manager is set up + super()._validate_anndata(adata) + + return adata + @classmethod @setup_anndata_dsp.dedent def setup_anndata( @@ -198,7 +250,7 @@ def setup_anndata( covariate_key_orders: dict | None = None, system_order: list[str] | None = None, **kwargs, - ) -> AnnData: + ): """Prepare adata for input to Model Parameters @@ -232,10 +284,16 @@ def setup_anndata( """ setup_method_args = cls._get_setup_method_args(**locals()) - # Make sure var names are unique if adata.shape[1] != len(set(adata.var_names)): raise ValueError("Adata var_names are not unique") + # The used layer argument + # This could be also done via registry, but that is too cumbersome + adata.uns["layer_information"] = { + "layer": layer, + "var_names": list(adata.var_names), + } + # If setup is to be prepared wtr another adata specs make sure all relevant info is present if covariate_categ_orders or covariate_key_orders or system_order: assert system_order is not None diff --git a/scvi/external/sysvi/_module.py b/scvi/external/sysvi/_module.py index 4c77aedada..12be6adaae 100644 --- a/scvi/external/sysvi/_module.py +++ b/scvi/external/sysvi/_module.py @@ -392,7 +392,6 @@ def standardize(x): z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"] ) - # Overall loss loss = ( reconst_loss * reconstruction_weight + kl_divergence_z * kl_weight diff --git a/scvi/external/sysvi/_priors.py b/scvi/external/sysvi/_priors.py index ed2e5acc02..d2c28e1375 100644 --- a/scvi/external/sysvi/_priors.py +++ b/scvi/external/sysvi/_priors.py @@ -22,7 +22,7 @@ def kl(self, m_q, v_q, z=None): class VampPrior(Prior): - """VampPrior adapted from https://github.com/jmtomczak/intro_dgm/main/vaes/vae_priors_example.ipynb + """VampPrior adapted from https://github.com/jmtomczak/intro_dgm/blob/main/vaes/vae_priors_example.ipynb Parameters ---------- diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 3fa2089e06..1bc074f1fe 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -77,12 +77,13 @@ def test_model(): ) # Check that setup of adata without covariates works + adata_no_cov = mock_adata() SysVI.setup_anndata( - adata0, + adata_no_cov, batch_key="system", ) - assert "covariates" not in adata0.obsm - assert "covariates_embed" not in adata0.obsm + assert "covariates" not in adata_no_cov.obsm + assert "covariates_embed" not in adata_no_cov.obsm # Model @@ -91,7 +92,7 @@ def test_model(): model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) # Check that mode runs through without covariates - model = SysVI(adata=adata0) + model = SysVI(adata=adata_no_cov) model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) # Check pre-specifying pseudoinput indices for vamp prior @@ -137,6 +138,12 @@ def test_model(): == adata.shape[0] ) + # Ensure that embedding with another adata properly checks if it was setu up correctly + _ = model.get_latent_representation(adata=adata0) + with assert_raises(AssertionError): + # TODO could add more check for each property separately + _ = model.get_latent_representation(adata=adata_no_cov) + # Check that indices in embedding works idx = [1, 2, 3] embed = model.get_latent_representation( From 9b05bca98665fe77b4e732a93c3d407a143d8899 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 19:00:13 +0000 Subject: [PATCH 06/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scvi/external/sysvi/_base_components.py | 16 +++--------- scvi/external/sysvi/_model.py | 34 ++++++------------------- scvi/external/sysvi/_module.py | 23 +++++------------ 3 files changed, 19 insertions(+), 54 deletions(-) diff --git a/scvi/external/sysvi/_base_components.py b/scvi/external/sysvi/_base_components.py index 6813101571..4cb545795a 100644 --- a/scvi/external/sysvi/_base_components.py +++ b/scvi/external/sysvi/_base_components.py @@ -40,9 +40,7 @@ def __init__(self, size, cov_embed_dims: int = 10, normalize: bool = True): if self.normalize: # TODO this could probably be implemented more efficiently as embed gives same result for every sample in # a give class. However, if we have many balanced classes there wont be many repetitions within minibatch - self.layer_norm = torch.nn.LayerNorm( - cov_embed_dims, elementwise_affine=False - ) + self.layer_norm = torch.nn.LayerNorm(cov_embed_dims, elementwise_affine=False) def forward(self, x): x = self.embedding(x) @@ -198,16 +196,12 @@ def __init__( BatchNorm1d(n_out, momentum=0.01, eps=0.001) if use_batch_norm else None, - LayerNorm(n_out, elementwise_affine=False) - if use_layer_norm - else None, + LayerNorm(n_out, elementwise_affine=False) if use_layer_norm else None, activation_fn() if use_activation else None, Dropout(p=dropout_rate) if dropout_rate > 0 else None, ), ) - for i, (n_in, n_out) in enumerate( - zip(layers_dim[:-1], layers_dim[1:]) - ) + for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])) ] ) ) @@ -264,9 +258,7 @@ def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): if layer is not None: if isinstance(layer, BatchNorm1d): if x.dim() == 3: - x = torch.cat( - [(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0 - ) + x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0) else: x = layer(x) else: diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index 9b19114b9a..ca17264f4f 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -83,9 +83,7 @@ def __init__( else: pseudoinput_data = None - n_cov_const = ( - adata.obsm["covariates"].shape[1] if "covariates" in adata.obsm else 0 - ) + n_cov_const = adata.obsm["covariates"].shape[1] if "covariates" in adata.obsm else 0 cov_embed_sizes = ( pd.DataFrame(adata.obsm["covariates_embed"]).nunique(axis=0).to_list() if "covariates_embed" in adata.obsm @@ -219,12 +217,8 @@ def _validate_anndata( == adata.uns["layer_information"]["var_names"] ) assert self.adata.uns["system_order"] == adata.uns["system_order"] - for covariate_type, covariate_keys in self.adata.uns[ - "covariate_key_orders" - ].items(): - assert ( - covariate_keys == adata.uns["covariate_key_orders"][covariate_type] - ) + for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): + assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] if "categorical" in covariate_type: for covariate_key in covariate_keys: assert ( @@ -317,14 +311,10 @@ def setup_anndata( ) # Make one-hot embedding with specified order - systems_dict = dict( - zip(system_order, ([float(i) for i in range(0, len(system_order))])) - ) + systems_dict = dict(zip(system_order, ([float(i) for i in range(0, len(system_order))]))) adata.uns["system_order"] = system_order system_cat = pd.Series( - pd.Categorical( - values=adata.obs[batch_key], categories=system_order, ordered=True - ), + pd.Categorical(values=adata.obs[batch_key], categories=system_order, ordered=True), index=adata.obs.index, name="system", ) @@ -357,10 +347,7 @@ def setup_anndata( # Save covariate representation and order information adata.uns["covariate_categ_orders"] = orders_dict adata.uns["covariate_key_orders"] = cov_dict - if ( - continuous_covariate_keys is not None - or categorical_covariate_keys is not None - ): + if continuous_covariate_keys is not None or categorical_covariate_keys is not None: adata.obsm["covariates"] = covariates else: # Remove if present since the presence of this key @@ -380,15 +367,10 @@ def setup_anndata( ObsmField("system", "system"), ] # Covariate fields are optional - if ( - continuous_covariate_keys is not None - or categorical_covariate_keys is not None - ): + if continuous_covariate_keys is not None or categorical_covariate_keys is not None: anndata_fields.append(ObsmField("covariates", "covariates")) if categorical_covariate_embed_keys is not None: anndata_fields.append(ObsmField("covariates_embed", "covariates_embed")) - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/scvi/external/sysvi/_module.py b/scvi/external/sysvi/_module.py index 12be6adaae..27a8e4ceda 100644 --- a/scvi/external/sysvi/_module.py +++ b/scvi/external/sysvi/_module.py @@ -87,10 +87,7 @@ def __init__( if self.embed_cov: self.cov_embeddings = torch.nn.ModuleList( - [ - Embedding(size=size, cov_embed_dims=cov_embed_dims) - for size in cov_embed_sizes - ] + [Embedding(size=size, cov_embed_dims=cov_embed_dims) for size in cov_embed_sizes] ) self.encoder = EncoderDecoder( @@ -300,9 +297,7 @@ def forward( get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) # Inference - inference_inputs = self._get_inference_input( - tensors, **get_inference_input_kwargs - ) + inference_inputs = self._get_inference_input(tensors, **get_inference_input_kwargs) inference_outputs = self.inference(**inference_inputs, **inference_kwargs) # Generative selected_system = self.random_select_systems(tensors["system"]) @@ -322,9 +317,7 @@ def forward( selected_system=selected_system, **get_inference_input_kwargs, ) - inference_cycle_outputs = self.inference( - **inference_cycle_inputs, **inference_kwargs - ) + inference_cycle_outputs = self.inference(**inference_cycle_inputs, **inference_kwargs) # Combine outputs of all forward pass components inference_outputs_merged = dict(**inference_outputs) @@ -384,13 +377,11 @@ def z_dist(z_x_m: torch.Tensor, z_y_m: torch.Tensor): def standardize(x): return (x - means) / stds - return torch.nn.MSELoss(reduction="none")( - standardize(z_x_m), standardize(z_y_m) - ).sum(dim=1) + return torch.nn.MSELoss(reduction="none")(standardize(z_x_m), standardize(z_y_m)).sum( + dim=1 + ) - z_distance_cyc = z_dist( - z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"] - ) + z_distance_cyc = z_dist(z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"]) loss = ( reconst_loss * reconstruction_weight From c5f5c379fe5a3c6d360a6009c3de1c51d02543ba Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:20:19 -0700 Subject: [PATCH 07/60] Update scvi/external/sysvi/_base_components.py --- scvi/external/sysvi/_base_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_base_components.py b/scvi/external/sysvi/_base_components.py index 4cb545795a..86f074edc0 100644 --- a/scvi/external/sysvi/_base_components.py +++ b/scvi/external/sysvi/_base_components.py @@ -30,7 +30,7 @@ class Embedding(Module): Apply layer normalization """ - def __init__(self, size, cov_embed_dims: int = 10, normalize: bool = True): + def __init__(self, size: int, cov_embed_dims: int = 10, normalize: bool = True): super().__init__() self.normalize = normalize From 5b4838cda7ba20c2b52bd007b8c07bcded1563e0 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:20:27 -0700 Subject: [PATCH 08/60] Update scvi/external/sysvi/_base_components.py --- scvi/external/sysvi/_base_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_base_components.py b/scvi/external/sysvi/_base_components.py index 86f074edc0..7eee13962c 100644 --- a/scvi/external/sysvi/_base_components.py +++ b/scvi/external/sysvi/_base_components.py @@ -108,7 +108,7 @@ def __init__( self.mean_encoder = Linear(n_hidden, n_output) self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps) - def forward(self, x, cov: torch.Tensor | None = None): + def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): y = self.decoder_y(x=x, cov=cov) # TODO better handling of inappropriate edge-case values than nan_to_num or at least warn y_m = torch.nan_to_num(self.mean_encoder(y)) From c885e20d7ca34af26e42bd0df9e73a686cbaf735 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:20:33 -0700 Subject: [PATCH 09/60] Update scvi/external/sysvi/_base_components.py --- scvi/external/sysvi/_base_components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scvi/external/sysvi/_base_components.py b/scvi/external/sysvi/_base_components.py index 7eee13962c..46c39cf0b9 100644 --- a/scvi/external/sysvi/_base_components.py +++ b/scvi/external/sysvi/_base_components.py @@ -116,7 +116,6 @@ def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): outputs = {"y_m": y_m, "y_v": y_v} - # Sample from latent distribution if self.sample: y = Normal(y_m, y_v.sqrt()).rsample() outputs["y"] = y From 661bbc67915c9f4227a5952e5360cd139b5ec977 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:20:51 -0700 Subject: [PATCH 10/60] Update scvi/external/sysvi/_model.py --- scvi/external/sysvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index ca17264f4f..bed74e0697 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -7,7 +7,7 @@ import pandas as pd import torch from anndata import AnnData -from typing_extensions import Literal +from typing import Literal from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager From 9a49d245be2e47147c93aa95d5a0415cbe0b4085 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:21:18 -0700 Subject: [PATCH 11/60] Update scvi/external/sysvi/_model.py --- scvi/external/sysvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index bed74e0697..de65b37ffa 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -43,7 +43,7 @@ class SysVI(TrainingCustom, BaseModelClass): Parameters ---------- adata - AnnData object that has been registered via :meth:`~scvi-tools.SysVI.setup_anndata`. + AnnData object that has been registered via :meth:`~scvi.external.SysVI.setup_anndata`. prior The prior distribution to be used. You can choose between "standard_normal" and "vamp". n_prior_components From 3622eeebfd94d5fa77f23f725d32b4834f036b2e Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:21:30 -0700 Subject: [PATCH 12/60] Update scvi/external/sysvi/_model.py --- scvi/external/sysvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index de65b37ffa..3a62308164 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -52,7 +52,7 @@ class SysVI(TrainingCustom, BaseModelClass): By default VampPrior pseudoinputs are randomly selected from data. Alternatively, one can specify pseudoinput indices using this parameter. **model_kwargs - Keyword args for :class:`~scvi.external.sysvi.module.Module` + Keyword args for :class:`~scvi.external.sysvi.SysVAE` """ def __init__( From e4c1ef92d03246d1d76b05633cc7761f5252a32c Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:21:40 -0700 Subject: [PATCH 13/60] Update scvi/external/sysvi/_model.py --- scvi/external/sysvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index 3a62308164..eb982bb709 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -59,7 +59,7 @@ def __init__( self, adata: AnnData, prior: Literal["standard_normal", "vamp"] = "vamp", - n_prior_components=5, + n_prior_components: int = 5, pseudoinputs_data_indices: np.array | None = None, **model_kwargs, ): From 0f7bd06c29f452e5867fe96a171ea5143020d63a Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:21:55 -0700 Subject: [PATCH 14/60] Update scvi/external/sysvi/_model.py --- scvi/external/sysvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index eb982bb709..a5d0e2df80 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -110,7 +110,7 @@ def __init__( logger.info("The model has been initialized") - @torch.no_grad() + @torch.inference_mode() def get_latent_representation( self, adata: AnnData, From 9e0cba9c04fa739f6de804988464497da5183a1e Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:22:07 -0700 Subject: [PATCH 15/60] Update scvi/external/sysvi/_model.py --- scvi/external/sysvi/_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index a5d0e2df80..25e31c52cf 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -146,7 +146,6 @@ def get_latent_representation( adata = self._validate_anndata(adata) if indices is None: indices = np.arange(adata.n_obs) - # Prediction # Do not shuffle to retain order tensors_fwd = self._make_data_loader( adata=adata, indices=indices, batch_size=batch_size, shuffle=False From 54f5734156f6b5486014820ceffc9a1750d0a062 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Fri, 15 Mar 2024 15:23:12 -0700 Subject: [PATCH 16/60] Update scvi/external/sysvi/_model.py --- scvi/external/sysvi/_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index 25e31c52cf..d795407674 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -90,7 +90,6 @@ def __init__( else [] ) - # self.summary_stats provides information about anndata dimensions and other tensor info self.module = SysVAE( n_input=adata.shape[1], n_cov_const=n_cov_const, From f65c403c09c50b3fae89d49ee96562e8975b19ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Sep 2024 11:45:50 +0000 Subject: [PATCH 17/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scvi/external/sysvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/external/sysvi/_model.py b/scvi/external/sysvi/_model.py index d795407674..2156d88d2a 100644 --- a/scvi/external/sysvi/_model.py +++ b/scvi/external/sysvi/_model.py @@ -2,12 +2,12 @@ import logging from collections.abc import Sequence +from typing import Literal import numpy as np import pandas as pd import torch from anndata import AnnData -from typing import Literal from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager From 5fcb6f66bb5f5f73ee7e3961e8002711d1bd8037 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 13:46:14 +0200 Subject: [PATCH 18/60] merge --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 2f9c2ac012..3d913fce03 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 2f9c2ac012942f3478405c2d489c4334abc1e22f +Subproject commit 3d913fce03a15ac42f46844840cd831e9b29d8ab From 3c93e7f97bae53b69006210fdadc6ece5f62d077 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 13:48:12 +0200 Subject: [PATCH 19/60] merge --- src/__init__.py | 0 {scvi => src/scvi}/external/sysvi/__init__.py | 0 {scvi => src/scvi}/external/sysvi/_base_components.py | 0 {scvi => src/scvi}/external/sysvi/_model.py | 0 {scvi => src/scvi}/external/sysvi/_module.py | 0 {scvi => src/scvi}/external/sysvi/_priors.py | 0 {scvi => src/scvi}/external/sysvi/_trainingplans.py | 0 {scvi => src/scvi}/external/sysvi/_utils.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/__init__.py rename {scvi => src/scvi}/external/sysvi/__init__.py (100%) rename {scvi => src/scvi}/external/sysvi/_base_components.py (100%) rename {scvi => src/scvi}/external/sysvi/_model.py (100%) rename {scvi => src/scvi}/external/sysvi/_module.py (100%) rename {scvi => src/scvi}/external/sysvi/_priors.py (100%) rename {scvi => src/scvi}/external/sysvi/_trainingplans.py (100%) rename {scvi => src/scvi}/external/sysvi/_utils.py (100%) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scvi/external/sysvi/__init__.py b/src/scvi/external/sysvi/__init__.py similarity index 100% rename from scvi/external/sysvi/__init__.py rename to src/scvi/external/sysvi/__init__.py diff --git a/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py similarity index 100% rename from scvi/external/sysvi/_base_components.py rename to src/scvi/external/sysvi/_base_components.py diff --git a/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py similarity index 100% rename from scvi/external/sysvi/_model.py rename to src/scvi/external/sysvi/_model.py diff --git a/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py similarity index 100% rename from scvi/external/sysvi/_module.py rename to src/scvi/external/sysvi/_module.py diff --git a/scvi/external/sysvi/_priors.py b/src/scvi/external/sysvi/_priors.py similarity index 100% rename from scvi/external/sysvi/_priors.py rename to src/scvi/external/sysvi/_priors.py diff --git a/scvi/external/sysvi/_trainingplans.py b/src/scvi/external/sysvi/_trainingplans.py similarity index 100% rename from scvi/external/sysvi/_trainingplans.py rename to src/scvi/external/sysvi/_trainingplans.py diff --git a/scvi/external/sysvi/_utils.py b/src/scvi/external/sysvi/_utils.py similarity index 100% rename from scvi/external/sysvi/_utils.py rename to src/scvi/external/sysvi/_utils.py From 3e8ffd0471391ff6531e1608e92a2109ac0e03c7 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 14:39:56 +0200 Subject: [PATCH 20/60] extend var documentation and remove unused "linear" mode --- src/scvi/external/sysvi/_base_components.py | 31 +++++---------------- 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 46c39cf0b9..99f5ac007b 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -69,10 +69,10 @@ class EncoderDecoder(Module): The number of fully-connected hidden layers n_layers Number of hidden layers - var_eps - See :class:`~scvi.external.sysvi.VarEncoder` var_mode - See :class:`~scvi.external.sysvi.VarEncoder` + How to compute variance from model outputs, see :class:`~scvi.external.sysvi.VarEncoder` + 'sample_feature' - learn per sample and feature + 'feature' - learn per feature, constant across samples sample Return samples from predicted distribution kwargs @@ -86,7 +86,6 @@ def __init__( n_cov: int, n_hidden: int = 256, n_layers: int = 3, - var_eps: float = 1e-4, var_mode: Literal["sample_feature", "feature", "linear"] = "feature", sample: bool = False, **kwargs, @@ -94,8 +93,6 @@ def __init__( super().__init__() self.sample = sample - self.var_eps = var_eps - self.decoder_y = Layers( n_in=n_input, n_cov=n_cov, @@ -106,7 +103,7 @@ def __init__( ) self.mean_encoder = Linear(n_hidden, n_output) - self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps) + self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode) def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): y = self.decoder_y(x=x, cov=cov) @@ -216,7 +213,7 @@ def set_online_update_hooks(self, hook_first_layer=True): def _hook_fn_weight(grad): new_grad = torch.zeros_like(grad) if self.n_cov > 0: - new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] + new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:] return new_grad def _hook_fn_zero_out(grad): @@ -285,9 +282,6 @@ class VarEncoder(Module): How to compute var 'sample_feature' - learn per sample and feature 'feature' - learn per feature, constant across samples - 'linear' - linear with respect to input mean, var = a1 * mean + a0; - not suggested to be used due to bad implementation for positive constraining - eps """ def __init__( @@ -295,32 +289,26 @@ def __init__( n_input: int, n_output: int, mode: Literal["sample_feature", "feature", "linear"], - eps: float = 1e-4, ): super().__init__() - self.eps = eps + self.eps = 1e-4 self.mode = mode if self.mode == "sample_feature": self.encoder = Linear(n_input, n_output) elif self.mode == "feature": self.var_param = Parameter(torch.zeros(1, n_output)) - elif self.mode == "linear": - self.var_param_a1 = Parameter(torch.tensor([1.0])) - self.var_param_a0 = Parameter(torch.tensor([self.eps])) else: raise ValueError("Mode not recognised.") self.activation = torch.exp - def forward(self, x: torch.Tensor, x_m: torch.Tensor): + def forward(self, x: torch.Tensor): """Forward pass through model Parameters ---------- x Used to encode var if mode is sample_feature; dim = n_samples x n_input - x_m - Used to predict var instead of x if mode is linear; dim = n_samples x 1 Returns ------- @@ -337,9 +325,4 @@ def forward(self, x: torch.Tensor, x_m: torch.Tensor): v = ( torch.nan_to_num(self.activation(v)) + self.eps ) # Ensure that var is strictly positive - elif self.mode == "linear": - v = self.var_param_a1 * x_m.detach().clone() + self.var_param_a0 - # TODO come up with a better way to constrain this to positive while having lin relationship - # Could activation be used for log-lin relationship? - v = torch.clamp(torch.nan_to_num(v), min=self.eps) return v From 4ce461474be26c3e53520dbbf959dfe904b82ce2 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 14:41:56 +0200 Subject: [PATCH 21/60] update var documentation --- src/scvi/external/sysvi/_base_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 99f5ac007b..dde9ea0fd4 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -86,7 +86,7 @@ def __init__( n_cov: int, n_hidden: int = 256, n_layers: int = 3, - var_mode: Literal["sample_feature", "feature", "linear"] = "feature", + var_mode: Literal["sample_feature", "feature"] = "feature", sample: bool = False, **kwargs, ): From f81abdaf2c389e732f3ecc6eb6c8f0e38e0f0354 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Sep 2024 12:42:09 +0000 Subject: [PATCH 22/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/sysvi/_base_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index dde9ea0fd4..c81b95b0b1 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -213,7 +213,7 @@ def set_online_update_hooks(self, hook_first_layer=True): def _hook_fn_weight(grad): new_grad = torch.zeros_like(grad) if self.n_cov > 0: - new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:] + new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] return new_grad def _hook_fn_zero_out(grad): From bdd2c7a75721451a93f05989cd01c9d1113eab96 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 14:47:01 +0200 Subject: [PATCH 23/60] var activation to softplus --- src/scvi/external/sysvi/_base_components.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index dde9ea0fd4..264b73617c 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -300,7 +300,7 @@ def __init__( self.var_param = Parameter(torch.zeros(1, n_output)) else: raise ValueError("Mode not recognised.") - self.activation = torch.exp + self.activation = torch.nn.Softplus() def forward(self, x: torch.Tensor): """Forward pass through model @@ -317,12 +317,8 @@ def forward(self, x: torch.Tensor): # Force to be non nan - TODO come up with better way to do so if self.mode == "sample_feature": v = self.encoder(x) - v = ( - torch.nan_to_num(self.activation(v)) + self.eps - ) # Ensure that var is strictly positive + v = (self.activation(v) + self.eps) # Ensure that var is strictly positive elif self.mode == "feature": v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size - v = ( - torch.nan_to_num(self.activation(v)) + self.eps - ) # Ensure that var is strictly positive + v = (self.activation(v) + self.eps) # Ensure that var is strictly positive return v From 45076024f18d2a43fb263ce5eba7a70ba3e51e0f Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:09:11 +0200 Subject: [PATCH 24/60] clarify pseudoinputs_data_indices description and assert it matches specified n_prior_components number --- src/scvi/external/sysvi/_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 2156d88d2a..fefd96690f 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -47,10 +47,11 @@ class SysVI(TrainingCustom, BaseModelClass): prior The prior distribution to be used. You can choose between "standard_normal" and "vamp". n_prior_components - Number of prior components in VampPrior. + Number of prior components (i.e. modes) to use in VampPrior. pseudoinputs_data_indices By default VampPrior pseudoinputs are randomly selected from data. Alternatively, one can specify pseudoinput indices using this parameter. + The number of specified indices in the input 1D array should match n_prior_components **model_kwargs Keyword args for :class:`~scvi.external.sysvi.SysVAE` """ @@ -70,6 +71,7 @@ def __init__( pseudoinputs_data_indices = np.random.randint( 0, adata.shape[0], n_prior_components ) + assert pseudoinputs_data_indices.shape[0] == n_prior_components pseudoinput_data = next( iter( self._make_data_loader( From cc19d91bf21d3911bb128efc1fb401e8ceaf013e Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:11:01 +0200 Subject: [PATCH 25/60] also check ndim for pseudoinputs_data_indices --- src/scvi/external/sysvi/_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index fefd96690f..f77d4b9113 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -72,6 +72,7 @@ def __init__( 0, adata.shape[0], n_prior_components ) assert pseudoinputs_data_indices.shape[0] == n_prior_components + assert pseudoinputs_data_indices.ndim == 1 pseudoinput_data = next( iter( self._make_data_loader( From c15b8a24fe34ea1254bc75056683c429f0ea1cf0 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:13:36 +0200 Subject: [PATCH 26/60] remove obtainng cycle latent representations in user interface --- src/scvi/external/sysvi/_model.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index f77d4b9113..a2747db4ff 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -117,7 +117,6 @@ def get_latent_representation( self, adata: AnnData, indices: Sequence[int] | None = None, - cycle: bool = False, give_mean: bool = True, batch_size: int | None = None, as_numpy: bool = True, @@ -130,8 +129,6 @@ def get_latent_representation( Input adata for which latent representation should be obtained. indices Data indices to embed. If None embedd all cells. - cycle - Return latent embedding of the cycle pass. give_mean Return the posterior mean instead of a sample from the posterior. batch_size @@ -157,22 +154,6 @@ def get_latent_representation( # Inference inference_inputs = self.module._get_inference_input(tensors) inference_outputs = self.module.inference(**inference_inputs) - if cycle: - selected_system = self.module.random_select_systems(tensors["system"]) - generative_inputs = self.module._get_generative_input( - tensors, - inference_outputs, - selected_system=selected_system, - ) - generative_outputs = self.module.generative( - **generative_inputs, x_x=False, x_y=True - ) - inference_cycle_inputs = self.module._get_inference_cycle_input( - tensors=tensors, - generative_outputs=generative_outputs, - selected_system=selected_system, - ) - inference_outputs = self.module.inference(**inference_cycle_inputs) if give_mean: predicted += [inference_outputs["z_m"]] else: From aacdafe399567de369c3f6be839888978f673a31 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:25:21 +0200 Subject: [PATCH 27/60] latent always returns np and add option to return_dist --- src/scvi/external/sysvi/_model.py | 34 ++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index a2747db4ff..f8c3a03336 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -2,7 +2,7 @@ import logging from collections.abc import Sequence -from typing import Literal +from typing import Literal, Tuple import numpy as np import pandas as pd @@ -119,8 +119,8 @@ def get_latent_representation( indices: Sequence[int] | None = None, give_mean: bool = True, batch_size: int | None = None, - as_numpy: bool = True, - ) -> np.ndarray | torch.Tensor: + return_dist: bool = False, + ) -> np.ndarray | Tuple[np.ndarray, np.ndarray]: """Return the latent representation for each cell. Parameters @@ -131,11 +131,12 @@ def get_latent_representation( Data indices to embed. If None embedd all cells. give_mean Return the posterior mean instead of a sample from the posterior. + Ignored if `return_dist` is ``True``. batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - as_numpy - Return in numpy rather than torch format. - + return_dist + If ``True``, returns the mean and variance of the latent distribution. Otherwise, + returns the mean of the latent distribution. Returns ------- Latent Embedding @@ -149,21 +150,26 @@ def get_latent_representation( tensors_fwd = self._make_data_loader( adata=adata, indices=indices, batch_size=batch_size, shuffle=False ) - predicted = [] + predicted_m = [] + predicted_v = [] for tensors in tensors_fwd: # Inference inference_inputs = self.module._get_inference_input(tensors) inference_outputs = self.module.inference(**inference_inputs) - if give_mean: - predicted += [inference_outputs["z_m"]] + if give_mean or return_dist: + predicted_m += [inference_outputs["z_m"]] else: - predicted += [inference_outputs["z"]] + predicted_m += [inference_outputs["z"]] + if return_dist: + predicted_v += [inference_outputs["z_v"]] - predicted = torch.cat(predicted).cpu() + predicted_m = torch.cat(predicted_m).cpu().numpy() + predicted_v = torch.cat(predicted_v).cpu().numpy() - if as_numpy: - predicted = predicted.numpy() - return predicted + if return_dist: + return predicted_m, predicted_v + else: + return predicted_m def _validate_anndata( self, adata: AnnData | None = None, copy_if_view: bool = True From de6db132dbcda990536b6a51d29414462a1e1d2a Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:35:17 +0200 Subject: [PATCH 28/60] bugfix --- src/scvi/external/sysvi/_base_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 264b73617c..6b7f747b42 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -109,7 +109,7 @@ def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): y = self.decoder_y(x=x, cov=cov) # TODO better handling of inappropriate edge-case values than nan_to_num or at least warn y_m = torch.nan_to_num(self.mean_encoder(y)) - y_v = self.var_encoder(y, x_m=y_m) + y_v = self.var_encoder(y) outputs = {"y_m": y_m, "y_v": y_v} From 235767c1b47b5e719472b0072f997ade1f94e7e0 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:35:27 +0200 Subject: [PATCH 29/60] rm unused cycle latent retrieval --- tests/external/sysvi/test_model.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 1bc074f1fe..15236515e7 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -172,14 +172,4 @@ def test_model(): ), ) - # Check predicting cycle - with assert_raises(AssertionError): - np.testing.assert_allclose( - embed, - model.get_latent_representation( - adata=adata, - indices=idx, - give_mean=True, - cycle=True, - ), - ) + From 5f04d654d0aa15b47feecc50463610c68264db63 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:36:20 +0200 Subject: [PATCH 30/60] bugfix --- tests/external/sysvi/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 15236515e7..2237adcb59 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -99,7 +99,7 @@ def test_model(): _ = SysVI( adata=adata, prior="vamp", - pseudoinputs_data_indices=list(range(5)), + pseudoinputs_data_indices=np.array(list(range(5))), n_prior_components=5, ) From 69759ec43543f2bd788324a038684c50182ec04e Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:36:51 +0200 Subject: [PATCH 31/60] rm adata validation parts repeated in super --- src/scvi/external/sysvi/_model.py | 44 +++++++++++++------------------ 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index f8c3a03336..5f20f455c2 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -187,34 +187,26 @@ def _validate_anndata( ------- """ "" - if adata is None: - adata = self.adata - - _check_if_view(adata, copy_if_view=copy_if_view) + super()._validate_anndata(adata) - if _SCVI_UUID_KEY not in adata.uns: - raise ValueError("Adata is not set up. Use SysVI.setup_anndata first.") - else: - # Check that all required fields are present and match the Model's adata - assert ( - self.adata.uns["layer_information"]["layer"] - == adata.uns["layer_information"]["layer"] - ) - assert ( - self.adata.uns["layer_information"]["var_names"] - == adata.uns["layer_information"]["var_names"] - ) - assert self.adata.uns["system_order"] == adata.uns["system_order"] - for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): - assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] - if "categorical" in covariate_type: - for covariate_key in covariate_keys: - assert ( - self.adata.uns["covariate_categ_orders"][covariate_key] - == adata.uns["covariate_categ_orders"][covariate_key] + # Check that all required fields are present and match the Model's adata + assert ( + self.adata.uns["layer_information"]["layer"] + == adata.uns["layer_information"]["layer"] + ) + assert ( + self.adata.uns["layer_information"]["var_names"] + == adata.uns["layer_information"]["var_names"] + ) + assert self.adata.uns["system_order"] == adata.uns["system_order"] + for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): + assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] + if "categorical" in covariate_type: + for covariate_key in covariate_keys: + assert ( + self.adata.uns["covariate_categ_orders"][covariate_key] + == adata.uns["covariate_categ_orders"][covariate_key] ) - # Ensures that manager is set up - super()._validate_anndata(adata) return adata From 2e447c833dfe263264dda1b06d46781f95d7840c Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:37:46 +0200 Subject: [PATCH 32/60] bugfix --- src/scvi/external/sysvi/_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 5f20f455c2..c94d968625 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -164,7 +164,8 @@ def get_latent_representation( predicted_v += [inference_outputs["z_v"]] predicted_m = torch.cat(predicted_m).cpu().numpy() - predicted_v = torch.cat(predicted_v).cpu().numpy() + if return_dist: + predicted_v = torch.cat(predicted_v).cpu().numpy() if return_dist: return predicted_m, predicted_v From 9e8cc35cfcec449c2874f04d3383cbd7d2432376 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:46:47 +0200 Subject: [PATCH 33/60] put back original _validate_anndata --- src/scvi/external/sysvi/_model.py | 48 ++++++++++++++++++------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index c94d968625..38ae8dcbe9 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -175,7 +175,7 @@ def get_latent_representation( def _validate_anndata( self, adata: AnnData | None = None, copy_if_view: bool = True ) -> AnnData: - """Validate anndata has been properly registered" + """Validate anndata has been properly registered Parameters ---------- @@ -187,27 +187,35 @@ def _validate_anndata( Returns ------- - """ "" - super()._validate_anndata(adata) + """ + if adata is None: + adata = self.adata - # Check that all required fields are present and match the Model's adata - assert ( - self.adata.uns["layer_information"]["layer"] - == adata.uns["layer_information"]["layer"] - ) - assert ( - self.adata.uns["layer_information"]["var_names"] - == adata.uns["layer_information"]["var_names"] - ) - assert self.adata.uns["system_order"] == adata.uns["system_order"] - for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): - assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] - if "categorical" in covariate_type: - for covariate_key in covariate_keys: - assert ( - self.adata.uns["covariate_categ_orders"][covariate_key] - == adata.uns["covariate_categ_orders"][covariate_key] + _check_if_view(adata, copy_if_view=copy_if_view) + + if _SCVI_UUID_KEY not in adata.uns: + raise ValueError("Adata is not set up. Use SysVI.setup_anndata first.") + else: + # Check that all required fields are present and match the Model's adata + assert ( + self.adata.uns["layer_information"]["layer"] + == adata.uns["layer_information"]["layer"] + ) + assert ( + self.adata.uns["layer_information"]["var_names"] + == adata.uns["layer_information"]["var_names"] + ) + assert self.adata.uns["system_order"] == adata.uns["system_order"] + for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): + assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] + if "categorical" in covariate_type: + for covariate_key in covariate_keys: + assert ( + self.adata.uns["covariate_categ_orders"][covariate_key] + == adata.uns["covariate_categ_orders"][covariate_key] ) + # Ensures that manager is set up + super()._validate_anndata(adata) return adata From b0829c0a30bb3b659053a42ab7fb27e95e03ebe8 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:51:47 +0200 Subject: [PATCH 34/60] remove too many custom checks from adata validation --- src/scvi/external/sysvi/_model.py | 46 ++++++++++++------------------ tests/external/sysvi/test_model.py | 2 +- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 38ae8dcbe9..edf683f338 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -188,35 +188,27 @@ def _validate_anndata( ------- """ - if adata is None: - adata = self.adata - - _check_if_view(adata, copy_if_view=copy_if_view) - - if _SCVI_UUID_KEY not in adata.uns: - raise ValueError("Adata is not set up. Use SysVI.setup_anndata first.") - else: - # Check that all required fields are present and match the Model's adata - assert ( - self.adata.uns["layer_information"]["layer"] - == adata.uns["layer_information"]["layer"] - ) - assert ( - self.adata.uns["layer_information"]["var_names"] - == adata.uns["layer_information"]["var_names"] - ) - assert self.adata.uns["system_order"] == adata.uns["system_order"] - for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): - assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] - if "categorical" in covariate_type: - for covariate_key in covariate_keys: - assert ( - self.adata.uns["covariate_categ_orders"][covariate_key] - == adata.uns["covariate_categ_orders"][covariate_key] - ) - # Ensures that manager is set up super()._validate_anndata(adata) + # Check that all required fields are present and match the Model's adata + assert ( + self.adata.uns["layer_information"]["layer"] + == adata.uns["layer_information"]["layer"] + ) + assert ( + self.adata.uns["layer_information"]["var_names"] + == adata.uns["layer_information"]["var_names"] + ) + assert self.adata.uns["system_order"] == adata.uns["system_order"] + for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): + assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] + if "categorical" in covariate_type: + for covariate_key in covariate_keys: + assert ( + self.adata.uns["covariate_categ_orders"][covariate_key] + == adata.uns["covariate_categ_orders"][covariate_key] + ) + return adata @classmethod diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 2237adcb59..631fdbd6da 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -140,7 +140,7 @@ def test_model(): # Ensure that embedding with another adata properly checks if it was setu up correctly _ = model.get_latent_representation(adata=adata0) - with assert_raises(AssertionError): + with assert_raises(KeyError): # TODO could add more check for each property separately _ = model.get_latent_representation(adata=adata_no_cov) From bf2e8507f7c21de1913bb15d8d80f666ffa47b8f Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:09:46 +0200 Subject: [PATCH 35/60] covariate type explanation --- src/scvi/external/sysvi/_model.py | 44 +++++++++++++++++++------------ 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index edf683f338..2c520d48ab 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -199,7 +199,7 @@ def _validate_anndata( self.adata.uns["layer_information"]["var_names"] == adata.uns["layer_information"]["var_names"] ) - assert self.adata.uns["system_order"] == adata.uns["system_order"] + assert self.adata.uns["batch_order"] == adata.uns["batch_order"] for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] if "categorical" in covariate_type: @@ -223,11 +223,21 @@ def setup_anndata( continuous_covariate_keys: list[str] | None = None, covariate_categ_orders: dict | None = None, covariate_key_orders: dict | None = None, - system_order: list[str] | None = None, + batch_order: list[str] | None = None, **kwargs, ): """Prepare adata for input to Model + Setup distinguishes between two main types of covariates that can be corrected for: + - batch (referred to as "system" in the original publication): Single categorical covariate that + will be corrected via cycle consistency loss. It will be also used as a condition in cVAE. + This covariate is expected to correspond to stronger batch effects, such as between datasets from the + different sequencing technology or model systems (animal species, in-vitro models, etc.). + - covariate (includes both continous and categorical covariates): Additional covariates to be used only + as a condition in cVAE, but not corrected via cycle loss. + These covariates are expected to correspond to weaker batch effects, such as between datasets from the + same sequencing technology and model system (animal, in-vitro, etc.) or between samples within a dataset. + Parameters ---------- adata @@ -254,8 +264,8 @@ def setup_anndata( covariate_key_orders Covariate encoding information. Should be used if a new adata is to be set up according to setup of an existing adata. Access via adata.uns['covariate_key_orders'] of already setup adata. - system_order - Same as covariate_orders, but for system. Access via adata.uns['system_order'] + batch_order + Same as covariate_orders, but for system. Access via adata.uns['batch_order'] """ setup_method_args = cls._get_setup_method_args(**locals()) @@ -270,8 +280,8 @@ def setup_anndata( } # If setup is to be prepared wtr another adata specs make sure all relevant info is present - if covariate_categ_orders or covariate_key_orders or system_order: - assert system_order is not None + if covariate_categ_orders or covariate_key_orders or batch_order: + assert batch_order is not None if ( categorical_covariate_keys is not None or categorical_covariate_embed_keys is not None @@ -283,19 +293,19 @@ def setup_anndata( # Make system embedding with specific category order # Define order of system categories - if system_order is None: - system_order = sorted(adata.obs[batch_key].unique()) - # Validate that the provided system_order matches the categories in adata.obs[system_key] - if set(system_order) != set(adata.obs[batch_key].unique()): + if batch_order is None: + batch_order = sorted(adata.obs[batch_key].unique()) + # Validate that the provided batch_order matches the categories in adata.obs[batch_key] + if set(batch_order) != set(adata.obs[batch_key].unique()): raise ValueError( - "Provided system_order does not match the categories in adata.obs[system_key]" + "Provided batch_order does not match the categories in adata.obs[batch_key]" ) # Make one-hot embedding with specified order - systems_dict = dict(zip(system_order, ([float(i) for i in range(0, len(system_order))]))) - adata.uns["system_order"] = system_order + systems_dict = dict(zip(batch_order, ([float(i) for i in range(0, len(batch_order))]))) + adata.uns["batch_order"] = batch_order system_cat = pd.Series( - pd.Categorical(values=adata.obs[batch_key], categories=system_order, ordered=True), + pd.Categorical(values=adata.obs[batch_key], categories=batch_order, ordered=True), index=adata.obs.index, name="system", ) @@ -307,13 +317,13 @@ def setup_anndata( # System must not be in cov if categorical_covariate_keys is not None: if batch_key in categorical_covariate_keys: - raise ValueError("system_key should not be within covariate keys") + raise ValueError("batch_key should not be within covariate keys") if categorical_covariate_embed_keys is not None: if batch_key in categorical_covariate_embed_keys: - raise ValueError("system_key should not be within covariate keys") + raise ValueError("batch_key should not be within covariate keys") if continuous_covariate_keys is not None: if batch_key in continuous_covariate_keys: - raise ValueError("system_key should not be within covariate keys") + raise ValueError("batch_key should not be within covariate keys") # Prepare covariate training representations/embedding covariates, covariates_embed, orders_dict, cov_dict = prepare_metadata( From bd5a8821ed0309c61401eba214fd76b44fc41547 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:26:09 +0200 Subject: [PATCH 36/60] revert var exp --- src/scvi/external/sysvi/_base_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 6b7f747b42..5241c5672f 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -300,7 +300,7 @@ def __init__( self.var_param = Parameter(torch.zeros(1, n_output)) else: raise ValueError("Mode not recognised.") - self.activation = torch.nn.Softplus() + self.activation = torch.exp def forward(self, x: torch.Tensor): """Forward pass through model From 7f503533da0ce534ddf2aa5c1d1103d2246a4222 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:29:48 +0200 Subject: [PATCH 37/60] bugfix --- tests/external/sysvi/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 631fdbd6da..32d57beab2 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -73,7 +73,7 @@ def test_model(): continuous_covariate_keys=["covariate_cont"], covariate_categ_orders=adata0.uns["covariate_categ_orders"], covariate_key_orders=adata0.uns["covariate_key_orders"], - system_order=adata0.uns["system_order"], + batch_order=adata0.uns["batch_order"], ) # Check that setup of adata without covariates works From 01db60a44c855033d30bf03d9a974dbcb3efe86e Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:43:26 +0200 Subject: [PATCH 38/60] rm reconstr loss f that not reused --- src/scvi/external/sysvi/_module.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 27a8e4ceda..3739d9fbb3 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -349,15 +349,9 @@ def loss( x_true = tensors[REGISTRY_KEYS.X_KEY] # Reconstruction loss + reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")( + generative_outputs["x_m"], x_true, generative_outputs["x_v"]).sum(dim=1) - def reconst_loss_part(x_m, x, x_v): - """Compute reconstruction loss""" - return torch.nn.GaussianNLLLoss(reduction="none")(x_m, x, x_v).sum(dim=1) - - # Reconstruction loss - reconst_loss_x = reconst_loss_part( - x_m=generative_outputs["x_m"], x=x_true, x_v=generative_outputs["x_v"] - ) reconst_loss = reconst_loss_x # Kl divergence on latent space From 6cdd3d6e48b4d2986ecfd87abbb37f8fdb648206 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:47:16 +0200 Subject: [PATCH 39/60] explain why cycle loss computed on standardized values --- src/scvi/external/sysvi/_module.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 3739d9fbb3..e497aee113 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -362,7 +362,11 @@ def loss( ) def z_dist(z_x_m: torch.Tensor, z_y_m: torch.Tensor): - """MSE loss between standardised inputs with standardizer fitted on concatenation of both inputs""" + """MSE loss between standardised inputs with standardizer fitted on concatenation of both inputs + + MSE loss should be computed on standardized latent values as else model can learn to cheat the MSE + loss by putting latent parameters to even smaller numbers. + """ # Standardise data (jointly both z-s) before MSE calculation z = torch.concat([z_x_m, z_y_m]) means = z.mean(dim=0, keepdim=True) From 62eb924c0045f885f4a168122f60af18c9301a74 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 14 Sep 2024 16:48:26 +0200 Subject: [PATCH 40/60] add return statement --- src/scvi/external/sysvi/_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index e497aee113..00a82d54ae 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -345,7 +345,7 @@ def loss( kl_weight: float = 1.0, reconstruction_weight: float = 1.0, z_distance_cycle_weight: float = 2.0, - ): + ) -> LossOutput: x_true = tensors[REGISTRY_KEYS.X_KEY] # Reconstruction loss From b95cb03e9e55e5a6666c894918f91d0fa1671a35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Sep 2024 15:16:23 +0000 Subject: [PATCH 41/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/sysvi/_base_components.py | 4 ++-- src/scvi/external/sysvi/_model.py | 6 ++---- src/scvi/external/sysvi/_module.py | 3 ++- tests/external/sysvi/test_model.py | 2 -- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index b556636652..ed4cfe214e 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -317,8 +317,8 @@ def forward(self, x: torch.Tensor): # Force to be non nan - TODO come up with better way to do so if self.mode == "sample_feature": v = self.encoder(x) - v = (self.activation(v) + self.eps) # Ensure that var is strictly positive + v = self.activation(v) + self.eps # Ensure that var is strictly positive elif self.mode == "feature": v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size - v = (self.activation(v) + self.eps) # Ensure that var is strictly positive + v = self.activation(v) + self.eps # Ensure that var is strictly positive return v diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 2c520d48ab..0da03933fe 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -11,8 +11,6 @@ from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data._constants import _SCVI_UUID_KEY -from scvi.data._utils import _check_if_view from scvi.data.fields import ( LayerField, ObsmField, @@ -137,6 +135,7 @@ def get_latent_representation( return_dist If ``True``, returns the mean and variance of the latent distribution. Otherwise, returns the mean of the latent distribution. + Returns ------- Latent Embedding @@ -192,8 +191,7 @@ def _validate_anndata( # Check that all required fields are present and match the Model's adata assert ( - self.adata.uns["layer_information"]["layer"] - == adata.uns["layer_information"]["layer"] + self.adata.uns["layer_information"]["layer"] == adata.uns["layer_information"]["layer"] ) assert ( self.adata.uns["layer_information"]["var_names"] diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 00a82d54ae..bec63c9fb3 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -350,7 +350,8 @@ def loss( # Reconstruction loss reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")( - generative_outputs["x_m"], x_true, generative_outputs["x_v"]).sum(dim=1) + generative_outputs["x_m"], x_true, generative_outputs["x_v"] + ).sum(dim=1) reconst_loss = reconst_loss_x diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 32d57beab2..77fc539bb1 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -171,5 +171,3 @@ def test_model(): give_mean=False, ), ) - - From c049ea1a6d7c7a7a33cf8d9246351c9680f26219 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 19 Oct 2024 07:14:30 +0000 Subject: [PATCH 42/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/sysvi/_base_components.py | 4 +++- src/scvi/external/sysvi/_model.py | 8 +++++--- src/scvi/external/sysvi/_module.py | 3 ++- src/scvi/external/sysvi/_utils.py | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index ed4cfe214e..09c0569f45 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -197,7 +197,9 @@ def __init__( Dropout(p=dropout_rate) if dropout_rate > 0 else None, ), ) - for i, (n_in, n_out) in enumerate(zip(layers_dim[:-1], layers_dim[1:])) + for i, (n_in, n_out) in enumerate( + zip(layers_dim[:-1], layers_dim[1:], strict=False) + ) ] ) ) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 0da03933fe..7277a76b64 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -2,7 +2,7 @@ import logging from collections.abc import Sequence -from typing import Literal, Tuple +from typing import Literal import numpy as np import pandas as pd @@ -118,7 +118,7 @@ def get_latent_representation( give_mean: bool = True, batch_size: int | None = None, return_dist: bool = False, - ) -> np.ndarray | Tuple[np.ndarray, np.ndarray]: + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Return the latent representation for each cell. Parameters @@ -300,7 +300,9 @@ def setup_anndata( ) # Make one-hot embedding with specified order - systems_dict = dict(zip(batch_order, ([float(i) for i in range(0, len(batch_order))]))) + systems_dict = dict( + zip(batch_order, ([float(i) for i in range(0, len(batch_order))]), strict=False) + ) adata.uns["batch_order"] = batch_order system_cat = pd.Series( pd.Categorical(values=adata.obs[batch_key], categories=batch_order, ordered=True), diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index bec63c9fb3..92ebfde80e 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -1,7 +1,8 @@ from __future__ import annotations +from typing import Literal + import torch -from typing_extensions import Literal from scvi import REGISTRY_KEYS from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data diff --git a/src/scvi/external/sysvi/_utils.py b/src/scvi/external/sysvi/_utils.py index 5eadcb8819..ad2e8bcab0 100644 --- a/src/scvi/external/sysvi/_utils.py +++ b/src/scvi/external/sysvi/_utils.py @@ -140,7 +140,7 @@ def dummies_categories(values: pd.Series, categories: list): cov_embed_data = [] for cov_cat_embed_key in cov_cat_embed_keys: cat_order = categ_orders[cov_cat_embed_key] - cat_map = dict(zip(cat_order, range(len(cat_order)))) + cat_map = dict(zip(cat_order, range(len(cat_order)), strict=False)) cov_embed_data.append(meta_data[cov_cat_embed_key].map(cat_map)) cov_embed_data = pd.concat(cov_embed_data, axis=1) else: From 041c6a41b29d513f05b4f624819ea732c01590e4 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 20 Oct 2024 19:39:14 +0200 Subject: [PATCH 43/60] unify SysVI with scvi-tools code --- src/scvi/external/sysvi/__init__.py | 3 +- src/scvi/external/sysvi/_base_components.py | 200 ++++++++------- src/scvi/external/sysvi/_model.py | 224 ++++------------ src/scvi/external/sysvi/_module.py | 265 +++++++++++-------- src/scvi/external/sysvi/_priors.py | 56 ++-- src/scvi/external/sysvi/_trainingplans.py | 270 -------------------- src/scvi/external/sysvi/_utils.py | 149 ----------- tests/external/sysvi/test_model.py | 170 ++++++------ 8 files changed, 447 insertions(+), 890 deletions(-) delete mode 100644 src/scvi/external/sysvi/_trainingplans.py delete mode 100644 src/scvi/external/sysvi/_utils.py diff --git a/src/scvi/external/sysvi/__init__.py b/src/scvi/external/sysvi/__init__.py index 916b81ff2d..5bd0403a95 100644 --- a/src/scvi/external/sysvi/__init__.py +++ b/src/scvi/external/sysvi/__init__.py @@ -1,5 +1,4 @@ -from ._base_components import Layers, VarEncoder from ._model import SysVI from ._module import SysVAE -__all__ = ["SysVI", "VarEncoder", "Layers", "SysVAE"] +__all__ = ["SysVI", "SysVAE"] diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 09c0569f45..239cbfe532 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -1,9 +1,14 @@ from __future__ import annotations -from collections import OrderedDict +import collections +import warnings +from collections.abc import Iterable from typing import Literal +import numpy as np + import torch +from torch import nn from torch.distributions import Normal from torch.nn import ( BatchNorm1d, @@ -17,39 +22,6 @@ ) -class Embedding(Module): - """Module for obtaining embedding of categorical covariates - - Parameters - ---------- - size - N categories - cov_embed_dims - Dimensions of embedding - normalize - Apply layer normalization - """ - - def __init__(self, size: int, cov_embed_dims: int = 10, normalize: bool = True): - super().__init__() - - self.normalize = normalize - - self.embedding = torch.nn.Embedding(size, cov_embed_dims) - - if self.normalize: - # TODO this could probably be implemented more efficiently as embed gives same result for every sample in - # a give class. However, if we have many balanced classes there wont be many repetitions within minibatch - self.layer_norm = torch.nn.LayerNorm(cov_embed_dims, elementwise_affine=False) - - def forward(self, x): - x = self.embedding(x) - if self.normalize: - x = self.layer_norm(x) - - return x - - class EncoderDecoder(Module): """Module that can be used as probabilistic encoder or decoder @@ -83,7 +55,8 @@ def __init__( self, n_input: int, n_output: int, - n_cov: int, + n_cat_list: list[int], + n_cont: int, n_hidden: int = 256, n_layers: int = 3, var_mode: Literal["sample_feature", "feature"] = "feature", @@ -93,9 +66,10 @@ def __init__( super().__init__() self.sample = sample - self.decoder_y = Layers( + self.decoder_y = FCLayers( n_in=n_input, - n_cov=n_cov, + n_cat_list=n_cat_list, + n_cont=n_cont, n_out=n_hidden, n_hidden=n_hidden, n_layers=n_layers, @@ -105,10 +79,18 @@ def __init__( self.mean_encoder = Linear(n_hidden, n_output) self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode) - def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): - y = self.decoder_y(x=x, cov=cov) - # TODO better handling of inappropriate edge-case values than nan_to_num or at least warn - y_m = torch.nan_to_num(self.mean_encoder(y)) + def forward( + self, + x: torch.Tensor, + cont: torch.Tensor | None = None, + cat_list: list | None = None, + ) -> dict[str, torch.Tensor]: + + y = self.decoder_y(x=x, cont=cont, cat_list=cat_list) + y_m = self.mean_encoder(y) + if y_m.isnan().any() or y_m.isinf().any(): + warnings.warn('Predicted mean contains nan or inf values. Setting to numerical.') + y_m = torch.nan_to_num(y_m) y_v = self.var_encoder(y) outputs = {"y_m": y_m, "y_v": y_v} @@ -120,21 +102,29 @@ def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): return outputs -class Layers(Module): - """A helper class to build fully-connected layers for a neural network. +class FCLayers(nn.Module): + """FCLayers class of scvi-tools adapted to also inject continous covariates. + + The only adaptation is addition of `n_cont` parameter in init and `cont` in forward, + with the associated handling of the two. + The forward method signature is changed to account for optional `cont`. - Adapted from scVI FCLayers to use covariates more flexibly + A helper class to build fully-connected layers for a neural network. Parameters ---------- n_in - The dimensionality of the main input + The dimensionality of the input n_out The dimensionality of the output - n_cov - Dimensionality of covariates. - If there are no cov this should be set to None - - in this case cov will not be used. + n_cat_list + The number of categorical covariates and + the number of category levels. + A list containing, for each covariate of interest, + the number of categories. Each covariate will be + included using a one-hot encoding. + n_cont + The number of continuous covariates. n_layers The number of fully-connected hidden layers n_hidden @@ -150,7 +140,7 @@ class Layers(Module): bias Whether to learn bias in linear layers or not inject_covariates - Whether to inject covariates in each layer, or just the first. + Whether to inject covariates in each layer, or just the first (default). activation_fn Which activation function to use """ @@ -159,7 +149,8 @@ def __init__( self, n_in: int, n_out: int, - n_cov: int | None = None, + n_cat_list: Iterable[int] = None, + n_cont: int = 0, n_layers: int = 1, n_hidden: int = 128, dropout_rate: float = 0.1, @@ -168,38 +159,45 @@ def __init__( use_activation: bool = True, bias: bool = True, inject_covariates: bool = True, - activation_fn: Module = ReLU, + activation_fn: nn.Module = nn.ReLU, ): super().__init__() - self.inject_covariates = inject_covariates - self.n_cov = n_cov if n_cov is not None else 0 - layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out] - self.fc_layers = Sequential( - OrderedDict( + if n_cat_list is not None: + # n_cat = 1 will be ignored + self.n_cat_list = [n_cat if n_cat > 1 else 0 for n_cat in n_cat_list] + else: + self.n_cat_list = [] + + self.n_cov = sum(self.n_cat_list) + n_cont + self.fc_layers = nn.Sequential( + collections.OrderedDict( [ ( f"Layer {i}", - Sequential( - Linear( + nn.Sequential( + nn.Linear( n_in + self.n_cov * self.inject_into_layer(i), n_out, bias=bias, ), - # non-default params come from defaults in original Tensorflow implementation - BatchNorm1d(n_out, momentum=0.01, eps=0.001) + # non-default params come from defaults in original Tensorflow + # implementation + nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001) if use_batch_norm else None, - LayerNorm(n_out, elementwise_affine=False) if use_layer_norm else None, + nn.LayerNorm(n_out, elementwise_affine=False) + if use_layer_norm + else None, activation_fn() if use_activation else None, - Dropout(p=dropout_rate) if dropout_rate > 0 else None, + nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None, ), ) for i, (n_in, n_out) in enumerate( - zip(layers_dim[:-1], layers_dim[1:], strict=False) - ) + zip(layers_dim[:-1], layers_dim[1:], strict=True) + ) ] ) ) @@ -210,12 +208,13 @@ def inject_into_layer(self, layer_num) -> bool: return user_cond def set_online_update_hooks(self, hook_first_layer=True): + """Set online update hooks.""" self.hooks = [] def _hook_fn_weight(grad): new_grad = torch.zeros_like(grad) if self.n_cov > 0: - new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] + new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:] return new_grad def _hook_fn_zero_out(grad): @@ -225,7 +224,7 @@ def _hook_fn_zero_out(grad): for layer in layers: if i == 0 and not hook_first_layer: continue - if isinstance(layer, Linear): + if isinstance(layer, nn.Linear): if self.inject_into_layer(i): w = layer.weight.register_hook(_hook_fn_weight) else: @@ -234,39 +233,62 @@ def _hook_fn_zero_out(grad): b = layer.bias.register_hook(_hook_fn_zero_out) self.hooks.append(b) - def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None): - """ - Forward computation on ``x``. + def forward( + self, + x: torch.Tensor, + cont: torch.Tensor | None = None, + cat_list: list | None = None + ) -> torch.Tensor: + """Forward computation on ``x``. Parameters ---------- x tensor of values with shape ``(n_in,)`` - cov - tensor of covariate values with shape ``(n_cov,)`` or None + cont + continuous covariates for this sample, + tensor of values with shape ``(n_cont,)`` + cat_list + list of category membership(s) for this sample Returns ------- - py:class:`torch.Tensor` + :class:`torch.Tensor` tensor of shape ``(n_out,)`` - """ + one_hot_cat_list = [] # for generality in this list many indices useless. + cont_list = [cont] if cont is not None else [] + cat_list = cat_list or [] + + if len(self.n_cat_list) > len(cat_list): + raise ValueError("nb. categorical args provided doesn't match init. params.") + for n_cat, cat in zip(self.n_cat_list, cat_list, strict=False): + if n_cat and cat is None: + raise ValueError("cat not provided while n_cat != 0 in init. params.") + if n_cat > 1: # n_cat = 1 will be ignored - no additional information + if cat.size(1) != n_cat: + one_hot_cat = nn.functional.one_hot(cat.squeeze(-1), n_cat) + else: + one_hot_cat = cat # cat has already been one_hot encoded + one_hot_cat_list += [one_hot_cat] for i, layers in enumerate(self.fc_layers): for layer in layers: if layer is not None: - if isinstance(layer, BatchNorm1d): + if isinstance(layer, nn.BatchNorm1d): if x.dim() == 3: x = torch.cat([(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0) else: x = layer(x) else: - # Injection of covariates - if ( - self.n_cov > 0 - and isinstance(layer, Linear) - and self.inject_into_layer(i) - ): - x = torch.cat((x, cov), dim=-1) + if isinstance(layer, nn.Linear) and self.inject_into_layer(i): + if x.dim() == 3: + cov_list_layer = [ + o.unsqueeze(0).expand((x.size(0), o.size(0), o.size(1))) + for o in one_hot_cat_list + ] + else: + cov_list_layer = one_hot_cat_list + x = torch.cat((x, *cov_list_layer, *cont_list), dim=-1) x = layer(x) return x @@ -294,7 +316,7 @@ def __init__( ): super().__init__() - self.eps = 1e-4 + self.clip_exp = np.log(torch.finfo(torch.get_default_dtype()).max) - 1e-4 self.mode = mode if self.mode == "sample_feature": self.encoder = Linear(n_input, n_output) @@ -316,11 +338,17 @@ def forward(self, x: torch.Tensor): ------- Predicted var """ - # Force to be non nan - TODO come up with better way to do so if self.mode == "sample_feature": v = self.encoder(x) - v = self.activation(v) + self.eps # Ensure that var is strictly positive elif self.mode == "feature": v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size - v = self.activation(v) + self.eps # Ensure that var is strictly positive + + # Ensure that var is strictly positive via exp - Bring back to non-log scale + # Clip to range that will not be inf after exp + v = torch.clip(v, min=-self.clip_exp, max=self.clip_exp) + v = self.activation(v) + if v.isnan().any(): + warnings.warn('Predicted variance contains nan values. Setting to 0.') + v = torch.nan_to_num(v) + return v diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 7277a76b64..5218510130 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import warnings from collections.abc import Sequence from typing import Literal @@ -13,29 +14,17 @@ from scvi.data import AnnDataManager from scvi.data.fields import ( LayerField, - ObsmField, + ObsmField, CategoricalObsField, CategoricalJointObsField, NumericalJointObsField, NumericalObsField, ) from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin from scvi.utils import setup_anndata_dsp from ._module import SysVAE -from ._trainingplans import TrainingPlanCustom -from ._utils import prepare_metadata logger = logging.getLogger(__name__) -class TrainingCustom(UnsupervisedTrainingMixin): - """Train method with custom TrainingPlan.""" - - # TODO could make custom Trainer (in a custom TrainRunner) to have in init params for early stopping - # "loss" rather than "elbo" components in available param specifications - for now just use - # a loss that is against the param specification - - _training_plan_cls = TrainingPlanCustom - - -class SysVI(TrainingCustom, BaseModelClass): +class SysVI(UnsupervisedTrainingMixin, BaseModelClass): """Integration model based on cVAE with optional VampPrior and latent cycle-consistency loss. Parameters @@ -84,18 +73,17 @@ def __init__( else: pseudoinput_data = None - n_cov_const = adata.obsm["covariates"].shape[1] if "covariates" in adata.obsm else 0 - cov_embed_sizes = ( - pd.DataFrame(adata.obsm["covariates_embed"]).nunique(axis=0).to_list() - if "covariates_embed" in adata.obsm - else [] + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None ) self.module = SysVAE( - n_input=adata.shape[1], - n_cov_const=n_cov_const, - cov_embed_sizes=cov_embed_sizes, - n_system=adata.obsm["system"].shape[1], + n_input=self.summary_stats.n_vars, + n_batch=self.summary_stats.n_batch, + n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), + n_cats_per_cov=n_cats_per_cov, prior=prior, n_prior_components=n_prior_components, pseudoinput_data=pseudoinput_data, @@ -103,13 +91,35 @@ def __init__( ) self._model_summary_string = ( - "cVAE model with optional VampPrior and latent cycle-consistency loss" + "SysVI - cVAE model with optional VampPrior and latent cycle-consistency loss." ) # necessary line to get params that will be used for saving/loading self.init_params_ = self._get_init_params(locals()) logger.info("The model has been initialized") + def train( + self, + *args, + plan_kwargs: dict | None = None, + **kwargs, + ): + plan_kwargs = plan_kwargs or {} + kl_weight_defaults = { + 'n_epochs_kl_warmup': 0, + 'n_steps_kl_warmup': 0 + } + if any([v != plan_kwargs.get(k, v) for k, v in kl_weight_defaults.items()]): + warnings.warn('The use of KL weight warmup is not recommended in SysVI. ' + + 'The n_epochs_kl_warmup and n_steps_kl_warmup will be reset to 0.') + # Overwrite plan kwargs with kl weight defaults + plan_kwargs = {**plan_kwargs, **kl_weight_defaults} + + # Pass to parent + kwargs = kwargs or {} + kwargs['plan_kwargs'] = plan_kwargs + super().train(*args, **kwargs) + @torch.inference_mode() def get_latent_representation( self, @@ -171,44 +181,6 @@ def get_latent_representation( else: return predicted_m - def _validate_anndata( - self, adata: AnnData | None = None, copy_if_view: bool = True - ) -> AnnData: - """Validate anndata has been properly registered - - Parameters - ---------- - adata - Adata to validate. If None use SysVI's adata. - copy_if_view - Whether to copy adata before - - Returns - ------- - - """ - super()._validate_anndata(adata) - - # Check that all required fields are present and match the Model's adata - assert ( - self.adata.uns["layer_information"]["layer"] == adata.uns["layer_information"]["layer"] - ) - assert ( - self.adata.uns["layer_information"]["var_names"] - == adata.uns["layer_information"]["var_names"] - ) - assert self.adata.uns["batch_order"] == adata.uns["batch_order"] - for covariate_type, covariate_keys in self.adata.uns["covariate_key_orders"].items(): - assert covariate_keys == adata.uns["covariate_key_orders"][covariate_type] - if "categorical" in covariate_type: - for covariate_key in covariate_keys: - assert ( - self.adata.uns["covariate_categ_orders"][covariate_key] - == adata.uns["covariate_categ_orders"][covariate_key] - ) - - return adata - @classmethod @setup_anndata_dsp.dedent def setup_anndata( @@ -217,24 +189,21 @@ def setup_anndata( batch_key: str, layer: str | None = None, categorical_covariate_keys: list[str] | None = None, - categorical_covariate_embed_keys: list[str] | None = None, continuous_covariate_keys: list[str] | None = None, - covariate_categ_orders: dict | None = None, - covariate_key_orders: dict | None = None, - batch_order: list[str] | None = None, + weight_batches: bool = False, **kwargs, ): """Prepare adata for input to Model Setup distinguishes between two main types of covariates that can be corrected for: - - batch (referred to as "system" in the original publication): Single categorical covariate that + - batch (referred to as "batch" in the original publication): Single categorical covariate that will be corrected via cycle consistency loss. It will be also used as a condition in cVAE. This covariate is expected to correspond to stronger batch effects, such as between datasets from the different sequencing technology or model systems (animal species, in-vitro models, etc.). - covariate (includes both continous and categorical covariates): Additional covariates to be used only as a condition in cVAE, but not corrected via cycle loss. These covariates are expected to correspond to weaker batch effects, such as between datasets from the - same sequencing technology and model system (animal, in-vitro, etc.) or between samples within a dataset. + same sequencing technology and model batch (animal, in-vitro, etc.) or between samples within a dataset. Parameters ---------- @@ -242,126 +211,31 @@ def setup_anndata( Adata object - will be modified in place. batch_key Name of the obs column with the substantial batch effect covariate, - referred to as system in the original publication (Hrovatin, et al., 2023). + referred to as batch in the original publication (Hrovatin, et al., 2023). Should be categorical. layer AnnData layer to use, default is X. Should contain normalized and log+1 transformed expression. categorical_covariate_keys - Name of obs column with additional categorical covariate information. Will be one hot encoded. - categorical_covariate_embed_keys - Name of obs column with additional categorical covariate information. Embedding will be learned. - This can be useful if the number of categories is very large, which would increase memory usage. - If using this type of covariate representation please also cite - `scPoli <[https://doi.org/10.1038/s41592-023-02035-2]>`_ . + Name of obs column with additional categorical covariate information. + Will be one hot encoded or embedded, as later defined in the model. continuous_covariate_keys Name of obs column with additional continuous covariate information. - covariate_categ_orders - Covariate encoding information. Should be used if a new adata is to be set up according - to setup of an existing adata. Access via adata.uns['covariate_categ_orders'] of already setup adata. - covariate_key_orders - Covariate encoding information. Should be used if a new adata is to be set up according - to setup of an existing adata. Access via adata.uns['covariate_key_orders'] of already setup adata. - batch_order - Same as covariate_orders, but for system. Access via adata.uns['batch_order'] """ setup_method_args = cls._get_setup_method_args(**locals()) - if adata.shape[1] != len(set(adata.var_names)): - raise ValueError("Adata var_names are not unique") - - # The used layer argument - # This could be also done via registry, but that is too cumbersome - adata.uns["layer_information"] = { - "layer": layer, - "var_names": list(adata.var_names), - } - - # If setup is to be prepared wtr another adata specs make sure all relevant info is present - if covariate_categ_orders or covariate_key_orders or batch_order: - assert batch_order is not None - if ( - categorical_covariate_keys is not None - or categorical_covariate_embed_keys is not None - or continuous_covariate_keys is not None - ): - assert covariate_categ_orders is not None - assert covariate_key_orders is not None - - # Make system embedding with specific category order - - # Define order of system categories - if batch_order is None: - batch_order = sorted(adata.obs[batch_key].unique()) - # Validate that the provided batch_order matches the categories in adata.obs[batch_key] - if set(batch_order) != set(adata.obs[batch_key].unique()): - raise ValueError( - "Provided batch_order does not match the categories in adata.obs[batch_key]" - ) - - # Make one-hot embedding with specified order - systems_dict = dict( - zip(batch_order, ([float(i) for i in range(0, len(batch_order))]), strict=False) - ) - adata.uns["batch_order"] = batch_order - system_cat = pd.Series( - pd.Categorical(values=adata.obs[batch_key], categories=batch_order, ordered=True), - index=adata.obs.index, - name="system", - ) - adata.obsm["system"] = pd.get_dummies(system_cat, dtype=float) - - # Set up covariates - # TODO this could be handled by specific field type in registry - - # System must not be in cov - if categorical_covariate_keys is not None: - if batch_key in categorical_covariate_keys: - raise ValueError("batch_key should not be within covariate keys") - if categorical_covariate_embed_keys is not None: - if batch_key in categorical_covariate_embed_keys: - raise ValueError("batch_key should not be within covariate keys") - if continuous_covariate_keys is not None: - if batch_key in continuous_covariate_keys: - raise ValueError("batch_key should not be within covariate keys") - - # Prepare covariate training representations/embedding - covariates, covariates_embed, orders_dict, cov_dict = prepare_metadata( - meta_data=adata.obs, - cov_cat_keys=categorical_covariate_keys, - cov_cat_embed_keys=categorical_covariate_embed_keys, - cov_cont_keys=continuous_covariate_keys, - categ_orders=covariate_categ_orders, - key_orders=covariate_key_orders, - ) - - # Save covariate representation and order information - adata.uns["covariate_categ_orders"] = orders_dict - adata.uns["covariate_key_orders"] = cov_dict - if continuous_covariate_keys is not None or categorical_covariate_keys is not None: - adata.obsm["covariates"] = covariates - else: - # Remove if present since the presence of this key - # is in model used to determine if cov should be used or not - if "covariates" in adata.obsm: - del adata.obsm["covariates"] - if categorical_covariate_embed_keys is not None: - adata.obsm["covariates_embed"] = covariates_embed - else: - if "covariates_embed" in adata.obsm: - del adata.obsm["covariates_embed"] - - # Anndata setup - anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=False), - ObsmField("system", "system"), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] - # Covariate fields are optional - if continuous_covariate_keys is not None or categorical_covariate_keys is not None: - anndata_fields.append(ObsmField("covariates", "covariates")) - if categorical_covariate_embed_keys is not None: - anndata_fields.append(ObsmField("covariates_embed", "covariates_embed")) + if weight_batches: + warnings.warn('The use of inverse batch proportion weights is experimental.') + batch_weights_key = 'batch_weights' + adata.obs[batch_weights_key] = adata.obs[batch_key].map({ + cat: 1 / n for cat, n in adata.obs[batch_key].value_counts().items()}) + anndata_fields.append(NumericalObsField(batch_weights_key, batch_weights_key)) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 92ebfde80e..eb6335e1f1 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -7,7 +7,7 @@ from scvi import REGISTRY_KEYS from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data -from ._base_components import Embedding, EncoderDecoder +from ._base_components import EncoderDecoder from ._priors import StandardPrior, VampPrior torch.backends.cudnn.benchmark = True @@ -27,8 +27,8 @@ class SysVAE(BaseModuleClass): cov_embed_sizes Number of categories per every cov to be embedded, e.g. [cov1_n_categ, cov2_n_categ, ...]. Passed directly from Model. - n_system - Number of systems. + n_batch + Number of batches. Passed directly from Model. cov_embed_dims Dimension for covariate embedding. @@ -62,10 +62,10 @@ class SysVAE(BaseModuleClass): def __init__( self, n_input: int, - n_cov_const: int, - cov_embed_sizes: list, - n_system: int, - cov_embed_dims: int = 10, + n_batch: int, + n_continuous_cov: int = 0, + n_cats_per_cov: list[int] | None = None, + embed_cat: bool = False, prior: Literal["standard_normal", "vamp"] = "vamp", n_prior_components: int = 5, trainable_priors: bool = True, @@ -74,27 +74,35 @@ def __init__( n_hidden: int = 256, n_layers: int = 2, dropout_rate: float = 0.1, - out_var_mode: str = "feature", - **enc_dec_kwargs, + out_var_mode: Literal["sample_feature", "feature"] = "feature", + enc_dec_kwargs: dict | None = None, + embedding_kwargs: dict | None = None, + ): super().__init__() - self.embed_cov = len(cov_embed_sizes) > 0 # Will any covs be embedded - self.n_cov_const = n_cov_const # Dimension of covariates that are not embedded - n_cov = ( - n_cov_const + len(cov_embed_sizes) * cov_embed_dims - ) # Total size of covs (embedded & not embedded) - n_cov_encoder = n_cov + n_system # N covariates passed to Module (cov & system) + self.embed_cat = embed_cat - if self.embed_cov: - self.cov_embeddings = torch.nn.ModuleList( - [Embedding(size=size, cov_embed_dims=cov_embed_dims) for size in cov_embed_sizes] - ) + enc_dec_kwargs = enc_dec_kwargs or {} + embedding_kwargs = embedding_kwargs or {} + + self.n_batch = n_batch + n_cat_list = [n_batch] + n_cont = n_continuous_cov + if n_cats_per_cov is not None: + if self.embed_cat: + for cov, n in enumerate(n_cats_per_cov): + cov = self._cov_idx_name(cov=cov) + self.init_embedding(cov, n, **embedding_kwargs) + n_cont += self.get_embedding(cov).embedding_dim + else: + n_cat_list.extend(n_cats_per_cov) self.encoder = EncoderDecoder( n_input=n_input, n_output=n_latent, - n_cov=n_cov_encoder, + n_cat_list=n_cat_list, + n_cont=n_cont, n_hidden=n_hidden, n_layers=n_layers, dropout_rate=dropout_rate, @@ -106,7 +114,8 @@ def __init__( self.decoder = EncoderDecoder( n_input=n_latent, n_output=n_input, - n_cov=n_cov_encoder, + n_cat_list=n_cat_list, + n_cont=n_cont, n_hidden=n_hidden, n_layers=n_layers, dropout_rate=dropout_rate, @@ -118,84 +127,114 @@ def __init__( if prior == "standard_normal": self.prior = StandardPrior() elif prior == "vamp": - if pseudoinput_data is not None: - pseudoinput_data = self._get_inference_input(pseudoinput_data) + assert pseudoinput_data is not None, 'Pseudoinput data must be specified if using VampPrior' + pseudoinput_data = self._get_inference_input(pseudoinput_data) self.prior = VampPrior( n_components=n_prior_components, - n_input=n_input, - n_cov=n_cov_encoder, encoder=self.encoder, - data=( - pseudoinput_data["expr"], - self._merge_cov( - cov=pseudoinput_data["cov"], system=pseudoinput_data["system"] - ), + data_x=pseudoinput_data["expr"], + n_cat_list=n_cat_list, + data_cat=self._merge_batch_cov( + cat=pseudoinput_data["cat"], batch=pseudoinput_data["batch"] ), + data_cont=pseudoinput_data["cont"], trainable_priors=trainable_priors, ) else: raise ValueError("Prior not recognised") + @staticmethod + def _cov_idx_name(cov: int | float): + return "cov" + str(cov) + def _get_inference_input(self, tensors, **kwargs) -> dict[str, torch.Tensor]: """Parse the input tensors to get inference inputs""" - expr = tensors[REGISTRY_KEYS.X_KEY] cov = self._get_cov(tensors=tensors) - system = tensors["system"] - input_dict = {"expr": expr, "cov": cov, "system": system} + input_dict = { + "expr": tensors[REGISTRY_KEYS.X_KEY], + "batch": tensors[REGISTRY_KEYS.BATCH_KEY], + "cat": cov['cat'], + "cont": cov['cont'] + } return input_dict def _get_inference_cycle_input( - self, tensors, generative_outputs, selected_system: torch.Tensor, **kwargs + self, tensors, generative_outputs, selected_batch: torch.Tensor, **kwargs ) -> dict[str, torch.Tensor]: - """Parse the input tensors and cycle system info to get cycle inference inputs""" - expr = generative_outputs["y_m"] + """Parse the input tensors and cycle batch info to get cycle inference inputs.""" cov = self._mock_cov(self._get_cov(tensors=tensors)) - system = selected_system - input_dict = {"expr": expr, "cov": cov, "system": system} + input_dict = { + "expr": generative_outputs["y_m"], + "batch": selected_batch, + "cat": cov['cat'], + "cont": cov['cont'] + } return input_dict def _get_generative_input( - self, tensors, inference_outputs, selected_system: torch.Tensor, **kwargs - ) -> dict[str, torch.Tensor]: - """Parse the input tensors, inference inputs, and cycle system to get generative inputs""" + self, + tensors: dict[str, torch.Tensor], + inference_outputs: dict[str, torch.Tensor], + selected_batch: torch.Tensor, + **kwargs + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor | list[torch.Tensor] | None]]: + """Parse the input tensors, inference inputs, and cycle batch to get generative inputs""" z = inference_outputs["z"] cov = self._get_cov(tensors=tensors) - cov = {"x": cov, "y": self._mock_cov(cov)} + cov_mock = self._mock_cov(cov) + cat = {'x': cov['cat'], 'y': cov_mock['cat']} + cont = {'x': cov['cont'], 'y': cov_mock['cont']} - system = {"x": tensors["system"], "y": selected_system} + batch = {"x": tensors["batch"], "y": selected_batch} - input_dict = {"z": z, "cov": cov, "system": system} + input_dict = {"z": z, "batch": batch, 'cat': cat, 'cont': cont} return input_dict - @auto_move_data - def _get_cov(self, tensors: dict[str, torch.Tensor]) -> torch.Tensor | None: - """Merge all covariates into single tensor, including embedding of covariates""" - cov = [] - if self.n_cov_const > 0: - cov.append(tensors["covariates"]) - if self.embed_cov: - cov.extend( - [ - embedding(tensors["covariates_embed"][:, idx].int()) - for idx, embedding in enumerate(self.cov_embeddings) - ] - ) - cov = torch.concat(cov, dim=1) if len(cov) > 0 else None + @auto_move_data # TODO remove? + def _get_cov(self, tensors: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """Process all covariates into continuous and categorical components for cVAE""" + + cat_parts = [] + cont_parts = [] + if REGISTRY_KEYS.CONT_COVS_KEY in tensors: + cont_parts.append(tensors[REGISTRY_KEYS.CONT_COVS_KEY]) + if REGISTRY_KEYS.CAT_COVS_KEY in tensors: + cat = torch.split(tensors[REGISTRY_KEYS.CAT_COVS_KEY], 1, dim=1) + if self.embed_cat: + for idx, tensor in enumerate(cat): + cont_parts.append( + self.compute_embedding(tensor, self._cov_idx_name(idx)) + ) + else: + cat_parts.extend(cat) + cov = { + 'cat': cat_parts, + 'cont': torch.concat(cont_parts, dim=1) if len(cont_parts) > 0 else None + } return cov @staticmethod - def _merge_cov(cov: torch.Tensor | None, system: torch.Tensor) -> torch.Tensor: - """Merge full covariate data and system data to get cov for model input""" - return torch.cat([cov, system], dim=1) if cov is not None else system + def _merge_batch_cov(cat: list[torch.Tensor], batch: torch.Tensor) -> list[torch.Tensor]: + return [batch] + cat @staticmethod - def _mock_cov(cov: torch.Tensor | None) -> torch.Tensor | None: + def _mock_cov(cov: dict[str, list[torch.Tensor], torch.Tensor, None]) -> torch.Tensor | None: """Make mock (all 0) covariates for cycle""" - return torch.zeros_like(cov) if cov is not None else None + mock = { + 'cat': [torch.zeros_like(cat) for cat in cov['cat']], + 'cont': torch.zeros_like(cov['cont']) if cov['cont'] is not None else None, + } + return mock @auto_move_data - def inference(self, expr, cov, system) -> dict: + def inference(self, + expr: torch.Tensor, + batch: torch.Tensor, + cat: list[torch.Tensor], + cont: torch.Tensor | None, + ) -> dict: """ expression & cov -> latent representation @@ -205,18 +244,18 @@ def inference(self, expr, cov, system) -> dict: Expression data cov Full covariate data (categorical, categorical embedded, and continuous - system + batch System representation Returns ------- Posterior parameters and sample """ - z = self.encoder(x=expr, cov=self._merge_cov(cov=cov, system=system)) + z = self.encoder(x=expr, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont) return {"z": z["y"], "z_m": z["y_m"], "z_v": z["y_v"]} @auto_move_data - def generative(self, z, cov, system, x_x: bool = True, x_y: bool = True) -> dict: + def generative(self, z, batch, cat, cont, x_x: bool = True, x_y: bool = True) -> dict: """ latent representation & cov -> expression @@ -226,28 +265,36 @@ def generative(self, z, cov, system, x_x: bool = True, x_y: bool = True) -> dict Latent embedding cov Full covariate data (categorical, categorical embedded, and continuous - system + batch System representation x_x - Decode to original system + Decode to original batch x_y - Decode to replacement system + Decode to replacement batch Returns ------- Decoded distribution parameters and sample """ - def outputs(compute, name, res, x, cov, system): + def outputs( + compute: bool, + name: str, + res: dict, + x: torch.Tensor, + batch: torch.Tensor, + cat: list[torch.Tensor], + cont: torch.Tensor | None, + ): if compute: - res_sub = self.decoder(x=x, cov=self._merge_cov(cov=cov, system=system)) + res_sub = self.decoder(x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont) res[name] = res_sub["y"] res[name + "_m"] = res_sub["y_m"] res[name + "_v"] = res_sub["y_v"] res = {} - outputs(compute=x_x, name="x", res=res, x=z, cov=cov["x"], system=system["x"]) - outputs(compute=x_y, name="y", res=res, x=z, cov=cov["y"], system=system["y"]) + outputs(compute=x_x, name="x", res=res, x=z, batch=batch['x'], cat=cat['x'], cont=cont['x']) + outputs(compute=x_y, name="y", res=res, x=z, batch=batch['y'], cat=cat['y'], cont=cont['y']) return res @auto_move_data @@ -291,21 +338,21 @@ def forward( # don't compute if not needed # Parse kwargs - inference_kwargs = _get_dict_if_none(inference_kwargs) - generative_kwargs = _get_dict_if_none(generative_kwargs) - loss_kwargs = _get_dict_if_none(loss_kwargs) - get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) - get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) + inference_kwargs = inference_kwargs or {} + generative_kwargs = generative_kwargs or {} + loss_kwargs = loss_kwargs or {} + get_inference_input_kwargs = get_inference_input_kwargs or {} + get_generative_input_kwargs = get_generative_input_kwargs or {} # Inference inference_inputs = self._get_inference_input(tensors, **get_inference_input_kwargs) inference_outputs = self.inference(**inference_inputs, **inference_kwargs) # Generative - selected_system = self.random_select_systems(tensors["system"]) + selected_batch = self.random_select_batch(tensors[REGISTRY_KEYS.BATCH_KEY]) generative_inputs = self._get_generative_input( tensors, inference_outputs, - selected_system=selected_system, + selected_batch=selected_batch, **get_generative_input_kwargs, ) generative_outputs = self.generative( @@ -315,12 +362,12 @@ def forward( inference_cycle_inputs = self._get_inference_cycle_input( tensors=tensors, generative_outputs=generative_outputs, - selected_system=selected_system, + selected_batch=selected_batch, **get_inference_input_kwargs, ) inference_cycle_outputs = self.inference(**inference_cycle_inputs, **inference_kwargs) - # Combine outputs of all forward pass components + # Combine outputs of all forward pass components - first and cycle pass inference_outputs_merged = dict(**inference_outputs) inference_outputs_merged.update( **{k.replace("z", "z_cyc"): v for k, v in inference_cycle_outputs.items()} @@ -347,13 +394,12 @@ def loss( reconstruction_weight: float = 1.0, z_distance_cycle_weight: float = 2.0, ) -> LossOutput: - x_true = tensors[REGISTRY_KEYS.X_KEY] # Reconstruction loss + x_true = tensors[REGISTRY_KEYS.X_KEY] reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")( generative_outputs["x_m"], x_true, generative_outputs["x_v"] ).sum(dim=1) - reconst_loss = reconst_loss_x # Kl divergence on latent space @@ -382,6 +428,8 @@ def standardize(x): ) z_distance_cyc = z_dist(z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"]) + if 'batch_weights' in tensors.keys(): + z_distance_cyc *= tensors['batch_weights'].flatten() loss = ( reconst_loss * reconstruction_weight @@ -390,56 +438,49 @@ def standardize(x): ) return LossOutput( - n_obs_minibatch=loss.shape[0], loss=loss.mean(), - extra_metrics={ - "reconstruction_loss": reconst_loss.mean(), - "kl_local": kl_divergence_z.mean(), - "z_distance_cycle": z_distance_cyc.mean(), - }, + reconstruction_loss=reconst_loss, + kl_local=kl_divergence_z, + extra_metrics={'cycle_loss':z_distance_cyc.mean()}, ) - @staticmethod - def random_select_systems(system: torch.Tensor) -> torch.Tensor: - """For every cell randomly selects a new system that is different from the original system + def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: + """For every cell randomly selects a new batch that is different from the original batch Parameters ---------- - system - One hot encoded system information for each cell + batch + One hot encoded batch information for each cell Returns ------- - One hot encoding of newly selected system for each cell + One hot encoding of newly selected batch for each cell """ - # get available systems - those that are zero will become nonzero and vice versa - available_systems = 1 - system + # Get available batches - + # those that are zero will become nonzero and vice versa + batch = torch.nn.functional.one_hot(batch.squeeze(-1), self.n_batch) + available_batches = 1 - batch # Get nonzero indices for each cell - row_indices, col_indices = torch.nonzero(available_systems, as_tuple=True) - col_pairs = col_indices.view(-1, system.shape[1] - 1) - # Select system for every cell from available systems + row_indices, col_indices = torch.nonzero(available_batches, as_tuple=True) + col_pairs = col_indices.view(-1, batch.shape[1] - 1) + # Select batch for every cell from available batches randomly_selected_indices = col_pairs.gather( 1, torch.randint( 0, - system.shape[1] - 1, + batch.shape[1] - 1, size=(col_pairs.size(0), 1), device=col_pairs.device, dtype=col_pairs.dtype, ), ) - new_tensor = torch.zeros_like(available_systems) - # generate system covariate tensor + new_tensor = torch.zeros_like(available_batches) + # generate batch covariate tensor new_tensor.scatter_(1, randomly_selected_indices, 1) return new_tensor + @torch.inference_mode() def sample(self, *args, **kwargs): - raise NotImplementedError("") - - -def _get_dict_if_none(param: dict | None) -> dict: - """If not a dict return empty dict""" - param = {} if not isinstance(param, dict) else param - return param + raise NotImplementedError("The use of decoded expression is not recommended for SysVI.") diff --git a/src/scvi/external/sysvi/_priors.py b/src/scvi/external/sysvi/_priors.py index d2c28e1375..b2d6361540 100644 --- a/src/scvi/external/sysvi/_priors.py +++ b/src/scvi/external/sysvi/_priors.py @@ -28,10 +28,6 @@ class VampPrior(Prior): ---------- n_components Prior components - n_input - Model input dimensions - n_cov - Model input covariate dimensions encoder The encoder data @@ -44,32 +40,44 @@ class VampPrior(Prior): def __init__( self, - n_components, - n_input, - n_cov, - encoder, - data: tuple[torch.tensor, torch.tensor] | None = None, - trainable_priors=True, + n_components: int, + encoder: torch.nn.Module, + data_x: torch.tensor, + n_cat_list: list[int], + data_cat: list[torch.tensor], + data_cont: torch.tensor|None=None, + trainable_priors: bool = True, ): super().__init__() self.encoder = encoder - # Get pseudoinputs - if data is None: - u = torch.rand(n_components, n_input) # K * I - u_cov = torch.zeros(n_components, n_cov) # K * C + # Make pseudoinputs into parameters + # X + assert n_components == data_x.shape[0] + self.u = torch.nn.Parameter(data_x, requires_grad=trainable_priors) # K x I + # Cat + assert all([cat.shape[0] == n_components for cat in data_cat]) + # For categorical covariates, since scvi-tools one-hot encodes + # them in the layers, we need to create a multinomial distn + # from which we can sample categories for layers input + # Initialise the multinomial distn weights based on + # one-hot encoding of pseudoinput categories + self.u_cat = torch.nn.ParameterList([ + torch.nn.Parameter( + torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), # K x C_cat_onehot + requires_grad=trainable_priors) + for cat, n in zip(data_cat, n_cat_list) # K x C_cat + ]) + # Cont + if data_cont is None: + self.u_cont = None else: - u = data[0] - u_cov = data[1] - assert n_components == data[0].shape[0] == data[1].shape[0] - assert n_input == data[0].shape[1] - assert n_cov == data[1].shape[1] - self.u = torch.nn.Parameter(u, requires_grad=trainable_priors) - self.u_cov = torch.nn.Parameter(u_cov, requires_grad=trainable_priors) + assert n_components == data_cont.shape[0] + self.u_cont = torch.nn.Parameter(data_cont, requires_grad=trainable_priors) # K x C_cont # mixing weights - self.w = torch.nn.Parameter(torch.zeros(self.u.shape[0], 1, 1)) # K x 1 x 1 + self.w = torch.nn.Parameter(torch.zeros(n_components, 1, 1)) # K x 1 x 1 def get_params(self) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -82,7 +90,9 @@ def get_params(self) -> tuple[torch.Tensor, torch.Tensor]: # u, u_cov -> encoder -> mean, var original_mode = self.encoder.training self.encoder.train(False) - z = self.encoder(x=self.u, cov=self.u_cov) + # Convert category weights to categories + cat_list = [torch.multinomial(cat, num_samples=1) for cat in self.u_cat] + z = self.encoder(x=self.u, cat_list=cat_list, cont=self.u_cont) self.encoder.train(original_mode) return z["y_m"], z["y_v"] # (K x L), (K x L) diff --git a/src/scvi/external/sysvi/_trainingplans.py b/src/scvi/external/sysvi/_trainingplans.py deleted file mode 100644 index ce4b521c1c..0000000000 --- a/src/scvi/external/sysvi/_trainingplans.py +++ /dev/null @@ -1,270 +0,0 @@ -from __future__ import annotations - -from inspect import getfullargspec -from typing import Literal - -import torch - -from scvi.module.base import BaseModuleClass, LossOutput -from scvi.train import TrainingPlan - -# TODO could make new metric class to not be called elbo metric as used for other metrics as well -from scvi.train._metrics import ElboMetric - - -class WeightScaling: - """Linearly scale loss weights between start and end weight accordingly to the current training stage - - Parameters - ---------- - weight_start - Starting weight value - weight_end - End weight vlue - point_start - Training point to start scaling - before weight is weight_start - Since the epochs are counted after they are run, - the start point must be set to 0 to represent 1st epoch - point_end - Training point to end scaling - after weight is weight_end - Since the epochs are counted after they are run, - the start point must be set to n-1 to represent the last epoch - update_on - Define training progression based on epochs or steps - - """ - - def __init__( - self, - weight_start: float, - weight_end: float, - point_start: int, - point_end: int, - update_on: Literal["epoch", "step"] = "step", - ): - self.weight_start = weight_start - self.weight_end = weight_end - self.point_start = point_start - self.point_end = point_end - if update_on not in ["step", "epoch"]: - raise ValueError("update_on not recognized") - self.update_on = update_on - - weight_diff = self.weight_end - self.weight_start - n_points = self.point_end - self.point_start - self.slope = weight_diff / n_points - - if ( - self.weight(epoch=self.point_start, step=self.point_start) < 0 - or self.weight(epoch=self.point_end, step=self.point_end) < 0 - ): - raise ValueError("Specified weight scaling would lead to negative weights") - - def weight( - self, - epoch: int, - step: int, - ) -> float: - """ - Computes the weight for the current step/epoch depending on which update type was set in init - - Parameters - ---------- - epoch - Current epoch. - step - Current step. - """ - if self.update_on == "epoch": - point = epoch - elif self.update_on == "step": - point = step - else: - # This is ensured not to happen by above init check - raise ValueError("self.update_on not recognised") - - if point < self.point_start: - return self.weight_start - elif point > self.point_end: - return self.weight_end - else: - return self.slope * (point - self.point_start) + self.weight_start - - -class TrainingPlanCustom(TrainingPlan): - """Extends scvi TrainingPlan for custom support for other losses. - - Parameters - ---------- - args - Passed to parent - log_on_epoch - See on_epoch of lightning Module log method - log_on_step - See on_step of lightning Module log method - loss_weights - Specifies how losses should be weighted and how it may change during training - Dict with keys being loss names and values being loss weights. - Loss weights can be floats for constant weight or dict of params passed to WeightScaling object - Note that other loss weight params from the parent class are ignored - (e.g. n_steps/epochs_kl_warmup and min/max_kl_weight) - kwargs - Passed to parent. - As described in param loss_weights the loss weighting params of parent are ignored - """ - - def __init__( - self, - module: BaseModuleClass, - loss_weights: None | dict[str, float | WeightScaling] = None, - log_on_epoch: bool = True, - log_on_step: bool = False, - **kwargs, - ): - super().__init__(module, **kwargs) - - self.log_on_epoch = log_on_epoch - self.log_on_step = log_on_step - - # automatic handling of loss component weights - if loss_weights is None: - loss_weights = {} - # Make weighting object - for loss, weight in loss_weights.items(): - if isinstance(weight, dict): - loss_weights[loss] = WeightScaling(**weight) - self.loss_weights = loss_weights - - # Ensure that all passed loss weight specifications are in available loss params - # Also update loss kwargs based on specified weights - self._loss_args = getfullargspec(self.module.loss)[0] - # Make sure no loss weights are already in loss kwargs (e.g. from parent init) - for loss in self._loss_args: - if loss in self.loss_kwargs: - del self.loss_kwargs[loss] - for loss, weight in loss_weights.items(): - if loss not in self._loss_args: - raise ValueError( - f"Loss {loss} for which a weight was specified is not in loss parameters" - ) - # This will also overwrite the kl_weight from parent - self.loss_kwargs.update({loss: self.compute_loss_weight(weight=weight)}) - - def compute_loss_weight(self, weight): - if isinstance(weight, float): - return weight - elif isinstance(weight, int): - return float(weight) - elif isinstance(weight, WeightScaling): - return weight.weight(epoch=self.current_epoch, step=self.global_step) - - @staticmethod - def _create_elbo_metric_components( - mode: Literal["train", "validation"], **kwargs - ) -> ElboMetric: - """ - Initialize the combined loss collection. - - Parameters - ---------- - mode - train/validation - - Returns - ------- - tuple - Objects for storing the combined loss - - """ - loss = ElboMetric("loss", mode, "batch") - return loss - - def initialize_train_metrics(self): - """Initialize train combined loss. - - TODO could add other losses - """ - self.loss_train = self._create_elbo_metric_components( - mode="train", n_total=self.n_obs_training - ) - self.loss_train.reset() - - def initialize_val_metrics(self): - """Initialize val combined loss. - - TODO could add other losses - """ - self.loss_val = self._create_elbo_metric_components( - mode="validation", n_total=self.n_obs_validation - ) - self.loss_val.reset() - - @torch.no_grad() - def compute_and_log_metrics( - self, - loss_output: LossOutput, - loss_metric: ElboMetric, - mode: str, - ): - """ - Computes and logs metrics. - - Parameters - ---------- - loss_output - LossRecorder object from scvi-tools module - loss_metric - The loss Metric to update - mode - Postfix string to add to the metric name of - extra metrics. If train also logs the loss in progress bar - """ - n_obs_minibatch = loss_output.n_obs_minibatch - loss = loss_output.loss - - loss_metric.update( - loss=loss, - n_obs_minibatch=n_obs_minibatch, - ) - - self.log( - f"loss_{mode}", - loss, - on_step=self.log_on_step, - on_epoch=self.log_on_epoch, - prog_bar=True if mode == "train" else False, - sync_dist=self.use_sync_dist, - ) - - # accumulate extra metrics passed to loss recorder - for extra_metric in loss_output.extra_metrics_keys: - met = loss_output.extra_metrics[extra_metric] - if isinstance(met, torch.Tensor): - if met.shape != torch.Size([]): - raise ValueError("Extra tracked metrics should be 0-d tensors.") - met = met.detach() - self.log( - f"{extra_metric}_{mode}", - met, - on_step=self.log_on_step, - on_epoch=self.log_on_epoch, - sync_dist=self.use_sync_dist, - ) - - def training_step(self, batch, batch_idx): - for loss, weight in self.loss_weights.items(): - self.loss_kwargs.update({loss: self.compute_loss_weight(weight=weight)}) - _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) - # combined loss is logged via compute_and_log_metrics - self.compute_and_log_metrics(scvi_loss, self.loss_train, "train") - return scvi_loss.loss - - def validation_step(self, batch, batch_idx): - _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs) - # Combined loss is logged via compute_and_log_metrics - self.compute_and_log_metrics(scvi_loss, self.loss_val, "validation") - - @property - def kl_weight(self): - # Can not raise not implemented error as used in parent init - pass diff --git a/src/scvi/external/sysvi/_utils.py b/src/scvi/external/sysvi/_utils.py deleted file mode 100644 index ad2e8bcab0..0000000000 --- a/src/scvi/external/sysvi/_utils.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -import pandas as pd - - -def prepare_metadata( - meta_data: pd.DataFrame, - cov_cat_keys: list | None = None, - cov_cat_embed_keys: list | None = None, - cov_cont_keys: list | None = None, - categ_orders: list | None = None, - key_orders: list | None = None, -): - """ - Prepare content of dataframe columns for model training (one hot encoding, encoding for embedding, ...) - - Parameters - ---------- - meta_data - Dataframe containing metadata columns. - cov_cat_keys - List of categorical covariates column names to be one-hot encoded. - cov_cat_embed_keys - List of categorical covariates column names to be embedded. - cov_cont_keys - List of continuous covariates column names. - categ_orders - Defined orders for categorical covariates. Dict with keys being - categorical covariates keys and values being lists of categories. May contain more - categories than data. - key_orders - Defines order of covariate columns. Dict with keys being 'categorical', 'categorical_embed', 'continuous' - and values being lists of keys. - - Returns - ------- - Tuple of: covariate data that does not require embedding, - covariate data that requires embedding, - dict with order of categories per covariate (as orders), - dict with keys (categorical, categorical_embed, and continuous) specifying as values - order of covariates used to construct the two covariate datas - - """ - - def get_categories_order(values: pd.Series, categories: list | None = None): - """ - Helper to get order of categories based on values and optional list of categories - - Parameters - ---------- - values - Categorical values - categories - Optional order of categories - - Returns - ------- - Categories order - """ - if categories is None: - categories = pd.Categorical(values).categories.values - else: - missing = set(values.unique()) - set(categories) - if len(missing) > 0: - raise ValueError( - f"Some values of {values.name} are not in the specified categories order: {missing}" - ) - return list(categories) - - if cov_cat_keys is None: - cov_cat_keys = [] - if cov_cat_embed_keys is None: - cov_cat_embed_keys = [] - if cov_cont_keys is None: - cov_cont_keys = [] - - # Check & set order of covariates and categories - if key_orders is not None: - assert set(key_orders["categorical"]) == set(cov_cat_keys) - cov_cat_keys = key_orders["categorical"] - assert set(key_orders["categorical_embed"]) == set(cov_cat_embed_keys) - cov_cat_embed_keys = key_orders["categorical_embed"] - assert set(key_orders["continuous"]) == set(cov_cont_keys) - cov_cont_keys = key_orders["continuous"] - cov_dict = { - "categorical": cov_cat_keys, - "categorical_embed": cov_cat_embed_keys, - "continuous": cov_cont_keys, - } - - if categ_orders is None: - categ_orders = {} - for cov_key in cov_cat_keys + cov_cat_embed_keys: - categ_orders[cov_key] = get_categories_order( - values=meta_data[cov_key], categories=None - ) - - def dummies_categories(values: pd.Series, categories: list): - """ - Make dummies of categorical covariates. Use specified order of categories. - - Parameters - ---------- - values - Categorical vales for each observation. - categories - Order of categories to use - - Returns - ------- - dummies - one-hot encoding of categories in same order as categories. - """ - # Get dummies - # Ensure ordering - values = pd.Series( - pd.Categorical(values=values, categories=categories, ordered=True), - index=values.index, - name=values.name, - ) - # This is problematic if many covariates - dummies = pd.get_dummies(values, prefix=values.name) - - return dummies - - # Covs that are not embedded: continuous and one-hot encoded categorical covariates - if len(cov_cat_keys) > 0 or len(cov_cont_keys) > 0: - cov_cat_data = [] - for cov_cat_key in cov_cat_keys: - cat_dummies = dummies_categories( - values=meta_data[cov_cat_key], categories=categ_orders[cov_cat_key] - ) - cov_cat_data.append(cat_dummies) - # Prepare single cov array for all covariates - cov_data_parsed = pd.concat(cov_cat_data + [meta_data[cov_cont_keys]], axis=1) - else: - cov_data_parsed = None - - # Data of covariates to be embedded - if len(cov_cat_embed_keys) > 0: - cov_embed_data = [] - for cov_cat_embed_key in cov_cat_embed_keys: - cat_order = categ_orders[cov_cat_embed_key] - cat_map = dict(zip(cat_order, range(len(cat_order)), strict=False)) - cov_embed_data.append(meta_data[cov_cat_embed_key].map(cat_map)) - cov_embed_data = pd.concat(cov_embed_data, axis=1) - else: - cov_embed_data = None - - return cov_data_parsed, cov_embed_data, categ_orders, cov_dict diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 77fc539bb1..c3b9e4c5dc 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -1,5 +1,5 @@ import math - +import pytest import numpy as np import pandas as pd from anndata import AnnData @@ -44,46 +44,48 @@ def mock_adata(): ) adata.obs["covariate_cont"] = list(range(200)) adata.obs["covariate_cat"] = ["a"] * 50 + ["b"] * 50 + ["c"] * 50 + ["d"] * 50 - adata.obs["covariate_cat_emb"] = ["a"] * 50 + ["b"] * 50 + ["c"] * 50 + ["d"] * 50 - adata.obs["system"] = ["a"] * 100 + ["b"] * 50 + ["c"] * 50 + adata.obs["batch"] = ["a"] * 100 + ["b"] * 50 + ["c"] * 50 return adata -def test_model(): - adata0 = mock_adata() - - # Run adata setup with all covariates - SysVI.setup_anndata( - adata0, - batch_key="system", - categorical_covariate_keys=["covariate_cat"], - categorical_covariate_embed_keys=["covariate_cat_emb"], - continuous_covariate_keys=["covariate_cont"], - ) - - # Run adata setup transfer - # TODO ensure this is actually done correctly, not just that it runs through +@pytest.mark.parametrize( + ( + "categorical_covariate_keys", + "continuous_covariate_keys", + "pseudoinputs_data_indices", + "embed_cat", + "weight_batches" + ), + [ + # Check different covariate combinations + (["covariate_cat"], ["covariate_cont"], None, False, False), + (["covariate_cat"], ["covariate_cont"], None, True, False), + (["covariate_cat"], None, None, False, False), + (["covariate_cat"], None, None, True, False), + (None, ["covariate_cont"], None, False, False), + # Check pre-specifying pseudoinputs + (None, None, np.array(list(range(5))), False, False), + + ], +) +def test_model( + categorical_covariate_keys, + continuous_covariate_keys, + pseudoinputs_data_indices, + embed_cat, + weight_batches +): adata = mock_adata() - SysVI.setup_anndata( - adata, - batch_key="system", - categorical_covariate_keys=["covariate_cat"], - categorical_covariate_embed_keys=["covariate_cat_emb"], - continuous_covariate_keys=["covariate_cont"], - covariate_categ_orders=adata0.uns["covariate_categ_orders"], - covariate_key_orders=adata0.uns["covariate_key_orders"], - batch_order=adata0.uns["batch_order"], - ) - # Check that setup of adata without covariates works - adata_no_cov = mock_adata() + # Run adata setup SysVI.setup_anndata( - adata_no_cov, - batch_key="system", + adata, + batch_key="batch", + categorical_covariate_keys=categorical_covariate_keys, + continuous_covariate_keys=continuous_covariate_keys, + weight_batches=weight_batches, ) - assert "covariates" not in adata_no_cov.obsm - assert "covariates_embed" not in adata_no_cov.obsm # Model @@ -91,42 +93,14 @@ def test_model(): model = SysVI(adata=adata, prior="standard_normal") model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) - # Check that mode runs through without covariates - model = SysVI(adata=adata_no_cov) - model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) - - # Check pre-specifying pseudoinput indices for vamp prior - _ = SysVI( + # Check that model runs through with vamp prior + model = SysVI( adata=adata, prior="vamp", - pseudoinputs_data_indices=np.array(list(range(5))), - n_prior_components=5, - ) - - # Check that model runs through with vamp prior without specifying pseudoinput indices, - # all covariates, and weight scaling - model = SysVI(adata=adata, prior="vamp") - model.train( - max_epochs=2, - batch_size=math.ceil(adata.n_obs / 2.0), - log_every_n_steps=1, - check_val_every_n_epoch=1, - val_check_interval=1, - plan_kwargs={ - "log_on_epoch": False, - "log_on_step": True, - "loss_weights": { - "kl_weight": 2, - "z_distance_cycle_weight": { - "weight_start": 1, - "weight_end": 3, - "point_start": 1, - "point_end": 3, - "update_on": "step", - }, - }, - }, + pseudoinputs_data_indices=pseudoinputs_data_indices, + n_prior_components=5 ) + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) # Embedding @@ -138,13 +112,21 @@ def test_model(): == adata.shape[0] ) - # Ensure that embedding with another adata properly checks if it was setu up correctly - _ = model.get_latent_representation(adata=adata0) - with assert_raises(KeyError): - # TODO could add more check for each property separately - _ = model.get_latent_representation(adata=adata_no_cov) - # Check that indices in embedding works +def test_latent_representation(): + # Train model + adata = mock_adata() + SysVI.setup_anndata( + adata, + batch_key="batch", + categorical_covariate_keys=None, + continuous_covariate_keys=None, + weight_batches=False, + ) + model = SysVI(adata=adata, prior="standard_normal") + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) + + # Check that specifying indices in embedding works idx = [1, 2, 3] embed = model.get_latent_representation( adata=adata, @@ -153,7 +135,7 @@ def test_model(): ) assert embed.shape[0] == 3 - # Check predicting mean/sample + # Check predicting mean vs sample np.testing.assert_allclose( embed, model.get_latent_representation( @@ -171,3 +153,45 @@ def test_model(): give_mean=False, ), ) + + # Test returning distn + mean, var = model.get_latent_representation( + adata=adata, + indices=idx, + return_dist=True, + ) + np.testing.assert_allclose(embed, mean) + + +def test_warnings(): + # Train model + adata = mock_adata() + SysVI.setup_anndata( + adata, + batch_key="batch", + categorical_covariate_keys=None, + continuous_covariate_keys=None, + weight_batches=False, + ) + model = SysVI(adata=adata, prior="standard_normal") + + # Assert that warning is printed if kl warmup is used + # Step warmup + with pytest.warns(Warning) as record: + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0), + plan_kwargs={'n_steps_kl_warmup': 1}) + assert any([ + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record]) + # Epoch warmup + with pytest.warns(Warning) as record: + model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0), + plan_kwargs={'n_epochs_kl_warmup': 1}) + assert any([ + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record]) + + # Asert that sampling is disabled + with pytest.raises(NotImplementedError): + model.module.sample() + From 7743c1923913b43d77be23199dbb434ebb64ae59 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Oct 2024 17:39:27 +0000 Subject: [PATCH 44/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/sysvi/_base_components.py | 22 ++----- src/scvi/external/sysvi/_model.py | 28 ++++---- src/scvi/external/sysvi/_module.py | 71 +++++++++++---------- src/scvi/external/sysvi/_priors.py | 21 +++--- tests/external/sysvi/test_model.py | 43 ++++++++----- 5 files changed, 99 insertions(+), 86 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 239cbfe532..ba223a7911 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -6,19 +6,13 @@ from typing import Literal import numpy as np - import torch from torch import nn from torch.distributions import Normal from torch.nn import ( - BatchNorm1d, - Dropout, - LayerNorm, Linear, Module, Parameter, - ReLU, - Sequential, ) @@ -85,11 +79,10 @@ def forward( cont: torch.Tensor | None = None, cat_list: list | None = None, ) -> dict[str, torch.Tensor]: - y = self.decoder_y(x=x, cont=cont, cat_list=cat_list) y_m = self.mean_encoder(y) if y_m.isnan().any() or y_m.isinf().any(): - warnings.warn('Predicted mean contains nan or inf values. Setting to numerical.') + warnings.warn("Predicted mean contains nan or inf values. Setting to numerical.") y_m = torch.nan_to_num(y_m) y_v = self.var_encoder(y) @@ -196,8 +189,8 @@ def __init__( ), ) for i, (n_in, n_out) in enumerate( - zip(layers_dim[:-1], layers_dim[1:], strict=True) - ) + zip(layers_dim[:-1], layers_dim[1:], strict=True) + ) ] ) ) @@ -214,7 +207,7 @@ def set_online_update_hooks(self, hook_first_layer=True): def _hook_fn_weight(grad): new_grad = torch.zeros_like(grad) if self.n_cov > 0: - new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:] + new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] return new_grad def _hook_fn_zero_out(grad): @@ -234,10 +227,7 @@ def _hook_fn_zero_out(grad): self.hooks.append(b) def forward( - self, - x: torch.Tensor, - cont: torch.Tensor | None = None, - cat_list: list | None = None + self, x: torch.Tensor, cont: torch.Tensor | None = None, cat_list: list | None = None ) -> torch.Tensor: """Forward computation on ``x``. @@ -348,7 +338,7 @@ def forward(self, x: torch.Tensor): v = torch.clip(v, min=-self.clip_exp, max=self.clip_exp) v = self.activation(v) if v.isnan().any(): - warnings.warn('Predicted variance contains nan values. Setting to 0.') + warnings.warn("Predicted variance contains nan values. Setting to 0.") v = torch.nan_to_num(v) return v diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 5218510130..70749136e7 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -6,15 +6,17 @@ from typing import Literal import numpy as np -import pandas as pd import torch from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager from scvi.data.fields import ( + CategoricalJointObsField, + CategoricalObsField, LayerField, - ObsmField, CategoricalObsField, CategoricalJointObsField, NumericalJointObsField, NumericalObsField, + NumericalJointObsField, + NumericalObsField, ) from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin from scvi.utils import setup_anndata_dsp @@ -105,19 +107,18 @@ def train( **kwargs, ): plan_kwargs = plan_kwargs or {} - kl_weight_defaults = { - 'n_epochs_kl_warmup': 0, - 'n_steps_kl_warmup': 0 - } + kl_weight_defaults = {"n_epochs_kl_warmup": 0, "n_steps_kl_warmup": 0} if any([v != plan_kwargs.get(k, v) for k, v in kl_weight_defaults.items()]): - warnings.warn('The use of KL weight warmup is not recommended in SysVI. ' + - 'The n_epochs_kl_warmup and n_steps_kl_warmup will be reset to 0.') + warnings.warn( + "The use of KL weight warmup is not recommended in SysVI. " + + "The n_epochs_kl_warmup and n_steps_kl_warmup will be reset to 0." + ) # Overwrite plan kwargs with kl weight defaults plan_kwargs = {**plan_kwargs, **kl_weight_defaults} # Pass to parent kwargs = kwargs or {} - kwargs['plan_kwargs'] = plan_kwargs + kwargs["plan_kwargs"] = plan_kwargs super().train(*args, **kwargs) @torch.inference_mode() @@ -231,10 +232,11 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] if weight_batches: - warnings.warn('The use of inverse batch proportion weights is experimental.') - batch_weights_key = 'batch_weights' - adata.obs[batch_weights_key] = adata.obs[batch_key].map({ - cat: 1 / n for cat, n in adata.obs[batch_key].value_counts().items()}) + warnings.warn("The use of inverse batch proportion weights is experimental.") + batch_weights_key = "batch_weights" + adata.obs[batch_weights_key] = adata.obs[batch_key].map( + {cat: 1 / n for cat, n in adata.obs[batch_key].value_counts().items()} + ) anndata_fields.append(NumericalObsField(batch_weights_key, batch_weights_key)) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index eb6335e1f1..2f74a2e6da 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -77,7 +77,6 @@ def __init__( out_var_mode: Literal["sample_feature", "feature"] = "feature", enc_dec_kwargs: dict | None = None, embedding_kwargs: dict | None = None, - ): super().__init__() @@ -127,7 +126,9 @@ def __init__( if prior == "standard_normal": self.prior = StandardPrior() elif prior == "vamp": - assert pseudoinput_data is not None, 'Pseudoinput data must be specified if using VampPrior' + assert ( + pseudoinput_data is not None + ), "Pseudoinput data must be specified if using VampPrior" pseudoinput_data = self._get_inference_input(pseudoinput_data) self.prior = VampPrior( n_components=n_prior_components, @@ -153,8 +154,8 @@ def _get_inference_input(self, tensors, **kwargs) -> dict[str, torch.Tensor]: input_dict = { "expr": tensors[REGISTRY_KEYS.X_KEY], "batch": tensors[REGISTRY_KEYS.BATCH_KEY], - "cat": cov['cat'], - "cont": cov['cont'] + "cat": cov["cat"], + "cont": cov["cont"], } return input_dict @@ -166,8 +167,8 @@ def _get_inference_cycle_input( input_dict = { "expr": generative_outputs["y_m"], "batch": selected_batch, - "cat": cov['cat'], - "cont": cov['cont'] + "cat": cov["cat"], + "cont": cov["cont"], } return input_dict @@ -176,26 +177,26 @@ def _get_generative_input( tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor], selected_batch: torch.Tensor, - **kwargs + **kwargs, ) -> dict[str, torch.Tensor | dict[str, torch.Tensor | list[torch.Tensor] | None]]: """Parse the input tensors, inference inputs, and cycle batch to get generative inputs""" z = inference_outputs["z"] cov = self._get_cov(tensors=tensors) cov_mock = self._mock_cov(cov) - cat = {'x': cov['cat'], 'y': cov_mock['cat']} - cont = {'x': cov['cont'], 'y': cov_mock['cont']} + cat = {"x": cov["cat"], "y": cov_mock["cat"]} + cont = {"x": cov["cont"], "y": cov_mock["cont"]} batch = {"x": tensors["batch"], "y": selected_batch} - input_dict = {"z": z, "batch": batch, 'cat': cat, 'cont': cont} + input_dict = {"z": z, "batch": batch, "cat": cat, "cont": cont} return input_dict @auto_move_data # TODO remove? - def _get_cov(self, tensors: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + def _get_cov( + self, tensors: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: """Process all covariates into continuous and categorical components for cVAE""" - cat_parts = [] cont_parts = [] if REGISTRY_KEYS.CONT_COVS_KEY in tensors: @@ -204,14 +205,12 @@ def _get_cov(self, tensors: dict[str, torch.Tensor] cat = torch.split(tensors[REGISTRY_KEYS.CAT_COVS_KEY], 1, dim=1) if self.embed_cat: for idx, tensor in enumerate(cat): - cont_parts.append( - self.compute_embedding(tensor, self._cov_idx_name(idx)) - ) + cont_parts.append(self.compute_embedding(tensor, self._cov_idx_name(idx))) else: cat_parts.extend(cat) cov = { - 'cat': cat_parts, - 'cont': torch.concat(cont_parts, dim=1) if len(cont_parts) > 0 else None + "cat": cat_parts, + "cont": torch.concat(cont_parts, dim=1) if len(cont_parts) > 0 else None, } return cov @@ -223,18 +222,19 @@ def _merge_batch_cov(cat: list[torch.Tensor], batch: torch.Tensor) -> list[torch def _mock_cov(cov: dict[str, list[torch.Tensor], torch.Tensor, None]) -> torch.Tensor | None: """Make mock (all 0) covariates for cycle""" mock = { - 'cat': [torch.zeros_like(cat) for cat in cov['cat']], - 'cont': torch.zeros_like(cov['cont']) if cov['cont'] is not None else None, + "cat": [torch.zeros_like(cat) for cat in cov["cat"]], + "cont": torch.zeros_like(cov["cont"]) if cov["cont"] is not None else None, } return mock @auto_move_data - def inference(self, - expr: torch.Tensor, - batch: torch.Tensor, - cat: list[torch.Tensor], - cont: torch.Tensor | None, - ) -> dict: + def inference( + self, + expr: torch.Tensor, + batch: torch.Tensor, + cat: list[torch.Tensor], + cont: torch.Tensor | None, + ) -> dict: """ expression & cov -> latent representation @@ -287,14 +287,20 @@ def outputs( cont: torch.Tensor | None, ): if compute: - res_sub = self.decoder(x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont) + res_sub = self.decoder( + x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont + ) res[name] = res_sub["y"] res[name + "_m"] = res_sub["y_m"] res[name + "_v"] = res_sub["y_v"] res = {} - outputs(compute=x_x, name="x", res=res, x=z, batch=batch['x'], cat=cat['x'], cont=cont['x']) - outputs(compute=x_y, name="y", res=res, x=z, batch=batch['y'], cat=cat['y'], cont=cont['y']) + outputs( + compute=x_x, name="x", res=res, x=z, batch=batch["x"], cat=cat["x"], cont=cont["x"] + ) + outputs( + compute=x_y, name="y", res=res, x=z, batch=batch["y"], cat=cat["y"], cont=cont["y"] + ) return res @auto_move_data @@ -394,7 +400,6 @@ def loss( reconstruction_weight: float = 1.0, z_distance_cycle_weight: float = 2.0, ) -> LossOutput: - # Reconstruction loss x_true = tensors[REGISTRY_KEYS.X_KEY] reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")( @@ -428,8 +433,8 @@ def standardize(x): ) z_distance_cyc = z_dist(z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"]) - if 'batch_weights' in tensors.keys(): - z_distance_cyc *= tensors['batch_weights'].flatten() + if "batch_weights" in tensors.keys(): + z_distance_cyc *= tensors["batch_weights"].flatten() loss = ( reconst_loss * reconstruction_weight @@ -441,7 +446,7 @@ def standardize(x): loss=loss.mean(), reconstruction_loss=reconst_loss, kl_local=kl_divergence_z, - extra_metrics={'cycle_loss':z_distance_cyc.mean()}, + extra_metrics={"cycle_loss": z_distance_cyc.mean()}, ) def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: diff --git a/src/scvi/external/sysvi/_priors.py b/src/scvi/external/sysvi/_priors.py index b2d6361540..0118fee3f7 100644 --- a/src/scvi/external/sysvi/_priors.py +++ b/src/scvi/external/sysvi/_priors.py @@ -45,7 +45,7 @@ def __init__( data_x: torch.tensor, n_cat_list: list[int], data_cat: list[torch.tensor], - data_cont: torch.tensor|None=None, + data_cont: torch.tensor | None = None, trainable_priors: bool = True, ): super().__init__() @@ -63,18 +63,23 @@ def __init__( # from which we can sample categories for layers input # Initialise the multinomial distn weights based on # one-hot encoding of pseudoinput categories - self.u_cat = torch.nn.ParameterList([ - torch.nn.Parameter( - torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), # K x C_cat_onehot - requires_grad=trainable_priors) - for cat, n in zip(data_cat, n_cat_list) # K x C_cat - ]) + self.u_cat = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), # K x C_cat_onehot + requires_grad=trainable_priors, + ) + for cat, n in zip(data_cat, n_cat_list, strict=False) # K x C_cat + ] + ) # Cont if data_cont is None: self.u_cont = None else: assert n_components == data_cont.shape[0] - self.u_cont = torch.nn.Parameter(data_cont, requires_grad=trainable_priors) # K x C_cont + self.u_cont = torch.nn.Parameter( + data_cont, requires_grad=trainable_priors + ) # K x C_cont # mixing weights self.w = torch.nn.Parameter(torch.zeros(n_components, 1, 1)) # K x 1 x 1 diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index c3b9e4c5dc..a878c1ffa9 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -1,7 +1,8 @@ import math -import pytest + import numpy as np import pandas as pd +import pytest from anndata import AnnData from numpy.testing import assert_raises from scipy import sparse @@ -55,7 +56,7 @@ def mock_adata(): "continuous_covariate_keys", "pseudoinputs_data_indices", "embed_cat", - "weight_batches" + "weight_batches", ), [ # Check different covariate combinations @@ -66,7 +67,6 @@ def mock_adata(): (None, ["covariate_cont"], None, False, False), # Check pre-specifying pseudoinputs (None, None, np.array(list(range(5))), False, False), - ], ) def test_model( @@ -74,7 +74,7 @@ def test_model( continuous_covariate_keys, pseudoinputs_data_indices, embed_cat, - weight_batches + weight_batches, ): adata = mock_adata() @@ -98,7 +98,7 @@ def test_model( adata=adata, prior="vamp", pseudoinputs_data_indices=pseudoinputs_data_indices, - n_prior_components=5 + n_prior_components=5, ) model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) @@ -178,20 +178,31 @@ def test_warnings(): # Assert that warning is printed if kl warmup is used # Step warmup with pytest.warns(Warning) as record: - model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0), - plan_kwargs={'n_steps_kl_warmup': 1}) - assert any([ - "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) - for rec in record]) + model.train( + max_epochs=2, + batch_size=math.ceil(adata.n_obs / 2.0), + plan_kwargs={"n_steps_kl_warmup": 1}, + ) + assert any( + [ + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record + ] + ) # Epoch warmup with pytest.warns(Warning) as record: - model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0), - plan_kwargs={'n_epochs_kl_warmup': 1}) - assert any([ - "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) - for rec in record]) + model.train( + max_epochs=2, + batch_size=math.ceil(adata.n_obs / 2.0), + plan_kwargs={"n_epochs_kl_warmup": 1}, + ) + assert any( + [ + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record + ] + ) # Asert that sampling is disabled with pytest.raises(NotImplementedError): model.module.sample() - From 1641654bfefb05128719bfba644d998bcb679df8 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 26 Oct 2024 07:07:16 +0200 Subject: [PATCH 45/60] use adatamanager to access filed statistics --- src/scvi/external/sysvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 5218510130..966797eba0 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -56,7 +56,7 @@ def __init__( if prior == "vamp": if pseudoinputs_data_indices is None: pseudoinputs_data_indices = np.random.randint( - 0, adata.shape[0], n_prior_components + 0, self.summary_stats.n_vars, n_prior_components ) assert pseudoinputs_data_indices.shape[0] == n_prior_components assert pseudoinputs_data_indices.ndim == 1 From 7c3bddc80f3dba5fa2631916a32444c7eb9e220b Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sat, 26 Oct 2024 07:07:35 +0200 Subject: [PATCH 46/60] documentation --- src/scvi/external/sysvi/_base_components.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 239cbfe532..31bfb388fb 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -301,11 +301,14 @@ class VarEncoder(Module): n_input Number of input dimensions, used if mode is sample_feature n_output - Number of variances to predict + Number of variances to predict, matching the desired number of features + (e.g. latent dimensions for variational encoding or output features + for variational decoding) mode - How to compute var - 'sample_feature' - learn per sample and feature - 'feature' - learn per feature, constant across samples + How to compute variance. + One of the following: + * ```'sample_feature'``` - learn variance per sample and feature + * ```'feature'``` - learn variance per feature, constant across samples """ def __init__( @@ -332,7 +335,7 @@ def forward(self, x: torch.Tensor): Parameters ---------- x - Used to encode var if mode is sample_feature; dim = n_samples x n_input + Used to encode variance if mode is sample_feature; dim = n_samples x n_input Returns ------- From 5015e8219b4099efe18d7110e49a95c4e34dd295 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 14:26:22 +0100 Subject: [PATCH 47/60] optionally change var activation function --- src/scvi/external/sysvi/_base_components.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index a5d1ceeb9a..f91db32264 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -3,7 +3,7 @@ import collections import warnings from collections.abc import Iterable -from typing import Literal +from typing import Literal, Callable import numpy as np import torch @@ -54,6 +54,7 @@ def __init__( n_hidden: int = 256, n_layers: int = 3, var_mode: Literal["sample_feature", "feature"] = "feature", + var_activation: Callable | None = None, sample: bool = False, **kwargs, ): @@ -71,7 +72,7 @@ def __init__( ) self.mean_encoder = Linear(n_hidden, n_output) - self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode) + self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, activation=var_activation) def forward( self, @@ -189,8 +190,8 @@ def __init__( ), ) for i, (n_in, n_out) in enumerate( - zip(layers_dim[:-1], layers_dim[1:], strict=True) - ) + zip(layers_dim[:-1], layers_dim[1:], strict=True) + ) ] ) ) @@ -207,7 +208,7 @@ def set_online_update_hooks(self, hook_first_layer=True): def _hook_fn_weight(grad): new_grad = torch.zeros_like(grad) if self.n_cov > 0: - new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] + new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:] return new_grad def _hook_fn_zero_out(grad): @@ -306,6 +307,7 @@ def __init__( n_input: int, n_output: int, mode: Literal["sample_feature", "feature", "linear"], + activation: Callable | None = None, ): super().__init__() @@ -317,7 +319,7 @@ def __init__( self.var_param = Parameter(torch.zeros(1, n_output)) else: raise ValueError("Mode not recognised.") - self.activation = torch.exp + self.activation = torch.exp if activation is None else activation def forward(self, x: torch.Tensor): """Forward pass through model From b64b2bafa31379a1f07f492ee460d98a42b564b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 27 Oct 2024 13:26:44 +0000 Subject: [PATCH 48/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/sysvi/_base_components.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index f91db32264..5cba834714 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -2,8 +2,8 @@ import collections import warnings -from collections.abc import Iterable -from typing import Literal, Callable +from collections.abc import Callable, Iterable +from typing import Literal import numpy as np import torch @@ -190,8 +190,8 @@ def __init__( ), ) for i, (n_in, n_out) in enumerate( - zip(layers_dim[:-1], layers_dim[1:], strict=True) - ) + zip(layers_dim[:-1], layers_dim[1:], strict=True) + ) ] ) ) @@ -208,7 +208,7 @@ def set_online_update_hooks(self, hook_first_layer=True): def _hook_fn_weight(grad): new_grad = torch.zeros_like(grad) if self.n_cov > 0: - new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:] + new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :] return new_grad def _hook_fn_zero_out(grad): From 9b78ba39802827326272da74bb923c5f7fc30d19 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 16:10:52 +0100 Subject: [PATCH 49/60] documentation improvements --- src/scvi/external/sysvi/_base_components.py | 86 ++++++--- src/scvi/external/sysvi/_model.py | 3 - src/scvi/external/sysvi/_module.py | 186 ++++++++++++++++---- 3 files changed, 208 insertions(+), 67 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index f91db32264..9e3a69ee80 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -17,32 +17,36 @@ class EncoderDecoder(Module): - """Module that can be used as probabilistic encoder or decoder + """Module that can be used as probabilistic encoder or decoder. - Based on inputs and optional covariates predicts output mean and var + Based on inputs and optional covariates predicts output mean and variance. Parameters ---------- n_input - The dimensionality of the main input + The dimensionality of the main input. n_output - The dimensionality of the output - n_cov - Dimensionality of covariates. - If there are no cov this should be set to None - - in this case cov will not be used. + The dimensionality of the output. + n_cat_list + A list containing the number of categories for each covariate. + n_cont + The dimensionality of the continuous covariates. n_hidden - The number of fully-connected hidden layers + The number of nodes per hidden layer. n_layers - Number of hidden layers + The number of hidden layers. var_mode - How to compute variance from model outputs, see :class:`~scvi.external.sysvi.VarEncoder` - 'sample_feature' - learn per sample and feature - 'feature' - learn per feature, constant across samples + How to compute variance from model outputs, see :class:`~scvi.external.sysvi.VarEncoder`. + One of the following: + * ```'sample_feature'``` - learn variance per sample and feature. + * ```'feature'``` - learn variance per feature, constant across samples. + var_activation + Function used to ensure positivity of the variance. + Defaults to :meth:`torch.exp`. sample - Return samples from predicted distribution + Return samples from predicted distribution. kwargs - Passed to :class:`~scvi.external.sysvi.Layers` + Passed to :class:`~scvi.external.sysvi.Layers`. """ def __init__( @@ -78,8 +82,29 @@ def forward( self, x: torch.Tensor, cont: torch.Tensor | None = None, - cat_list: list | None = None, + cat_list: list[torch.Tensor] | None = None, ) -> dict[str, torch.Tensor]: + """Forward pass. + + Parameters + ---------- + x + Main input (i.e. expression for encoder or latent embedding for decoder.). + dim = n_samples * n_input + cont + Optional continuous covariates. + dim = n_samples * n_cont + cat_list + List of optional categorical covariates. + Will be one hot encoded in `~scvi.nn.FCLayers`. + Each list entry is of dim = n_samples * 1 + + Returns + ------- + Predicted mean (``'y_m'``) and variance (``'y_v'``) and + optionally samples (``'y'``) from normal distribution + parametrized with the predicted parameters. + """ y = self.decoder_y(x=x, cont=cont, cat_list=cat_list) y_m = self.mean_encoder(y) if y_m.isnan().any() or y_m.isinf().any(): @@ -97,14 +122,14 @@ def forward( class FCLayers(nn.Module): - """FCLayers class of scvi-tools adapted to also inject continous covariates. + """A helper class to build fully-connected layers for a neural network. + + FCLayers class of scvi-tools adapted to also inject continous covariates. The only adaptation is addition of `n_cont` parameter in init and `cont` in forward, with the associated handling of the two. The forward method signature is changed to account for optional `cont`. - A helper class to build fully-connected layers for a neural network. - Parameters ---------- n_in @@ -290,16 +315,18 @@ class VarEncoder(Module): Parameters ---------- n_input - Number of input dimensions, used if mode is sample_feature + Number of input dimensions. + Used if mode is ``'sample_feature'`` to construct a network predicting + variance from input features. n_output Number of variances to predict, matching the desired number of features (e.g. latent dimensions for variational encoding or output features - for variational decoding) + for variational decoding). mode How to compute variance. One of the following: - * ```'sample_feature'``` - learn variance per sample and feature - * ```'feature'``` - learn variance per feature, constant across samples + * ``'sample_feature'`` - learn variance per sample and feature. + * ``'feature'`` - learn variance per feature, constant across samples. """ def __init__( @@ -321,17 +348,22 @@ def __init__( raise ValueError("Mode not recognised.") self.activation = torch.exp if activation is None else activation - def forward(self, x: torch.Tensor): - """Forward pass through model + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """Forward pass through model. Parameters ---------- x - Used to encode variance if mode is sample_feature; dim = n_samples x n_input + Used to encode variance if mode is ``'sample_feature'``. + dim = n_samples x n_input Returns ------- - Predicted var + Predicted variance + dim = n_samples * 1 """ if self.mode == "sample_feature": v = self.encoder(x) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index dd10fcdf30..f9bda90346 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -95,7 +95,6 @@ def __init__( self._model_summary_string = ( "SysVI - cVAE model with optional VampPrior and latent cycle-consistency loss." ) - # necessary line to get params that will be used for saving/loading self.init_params_ = self._get_init_params(locals()) logger.info("The model has been initialized") @@ -151,7 +150,6 @@ def get_latent_representation( ------- Latent Embedding """ - # Check model and adata self._check_if_trained(warn=False) adata = self._validate_anndata(adata) if indices is None: @@ -163,7 +161,6 @@ def get_latent_representation( predicted_m = [] predicted_v = [] for tensors in tensors_fwd: - # Inference inference_inputs = self.module._get_inference_input(tensors) inference_outputs = self.module.inference(**inference_inputs) if give_mean or return_dist: diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 2f74a2e6da..95f6e7ceb9 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -16,48 +16,57 @@ class SysVAE(BaseModuleClass): """CVAE with optional VampPrior and latent cycle consistency loss. + Described in `Hrovatin et al. (2023) `_. + Parameters ---------- n_input Number of input features. - Passed directly from Model. - n_cov_const - Dimensionality of covariate data that will not be further embedded. - Passed directly from Model. - cov_embed_sizes - Number of categories per every cov to be embedded, e.g. [cov1_n_categ, cov2_n_categ, ...]. - Passed directly from Model. n_batch Number of batches. - Passed directly from Model. - cov_embed_dims - Dimension for covariate embedding. + n_continuous_cov + Number of continuous covariates. + n_cats_per_cov + A list of integers containing the number of categories for each categorical covariate. + embed_cat + If ``True`` embeds categorical covariates and batches + into continuously-valued vectors instead of using one-hot encoding. prior Which prior distribution to use. - Passed directly from Model. + * ``'standard_normal'``: Standard normal distribution. + * ``'vamp'``: VampPrior. n_prior_components - If VampPrior - how many prior components to use. - Passed directly from Model. + Number of prior components for VampPrior. trainable_priors - If VampPrior - should prior components be trainable. + Should prior components of VampPrior be trainable. pseudoinput_data - Initialisation data for VampPrior. Should match input tensors structure. - Passed directly from Model. + Initialisation data for VampPrior. + Should match input tensors structure. n_latent Numer of latent space dimensions. n_hidden - Number of nodes in hidden layers. + Numer of hidden nodes per layer for encoder and decoder. n_layers - Number of hidden layers. + Number of hidden layers for encoder and decoder. dropout_rate - Dropout rate. + Dropout rate for encoder and decoder. out_var_mode - See :class:`~scvi.external.sysvi.nn.VarEncoder` + How variance is predicted in decoder, see :class:`~scvi.external.sysvi.nn.VarEncoder`. + One of the following: + * ``'sample_feature'`` - learn variance per sample and feature. + * ``'feature'`` - learn variance per feature, constant across samples. enc_dec_kwargs Additional kwargs passed to encoder and decoder. + For both encoder and decoder :class:`~scvi.external.sysvi.nn.EncoderDecoder` is used. + embedding_kwargs + Keyword arguments passed into :class:`~scvi.nn.Embedding` + if ``embed_cat`` is set to ``True``. """ - # TODO could disable computation of cycle if predefined that cycle wil not be used + # TODO could disable computation of cycle if predefined that cycle loss will not be used. + # Cycle loss is not expected to be disabled in practice for typical use cases. + # As the use of cycle is currently only based on loss kwargs, + # which are specified only later, it can not be inferred here. def __init__( self, @@ -145,11 +154,48 @@ def __init__( raise ValueError("Prior not recognised") @staticmethod - def _cov_idx_name(cov: int | float): + def _cov_idx_name(cov: int) -> str: + """Convert covariate index into a name used for embedding. + + Parameters + ---------- + cov + Covariate index. + + Returns + ------- + Covariate name. + + """ return "cov" + str(cov) - def _get_inference_input(self, tensors, **kwargs) -> dict[str, torch.Tensor]: - """Parse the input tensors to get inference inputs""" + def _get_inference_input( + self, + tensors: dict[str, torch.Tensor], + **kwargs + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """Parse the input tensors to get inference inputs. + + Parameters + ---------- + tensors + Input tensors. + kwargs + Not used. Added for inheritance compatibility. + + Returns + ------- + Tensors that can be used for inference. + Keys: + * ``'expr'``: Expression. + * ``'batch'``: Batch covariate. + * ``'cat'``: All covariates that require one-hot encoding. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded categorical covariates. + If absent returns None. + """ cov = self._get_cov(tensors=tensors) input_dict = { "expr": tensors[REGISTRY_KEYS.X_KEY], @@ -160,9 +206,39 @@ def _get_inference_input(self, tensors, **kwargs) -> dict[str, torch.Tensor]: return input_dict def _get_inference_cycle_input( - self, tensors, generative_outputs, selected_batch: torch.Tensor, **kwargs - ) -> dict[str, torch.Tensor]: - """Parse the input tensors and cycle batch info to get cycle inference inputs.""" + self, + tensors: dict[str, torch.Tensor], + generative_outputs: dict[str, torch.Tensor], + selected_batch: torch.Tensor, **kwargs + ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: + """Parse the input tensors, generative outputs, and cycle batch info to get cycle inference inputs. + + Parameters + ---------- + tensors + Input tensors. + generative_outputs + Outputs of the generative pass. + selected_batch + Batch covariate to be used for the cycle inference. + dim = n_samples * 1 + kwargs + Not used. Added for inheritance compatibility. + + Returns + ------- + Tensors that can be used for cycle inference. + Keys: + * ``'expr'``: Expression. + * ``'batch'``: Batch covariate. + * ``'cat'``: All covariates that require one-hot encoding. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded categorical covariates. + If absent returns None. + + """ cov = self._mock_cov(self._get_cov(tensors=tensors)) input_dict = { "expr": generative_outputs["y_m"], @@ -179,7 +255,36 @@ def _get_generative_input( selected_batch: torch.Tensor, **kwargs, ) -> dict[str, torch.Tensor | dict[str, torch.Tensor | list[torch.Tensor] | None]]: - """Parse the input tensors, inference inputs, and cycle batch to get generative inputs""" + """Parse the input tensors, inference outputs, and cycle batch info to get generative inputs. + + Parameters + ---------- + tensors + Input tensors. + inference_outputs + Outputs of the inference pass. + selected_batch + Batch covariate to be used for the cycle expression generation. + dim = n_samples * 1 + kwargs + Not used. Added for inheritance compatibility. + + Returns + ------- + Tensors that can be used for normal and cycle generation. + Keys: + * ``'z'``: Latent representation. + * ``'batch'``: Batch covariates. + Dict with keys ``'x'`` for normal and ``'y'`` for cycle pass. + * ``'cat'``: All covariates that require one-hot encoding. + Dict with keys ``'x'`` for normal and ``'y'`` for cycle pass. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded categorical covariates. + If absent returns None. + + """ z = inference_outputs["z"] cov = self._get_cov(tensors=tensors) @@ -220,7 +325,14 @@ def _merge_batch_cov(cat: list[torch.Tensor], batch: torch.Tensor) -> list[torch @staticmethod def _mock_cov(cov: dict[str, list[torch.Tensor], torch.Tensor, None]) -> torch.Tensor | None: - """Make mock (all 0) covariates for cycle""" + """Make mock covariates for cycle. + + In the cycle pass mock covariates are used due to the following assumption: The encoder + and decoder could have trouble learning how covariates from the input system would behave + in the output/predicted (cycle) system if these covariats are not also present in the real + data of the output system (and usually they are not). + However, I did not test passing real input covariates instead of mock covariates in the cycle. + """ mock = { "cat": [torch.zeros_like(cat) for cat in cov["cat"]], "cont": torch.zeros_like(cov["cont"]) if cov["cont"] is not None else None, @@ -373,7 +485,10 @@ def forward( ) inference_cycle_outputs = self.inference(**inference_cycle_inputs, **inference_kwargs) - # Combine outputs of all forward pass components - first and cycle pass + # Combine outputs of all forward pass components (first and cycle pass) into a single dict, + # separately for inference and generative outputs + # Rename keys in outputs of cycle pass to be distinguishable from the first pass + # for the merging into a single dict inference_outputs_merged = dict(**inference_outputs) inference_outputs_merged.update( **{k.replace("z", "z_cyc"): v for k, v in inference_cycle_outputs.items()} @@ -450,16 +565,16 @@ def standardize(x): ) def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: - """For every cell randomly selects a new batch that is different from the original batch + """For every cell randomly selects a new batch that is different from the real batch. Parameters ---------- batch - One hot encoded batch information for each cell + Real batch information for each cell Returns ------- - One hot encoding of newly selected batch for each cell + Newly selected batch for each cell """ # Get available batches - @@ -480,11 +595,8 @@ def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: dtype=col_pairs.dtype, ), ) - new_tensor = torch.zeros_like(available_batches) - # generate batch covariate tensor - new_tensor.scatter_(1, randomly_selected_indices, 1) - return new_tensor + return randomly_selected_indices @torch.inference_mode() def sample(self, *args, **kwargs): From 2aab4b4c0633680caf43b1234bea6833226be05a Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 16:19:54 +0100 Subject: [PATCH 50/60] Use real instead of mock covariates in cycle For code simplicity the use of mock covariates was replaced with real covariates in cycle. Needs testing on real data. --- src/scvi/external/sysvi/_module.py | 36 +++++++++--------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 95f6e7ceb9..c84b220b5f 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -238,8 +238,11 @@ def _get_inference_cycle_input( Includes continous and embedded categorical covariates. If absent returns None. + Note: cycle covariates differ from the original publication. + Instead of mock covariates the real input covaiates are used in cycle. + """ - cov = self._mock_cov(self._get_cov(tensors=tensors)) + cov = self._get_cov(tensors=tensors) input_dict = { "expr": generative_outputs["y_m"], "batch": selected_batch, @@ -277,24 +280,21 @@ def _get_generative_input( * ``'batch'``: Batch covariates. Dict with keys ``'x'`` for normal and ``'y'`` for cycle pass. * ``'cat'``: All covariates that require one-hot encoding. - Dict with keys ``'x'`` for normal and ``'y'`` for cycle pass. List of tensors with dim = n_samples * 1. If absent returns empty list. * ``'cont'``: All covariates that are already continous. Includes continous and embedded categorical covariates. If absent returns None. + Note: cycle covariates differ from the original publication. + Instead of mock covariates the real input covaiates are used in cycle. + """ z = inference_outputs["z"] - cov = self._get_cov(tensors=tensors) - cov_mock = self._mock_cov(cov) - cat = {"x": cov["cat"], "y": cov_mock["cat"]} - cont = {"x": cov["cont"], "y": cov_mock["cont"]} - batch = {"x": tensors["batch"], "y": selected_batch} - input_dict = {"z": z, "batch": batch, "cat": cat, "cont": cont} + input_dict = {"z": z, "batch": batch, "cat": cov["cat"], "cont": cov["cont"]} return input_dict @auto_move_data # TODO remove? @@ -323,22 +323,6 @@ def _get_cov( def _merge_batch_cov(cat: list[torch.Tensor], batch: torch.Tensor) -> list[torch.Tensor]: return [batch] + cat - @staticmethod - def _mock_cov(cov: dict[str, list[torch.Tensor], torch.Tensor, None]) -> torch.Tensor | None: - """Make mock covariates for cycle. - - In the cycle pass mock covariates are used due to the following assumption: The encoder - and decoder could have trouble learning how covariates from the input system would behave - in the output/predicted (cycle) system if these covariats are not also present in the real - data of the output system (and usually they are not). - However, I did not test passing real input covariates instead of mock covariates in the cycle. - """ - mock = { - "cat": [torch.zeros_like(cat) for cat in cov["cat"]], - "cont": torch.zeros_like(cov["cont"]) if cov["cont"] is not None else None, - } - return mock - @auto_move_data def inference( self, @@ -408,10 +392,10 @@ def outputs( res = {} outputs( - compute=x_x, name="x", res=res, x=z, batch=batch["x"], cat=cat["x"], cont=cont["x"] + compute=x_x, name="x", res=res, x=z, batch=batch["x"], cat=cat, cont=cont ) outputs( - compute=x_y, name="y", res=res, x=z, batch=batch["y"], cat=cat["y"], cont=cont["y"] + compute=x_y, name="y", res=res, x=z, batch=batch["y"], cat=cat, cont=cont ) return res From 1b50be1fcd93b947b235d15066e5f9b62de38651 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 17:10:42 +0100 Subject: [PATCH 51/60] improve documentation --- src/scvi/external/sysvi/_module.py | 266 ++++++++++++++++++++++------- 1 file changed, 204 insertions(+), 62 deletions(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index c84b220b5f..2a421cd08f 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -194,6 +194,7 @@ def _get_inference_input( If absent returns empty list. * ``'cont'``: All covariates that are already continous. Includes continous and embedded categorical covariates. + Single tensor of dim = n_samples * n_concatenated_cov_features. If absent returns None. """ cov = self._get_cov(tensors=tensors) @@ -236,6 +237,7 @@ def _get_inference_cycle_input( If absent returns empty list. * ``'cont'``: All covariates that are already continous. Includes continous and embedded categorical covariates. + Single tensor of dim = n_samples * n_concatenated_cov_features. If absent returns None. Note: cycle covariates differ from the original publication. @@ -284,6 +286,7 @@ def _get_generative_input( If absent returns empty list. * ``'cont'``: All covariates that are already continous. Includes continous and embedded categorical covariates. + Single tensor of dim = n_samples * n_concatenated_cov_features. If absent returns None. Note: cycle covariates differ from the original publication. @@ -299,9 +302,29 @@ def _get_generative_input( @auto_move_data # TODO remove? def _get_cov( - self, tensors: dict[str, torch.Tensor] + self, + tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: - """Process all covariates into continuous and categorical components for cVAE""" + """Process all covariates into continuous and categorical components for cVAE. + + Parameters + ---------- + tensors + Input tensors. + + Returns + ------- + Covariates that can be used for decoder and encoder. + Keys: + * ``'cat'``: All covariates that require one-hot encoding. + List of tensors with dim = n_samples * 1. + If absent returns empty list. + * ``'cont'``: All covariates that are already continous. + Includes continous and embedded categorical covariates. + Single tensor of dim = n_samples * n_concatenated_cov_features. + If absent returns None. + + """ cat_parts = [] cont_parts = [] if REGISTRY_KEYS.CONT_COVS_KEY in tensors: @@ -320,7 +343,26 @@ def _get_cov( return cov @staticmethod - def _merge_batch_cov(cat: list[torch.Tensor], batch: torch.Tensor) -> list[torch.Tensor]: + def _merge_batch_cov( + cat: list[torch.Tensor], + batch: torch.Tensor, + ) -> list[torch.Tensor]: + """Merge batch and continuous covariates for input into encoder and decoder. + + Parameters + ---------- + cat + Categorical covariates. + List of tensors with dim = n_samples * 1. + batch + Batch covariate. + dim = n_samples * 1 + + Returns + ------- + Single list with batch and categorical covariates. + + """ return [batch] + cat @auto_move_data @@ -330,96 +372,133 @@ def inference( batch: torch.Tensor, cat: list[torch.Tensor], cont: torch.Tensor | None, - ) -> dict: - """ - expression & cov -> latent representation + ) -> dict[str, torch.Tensor]: + """Inference: expression & cov -> latent representation. Parameters ---------- expr - Expression data - cov - Full covariate data (categorical, categorical embedded, and continuous + Expression data. batch - System representation + Batch covariate. + cat + All covariates that require one-hot encoding. + cont + All covariates that are already continous. + Includes continous and embedded categorical covariates. Returns ------- - Posterior parameters and sample + Predicted mean (``'z_m'``) and variance (``'z_v'``) of the latent distribution + as wll as a sample (``'z'``) from it. + """ z = self.encoder(x=expr, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont) return {"z": z["y"], "z_m": z["y_m"], "z_v": z["y_v"]} @auto_move_data - def generative(self, z, batch, cat, cont, x_x: bool = True, x_y: bool = True) -> dict: - """ - latent representation & cov -> expression + def generative( + self, + z: torch.Tensor, + batch: dict[str, torch.Tensor], + cat: list[torch.Tensor], + cont: torch.Tensor | None, + x_x: bool = True, + x_y: bool = True + ) -> dict[str, torch.Tensor]: + """Generation: latent representation & cov -> expression. Parameters ---------- z - Latent embedding - cov - Full covariate data (categorical, categorical embedded, and continuous + Latent representation. batch - System representation + Batch covariate for normal (``'x'``) and cycle (``'y'``) generation. + cat + All covariates that require one-hot encoding. + cont + All covariates that are already continous. + Includes continous and embedded categorical covariates. x_x - Decode to original batch + Decode to original batch. x_y - Decode to replacement batch + Decode to cycle batch. Returns ------- - Decoded distribution parameters and sample + Predicted mean (``'x_m'``) and variance (``'x_v'``) of the expression distribution + as wll as a sample (``'expr'``) from it. Same outputs are returned for the cycle generation + with ``'expr'`` in keys being replaced by ``'y'``. """ def outputs( - compute: bool, name: str, res: dict, - x: torch.Tensor, + expr: torch.Tensor, batch: torch.Tensor, cat: list[torch.Tensor], cont: torch.Tensor | None, ): - if compute: - res_sub = self.decoder( - x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont - ) - res[name] = res_sub["y"] - res[name + "_m"] = res_sub["y_m"] - res[name + "_v"] = res_sub["y_v"] + """Helper to compute generative outputs for normal and cycle pass. + + Adds generative outputs directly to the ``res`` dict. + + Parameters + ---------- + name + Name prepended to the keys added to the ``res`` dict. + res + Dict to store generative outputs in. + Mean is stored in ``'name_m'``, variance to ``'name_v'`` + and sample to ``'name'``. + expr + Expression data. + batch + Batch covariate. + cat + All covariates that require one-hot encoding. + cont + All covariates that are already continous. + Includes continous and embedded categorical covariates. + """ + res_sub = self.decoder( + x=expr, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont + ) + res[name] = res_sub["y"] + res[name + "_m"] = res_sub["y_m"] + res[name + "_v"] = res_sub["y_v"] res = {} - outputs( - compute=x_x, name="x", res=res, x=z, batch=batch["x"], cat=cat, cont=cont - ) - outputs( - compute=x_y, name="y", res=res, x=z, batch=batch["y"], cat=cat, cont=cont - ) + if x_x: + outputs( + name="expr", res=res, expr=z, batch=batch["expr"], cat=cat, cont=cont + ) + if x_y: + outputs( + name="y", res=res, expr=z, batch=batch["y"], cat=cat, cont=cont + ) return res @auto_move_data def forward( self, - tensors, + tensors: dict[str, torch.Tensor], get_inference_input_kwargs: dict | None = None, get_generative_input_kwargs: dict | None = None, inference_kwargs: dict | None = None, generative_kwargs: dict | None = None, loss_kwargs: dict | None = None, - compute_loss=True, + compute_loss: bool = True, ) -> ( tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]] | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], LossOutput] ): - """ - Forward pass through the network. + """Forward pass through the network. Parameters ---------- tensors - tensors to pass through + Input tensors. get_inference_input_kwargs Keyword args for ``_get_inference_input()`` get_generative_input_kwargs @@ -431,13 +510,17 @@ def forward( loss_kwargs Keyword args for ``loss()`` compute_loss - Whether to compute loss on forward pass. This adds - another return value. + Whether to compute loss on forward pass. + + Returns + ------- + Inference outputs, generative outputs of the normal pass, and optionally loss components. + Inference normal and cycle outputs are combined into a single dict. + Thus, the keys of cycle inference outputs are modified by replacing ``'z'`` with ``'z_cyc'``. """ - """Core of the forward call shared by PyTorch- and Jax-based modules.""" - # TODO currently some forward paths are computed despite potentially having loss weight=0 - - # don't compute if not needed + # TODO could disable computation of cycle if cycle loss will not be used (weight = 0). + # Cycle loss is not expected to be disabled in practice for typical use cases. # Parse kwargs inference_kwargs = inference_kwargs or {} @@ -492,13 +575,35 @@ def forward( def loss( self, - tensors, - inference_outputs, - generative_outputs, + tensors: dict[str, torch.Tensor], + inference_outputs: dict[str, torch.Tensor], + generative_outputs: dict[str, torch.Tensor], kl_weight: float = 1.0, reconstruction_weight: float = 1.0, z_distance_cycle_weight: float = 2.0, ) -> LossOutput: + """Compute loss of forward pass. + + Parameters + ---------- + tensors + Input tensors. + inference_outputs + Outputs of normal and cycle inference pass. + generative_outputs + Outputs of the normal generative pass. + kl_weight + Weight for KL loss. + reconstruction_weight + Weight for reconstruction loss. + z_distance_cycle_weight + Weight for cycle loss. + + Returns + ------- + Loss components: + Cycle loss is added to extra metrics as ``'cycle_loss'``. + """ # Reconstruction loss x_true = tensors[REGISTRY_KEYS.X_KEY] reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")( @@ -513,23 +618,52 @@ def loss( z=inference_outputs["z"], ) - def z_dist(z_x_m: torch.Tensor, z_y_m: torch.Tensor): - """MSE loss between standardised inputs with standardizer fitted on concatenation of both inputs - - MSE loss should be computed on standardized latent values as else model can learn to cheat the MSE - loss by putting latent parameters to even smaller numbers. + def z_dist( + z_x_m: torch.Tensor, + z_y_m: torch.Tensor, + ) -> torch.Tensor: + """MSE loss between standardised inputs. + + MSE loss should be computed on standardized latent representations + as else model can learn to cheat the MSE loss + by setting the latent representations to smaller numbers. + Standardizer is fitted on concatenation of both inputs. + + Parameters + ---------- + z_x_m + First input. + z_y_m + Second input. + + Returns + ------- + The loss. + dim = n_samples * 1 """ # Standardise data (jointly both z-s) before MSE calculation z = torch.concat([z_x_m, z_y_m]) means = z.mean(dim=0, keepdim=True) stds = z.std(dim=0, keepdim=True) - def standardize(x): + def standardize(x: torch.Tensor) -> torch.Tensor: + """Helper function to standardize a tensor. + + Mean and variance from the outer scope are used for standardization. + + Parameters + ---------- + x + Input tensor. + + Returns + ------- + Standardized tensor. + """ return (x - means) / stds - return torch.nn.MSELoss(reduction="none")(standardize(z_x_m), standardize(z_y_m)).sum( - dim=1 - ) + return torch.nn.MSELoss(reduction="none")( + standardize(z_x_m), standardize(z_y_m)).sum(dim=1) z_distance_cyc = z_dist(z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"]) if "batch_weights" in tensors.keys(): @@ -554,18 +688,18 @@ def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: Parameters ---------- batch - Real batch information for each cell + Real batch information for each cell. Returns ------- - Newly selected batch for each cell - + Newly selected batch for each cell. """ # Get available batches - # those that are zero will become nonzero and vice versa batch = torch.nn.functional.one_hot(batch.squeeze(-1), self.n_batch) available_batches = 1 - batch - # Get nonzero indices for each cell + # Get nonzero indices for each cell - batches that differ from the real batch + # and are thus available row_indices, col_indices = torch.nonzero(available_batches, as_tuple=True) col_pairs = col_indices.view(-1, batch.shape[1] - 1) # Select batch for every cell from available batches @@ -584,4 +718,12 @@ def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: @torch.inference_mode() def sample(self, *args, **kwargs): + """Generate expression samples from the posterior generative distribution. + + Not implemented as the use of decoded expression is not recommended for SysVI. + + Raises + ------ + NotImplementedError + """ raise NotImplementedError("The use of decoded expression is not recommended for SysVI.") From 6d32d89c471a97bf0d78e5ec35887f1f36b37413 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 27 Oct 2024 16:11:05 +0000 Subject: [PATCH 52/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/sysvi/_module.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 2a421cd08f..fde571e529 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -170,9 +170,7 @@ def _cov_idx_name(cov: int) -> str: return "cov" + str(cov) def _get_inference_input( - self, - tensors: dict[str, torch.Tensor], - **kwargs + self, tensors: dict[str, torch.Tensor], **kwargs ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: """Parse the input tensors to get inference inputs. @@ -210,7 +208,8 @@ def _get_inference_cycle_input( self, tensors: dict[str, torch.Tensor], generative_outputs: dict[str, torch.Tensor], - selected_batch: torch.Tensor, **kwargs + selected_batch: torch.Tensor, + **kwargs, ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: """Parse the input tensors, generative outputs, and cycle batch info to get cycle inference inputs. @@ -404,7 +403,7 @@ def generative( cat: list[torch.Tensor], cont: torch.Tensor | None, x_x: bool = True, - x_y: bool = True + x_y: bool = True, ) -> dict[str, torch.Tensor]: """Generation: latent representation & cov -> expression. @@ -470,13 +469,9 @@ def outputs( res = {} if x_x: - outputs( - name="expr", res=res, expr=z, batch=batch["expr"], cat=cat, cont=cont - ) + outputs(name="expr", res=res, expr=z, batch=batch["expr"], cat=cat, cont=cont) if x_y: - outputs( - name="y", res=res, expr=z, batch=batch["y"], cat=cat, cont=cont - ) + outputs(name="y", res=res, expr=z, batch=batch["y"], cat=cat, cont=cont) return res @auto_move_data @@ -518,7 +513,6 @@ def forward( Inference normal and cycle outputs are combined into a single dict. Thus, the keys of cycle inference outputs are modified by replacing ``'z'`` with ``'z_cyc'``. """ - # TODO could disable computation of cycle if cycle loss will not be used (weight = 0). # Cycle loss is not expected to be disabled in practice for typical use cases. @@ -662,8 +656,9 @@ def standardize(x: torch.Tensor) -> torch.Tensor: """ return (x - means) / stds - return torch.nn.MSELoss(reduction="none")( - standardize(z_x_m), standardize(z_y_m)).sum(dim=1) + return torch.nn.MSELoss(reduction="none")(standardize(z_x_m), standardize(z_y_m)).sum( + dim=1 + ) z_distance_cyc = z_dist(z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"]) if "batch_weights" in tensors.keys(): From 2e59be7ad75ed09ce02df5dd1860bac0adcdffdb Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 19:45:04 +0100 Subject: [PATCH 53/60] improve documentation --- src/scvi/external/sysvi/_model.py | 56 ++++++++++---- src/scvi/external/sysvi/_priors.py | 120 ++++++++++++++++++++++++----- tests/external/sysvi/test_model.py | 5 +- 3 files changed, 142 insertions(+), 39 deletions(-) diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index f9bda90346..53a310b287 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -29,18 +29,21 @@ class SysVI(UnsupervisedTrainingMixin, BaseModelClass): """Integration model based on cVAE with optional VampPrior and latent cycle-consistency loss. + Described in `Hrovatin et al. (2023) `_. + Parameters ---------- adata AnnData object that has been registered via :meth:`~scvi.external.SysVI.setup_anndata`. prior - The prior distribution to be used. You can choose between "standard_normal" and "vamp". + The prior distribution to be used. + You can choose between ``"standard_normal"`` and ``"vamp"``. n_prior_components Number of prior components (i.e. modes) to use in VampPrior. pseudoinputs_data_indices - By default VampPrior pseudoinputs are randomly selected from data. + By default, VampPrior pseudoinputs are randomly selected from data. Alternatively, one can specify pseudoinput indices using this parameter. - The number of specified indices in the input 1D array should match n_prior_components + The number of specified indices in the input 1D array should match ``n_prior_components``. **model_kwargs Keyword args for :class:`~scvi.external.sysvi.SysVAE` """ @@ -105,6 +108,22 @@ def train( plan_kwargs: dict | None = None, **kwargs, ): + """Train the models. + + Overwrites the ``train`` method of class:`~scvi.model.base.UnsupervisedTrainingMixin` + to prevent the use of KL loss warmup (specified in ``plan_kwargs``). + This is disabled as our experiments showed poor integration in the cycle model + when using KL loss warmup. + + Parameters + ---------- + args + Training args. + plan_kwargs + Training plan kwargs. + kwargs + Training kwargs. + """ plan_kwargs = plan_kwargs or {} kl_weight_defaults = {"n_epochs_kl_warmup": 0, "n_steps_kl_warmup": 0} if any([v != plan_kwargs.get(k, v) for k, v in kl_weight_defaults.items()]): @@ -136,19 +155,21 @@ def get_latent_representation( adata Input adata for which latent representation should be obtained. indices - Data indices to embed. If None embedd all cells. + Data indices to embed. If None embedd all samples. give_mean - Return the posterior mean instead of a sample from the posterior. + Return the posterior latent distribution mean instead of a sample from it. Ignored if `return_dist` is ``True``. batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. return_dist - If ``True``, returns the mean and variance of the latent distribution. Otherwise, - returns the mean of the latent distribution. + If ``True``, returns the mean and variance of the posterior latent distribution. + Otherwise, returns its mean or a sample from it. Returns ------- - Latent Embedding + Latent representation of a cell. + If ``return_dist`` is ``True``, returns the mean and variance of the posterior latent distribution. + Else, returns the mean or a sample, depending on ``give_mean``. """ self._check_if_trained(warn=False) adata = self._validate_anndata(adata) @@ -191,17 +212,18 @@ def setup_anndata( weight_batches: bool = False, **kwargs, ): - """Prepare adata for input to Model + """Prepare adata for input to SysVI model. Setup distinguishes between two main types of covariates that can be corrected for: - - batch (referred to as "batch" in the original publication): Single categorical covariate that - will be corrected via cycle consistency loss. It will be also used as a condition in cVAE. - This covariate is expected to correspond to stronger batch effects, such as between datasets from the - different sequencing technology or model systems (animal species, in-vitro models, etc.). + - batch (referred to as "system" in the original publication Hrovatin, et al., 2023): + Single categorical covariate that will be corrected via cycle consistency loss. + It will be also used as a condition in cVAE. + This covariate is expected to correspond to stronger batch effects, such as between datasets from + different sequencing technology or model systems (animal species, in-vitro models and tissue, etc.). - covariate (includes both continous and categorical covariates): Additional covariates to be used only as a condition in cVAE, but not corrected via cycle loss. These covariates are expected to correspond to weaker batch effects, such as between datasets from the - same sequencing technology and model batch (animal, in-vitro, etc.) or between samples within a dataset. + same sequencing technology and system (animal, in-vitro, etc.) or between samples within a dataset. Parameters ---------- @@ -215,10 +237,10 @@ def setup_anndata( AnnData layer to use, default is X. Should contain normalized and log+1 transformed expression. categorical_covariate_keys - Name of obs column with additional categorical covariate information. - Will be one hot encoded or embedded, as later defined in the model. + Name of obs columns with additional categorical covariate information. + Will be one hot encoded or embedded, as later defined in the ``SysVI`` model. continuous_covariate_keys - Name of obs column with additional continuous covariate information. + Name of obs columns with additional continuous covariate information. """ setup_method_args = cls._get_setup_method_args(**locals()) diff --git a/src/scvi/external/sysvi/_priors.py b/src/scvi/external/sysvi/_priors.py index 0118fee3f7..7e8b388a90 100644 --- a/src/scvi/external/sysvi/_priors.py +++ b/src/scvi/external/sysvi/_priors.py @@ -8,13 +8,57 @@ class Prior(torch.nn.Module, abc.ABC): + """Abstract class for prior distributions.""" + @abstractmethod - def kl(self, m_q, v_q, z): + def kl( + self, + m_q: torch.Tensor, + v_q: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + """Compute KL divergence between prior and posterior distribution. + + Parameters + ---------- + m_q + Posterior distribution mean. + v_q + Posterior distribution variance. + z + Sample from the posterior distribution. + + Returns + ------- + KL divergence. + """ pass class StandardPrior(Prior): - def kl(self, m_q, v_q, z=None): + """Standard prior distribution.""" + + def kl( + self, + m_q: torch.Tensor, + v_q: torch.Tensor, + z: None = None + ) -> torch.Tensor: + """Compute KL divergence between standard normal prior and the posterior distribution. + + Parameters + ---------- + m_q + Posterior distribution mean. + v_q + Posterior distribution variance. + z + Ignored. + + Returns + ------- + KL divergence. + """ # 1 x N return kl_divergence( Normal(m_q, v_q.sqrt()), Normal(torch.zeros_like(m_q), torch.ones_like(v_q)) @@ -22,18 +66,33 @@ def kl(self, m_q, v_q, z=None): class VampPrior(Prior): - """VampPrior adapted from https://github.com/jmtomczak/intro_dgm/blob/main/vaes/vae_priors_example.ipynb + """VampPrior. + + Adapted from a + `blog post `_ + of the original VampPrior author. Parameters ---------- n_components - Prior components + Number of prior components. encoder - The encoder - data - Data for pseudoinputs initialisation tuple(input,covs) + The encoder. + data_x + Expression data for pseudoinputs initialisation. + n_cat_list + The number of categorical covariates and + the number of category levels. + A list containing, for each covariate of interest, + the number of categories. + data_cat + List of categorical covariates for pseudoinputs initialisation. + Includes all covariates that will be one-hot encoded by the ``encoder``, + including the batch. + data_cont + Continuous covariates for pseudoinputs initialisation. trainable_priors - Are pseudoinput parameters trainable or fixed + Are pseudoinput parameters trainable or fixed. """ # K - components, I - inputs, L - latent, N - samples @@ -42,10 +101,10 @@ def __init__( self, n_components: int, encoder: torch.nn.Module, - data_x: torch.tensor, + data_x: torch.Tensor, n_cat_list: list[int], - data_cat: list[torch.tensor], - data_cont: torch.tensor | None = None, + data_cat: list[torch.Tensor], + data_cont: torch.Tensor | None = None, trainable_priors: bool = True, ): super().__init__() @@ -85,12 +144,11 @@ def __init__( self.w = torch.nn.Parameter(torch.zeros(n_components, 1, 1)) # K x 1 x 1 def get_params(self) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get posterior of pseudoinputs + """Get posterior of pseudoinputs. Returns ------- - Posterior mean, var + Posterior representation mean and variance for each pseudoinput. """ # u, u_cov -> encoder -> mean, var original_mode = self.encoder.training @@ -101,23 +159,23 @@ def get_params(self) -> tuple[torch.Tensor, torch.Tensor]: self.encoder.train(original_mode) return z["y_m"], z["y_v"] # (K x L), (K x L) - def log_prob(self, z) -> torch.Tensor: - """ - Log probability of z under the prior + def log_prob(self, z: torch.Tensor) -> torch.Tensor: + """Log probability of posterior sample under the prior. Parameters ---------- z - Latent embedding of samples + Latent embedding of samples. Returns ------- - Log probability of every sample: samples * latent dimensions + Log probability of every sample. + dim = n_samples * n_latent_dimensions """ # Mixture of gaussian computed on K x N x L z = z.unsqueeze(0) # 1 x N x L - # Get pseudoinputs posteriors - prior params + # Get pseudoinputs posteriors which are prior params m_p, v_p = self.get_params() # (K x L), (K x L) m_p = m_p.unsqueeze(1) # K x 1 x L v_p = v_p.unsqueeze(1) # K x 1 x L @@ -131,5 +189,25 @@ def log_prob(self, z) -> torch.Tensor: return log_prob # N x L - def kl(self, m_q, v_q, z): + def kl( + self, + m_q: torch.Tensor, + v_q: torch.Tensor, + z: torch.Tensor + ) -> torch.Tensor: + """Compute KL divergence between VampPrior and the posterior distribution. + + Parameters + ---------- + m_q + Posterior distribution mean. + v_q + Posterior distribution variance. + z + Sample from the posterior distribution. + + Returns + ------- + KL divergence. + """ return (Normal(m_q, v_q.sqrt()).log_prob(z) - self.log_prob(z)).sum(1) diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index a878c1ffa9..537da914b2 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -11,7 +11,7 @@ def mock_adata(): - # Make random data + """Mock data for testing.""" adata = AnnData( sparse.csr_matrix( np.exp( @@ -76,6 +76,7 @@ def test_model( embed_cat, weight_batches, ): + """Test model with different input and parameters settings.""" adata = mock_adata() # Run adata setup @@ -114,6 +115,7 @@ def test_model( def test_latent_representation(): + """Test different parameters for computing later representation.""" # Train model adata = mock_adata() SysVI.setup_anndata( @@ -164,6 +166,7 @@ def test_latent_representation(): def test_warnings(): + """Test that the most important warnings and exceptions are raised.""" # Train model adata = mock_adata() SysVI.setup_anndata( From 4e0988e03f412d937b3ecf63d602fa671c9c68f5 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 19:48:53 +0100 Subject: [PATCH 54/60] fix bug introduced when renaming parameters in generative function --- src/scvi/external/sysvi/_module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 2a421cd08f..08d4a261f5 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -434,7 +434,7 @@ def generative( def outputs( name: str, res: dict, - expr: torch.Tensor, + x: torch.Tensor, batch: torch.Tensor, cat: list[torch.Tensor], cont: torch.Tensor | None, @@ -451,8 +451,8 @@ def outputs( Dict to store generative outputs in. Mean is stored in ``'name_m'``, variance to ``'name_v'`` and sample to ``'name'``. - expr - Expression data. + x + Latent representation. batch Batch covariate. cat @@ -462,7 +462,7 @@ def outputs( Includes continous and embedded categorical covariates. """ res_sub = self.decoder( - x=expr, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont + x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont ) res[name] = res_sub["y"] res[name + "_m"] = res_sub["y_m"] @@ -471,11 +471,11 @@ def outputs( res = {} if x_x: outputs( - name="expr", res=res, expr=z, batch=batch["expr"], cat=cat, cont=cont + name="x", res=res, x=z, batch=batch["x"], cat=cat, cont=cont ) if x_y: outputs( - name="y", res=res, expr=z, batch=batch["y"], cat=cat, cont=cont + name="y", res=res, x=z, batch=batch["y"], cat=cat, cont=cont ) return res From 3f8be47be5acc47272145fc47066394193f4a011 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 27 Oct 2024 18:49:40 +0000 Subject: [PATCH 55/60] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/sysvi/_module.py | 8 ++------ src/scvi/external/sysvi/_priors.py | 14 ++------------ 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 866f8c4067..0c7176fcbd 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -469,13 +469,9 @@ def outputs( res = {} if x_x: - outputs( - name="x", res=res, x=z, batch=batch["x"], cat=cat, cont=cont - ) + outputs(name="x", res=res, x=z, batch=batch["x"], cat=cat, cont=cont) if x_y: - outputs( - name="y", res=res, x=z, batch=batch["y"], cat=cat, cont=cont - ) + outputs(name="y", res=res, x=z, batch=batch["y"], cat=cat, cont=cont) return res @auto_move_data diff --git a/src/scvi/external/sysvi/_priors.py b/src/scvi/external/sysvi/_priors.py index 7e8b388a90..91a9ba22cf 100644 --- a/src/scvi/external/sysvi/_priors.py +++ b/src/scvi/external/sysvi/_priors.py @@ -38,12 +38,7 @@ def kl( class StandardPrior(Prior): """Standard prior distribution.""" - def kl( - self, - m_q: torch.Tensor, - v_q: torch.Tensor, - z: None = None - ) -> torch.Tensor: + def kl(self, m_q: torch.Tensor, v_q: torch.Tensor, z: None = None) -> torch.Tensor: """Compute KL divergence between standard normal prior and the posterior distribution. Parameters @@ -189,12 +184,7 @@ def log_prob(self, z: torch.Tensor) -> torch.Tensor: return log_prob # N x L - def kl( - self, - m_q: torch.Tensor, - v_q: torch.Tensor, - z: torch.Tensor - ) -> torch.Tensor: + def kl(self, m_q: torch.Tensor, v_q: torch.Tensor, z: torch.Tensor) -> torch.Tensor: """Compute KL divergence between VampPrior and the posterior distribution. Parameters From 776a8e2903549e1d95360b815628ccb5f9a12ed7 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 20:06:43 +0100 Subject: [PATCH 56/60] rename tests to prevent automatic test failure --- tests/external/sysvi/{test_model.py => test_sysvi.py} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename tests/external/sysvi/{test_model.py => test_sysvi.py} (98%) diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_sysvi.py similarity index 98% rename from tests/external/sysvi/test_model.py rename to tests/external/sysvi/test_sysvi.py index 537da914b2..ef19d0f7b2 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_sysvi.py @@ -69,7 +69,7 @@ def mock_adata(): (None, None, np.array(list(range(5))), False, False), ], ) -def test_model( +def test_sysvi_model( categorical_covariate_keys, continuous_covariate_keys, pseudoinputs_data_indices, @@ -114,7 +114,7 @@ def test_model( ) -def test_latent_representation(): +def test_sysvi_latent_representation(): """Test different parameters for computing later representation.""" # Train model adata = mock_adata() @@ -165,7 +165,7 @@ def test_latent_representation(): np.testing.assert_allclose(embed, mean) -def test_warnings(): +def test_sysvi_warnings(): """Test that the most important warnings and exceptions are raised.""" # Train model adata = mock_adata() From 35e3385e558bd43b62ab65f99b842d5e87fc57de Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 20:17:15 +0100 Subject: [PATCH 57/60] fix typo in docstring and formatting --- tests/external/sysvi/test_sysvi.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/external/sysvi/test_sysvi.py b/tests/external/sysvi/test_sysvi.py index ef19d0f7b2..c9b4e8f48c 100644 --- a/tests/external/sysvi/test_sysvi.py +++ b/tests/external/sysvi/test_sysvi.py @@ -11,7 +11,7 @@ def mock_adata(): - """Mock data for testing.""" + """Mock adata for testing.""" adata = AnnData( sparse.csr_matrix( np.exp( @@ -187,10 +187,8 @@ def test_sysvi_warnings(): plan_kwargs={"n_steps_kl_warmup": 1}, ) assert any( - [ - "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) - for rec in record - ] + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record ) # Epoch warmup with pytest.warns(Warning) as record: @@ -200,10 +198,8 @@ def test_sysvi_warnings(): plan_kwargs={"n_epochs_kl_warmup": 1}, ) assert any( - [ - "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) - for rec in record - ] + "The use of KL weight warmup is not recommended in SysVI." in str(rec.message) + for rec in record ) # Asert that sampling is disabled From 4601808b7f7c7e5e77e731532dc542ebb3c2a809 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 27 Oct 2024 21:01:53 +0100 Subject: [PATCH 58/60] ruff fixes --- src/scvi/external/sysvi/_base_components.py | 39 ++++++--- src/scvi/external/sysvi/_model.py | 94 +++++++++++++-------- src/scvi/external/sysvi/_module.py | 89 +++++++++++-------- src/scvi/external/sysvi/_priors.py | 15 ++-- 4 files changed, 153 insertions(+), 84 deletions(-) diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index 9708432c1f..b280ae0bb0 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -2,8 +2,11 @@ import collections import warnings -from collections.abc import Callable, Iterable -from typing import Literal +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from typing import Literal import numpy as np import torch @@ -15,6 +18,8 @@ Parameter, ) +from scvi import settings + class EncoderDecoder(Module): """Module that can be used as probabilistic encoder or decoder. @@ -36,7 +41,8 @@ class EncoderDecoder(Module): n_layers The number of hidden layers. var_mode - How to compute variance from model outputs, see :class:`~scvi.external.sysvi.VarEncoder`. + How to compute variance from model outputs, + see :class:`~scvi.external.sysvi.VarEncoder`. One of the following: * ```'sample_feature'``` - learn variance per sample and feature. * ```'feature'``` - learn variance per feature, constant across samples. @@ -89,7 +95,8 @@ def forward( Parameters ---------- x - Main input (i.e. expression for encoder or latent embedding for decoder.). + Main input (i.e. expression for encoder or + latent embedding for decoder.). dim = n_samples * n_input cont Optional continuous covariates. @@ -108,7 +115,10 @@ def forward( y = self.decoder_y(x=x, cont=cont, cat_list=cat_list) y_m = self.mean_encoder(y) if y_m.isnan().any() or y_m.isinf().any(): - warnings.warn("Predicted mean contains nan or inf values. Setting to numerical.") + warnings.warn( + "Predicted mean contains nan or inf values. " + "Setting to numerical.", + stacklevel=settings.warnings_stacklevel, + ) y_m = torch.nan_to_num(y_m) y_v = self.var_encoder(y) @@ -126,8 +136,8 @@ class FCLayers(nn.Module): FCLayers class of scvi-tools adapted to also inject continous covariates. - The only adaptation is addition of `n_cont` parameter in init and `cont` in forward, - with the associated handling of the two. + The only adaptation is addition of `n_cont` parameter in init + and `cont` in forward, with the associated handling of the two. The forward method signature is changed to account for optional `cont`. Parameters @@ -202,7 +212,8 @@ def __init__( n_out, bias=bias, ), - # non-default params come from defaults in original Tensorflow + # non-default params come from defaults + # in original Tensorflow # implementation nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001) if use_batch_norm @@ -272,7 +283,7 @@ def forward( :class:`torch.Tensor` tensor of shape ``(n_out,)`` """ - one_hot_cat_list = [] # for generality in this list many indices useless. + one_hot_cat_list = [] # for generality in this list many idxs useless. cont_list = [cont] if cont is not None else [] cat_list = cat_list or [] @@ -281,7 +292,7 @@ def forward( for n_cat, cat in zip(self.n_cat_list, cat_list, strict=False): if n_cat and cat is None: raise ValueError("cat not provided while n_cat != 0 in init. params.") - if n_cat > 1: # n_cat = 1 will be ignored - no additional information + if n_cat > 1: # n_cat = 1 will be ignored - no additional info if cat.size(1) != n_cat: one_hot_cat = nn.functional.one_hot(cat.squeeze(-1), n_cat) else: @@ -370,12 +381,16 @@ def forward( elif self.mode == "feature": v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size - # Ensure that var is strictly positive via exp - Bring back to non-log scale + # Ensure that var is strictly positive via exp - + # Bring back to non-log scale # Clip to range that will not be inf after exp v = torch.clip(v, min=-self.clip_exp, max=self.clip_exp) v = self.activation(v) if v.isnan().any(): - warnings.warn("Predicted variance contains nan values. Setting to 0.") + warnings.warn( + "Predicted variance contains nan values. Setting to 0.", + stacklevel=settings.warnings_stacklevel, + ) v = torch.nan_to_num(v) return v diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 53a310b287..d75109b4d6 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -2,14 +2,18 @@ import logging import warnings -from collections.abc import Sequence -from typing import Literal +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + from anndata import AnnData import numpy as np import torch -from anndata import AnnData -from scvi import REGISTRY_KEYS +from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager from scvi.data.fields import ( CategoricalJointObsField, @@ -27,14 +31,16 @@ class SysVI(UnsupervisedTrainingMixin, BaseModelClass): - """Integration model based on cVAE with optional VampPrior and latent cycle-consistency loss. + """Integration with cVAE & optional VampPrior and latent cycle-consistency. - Described in `Hrovatin et al. (2023) `_. + Described in + `Hrovatin et al. (2023) `_. Parameters ---------- adata - AnnData object that has been registered via :meth:`~scvi.external.SysVI.setup_anndata`. + AnnData object that has been registered via + :meth:`~scvi.external.SysVI.setup_anndata`. prior The prior distribution to be used. You can choose between ``"standard_normal"`` and ``"vamp"``. @@ -43,7 +49,8 @@ class SysVI(UnsupervisedTrainingMixin, BaseModelClass): pseudoinputs_data_indices By default, VampPrior pseudoinputs are randomly selected from data. Alternatively, one can specify pseudoinput indices using this parameter. - The number of specified indices in the input 1D array should match ``n_prior_components``. + The number of specified indices in the input 1D array should match + ``n_prior_components``. **model_kwargs Keyword args for :class:`~scvi.external.sysvi.SysVAE` """ @@ -96,7 +103,7 @@ def __init__( ) self._model_summary_string = ( - "SysVI - cVAE model with optional VampPrior and latent cycle-consistency loss." + "SysVI - cVAE model with optional VampPrior " + "and latent cycle-consistency loss." ) self.init_params_ = self._get_init_params(locals()) @@ -110,10 +117,11 @@ def train( ): """Train the models. - Overwrites the ``train`` method of class:`~scvi.model.base.UnsupervisedTrainingMixin` + Overwrites the ``train`` method of + class:`~scvi.model.base.UnsupervisedTrainingMixin` to prevent the use of KL loss warmup (specified in ``plan_kwargs``). - This is disabled as our experiments showed poor integration in the cycle model - when using KL loss warmup. + This is disabled as our experiments showed poor integration in the + cycle model when using KL loss warmup. Parameters ---------- @@ -126,10 +134,12 @@ def train( """ plan_kwargs = plan_kwargs or {} kl_weight_defaults = {"n_epochs_kl_warmup": 0, "n_steps_kl_warmup": 0} - if any([v != plan_kwargs.get(k, v) for k, v in kl_weight_defaults.items()]): + if any(v != plan_kwargs.get(k, v) for k, v in kl_weight_defaults.items()): warnings.warn( "The use of KL weight warmup is not recommended in SysVI. " - + "The n_epochs_kl_warmup and n_steps_kl_warmup will be reset to 0." + + "The n_epochs_kl_warmup and n_steps_kl_warmup " + + "will be reset to 0.", + stacklevel=settings.warnings_stacklevel, ) # Overwrite plan kwargs with kl weight defaults plan_kwargs = {**plan_kwargs, **kl_weight_defaults} @@ -157,18 +167,22 @@ def get_latent_representation( indices Data indices to embed. If None embedd all samples. give_mean - Return the posterior latent distribution mean instead of a sample from it. + Return the posterior latent distribution mean + instead of a sample from it. Ignored if `return_dist` is ``True``. batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + Minibatch size for data loading into model. + Defaults to `scvi.settings.batch_size`. return_dist - If ``True``, returns the mean and variance of the posterior latent distribution. + If ``True``, returns the mean and variance of the posterior + latent distribution. Otherwise, returns its mean or a sample from it. Returns ------- Latent representation of a cell. - If ``return_dist`` is ``True``, returns the mean and variance of the posterior latent distribution. + If ``return_dist`` is ``True``, returns the mean and variance + of the posterior latent distribution. Else, returns the mean or a sample, depending on ``give_mean``. """ self._check_if_trained(warn=False) @@ -214,16 +228,23 @@ def setup_anndata( ): """Prepare adata for input to SysVI model. - Setup distinguishes between two main types of covariates that can be corrected for: - - batch (referred to as "system" in the original publication Hrovatin, et al., 2023): - Single categorical covariate that will be corrected via cycle consistency loss. - It will be also used as a condition in cVAE. - This covariate is expected to correspond to stronger batch effects, such as between datasets from - different sequencing technology or model systems (animal species, in-vitro models and tissue, etc.). - - covariate (includes both continous and categorical covariates): Additional covariates to be used only - as a condition in cVAE, but not corrected via cycle loss. - These covariates are expected to correspond to weaker batch effects, such as between datasets from the - same sequencing technology and system (animal, in-vitro, etc.) or between samples within a dataset. + Setup distinguishes between two main types of covariates that can be + corrected for: + + - batch - referred to as "system" in the original publication + Hrovatin, et al., 2023): + Single categorical covariate that will be corrected via cycle + consistency loss. + It will be also used as a condition in cVAE. + This covariate is expected to correspond to stronger batch effects, + such as between datasets from different sequencing technology or + model systems (animal species, in-vitro models and tissue, etc.). + - covariate (includes both continous and categorical covariates): + Additional covariates to be used only + as a condition in cVAE, but not corrected via cycle loss. + These covariates are expected to correspond to weaker batch effects, + such as between datasets from the same sequencing technology and + system (animal, in-vitro, etc.) or between samples within a dataset. Parameters ---------- @@ -231,16 +252,20 @@ def setup_anndata( Adata object - will be modified in place. batch_key Name of the obs column with the substantial batch effect covariate, - referred to as batch in the original publication (Hrovatin, et al., 2023). + referred to as batch in the original publication + (Hrovatin, et al., 2023). Should be categorical. layer AnnData layer to use, default is X. Should contain normalized and log+1 transformed expression. categorical_covariate_keys - Name of obs columns with additional categorical covariate information. - Will be one hot encoded or embedded, as later defined in the ``SysVI`` model. + Name of obs columns with additional categorical + covariate information. + Will be one hot encoded or embedded, as later defined in the + ``SysVI`` model. continuous_covariate_keys - Name of obs columns with additional continuous covariate information. + Name of obs columns with additional continuous + covariate information. """ setup_method_args = cls._get_setup_method_args(**locals()) @@ -251,7 +276,10 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), ] if weight_batches: - warnings.warn("The use of inverse batch proportion weights is experimental.") + warnings.warn( + "The use of inverse batch proportion weights " + "is experimental.", + stacklevel=settings.warnings_stacklevel, + ) batch_weights_key = "batch_weights" adata.obs[batch_weights_key] = adata.obs[batch_key].map( {cat: 1 / n for cat, n in adata.obs[batch_key].value_counts().items()} diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 0c7176fcbd..a1f83d6a22 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Literal import torch @@ -16,7 +19,8 @@ class SysVAE(BaseModuleClass): """CVAE with optional VampPrior and latent cycle consistency loss. - Described in `Hrovatin et al. (2023) `_. + Described in + `Hrovatin et al. (2023) `_. Parameters ---------- @@ -27,7 +31,8 @@ class SysVAE(BaseModuleClass): n_continuous_cov Number of continuous covariates. n_cats_per_cov - A list of integers containing the number of categories for each categorical covariate. + A list of integers containing the number of categories + for each categorical covariate. embed_cat If ``True`` embeds categorical covariates and batches into continuously-valued vectors instead of using one-hot encoding. @@ -51,20 +56,24 @@ class SysVAE(BaseModuleClass): dropout_rate Dropout rate for encoder and decoder. out_var_mode - How variance is predicted in decoder, see :class:`~scvi.external.sysvi.nn.VarEncoder`. + How variance is predicted in decoder, + see :class:`~scvi.external.sysvi.nn.VarEncoder`. One of the following: * ``'sample_feature'`` - learn variance per sample and feature. * ``'feature'`` - learn variance per feature, constant across samples. enc_dec_kwargs Additional kwargs passed to encoder and decoder. - For both encoder and decoder :class:`~scvi.external.sysvi.nn.EncoderDecoder` is used. + For both encoder and decoder + :class:`~scvi.external.sysvi.nn.EncoderDecoder` is used. embedding_kwargs Keyword arguments passed into :class:`~scvi.nn.Embedding` if ``embed_cat`` is set to ``True``. """ - # TODO could disable computation of cycle if predefined that cycle loss will not be used. - # Cycle loss is not expected to be disabled in practice for typical use cases. + # TODO could disable computation of cycle if predefined + # that cycle loss will not be used. + # Cycle loss is not expected to be disabled in practice + # for typical use cases. # As the use of cycle is currently only based on loss kwargs, # which are specified only later, it can not be inferred here. @@ -191,8 +200,10 @@ def _get_inference_input( List of tensors with dim = n_samples * 1. If absent returns empty list. * ``'cont'``: All covariates that are already continous. - Includes continous and embedded categorical covariates. - Single tensor of dim = n_samples * n_concatenated_cov_features. + Includes continous and embedded + categorical covariates. + Single tensor of + dim = n_samples * n_concatenated_cov_features. If absent returns None. """ cov = self._get_cov(tensors=tensors) @@ -211,7 +222,7 @@ def _get_inference_cycle_input( selected_batch: torch.Tensor, **kwargs, ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: - """Parse the input tensors, generative outputs, and cycle batch info to get cycle inference inputs. + """In. tensors, gen. outputs, and cycle batch -> cycle inference inputs. Parameters ---------- @@ -236,7 +247,8 @@ def _get_inference_cycle_input( If absent returns empty list. * ``'cont'``: All covariates that are already continous. Includes continous and embedded categorical covariates. - Single tensor of dim = n_samples * n_concatenated_cov_features. + Single tensor of + dim = n_samples * n_concatenated_cov_features. If absent returns None. Note: cycle covariates differ from the original publication. @@ -259,7 +271,7 @@ def _get_generative_input( selected_batch: torch.Tensor, **kwargs, ) -> dict[str, torch.Tensor | dict[str, torch.Tensor | list[torch.Tensor] | None]]: - """Parse the input tensors, inference outputs, and cycle batch info to get generative inputs. + """In. tensors, inf. outputs, and cycle batch info -> generative inputs. Parameters ---------- @@ -279,13 +291,15 @@ def _get_generative_input( Keys: * ``'z'``: Latent representation. * ``'batch'``: Batch covariates. - Dict with keys ``'x'`` for normal and ``'y'`` for cycle pass. + Dict with keys ``'x'`` for normal and + ``'y'`` for cycle pass. * ``'cat'``: All covariates that require one-hot encoding. List of tensors with dim = n_samples * 1. If absent returns empty list. * ``'cont'``: All covariates that are already continous. Includes continous and embedded categorical covariates. - Single tensor of dim = n_samples * n_concatenated_cov_features. + Single tensor of + dim = n_samples * n_concatenated_cov_features. If absent returns None. Note: cycle covariates differ from the original publication. @@ -304,7 +318,7 @@ def _get_cov( self, tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor | list[torch.Tensor] | None]: - """Process all covariates into continuous and categorical components for cVAE. + """Process all covs into continuous and categorical components for cVAE. Parameters ---------- @@ -320,7 +334,8 @@ def _get_cov( If absent returns empty list. * ``'cont'``: All covariates that are already continous. Includes continous and embedded categorical covariates. - Single tensor of dim = n_samples * n_concatenated_cov_features. + Single tensor of + dim = n_samples * n_concatenated_cov_features. If absent returns None. """ @@ -346,7 +361,7 @@ def _merge_batch_cov( cat: list[torch.Tensor], batch: torch.Tensor, ) -> list[torch.Tensor]: - """Merge batch and continuous covariates for input into encoder and decoder. + """Merge batch and continuous covs for input into encoder and decoder. Parameters ---------- @@ -388,8 +403,8 @@ def inference( Returns ------- - Predicted mean (``'z_m'``) and variance (``'z_v'``) of the latent distribution - as wll as a sample (``'z'``) from it. + Predicted mean (``'z_m'``) and variance (``'z_v'``) + of the latent distribution as wll as a sample (``'z'``) from it. """ z = self.encoder(x=expr, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont) @@ -425,9 +440,10 @@ def generative( Returns ------- - Predicted mean (``'x_m'``) and variance (``'x_v'``) of the expression distribution - as wll as a sample (``'expr'``) from it. Same outputs are returned for the cycle generation - with ``'expr'`` in keys being replaced by ``'y'``. + Predicted mean (``'x_m'``) and variance (``'x_v'``) + of the expression distribution as wll as a sample (``'x'``) from it. + Same outputs are returned for the cycle generation with ``'x'`` + in keys being replaced by ``'y'``. """ def outputs( @@ -509,12 +525,16 @@ def forward( Returns ------- - Inference outputs, generative outputs of the normal pass, and optionally loss components. + Inference outputs, generative outputs of the normal pass, + and optionally loss components. Inference normal and cycle outputs are combined into a single dict. - Thus, the keys of cycle inference outputs are modified by replacing ``'z'`` with ``'z_cyc'``. + Thus, the keys of cycle inference outputs are modified by replacing + ``'z'`` with ``'z_cyc'``. """ - # TODO could disable computation of cycle if cycle loss will not be used (weight = 0). - # Cycle loss is not expected to be disabled in practice for typical use cases. + # TODO could disable computation of cycle if cycle loss + # will not be used (weight = 0). + # Cycle loss is not expected to be disabled in practice + # for typical use cases. # Parse kwargs inference_kwargs = inference_kwargs or {} @@ -546,9 +566,11 @@ def forward( ) inference_cycle_outputs = self.inference(**inference_cycle_inputs, **inference_kwargs) - # Combine outputs of all forward pass components (first and cycle pass) into a single dict, + # Combine outputs of all forward pass components + # (first and cycle pass) into a single dict, # separately for inference and generative outputs - # Rename keys in outputs of cycle pass to be distinguishable from the first pass + # Rename keys in outputs of cycle pass + # to be distinguishable from the first pass # for the merging into a single dict inference_outputs_merged = dict(**inference_outputs) inference_outputs_merged.update( @@ -678,7 +700,7 @@ def standardize(x: torch.Tensor) -> torch.Tensor: ) def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: - """For every cell randomly selects a new batch that is different from the real batch. + """For each cell randomly selects new batch different from the real one. Parameters ---------- @@ -693,8 +715,8 @@ def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: # those that are zero will become nonzero and vice versa batch = torch.nn.functional.one_hot(batch.squeeze(-1), self.n_batch) available_batches = 1 - batch - # Get nonzero indices for each cell - batches that differ from the real batch - # and are thus available + # Get nonzero indices for each cell - + # batches that differ from the real batch and are thus available row_indices, col_indices = torch.nonzero(available_batches, as_tuple=True) col_pairs = col_indices.view(-1, batch.shape[1] - 1) # Select batch for every cell from available batches @@ -713,9 +735,10 @@ def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor: @torch.inference_mode() def sample(self, *args, **kwargs): - """Generate expression samples from the posterior generative distribution. + """Generate expression samples from posterior generative distribution. - Not implemented as the use of decoded expression is not recommended for SysVI. + Not implemented as the use of decoded expression + is not recommended for SysVI. Raises ------ diff --git a/src/scvi/external/sysvi/_priors.py b/src/scvi/external/sysvi/_priors.py index 91a9ba22cf..caece4a64d 100644 --- a/src/scvi/external/sysvi/_priors.py +++ b/src/scvi/external/sysvi/_priors.py @@ -39,7 +39,7 @@ class StandardPrior(Prior): """Standard prior distribution.""" def kl(self, m_q: torch.Tensor, v_q: torch.Tensor, z: None = None) -> torch.Tensor: - """Compute KL divergence between standard normal prior and the posterior distribution. + """Compute KL div between std. normal prior and the posterior distn. Parameters ---------- @@ -64,7 +64,8 @@ class VampPrior(Prior): """VampPrior. Adapted from a - `blog post `_ + `blog post + `_ of the original VampPrior author. Parameters @@ -111,7 +112,7 @@ def __init__( assert n_components == data_x.shape[0] self.u = torch.nn.Parameter(data_x, requires_grad=trainable_priors) # K x I # Cat - assert all([cat.shape[0] == n_components for cat in data_cat]) + assert all(cat.shape[0] == n_components for cat in data_cat) # For categorical covariates, since scvi-tools one-hot encodes # them in the layers, we need to create a multinomial distn # from which we can sample categories for layers input @@ -120,10 +121,12 @@ def __init__( self.u_cat = torch.nn.ParameterList( [ torch.nn.Parameter( - torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), # K x C_cat_onehot + torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), + # K x C_cat_onehot requires_grad=trainable_priors, ) - for cat, n in zip(data_cat, n_cat_list, strict=False) # K x C_cat + for cat, n in zip(data_cat, n_cat_list, strict=False) + # K x C_cat ] ) # Cont @@ -185,7 +188,7 @@ def log_prob(self, z: torch.Tensor) -> torch.Tensor: return log_prob # N x L def kl(self, m_q: torch.Tensor, v_q: torch.Tensor, z: torch.Tensor) -> torch.Tensor: - """Compute KL divergence between VampPrior and the posterior distribution. + """Compute KL div. between VampPrior and the posterior distribution. Parameters ---------- From 87a367e9aa56b0b3d844da4f800c63b2644a2f9f Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:57:42 +0100 Subject: [PATCH 59/60] bugfix in embedding of covariates --- src/scvi/external/sysvi/_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index a1f83d6a22..fd773a7b01 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -8,7 +8,7 @@ import torch from scvi import REGISTRY_KEYS -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module.base import BaseModuleClass, EmbeddingModuleMixin, LossOutput, auto_move_data from ._base_components import EncoderDecoder from ._priors import StandardPrior, VampPrior @@ -16,7 +16,7 @@ torch.backends.cudnn.benchmark = True -class SysVAE(BaseModuleClass): +class SysVAE(BaseModuleClass, EmbeddingModuleMixin): """CVAE with optional VampPrior and latent cycle consistency loss. Described in @@ -347,7 +347,7 @@ def _get_cov( cat = torch.split(tensors[REGISTRY_KEYS.CAT_COVS_KEY], 1, dim=1) if self.embed_cat: for idx, tensor in enumerate(cat): - cont_parts.append(self.compute_embedding(tensor, self._cov_idx_name(idx))) + cont_parts.append(self.compute_embedding(self._cov_idx_name(idx), tensor)) else: cat_parts.extend(cat) cov = { From df4b3356e4a92aca536d5f2d97184bbc2d34b65c Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:57:57 +0100 Subject: [PATCH 60/60] bugfix in test for checking cov embeding --- tests/external/sysvi/test_sysvi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/external/sysvi/test_sysvi.py b/tests/external/sysvi/test_sysvi.py index c9b4e8f48c..c47b22a9d1 100644 --- a/tests/external/sysvi/test_sysvi.py +++ b/tests/external/sysvi/test_sysvi.py @@ -91,7 +91,7 @@ def test_sysvi_model( # Model # Check that model runs through with standard normal prior - model = SysVI(adata=adata, prior="standard_normal") + model = SysVI(adata=adata, prior="standard_normal", embed_cat=embed_cat) model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0)) # Check that model runs through with vamp prior @@ -100,6 +100,7 @@ def test_sysvi_model( prior="vamp", pseudoinputs_data_indices=pseudoinputs_data_indices, n_prior_components=5, + embed_cat=embed_cat, ) model.train(max_epochs=2, batch_size=math.ceil(adata.n_obs / 2.0))