Commit

Hrovatin committed Oct 26, 2024
2 parents 7c3bddc + 7743c19 commit eb872bc
Showing 5 changed files with 99 additions and 86 deletions.
22 changes: 6 additions & 16 deletions src/scvi/external/sysvi/_base_components.py
@@ -6,19 +6,13 @@
from typing import Literal

import numpy as np

import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import (
BatchNorm1d,
Dropout,
LayerNorm,
Linear,
Module,
Parameter,
ReLU,
Sequential,
)


@@ -85,11 +79,10 @@ def forward(
cont: torch.Tensor | None = None,
cat_list: list | None = None,
) -> dict[str, torch.Tensor]:

y = self.decoder_y(x=x, cont=cont, cat_list=cat_list)
y_m = self.mean_encoder(y)
if y_m.isnan().any() or y_m.isinf().any():
warnings.warn('Predicted mean contains nan or inf values. Setting to numerical.')
warnings.warn("Predicted mean contains nan or inf values. Setting to numerical.")
y_m = torch.nan_to_num(y_m)
y_v = self.var_encoder(y)

@@ -196,8 +189,8 @@ def __init__(
),
)
for i, (n_in, n_out) in enumerate(
zip(layers_dim[:-1], layers_dim[1:], strict=True)
)
zip(layers_dim[:-1], layers_dim[1:], strict=True)
)
]
)
)
@@ -214,7 +207,7 @@ def set_online_update_hooks(self, hook_first_layer=True):
def _hook_fn_weight(grad):
new_grad = torch.zeros_like(grad)
if self.n_cov > 0:
new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:]
new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :]
return new_grad

def _hook_fn_zero_out(grad):
@@ -234,10 +227,7 @@ def _hook_fn_zero_out(grad):
self.hooks.append(b)
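
The hunks above keep the gradient-surgery hook that freezes all but the covariate columns of a layer's weight. A standalone sketch of the same pattern, with a toy layer size and a hypothetical n_cov, not the module's actual layer setup:

import torch

n_cov = 2  # hypothetical number of covariate columns appended to the input
layer = torch.nn.Linear(5, 3)

def _hook_fn_weight(grad):
    # Zero the gradient everywhere except the last n_cov input columns,
    # so only the covariate weights keep updating.
    new_grad = torch.zeros_like(grad)
    new_grad[:, -n_cov:] = grad[:, -n_cov:]
    return new_grad

layer.weight.register_hook(_hook_fn_weight)
layer(torch.randn(4, 5)).sum().backward()
print(layer.weight.grad[:, :-n_cov].abs().sum())  # tensor(0.) -> frozen columns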

def forward(
self,
x: torch.Tensor,
cont: torch.Tensor | None = None,
cat_list: list | None = None
self, x: torch.Tensor, cont: torch.Tensor | None = None, cat_list: list | None = None
) -> torch.Tensor:
"""Forward computation on ``x``.
@@ -351,7 +341,7 @@ def forward(self, x: torch.Tensor):
v = torch.clip(v, min=-self.clip_exp, max=self.clip_exp)
v = self.activation(v)
if v.isnan().any():
warnings.warn('Predicted variance contains nan values. Setting to 0.')
warnings.warn("Predicted variance contains nan values. Setting to 0.")
v = torch.nan_to_num(v)

return v
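
The variance-encoder change above only reformats the warning string; the clip-then-sanitize guard itself is unchanged. A minimal sketch of that guard, assuming an exponential activation (the module's configured activation may differ):

import warnings
import torch

def safe_exp_var(v: torch.Tensor, clip_exp: float = 20.0) -> torch.Tensor:
    # Clip before exponentiating so exp() cannot overflow to inf.
    v = torch.clip(v, min=-clip_exp, max=clip_exp)
    v = torch.exp(v)
    if v.isnan().any():
        warnings.warn("Predicted variance contains nan values. Setting to 0.")
        v = torch.nan_to_num(v)
    return v

print(safe_exp_var(torch.tensor([0.0, 50.0, float("nan")])))  # nan entry becomes 0.0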
28 changes: 15 additions & 13 deletions src/scvi/external/sysvi/_model.py
@@ -6,15 +6,17 @@
from typing import Literal

import numpy as np
import pandas as pd
import torch
from anndata import AnnData

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
ObsmField, CategoricalObsField, CategoricalJointObsField, NumericalJointObsField, NumericalObsField,
NumericalJointObsField,
NumericalObsField,
)
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.utils import setup_anndata_dsp
@@ -105,19 +107,18 @@ def train(
**kwargs,
):
plan_kwargs = plan_kwargs or {}
kl_weight_defaults = {
'n_epochs_kl_warmup': 0,
'n_steps_kl_warmup': 0
}
kl_weight_defaults = {"n_epochs_kl_warmup": 0, "n_steps_kl_warmup": 0}
if any([v != plan_kwargs.get(k, v) for k, v in kl_weight_defaults.items()]):
warnings.warn('The use of KL weight warmup is not recommended in SysVI. ' +
'The n_epochs_kl_warmup and n_steps_kl_warmup will be reset to 0.')
warnings.warn(
"The use of KL weight warmup is not recommended in SysVI. "
+ "The n_epochs_kl_warmup and n_steps_kl_warmup will be reset to 0."
)
# Overwrite plan kwargs with kl weight defaults
plan_kwargs = {**plan_kwargs, **kl_weight_defaults}

# Pass to parent
kwargs = kwargs or {}
kwargs['plan_kwargs'] = plan_kwargs
kwargs["plan_kwargs"] = plan_kwargs
super().train(*args, **kwargs)

@torch.inference_mode()
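
The train() change above is mostly a formatting pass; the behaviour it preserves is that user-supplied KL warmup settings are overridden by the merge into plan_kwargs. A toy illustration of that dict merge, with hypothetical values rather than SysVI defaults:

plan_kwargs = {"n_epochs_kl_warmup": 400, "lr": 1e-3}  # user-supplied
kl_weight_defaults = {"n_epochs_kl_warmup": 0, "n_steps_kl_warmup": 0}
# Later keys win in a dict merge, so the KL warmup settings are forced back to 0.
plan_kwargs = {**plan_kwargs, **kl_weight_defaults}
print(plan_kwargs)  # {'n_epochs_kl_warmup': 0, 'lr': 0.001, 'n_steps_kl_warmup': 0}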
@@ -231,10 +232,11 @@ def setup_anndata(
NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
]
if weight_batches:
warnings.warn('The use of inverse batch proportion weights is experimental.')
batch_weights_key = 'batch_weights'
adata.obs[batch_weights_key] = adata.obs[batch_key].map({
cat: 1 / n for cat, n in adata.obs[batch_key].value_counts().items()})
warnings.warn("The use of inverse batch proportion weights is experimental.")
batch_weights_key = "batch_weights"
adata.obs[batch_weights_key] = adata.obs[batch_key].map(
{cat: 1 / n for cat, n in adata.obs[batch_key].value_counts().items()}
)
anndata_fields.append(NumericalObsField(batch_weights_key, batch_weights_key))
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
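
For the experimental weight_batches option registered above, a small pandas-only sketch (toy column name, not the scvi registry) of how the inverse batch-size weights are derived, each cell getting one over the number of cells in its batch:

import pandas as pd

obs = pd.DataFrame({"system": ["a", "a", "a", "b"]})
counts = obs["system"].value_counts()  # a -> 3, b -> 1
obs["batch_weights"] = obs["system"].map({cat: 1 / n for cat, n in counts.items()})
print(obs["batch_weights"].tolist())  # ~[0.33, 0.33, 0.33, 1.0]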
71 changes: 38 additions & 33 deletions src/scvi/external/sysvi/_module.py
@@ -77,7 +77,6 @@ def __init__(
out_var_mode: Literal["sample_feature", "feature"] = "feature",
enc_dec_kwargs: dict | None = None,
embedding_kwargs: dict | None = None,

):
super().__init__()

@@ -127,7 +126,9 @@ def __init__(
if prior == "standard_normal":
self.prior = StandardPrior()
elif prior == "vamp":
assert pseudoinput_data is not None, 'Pseudoinput data must be specified if using VampPrior'
assert (
pseudoinput_data is not None
), "Pseudoinput data must be specified if using VampPrior"
pseudoinput_data = self._get_inference_input(pseudoinput_data)
self.prior = VampPrior(
n_components=n_prior_components,
@@ -153,8 +154,8 @@ def _get_inference_input(self, tensors, **kwargs) -> dict[str, torch.Tensor]:
input_dict = {
"expr": tensors[REGISTRY_KEYS.X_KEY],
"batch": tensors[REGISTRY_KEYS.BATCH_KEY],
"cat": cov['cat'],
"cont": cov['cont']
"cat": cov["cat"],
"cont": cov["cont"],
}
return input_dict

@@ -166,8 +167,8 @@ def _get_inference_cycle_input(
input_dict = {
"expr": generative_outputs["y_m"],
"batch": selected_batch,
"cat": cov['cat'],
"cont": cov['cont']
"cat": cov["cat"],
"cont": cov["cont"],
}
return input_dict

@@ -176,26 +177,26 @@ def _get_generative_input(
tensors: dict[str, torch.Tensor],
inference_outputs: dict[str, torch.Tensor],
selected_batch: torch.Tensor,
**kwargs
**kwargs,
) -> dict[str, torch.Tensor | dict[str, torch.Tensor | list[torch.Tensor] | None]]:
"""Parse the input tensors, inference inputs, and cycle batch to get generative inputs"""
z = inference_outputs["z"]

cov = self._get_cov(tensors=tensors)
cov_mock = self._mock_cov(cov)
cat = {'x': cov['cat'], 'y': cov_mock['cat']}
cont = {'x': cov['cont'], 'y': cov_mock['cont']}
cat = {"x": cov["cat"], "y": cov_mock["cat"]}
cont = {"x": cov["cont"], "y": cov_mock["cont"]}

batch = {"x": tensors["batch"], "y": selected_batch}

input_dict = {"z": z, "batch": batch, 'cat': cat, 'cont': cont}
input_dict = {"z": z, "batch": batch, "cat": cat, "cont": cont}
return input_dict

@auto_move_data # TODO remove?
def _get_cov(self, tensors: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor | list[torch.Tensor] | None]:
def _get_cov(
self, tensors: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor | list[torch.Tensor] | None]:
"""Process all covariates into continuous and categorical components for cVAE"""

cat_parts = []
cont_parts = []
if REGISTRY_KEYS.CONT_COVS_KEY in tensors:
@@ -204,14 +205,12 @@ def _get_cov(self, tensors: dict[str, torch.Tensor]
cat = torch.split(tensors[REGISTRY_KEYS.CAT_COVS_KEY], 1, dim=1)
if self.embed_cat:
for idx, tensor in enumerate(cat):
cont_parts.append(
self.compute_embedding(tensor, self._cov_idx_name(idx))
)
cont_parts.append(self.compute_embedding(tensor, self._cov_idx_name(idx)))
else:
cat_parts.extend(cat)
cov = {
'cat': cat_parts,
'cont': torch.concat(cont_parts, dim=1) if len(cont_parts) > 0 else None
"cat": cat_parts,
"cont": torch.concat(cont_parts, dim=1) if len(cont_parts) > 0 else None,
}
return cov

@@ -223,18 +222,19 @@ def _merge_batch_cov(cat: list[torch.Tensor], batch: torch.Tensor) -> list[torch
def _mock_cov(cov: dict[str, list[torch.Tensor], torch.Tensor, None]) -> torch.Tensor | None:
"""Make mock (all 0) covariates for cycle"""
mock = {
'cat': [torch.zeros_like(cat) for cat in cov['cat']],
'cont': torch.zeros_like(cov['cont']) if cov['cont'] is not None else None,
"cat": [torch.zeros_like(cat) for cat in cov["cat"]],
"cont": torch.zeros_like(cov["cont"]) if cov["cont"] is not None else None,
}
return mock

@auto_move_data
def inference(self,
expr: torch.Tensor,
batch: torch.Tensor,
cat: list[torch.Tensor],
cont: torch.Tensor | None,
) -> dict:
def inference(
self,
expr: torch.Tensor,
batch: torch.Tensor,
cat: list[torch.Tensor],
cont: torch.Tensor | None,
) -> dict:
"""
expression & cov -> latent representation
@@ -287,14 +287,20 @@ def outputs(
cont: torch.Tensor | None,
):
if compute:
res_sub = self.decoder(x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont)
res_sub = self.decoder(
x=x, cat_list=self._merge_batch_cov(cat=cat, batch=batch), cont=cont
)
res[name] = res_sub["y"]
res[name + "_m"] = res_sub["y_m"]
res[name + "_v"] = res_sub["y_v"]

res = {}
outputs(compute=x_x, name="x", res=res, x=z, batch=batch['x'], cat=cat['x'], cont=cont['x'])
outputs(compute=x_y, name="y", res=res, x=z, batch=batch['y'], cat=cat['y'], cont=cont['y'])
outputs(
compute=x_x, name="x", res=res, x=z, batch=batch["x"], cat=cat["x"], cont=cont["x"]
)
outputs(
compute=x_y, name="y", res=res, x=z, batch=batch["y"], cat=cat["y"], cont=cont["y"]
)
return res

@auto_move_data
@@ -394,7 +400,6 @@ def loss(
reconstruction_weight: float = 1.0,
z_distance_cycle_weight: float = 2.0,
) -> LossOutput:

# Reconstruction loss
x_true = tensors[REGISTRY_KEYS.X_KEY]
reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")(
@@ -428,8 +433,8 @@ def standardize(x):
)

z_distance_cyc = z_dist(z_x_m=inference_outputs["z_m"], z_y_m=inference_outputs["z_cyc_m"])
if 'batch_weights' in tensors.keys():
z_distance_cyc *= tensors['batch_weights'].flatten()
if "batch_weights" in tensors.keys():
z_distance_cyc *= tensors["batch_weights"].flatten()

loss = (
reconst_loss * reconstruction_weight
Expand All @@ -441,7 +446,7 @@ def standardize(x):
loss=loss.mean(),
reconstruction_loss=reconst_loss,
kl_local=kl_divergence_z,
extra_metrics={'cycle_loss':z_distance_cyc.mean()},
extra_metrics={"cycle_loss": z_distance_cyc.mean()},
)

def random_select_batch(self, batch: torch.Tensor) -> torch.Tensor:
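
The loss changes above are quoting and formatting fixes; the logic they preserve is that the per-cell cycle distance is rescaled by the optional batch weights before entering the total loss. A toy sketch with standalone tensors, omitting the standardization step the module applies to the latent means:

import torch

z_m = torch.randn(4, 8)      # latent means from the original pass
z_cyc_m = torch.randn(4, 8)  # latent means after the cycle pass
batch_weights = torch.tensor([0.33, 0.33, 0.33, 1.0])

# Per-cell squared distance between the two latent representations.
z_distance_cyc = torch.nn.MSELoss(reduction="none")(z_m, z_cyc_m).mean(dim=1)
z_distance_cyc = z_distance_cyc * batch_weights  # down-weight over-represented batches
print(z_distance_cyc.mean())  # scalar term, scaled by z_distance_cycle_weight in the loss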
21 changes: 13 additions & 8 deletions src/scvi/external/sysvi/_priors.py
@@ -45,7 +45,7 @@ def __init__(
data_x: torch.tensor,
n_cat_list: list[int],
data_cat: list[torch.tensor],
data_cont: torch.tensor|None=None,
data_cont: torch.tensor | None = None,
trainable_priors: bool = True,
):
super().__init__()
@@ -63,18 +63,23 @@ def __init__(
# from which we can sample categories for layers input
# Initialise the multinomial distn weights based on
# one-hot encoding of pseudoinput categories
self.u_cat = torch.nn.ParameterList([
torch.nn.Parameter(
torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), # K x C_cat_onehot
requires_grad=trainable_priors)
for cat, n in zip(data_cat, n_cat_list) # K x C_cat
])
self.u_cat = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.nn.functional.one_hot(cat.squeeze(-1), n).float(), # K x C_cat_onehot
requires_grad=trainable_priors,
)
for cat, n in zip(data_cat, n_cat_list, strict=False) # K x C_cat
]
)
# Cont
if data_cont is None:
self.u_cont = None
else:
assert n_components == data_cont.shape[0]
self.u_cont = torch.nn.Parameter(data_cont, requires_grad=trainable_priors) # K x C_cont
self.u_cont = torch.nn.Parameter(
data_cont, requires_grad=trainable_priors
) # K x C_cont

# mixing weights
self.w = torch.nn.Parameter(torch.zeros(n_components, 1, 1)) # K x 1 x 1
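
The VampPrior hunk above wraps the same initialisation in a longer call; the idea is that each pseudoinput's categorical covariates become trainable one-hot parameters. A hedged sketch with made-up category counts, not the prior's full constructor:

import torch

# Three pseudoinputs (K = 3), two categorical covariates with 4 and 2 categories.
n_cat_list = [4, 2]
data_cat = [torch.tensor([[0], [2], [3]]), torch.tensor([[1], [0], [1]])]  # each K x 1

u_cat = torch.nn.ParameterList(
    [
        torch.nn.Parameter(
            torch.nn.functional.one_hot(cat.squeeze(-1), n).float(),  # K x C one-hot
            requires_grad=True,
        )
        for cat, n in zip(data_cat, n_cat_list, strict=False)
    ]
)
print([tuple(p.shape) for p in u_cat])  # [(3, 4), (3, 2)]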
