Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 7, 2024
1 parent 815555c commit e03b006
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
17 changes: 10 additions & 7 deletions src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import warnings
from collections.abc import Iterable as IterableClass
from functools import partial
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -46,7 +45,7 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import Number, AnnOrMuData
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -590,8 +589,9 @@ def get_accessibility_estimates(
return pd.DataFrame(
imputed,
index=adata.obs_names[indices],
columns=adata["rna"].var_names[self.n_genes :][region_mask] if
type(adata).__name__ == "MuData" else adata.var_names[self.n_genes :][region_mask],
columns=adata["rna"].var_names[self.n_genes :][region_mask]
if type(adata).__name__ == "MuData"
else adata.var_names[self.n_genes :][region_mask],
)

@torch.inference_mode()
Expand Down Expand Up @@ -1089,7 +1089,8 @@ def setup_mudata(
mod_key=modalities.rna_layer,
is_count_data=True,
mod_required=True,
))
)
)
if modalities.atac_layer is not None:
mudata_fields.append(
fields.MuDataLayerField(
Expand All @@ -1098,7 +1099,8 @@ def setup_mudata(
mod_key=modalities.atac_layer,
is_count_data=True,
mod_required=True,
))
)
)
if modalities.protein_layer is not None:
mudata_fields.append(
fields.MuDataProteinLayerField(
Expand All @@ -1109,7 +1111,8 @@ def setup_mudata(
batch_field=batch_field,
is_count_data=True,
mod_required=True,
))
)
)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)
2 changes: 1 addition & 1 deletion src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import Number, AnnOrMuData
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)

Expand Down
33 changes: 20 additions & 13 deletions tests/model/test_multivi.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import os

import anndata as ad
import muon
import numpy as np
import pytest
from mudata import MuData
import scanpy as sc
import anndata as ad
from mudata import MuData

import scvi
import os
from scvi import REGISTRY_KEYS
from scvi.data import synthetic_iid
from scvi.model import MULTIVI
from scvi import REGISTRY_KEYS
from scvi.utils import attrdict


Expand Down Expand Up @@ -145,13 +147,15 @@ def test_multivi_mudata_rna_atac_external():
n_top_genes=4000,
flavor="seurat_v3",
)
mdata.mod["atac_subset"] = (mdata.mod["atac"][:, mdata.mod["atac"].var["highly_variable"]].
copy())
mdata.mod["atac_subset"] = mdata.mod["atac"][
:, mdata.mod["atac"].var["highly_variable"]
].copy()
mdata.update()
# mdata
# mdata.mod
MULTIVI.setup_mudata(mdata, modalities={"rna_layer": "rna_subset",
"atac_layer": "atac_subset"})
MULTIVI.setup_mudata(
mdata, modalities={"rna_layer": "rna_subset", "atac_layer": "atac_subset"}
)
model = MULTIVI(mdata, n_genes=50, n_regions=50)
model.train(1, train_size=0.9)

Expand All @@ -178,8 +182,11 @@ def test_multivi_mudata():
MULTIVI.setup_mudata(
mdata,
batch_key="batch",
modalities={"rna_layer": "rna", "protein_layer": "protein_expression",
"atac_layer": "accessibility"},
modalities={
"rna_layer": "rna",
"protein_layer": "protein_expression",
"atac_layer": "accessibility",
},
)
n_obs = mdata.n_obs
# n_genes = np.min([mdata.n_vars, mdata["protein_expression"].n_vars])
Expand Down Expand Up @@ -354,7 +361,7 @@ def test_multivi_size_factor_mudata():
model.train(1, train_size=0.5)


def test_multivi_saving_and_loading_mudata(save_path: str="."):
def test_multivi_saving_and_loading_mudata(save_path: str = "."):
adata = synthetic_iid()
protein_adata = synthetic_iid(n_genes=50)
mdata = MuData({"rna": adata, "protein": protein_adata})
Expand Down Expand Up @@ -415,7 +422,7 @@ def test_multivi_saving_and_loading_mudata(save_path: str="."):
)


def test_scarches_mudata_prep_layer(save_path: str="."):
def test_scarches_mudata_prep_layer(save_path: str = "."):
n_latent = 5
mdata1 = synthetic_iid(return_mudata=True)

Expand Down Expand Up @@ -463,7 +470,7 @@ def test_scarches_mudata_prep_layer(save_path: str="."):
MULTIVI.load_query_data(mdata2, dir_path)


def test_multivi_save_load_mudata_format(save_path: str="."):
def test_multivi_save_load_mudata_format(save_path: str = "."):
mdata = synthetic_iid(return_mudata=True, protein_expression_key="protein")
invalid_mdata = mdata.copy()
invalid_mdata.mod["protein"] = invalid_mdata.mod["protein"][:, :10].copy()
Expand Down

0 comments on commit e03b006

Please sign in to comment.