From 143a9aab02bc470cc9a2cb17648c4cc2e5b867f6 Mon Sep 17 00:00:00 2001
From: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Date: Wed, 19 Jul 2023 06:29:18 -0700
Subject: [PATCH 1/9] Update README with new badges (#2193)
* Update README with new badges
* Update build
---
README.md | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
index 7294b6cbad..f31f4e2f6f 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,14 @@
-[![Stars](https://img.shields.io/github/stars/scverse/scvi-tools?logo=GitHub&color=yellow)](https://github.com/YosefLab/scvi-tools/stargazers)
+[![Stars](https://img.shields.io/github/stars/scverse/scvi-tools?logo=GitHub&color=yellow)](https://github.com/scverse/scvi-tools/stargazers)
[![PyPI](https://img.shields.io/pypi/v/scvi-tools.svg)](https://pypi.org/project/scvi-tools)
-[![Documentation Status](https://readthedocs.org/projects/scvi/badge/?version=latest)](https://scvi.readthedocs.io/en/stable/?badge=stable)
-![Build
-Status](https://github.com/scverse/scvi-tools/workflows/scvi-tools/badge.svg)
-[![Coverage](https://codecov.io/gh/scverse/scvi-tools/branch/master/graph/badge.svg)](https://codecov.io/gh/YosefLab/scvi-tools)
-[![Code
-Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
-[![Downloads](https://pepy.tech/badge/scvi-tools)](https://pepy.tech/project/scvi-tools)
-[![Project chat](https://img.shields.io/badge/zulip-join_chat-brightgreen.svg)](https://scverse.zulipchat.com/)
+[![PyPIDownloads](https://pepy.tech/badge/scvi-tools)](https://pepy.tech/project/scvi-tools)
+[![CondaDownloads](https://img.shields.io/conda/dn/conda-forge/scvi-tools?logo=Anaconda)](https://anaconda.org/conda-forge/scvi-tools)
+[![Docs](https://readthedocs.org/projects/scvi/badge/?version=latest)](https://scvi.readthedocs.io/en/stable/?badge=stable)
+[![Build](https://github.com/scverse/scvi-tools/actions/workflows/build.yml/badge.svg)](https://github.com/scverse/scvi-tools/actions/workflows/build.yml/)
+[![Coverage](https://codecov.io/gh/scverse/scvi-tools/branch/main/graph/badge.svg)](https://codecov.io/gh/scverse/scvi-tools)
+[![Discourse](https://img.shields.io/discourse/posts?color=yellow&logo=discourse&server=https%3A%2F%2Fdiscourse.scverse.org)](https://discourse.scverse.org/)
+[![Chat](https://img.shields.io/badge/zulip-join_chat-brightgreen.svg)](https://scverse.zulipchat.com/)
[scvi-tools](https://scvi-tools.org/) (single-cell variational inference
tools) is a package for probabilistic modeling and analysis of single-cell omics
From b7563e8a09f7cc4ec25e1290e84d326c541ef345 Mon Sep 17 00:00:00 2001
From: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Date: Wed, 19 Jul 2023 12:12:44 -0700
Subject: [PATCH 2/9] Reformat optional dependencies (#2188)
* Reformat pyproject
* Update dependencies
* Potential docs fix
* Fix test dependencies
* Reformat dependencies
* Update release workflow to 3.11
---
.github/workflows/build.yml | 9 ++-
.github/workflows/release.yml | 8 +--
.github/workflows/test_linux_cpu.yml | 9 ++-
.github/workflows/test_linux_cpu_pre.yml | 9 ++-
.github/workflows/test_linux_cuda.yml | 4 +-
.github/workflows/test_macos_cpu.yml | 9 ++-
.github/workflows/test_windows_cpu.yml | 9 ++-
pyproject.toml | 83 ++++++++++++++----------
readthedocs.yml | 5 +-
9 files changed, 90 insertions(+), 55 deletions(-)
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index c383b986bf..557fb64832 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -9,17 +9,22 @@ on:
jobs:
package:
runs-on: ubuntu-latest
+
steps:
- uses: actions/checkout@v3
- - name: Set up Python 3.10
+
+ - name: Set up Python 3.11
uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
cache: "pip"
cache-dependency-path: "**/pyproject.toml"
+
- name: Install build dependencies
run: python -m pip install --upgrade pip wheel twine build
+
- name: Build package
run: python -m build
+
- name: Check package
run: twine check --strict dist/*.whl
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 1dd25654d3..da01223c2b 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -9,19 +9,19 @@ jobs:
release:
name: Release
runs-on: ubuntu-latest
+
steps:
- # will use ref/SHA that triggered it
- name: Checkout code
uses: actions/checkout@v3
- - name: Set up Python 3.9
+ - name: Set up Python 3.11
uses: actions/setup-python@v4
with:
- python-version: "3.9"
+ python-version: "3.11"
- name: Install hatch
run: |
- pip install hatch
+ python -m pip install --upgrade hatch
- name: Build project for distribution
run: hatch build
diff --git a/.github/workflows/test_linux_cpu.yml b/.github/workflows/test_linux_cpu.yml
index b14de9c4f4..67499a84d4 100644
--- a/.github/workflows/test_linux_cpu.yml
+++ b/.github/workflows/test_linux_cpu.yml
@@ -34,18 +34,22 @@ jobs:
steps:
- uses: actions/checkout@v3
+
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/pyproject.toml"
+
- name: Install test dependencies
run: |
python -m pip install --upgrade pip wheel
- - name: Install dependencies
+
+ - name: Install scvi-tools test dependencies
run: |
- pip install ".[dev,pymde,autotune,hub]"
+ pip install ".[tests]"
+
- name: Test
env:
MPLBACKEND: agg
@@ -53,5 +57,6 @@ jobs:
DISPLAY: :42
run: |
pytest -v --cov --color=yes
+
- name: Upload coverage
uses: codecov/codecov-action@v3
diff --git a/.github/workflows/test_linux_cpu_pre.yml b/.github/workflows/test_linux_cpu_pre.yml
index a5046dcce2..b6678c64f1 100644
--- a/.github/workflows/test_linux_cpu_pre.yml
+++ b/.github/workflows/test_linux_cpu_pre.yml
@@ -36,18 +36,22 @@ jobs:
steps:
- uses: actions/checkout@v3
+
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/pyproject.toml"
+
- name: Install test dependencies
run: |
python -m pip install --upgrade pip wheel
- - name: Install dependencies
+
+ - name: Install scvi-tools test dependencies
run: |
- pip install --pre ".[dev,pymde,autotune,hub]"
+ pip install --pre ".[tests]"
+
- name: Test
env:
MPLBACKEND: agg
@@ -55,5 +59,6 @@ jobs:
DISPLAY: :42
run: |
pytest -v --cov --color=yes
+
- name: Upload coverage
uses: codecov/codecov-action@v3
diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml
index 53a7b8ca1c..6c1bcd92db 100644
--- a/.github/workflows/test_linux_cuda.yml
+++ b/.github/workflows/test_linux_cuda.yml
@@ -21,7 +21,7 @@ jobs:
cuda: ["11"]
container:
- image: martinkim0/scvi-tools:ubuntu-${{ matrix.ubuntu }}-mamba-${{ matrix.mamba}}-python-${{ matrix.python }}-cuda-${{ matrix.cuda }}
+ image: martinkim0/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}
options: --user root --gpus all
steps:
@@ -30,7 +30,7 @@ jobs:
- name: Install dependencies
run: |
- pip install ".[dev,pymde,autotune,hub]"
+ pip install ".[tests]"
- name: Test
env:
diff --git a/.github/workflows/test_macos_cpu.yml b/.github/workflows/test_macos_cpu.yml
index c1147ffe6c..e35c1117ae 100644
--- a/.github/workflows/test_macos_cpu.yml
+++ b/.github/workflows/test_macos_cpu.yml
@@ -36,18 +36,22 @@ jobs:
steps:
- uses: actions/checkout@v3
+
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/pyproject.toml"
+
- name: Install test dependencies
run: |
python -m pip install --upgrade pip wheel
- - name: Install dependencies
+
+ - name: Install scvi-tools test dependencies
run: |
- pip install ".[dev,pymde,autotune,hub]"
+ pip install ".[tests]"
+
- name: Test
env:
MPLBACKEND: agg
@@ -55,5 +59,6 @@ jobs:
DISPLAY: :42
run: |
pytest -v --cov --color=yes
+
- name: Upload coverage
uses: codecov/codecov-action@v3
diff --git a/.github/workflows/test_windows_cpu.yml b/.github/workflows/test_windows_cpu.yml
index 2a0909054d..5d5d874dd7 100644
--- a/.github/workflows/test_windows_cpu.yml
+++ b/.github/workflows/test_windows_cpu.yml
@@ -36,18 +36,22 @@ jobs:
steps:
- uses: actions/checkout@v3
+
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/pyproject.toml"
+
- name: Install test dependencies
run: |
python -m pip install --upgrade pip wheel
- - name: Install dependencies
+
+ - name: Install scvi-tools test dependencies
run: |
- pip install ".[dev,pymde,autotune,hub]"
+ pip install ".[tests]"
+
- name: Test
env:
MPLBACKEND: agg
@@ -55,5 +59,6 @@ jobs:
DISPLAY: :42
run: |
pytest -v --cov --color=yes
+
- name: Upload coverage
uses: codecov/codecov-action@v3
diff --git a/pyproject.toml b/pyproject.toml
index a24fcbd342..3b5a4a4ae5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ classifiers = [
]
dependencies = [
"anndata>=0.7.5",
- "chex<=0.1.8",
+ "chex<=0.1.8", # see https://github.com/scverse/scvi-tools/pull/2187
"docrep>=0.3.2",
"flax",
"jax>=0.4.4",
@@ -60,50 +60,63 @@ dependencies = [
[project.optional-dependencies]
-dev = [
- "black",
+tests = [
"pytest",
"pytest-cov",
+ "scvi-tools[optional]"
+] # dependencies for running the test suite
+editing = [
+ "black",
"flake8",
- "scanpy>=1.6",
- "loompy>=3.0.6",
"jupyter",
"nbformat",
"nbconvert",
"pre-commit",
"ruff",
- "pymde",
- "genomepy",
- "cellxgene-census"
-]
+] # dependencies for editing and committing code
+dev = ["scvi-tools[editing,tests]"] # dependencies for dev work
+
docs = [
- "docutils>=0.8,!=0.18.*,!=0.19.*",
- "sphinx>=4.1",
- "ipython",
- "sphinx-book-theme>=1.0.1",
- "sphinx_copybutton",
- "sphinx-design",
- "sphinxext-opengraph",
- "sphinx-hoverxref",
- "sphinxcontrib-bibtex>=1.0.0",
- "myst-parser",
- "myst-nb",
- "sphinx-autodoc-typehints",
-]
+ "docutils>=0.8,!=0.18.*,!=0.19.*", # see https://github.com/scverse/cookiecutter-scverse/pull/205
+ "sphinx>=4.1",
+ "ipython",
+ "sphinx-book-theme>=1.0.1",
+ "sphinx_copybutton",
+ "sphinx-design",
+ "sphinxext-opengraph",
+ "sphinx-hoverxref",
+ "sphinxcontrib-bibtex>=1.0.0",
+ "myst-parser",
+ "myst-nb",
+ "sphinx-autodoc-typehints",
+] # basic docs dependencies
+docsbuild = ["scvi-tools[docs,optional]"] # docs build dependencies
+
autotune = [
- "hyperopt>=0.2",
- "ray[tune]>=2.5.0",
- "ipython",
- "scib-metrics>=0.3",
-]
-pymde = ["pymde"]
-tutorials = ["scanpy", "leidenalg", "python-igraph", "loompy", "scikit-misc", "pynndescent", "pymde", "huggingface_hub", "genomepy"]
-hub = ["huggingface_hub"]
-regseq = [
- "biopython>=1.81",
- "genomepy",
-]
-census = ["cellxgene-census"]
+ "hyperopt>=0.2",
+ "ray[tune]>=2.5.0",
+ "ipython",
+ "scib-metrics>=0.3",
+] # scvi.autotune
+census = ["cellxgene-census"] # scvi.data.cellxgene
+hub = ["huggingface_hub"] # scvi.hub dependencies
+pymde = ["pymde"] # scvi.model.utils.mde dependencies
+regseq = ["biopython>=1.81", "genomepy"] # scvi.data.add_dna_sequence
+loompy = ["loompy>=3.0.6"] # read loom
+scanpy = ["scanpy>=1.6"] # scvi.criticism and read 10x
+optional = [
+ "scvi-tools[autotune,census,hub,loompy,pymde,regseq,scanpy]"
+] # all optional user functionality
+
+tutorials = [
+ "leidenalg",
+ "pynndescent",
+ "python-igraph",
+ "scikit-misc",
+ "scvi-tools[optional]",
+] # dependencies for all tutorials
+
+all = ["scvi-tools[dev,docs,tutorials]"] # all dependencies
[tool.hatch.build.targets.wheel]
packages = ['scvi']
diff --git a/readthedocs.yml b/readthedocs.yml
index 96fe2a79be..6af386e710 100644
--- a/readthedocs.yml
+++ b/readthedocs.yml
@@ -10,10 +10,7 @@ python:
- method: pip
path: .
extra_requirements:
- - docs
- - pymde
- - hub
- - autotune
+ - docsbuild
submodules:
include:
- "docs/tutorials/notebooks"
From 1d1082d48f7304e263a7c9cdbe1e198e7c367c13 Mon Sep 17 00:00:00 2001
From: Valeh Valiollah Pour Amiri <4193454+watiss@users.noreply.github.com>
Date: Thu, 20 Jul 2023 09:25:46 -0700
Subject: [PATCH 3/9] Store per-group LFC information (#2173)
---
docs/release_notes/index.md | 8 ++++++++
scvi/criticism/_ppc.py | 21 ++++++++++++++++++++-
2 files changed, 28 insertions(+), 1 deletion(-)
diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md
index e99e76df78..4d4fa4f8a7 100644
--- a/docs/release_notes/index.md
+++ b/docs/release_notes/index.md
@@ -22,11 +22,19 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
- Add `load_sparse_tensor` argument in {class}`scvi.data.AnnTorchDataset` for directly
loading SciPy CSR and CSC data structures to their PyTorch counterparts, leading to
faster data loading depending on the sparsity of the data {pr}`2158`.
+- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression`
+ method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary`
+ stores the summary dataframe, and the `lfc_per_model_per_group` key stores the
+ per-group LFC.
#### Changed
- Replace `sparse` with `sparse_format` argument in {meth}`scvi.data.synthetic_iid`
for increased flexibility over dataset format {pr}`2163`.
+- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression`
+ method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary`
+ stores the summary dataframe, and the `lfc_per_model_per_group` key stores the
+ per-group LFC.
#### Removed
diff --git a/scvi/criticism/_ppc.py b/scvi/criticism/_ppc.py
index 0be9c54029..f859321381 100644
--- a/scvi/criticism/_ppc.py
+++ b/scvi/criticism/_ppc.py
@@ -371,6 +371,8 @@ def differential_expression(
],
)
i = 0
+ self.metrics[METRIC_DIFF_EXP] = {}
+ self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"] = {}
for g in groups:
raw_group_data = sc.get.rank_genes_groups_df(
adata_de, group=g, key=UNS_NAME_RGG_RAW
@@ -378,6 +380,8 @@ def differential_expression(
raw_group_data.set_index("names", inplace=True)
for model in de_keys.keys():
gene_overlap_f1s = []
+ rgds = []
+ sgds = []
lfc_maes = []
lfc_pearsons = []
lfc_spearmans = []
@@ -408,6 +412,8 @@ def differential_expression(
raw_group_data["logfoldchanges"],
sample_group_data["logfoldchanges"],
)
+ rgds.append(rgd)
+ sgds.append(sgd)
lfc_maes.append(np.mean(np.abs(rgd - sgd)))
lfc_pearsons.append(pearsonr(rgd, sgd)[0])
lfc_spearmans.append(spearmanr(rgd, sgd)[0])
@@ -430,6 +436,19 @@ def differential_expression(
df.loc[i, "lfc_spearman"] = np.mean(lfc_spearmans)
df.loc[i, "roc_auc"] = np.mean(roc_aucs)
df.loc[i, "pr_auc"] = np.mean(pr_aucs)
+ rgd, sgd = pd.DataFrame(rgds).mean(axis=0), pd.DataFrame(sgds).mean(
+ axis=0
+ )
+ if (
+ model
+ not in self.metrics[METRIC_DIFF_EXP][
+ "lfc_per_model_per_group"
+ ].keys()
+ ):
+ self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"][model] = {}
+ self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"][model][
+ g
+ ] = pd.DataFrame([rgd, sgd], index=["raw", "approx"]).T
i += 1
- self.metrics[METRIC_DIFF_EXP] = df
+ self.metrics[METRIC_DIFF_EXP]["summary"] = df
From 15108a42a1bbb75e5e6fb4f66c99038e046877de Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 25 Jul 2023 00:44:05 -0700
Subject: [PATCH 4/9] [pre-commit.ci] pre-commit autoupdate (#2199)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.0.278 → v0.0.280](https://github.com/astral-sh/ruff-pre-commit/compare/v0.0.278...v0.0.280)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.pre-commit-config.yaml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 164bf11844..beaa6daeb9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,7 +25,7 @@ repos:
# https://github.com/jupyterlab/jupyterlab/issues/12675
language_version: "17.9.1"
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.0.278
+ rev: v0.0.280
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
From 3a7f19e933d9388f45b350ea673bab4bcb05407d Mon Sep 17 00:00:00 2001
From: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Date: Tue, 25 Jul 2023 01:45:09 -0700
Subject: [PATCH 5/9] Expose `torch.save` kwargs (#2200)
* Expose torch.save kwargs
* Add release note
* Annotations
* Fix docs annotation
---
docs/release_notes/index.md | 6 ++-
scvi/external/gimvi/_model.py | 72 +++++++++++++++++++---------------
scvi/model/base/_base_model.py | 60 ++++++++++++++++------------
3 files changed, 81 insertions(+), 57 deletions(-)
diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md
index 4d4fa4f8a7..c5e471a284 100644
--- a/docs/release_notes/index.md
+++ b/docs/release_notes/index.md
@@ -22,16 +22,18 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
- Add `load_sparse_tensor` argument in {class}`scvi.data.AnnTorchDataset` for directly
loading SciPy CSR and CSC data structures to their PyTorch counterparts, leading to
faster data loading depending on the sparsity of the data {pr}`2158`.
-- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression`
+- Add per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression`
method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary`
stores the summary dataframe, and the `lfc_per_model_per_group` key stores the
per-group LFC.
+- Expose {meth}`torch.save` keyword arguments in {class}`scvi.model.base.BaseModelClass.save`
+ and {class}`scvi.external.GIMVI.save` {pr}`2200`.
#### Changed
- Replace `sparse` with `sparse_format` argument in {meth}`scvi.data.synthetic_iid`
for increased flexibility over dataset format {pr}`2163`.
-- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression`
+- Add per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression`
method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary`
stores the summary dataframe, and the `lfc_per_model_per_group` key stores the
per-group LFC.
diff --git a/scvi/external/gimvi/_model.py b/scvi/external/gimvi/_model.py
index 6b1044d720..4cf7d2978b 100644
--- a/scvi/external/gimvi/_model.py
+++ b/scvi/external/gimvi/_model.py
@@ -1,8 +1,9 @@
+from __future__ import annotations
+
import logging
import os
import warnings
from itertools import cycle
-from typing import List, Optional, Union
import numpy as np
import torch
@@ -77,8 +78,8 @@ def __init__(
self,
adata_seq: AnnData,
adata_spatial: AnnData,
- generative_distributions: Optional[List[str]] = None,
- model_library_size: Optional[List[bool]] = None,
+ generative_distributions: list[str] | None = None,
+ model_library_size: list[bool] | None = None,
n_latent: int = 10,
**model_kwargs,
):
@@ -162,13 +163,13 @@ def train(
self,
max_epochs: int = 200,
accelerator: str = "auto",
- devices: Union[int, List[int], str] = "auto",
+ devices: int | list[int] | str = "auto",
kappa: int = 5,
train_size: float = 0.9,
- validation_size: Optional[float] = None,
+ validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
- plan_kwargs: Optional[dict] = None,
+ plan_kwargs: dict | None = None,
**kwargs,
):
"""Train the model.
@@ -254,7 +255,7 @@ def train(
self.to_device(device)
self.is_trained_ = True
- def _make_scvi_dls(self, adatas: List[AnnData] = None, batch_size=128):
+ def _make_scvi_dls(self, adatas: list[AnnData] = None, batch_size=128):
if adatas is None:
adatas = self.adatas
post_list = [self._make_data_loader(ad) for ad in adatas]
@@ -266,10 +267,10 @@ def _make_scvi_dls(self, adatas: List[AnnData] = None, batch_size=128):
@torch.inference_mode()
def get_latent_representation(
self,
- adatas: List[AnnData] = None,
+ adatas: list[AnnData] = None,
deterministic: bool = True,
batch_size: int = 128,
- ) -> List[np.ndarray]:
+ ) -> list[np.ndarray]:
"""Return the latent space embedding for each dataset.
Parameters
@@ -307,12 +308,12 @@ def get_latent_representation(
@torch.inference_mode()
def get_imputed_values(
self,
- adatas: List[AnnData] = None,
+ adatas: list[AnnData] = None,
deterministic: bool = True,
normalized: bool = True,
- decode_mode: Optional[int] = None,
+ decode_mode: int | None = None,
batch_size: int = 128,
- ) -> List[np.ndarray]:
+ ) -> list[np.ndarray]:
"""Return imputed values for all genes for each dataset.
Parameters
@@ -376,9 +377,10 @@ def get_imputed_values(
def save(
self,
dir_path: str,
- prefix: Optional[str] = None,
+ prefix: str | None = None,
overwrite: bool = False,
save_anndata: bool = False,
+ save_kwargs: dict | None = None,
**anndata_write_kwargs,
):
"""Save the state of the model.
@@ -398,6 +400,8 @@ def save(
already exists at `dir_path`, error will be raised.
save_anndata
If True, also saves the anndata
+ save_kwargs
+ Keyword arguments passed into :func:`~torch.save`.
anndata_write_kwargs
Kwargs for anndata write function
"""
@@ -409,6 +413,7 @@ def save(
)
file_name_prefix = prefix or ""
+ save_kwargs = save_kwargs or {}
seq_adata = self.adatas[0]
spatial_adata = self.adatas[1]
@@ -442,6 +447,7 @@ def save(
"attr_dict": user_attributes,
},
model_save_path,
+ **save_kwargs,
)
@classmethod
@@ -449,12 +455,12 @@ def save(
def load(
cls,
dir_path: str,
- adata_seq: Optional[AnnData] = None,
- adata_spatial: Optional[AnnData] = None,
+ adata_seq: AnnData | None = None,
+ adata_spatial: AnnData | None = None,
accelerator: str = "auto",
- device: Union[int, str] = "auto",
- prefix: Optional[str] = None,
- backup_url: Optional[str] = None,
+ device: int | str = "auto",
+ prefix: str | None = None,
+ backup_url: str | None = None,
):
"""Instantiate a model from the saved output.
@@ -578,21 +584,24 @@ def convert_legacy_save(
dir_path: str,
output_dir_path: str,
overwrite: bool = False,
- prefix: Optional[str] = None,
+ prefix: str | None = None,
+ **save_kwargs,
) -> None:
"""Converts a legacy saved GIMVI model ( AnnDataManager:
"""Manager instance associated with self.adata."""
return self._adata_manager
- def to_device(self, device: Union[str, int]):
+ def to_device(self, device: str | int):
"""Move model to device.
Parameters
@@ -169,7 +171,7 @@ def _get_setup_method_args(**setup_locals) -> dict:
@staticmethod
def _create_modalities_attr_dict(
- modalities: Dict[str, str], setup_method_args: dict
+ modalities: dict[str, str], setup_method_args: dict
) -> attrdict:
"""Preprocesses a ``modalities`` dictionary used in ``setup_mudata()`` to map modality names.
@@ -224,7 +226,7 @@ def _register_manager_for_instance(self, adata_manager: AnnDataManager):
instance_manager_store = self._per_instance_manager_store[self.id]
instance_manager_store[adata_id] = adata_manager
- def deregister_manager(self, adata: Optional[AnnData] = None):
+ def deregister_manager(self, adata: AnnData | None = None):
"""Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`.
If `adata` is `None`, deregisters all :class:`~scvi.data.AnnDataManager` instances
@@ -263,7 +265,7 @@ def deregister_manager(self, adata: Optional[AnnData] = None):
@classmethod
def _get_most_recent_anndata_manager(
cls, adata: AnnOrMuData, required: bool = False
- ) -> Optional[AnnDataManager]:
+ ) -> AnnDataManager | None:
"""Retrieves the :class:`~scvi.data.AnnDataManager` for a given AnnData object specific to this model class.
Checks for the most recent :class:`~scvi.data.AnnDataManager` created for the given AnnData object via
@@ -305,7 +307,7 @@ def _get_most_recent_anndata_manager(
def get_anndata_manager(
self, adata: AnnOrMuData, required: bool = False
- ) -> Optional[AnnDataManager]:
+ ) -> AnnDataManager | None:
"""Retrieves the :class:`~scvi.data.AnnDataManager` for a given AnnData object specific to this model instance.
Requires ``self.id`` has been set. Checks for an :class:`~scvi.data.AnnDataManager`
@@ -384,8 +386,8 @@ def get_from_registry(
def _make_data_loader(
self,
adata: AnnOrMuData,
- indices: Optional[Sequence[int]] = None,
- batch_size: Optional[int] = None,
+ indices: Sequence[int] | None = None,
+ batch_size: int | None = None,
shuffle: bool = False,
data_loader_class=None,
**data_loader_kwargs,
@@ -435,7 +437,7 @@ def _make_data_loader(
return dl
def _validate_anndata(
- self, adata: Optional[AnnOrMuData] = None, copy_if_view: bool = True
+ self, adata: AnnOrMuData | None = None, copy_if_view: bool = True
) -> AnnData:
"""Validate anndata has been properly registered, transfer if necessary."""
if adata is None:
@@ -557,9 +559,10 @@ def train(self):
def save(
self,
dir_path: str,
- prefix: Optional[str] = None,
+ prefix: str | None = None,
overwrite: bool = False,
save_anndata: bool = False,
+ save_kwargs: dict | None = None,
**anndata_write_kwargs,
):
"""Save the state of the model.
@@ -579,6 +582,8 @@ def save(
already exists at `dir_path`, error will be raised.
save_anndata
If True, also saves the anndata
+ save_kwargs
+ Keyword arguments passed into :func:`~torch.save`.
anndata_write_kwargs
Kwargs for :meth:`~anndata.AnnData.write`
"""
@@ -590,6 +595,8 @@ def save(
)
file_name_prefix = prefix or ""
+ save_kwargs = save_kwargs or {}
+
if save_anndata:
file_suffix = ""
if isinstance(self.adata, AnnData):
@@ -621,6 +628,7 @@ def save(
"attr_dict": user_attributes,
},
model_save_path,
+ **save_kwargs,
)
@classmethod
@@ -628,11 +636,11 @@ def save(
def load(
cls,
dir_path: str,
- adata: Optional[AnnOrMuData] = None,
+ adata: AnnOrMuData | None = None,
accelerator: str = "auto",
- device: Union[int, str] = "auto",
- prefix: Optional[str] = None,
- backup_url: Optional[str] = None,
+ device: int | str = "auto",
+ prefix: str | None = None,
+ backup_url: str | None = None,
):
"""Instantiate a model from the saved output.
@@ -720,7 +728,8 @@ def convert_legacy_save(
dir_path: str,
output_dir_path: str,
overwrite: bool = False,
- prefix: Optional[str] = None,
+ prefix: str | None = None,
+ **save_kwargs,
) -> None:
"""Converts a legacy saved model ( None:
+ def view_setup_args(dir_path: str, prefix: str | None = None) -> None:
"""Print args used to setup a saved model.
Parameters
@@ -813,7 +825,7 @@ def view_setup_args(dir_path: str, prefix: Optional[str] = None) -> None:
AnnDataManager.view_setup_method_args(registry)
@staticmethod
- def load_registry(dir_path: str, prefix: Optional[str] = None) -> dict:
+ def load_registry(dir_path: str, prefix: str | None = None) -> dict:
"""Return the full registry saved with the model.
Parameters
@@ -839,7 +851,7 @@ def load_registry(dir_path: str, prefix: Optional[str] = None) -> dict:
return attr_dict.pop("registry_")
def view_anndata_setup(
- self, adata: Optional[AnnOrMuData] = None, hide_state_registries: bool = False
+ self, adata: AnnOrMuData | None = None, hide_state_registries: bool = False
) -> None:
"""Print summary of the setup for the initial AnnData or a given AnnData object.
@@ -867,7 +879,7 @@ class BaseMinifiedModeModelClass(BaseModelClass):
"""Abstract base class for scvi-tools models that can handle minified data."""
@property
- def minified_data_type(self) -> Union[MinifiedDataType, None]:
+ def minified_data_type(self) -> MinifiedDataType | None:
"""The type of minified data associated with this model, if applicable."""
return (
self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY)
From c8df1b0acdcca332aa03dcb5c29a276e4640a594 Mon Sep 17 00:00:00 2001
From: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Date: Tue, 25 Jul 2023 02:32:05 -0700
Subject: [PATCH 6/9] Add `accelerator` and `devices` options in pytest (#2201)
* Add accelerator and devices args to pytest
* Update cuda test workflow
* Fix annotations
---
.github/workflows/test_linux_cuda.yml | 2 +-
tests/conftest.py | 35 +++++++++++++-----------
tests/dataloaders/test_datasplitter.py | 12 ++++++---
tests/model/test_scvi.py | 21 ---------------
tests/models/test_pyro.py | 37 ++++++++++++++++----------
tests/train/test_distributed.py | 17 ------------
6 files changed, 52 insertions(+), 72 deletions(-)
delete mode 100644 tests/model/test_scvi.py
delete mode 100644 tests/train/test_distributed.py
diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml
index 6c1bcd92db..827fb2169d 100644
--- a/.github/workflows/test_linux_cuda.yml
+++ b/.github/workflows/test_linux_cuda.yml
@@ -38,7 +38,7 @@ jobs:
PLATFORM: ubuntu
DISPLAY: :42
run: |
- pytest -v --cov --color=yes --cuda
+ pytest -v --cov --color=yes --accelerator cuda --devices auto
- name: Upload coverage
uses: codecov/codecov-action@v3
diff --git a/tests/conftest.py b/tests/conftest.py
index 8ce99f0c70..129b17196f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,6 @@
from distutils.dir_util import copy_tree
import pytest
-import torch
import scvi
@@ -29,10 +28,16 @@ def pytest_addoption(parser):
help="Run tests that are optional.",
)
parser.addoption(
- "--cuda",
- action="store_true",
- default=False,
- help="Run tests that required a CUDA backend.",
+ "--accelerator",
+ action="store",
+ default="cpu",
+ help="Option to specify which accelerator to use for tests.",
+ )
+ parser.addoption(
+ "--devices",
+ action="store",
+ default="auto",
+ help="Option to specify which devices to use for tests.",
)
@@ -59,14 +64,6 @@ def pytest_collection_modifyitems(config, items):
if not run_optional and ("optional" in item.keywords):
item.add_marker(skip_optional)
- run_cuda = config.getoption("--cuda")
- skip_cuda = pytest.mark.skip(reason="need --cuda option to run")
- for item in items:
- # All tests marked with `pytest.mark.cuda` get skipped unless
- # `--cuda` passed
- if not run_cuda and ("cuda" in item.keywords):
- item.add_marker(skip_cuda)
-
@pytest.fixture(scope="session")
def save_path(tmpdir_factory):
@@ -91,6 +88,12 @@ def model_fit(request):
@pytest.fixture(scope="session")
-def cuda():
- """Docstring for cuda."""
- assert torch.cuda.is_available()
+def accelerator(request):
+ """Docstring for accelerator."""
+ return request.config.getoption("--accelerator")
+
+
+@pytest.fixture(scope="session")
+def devices(request):
+ """Docstring for devices."""
+ return request.config.getoption("--devices")
diff --git a/tests/dataloaders/test_datasplitter.py b/tests/dataloaders/test_datasplitter.py
index d8b2648a26..5671d640d5 100644
--- a/tests/dataloaders/test_datasplitter.py
+++ b/tests/dataloaders/test_datasplitter.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from math import ceil, floor
import numpy as np
@@ -41,12 +43,16 @@ def test_datasplitter_shuffle():
@pytest.mark.parametrize(
"sparse_format", ["csr_matrix", "csr_array", "csc_matrix", "csc_array"]
)
-def test_datasplitter_load_sparse_tensor(sparse_format: str):
+def test_datasplitter_load_sparse_tensor(
+ sparse_format: str,
+ accelerator: str,
+ devices: list | str | int,
+):
adata = scvi.data.synthetic_iid(sparse_format=sparse_format)
TestSparseModel.setup_anndata(adata)
model = TestSparseModel(adata)
model.train(
- accelerator="cpu",
- devices=1,
+ accelerator=accelerator,
+ devices=devices,
expected_sparse_layout=sparse_format.split("_")[0],
)
diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py
deleted file mode 100644
index 9fde0b532e..0000000000
--- a/tests/model/test_scvi.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import pytest
-
-import scvi
-
-
-@pytest.mark.cuda
-def test_scvi_train_ddp(devices: int = -1):
- adata = scvi.data.synthetic_iid()
- scvi.model.SCVI.setup_anndata(adata)
- model = scvi.model.SCVI(adata)
-
- model.train(
- max_epochs=1,
- check_val_every_n_epoch=1,
- accelerator="gpu",
- devices=devices,
- strategy="ddp_find_unused_parameters_true",
- )
-
- assert model.is_trained
- assert len(model.history) > 0
diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py
index 26a75b9eae..1e8184c8b1 100644
--- a/tests/models/test_pyro.py
+++ b/tests/models/test_pyro.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
import os
-from typing import Optional
import numpy as np
import pyro
@@ -157,7 +158,7 @@ def setup_anndata(
cls,
adata: AnnData,
**kwargs,
- ) -> Optional[AnnData]:
+ ) -> AnnData | None:
setup_method_args = cls._get_setup_method_args(**locals())
# add index for each cell (provided to pyro plate for correct minibatching)
@@ -187,7 +188,10 @@ def _create_indices_adata_manager(adata: AnnData) -> AnnDataManager:
return adata_manager
-def test_pyro_bayesian_regression_low_level():
+def test_pyro_bayesian_regression_low_level(
+ accelerator: str,
+ devices: list | str | int,
+):
adata = synthetic_iid()
adata_manager = _create_indices_adata_manager(adata)
train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
@@ -196,8 +200,8 @@ def test_pyro_bayesian_regression_low_level():
plan = LowLevelPyroTrainingPlan(model)
plan.n_obs_training = len(train_dl.indices)
trainer = Trainer(
- accelerator="auto",
- devices="auto",
+ accelerator=accelerator,
+ devices=devices,
max_epochs=2,
callbacks=[PyroModelGuideWarmup(train_dl)],
)
@@ -213,7 +217,9 @@ def test_pyro_bayesian_regression_low_level():
]
-def test_pyro_bayesian_regression(save_path):
+def test_pyro_bayesian_regression(
+ accelerator: str, devices: list | str | int, save_path: str
+):
adata = synthetic_iid()
adata_manager = _create_indices_adata_manager(adata)
train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
@@ -222,8 +228,8 @@ def test_pyro_bayesian_regression(save_path):
plan = PyroTrainingPlan(model)
plan.n_obs_training = len(train_dl.indices)
trainer = Trainer(
- accelerator="auto",
- devices="auto",
+ accelerator=accelerator,
+ devices=devices,
max_epochs=2,
)
trainer.fit(plan, train_dl)
@@ -258,8 +264,8 @@ def test_pyro_bayesian_regression(save_path):
plan = PyroTrainingPlan(new_model)
plan.n_obs_training = len(train_dl.indices)
trainer = Trainer(
- accelerator="auto",
- devices="auto",
+ accelerator=accelerator,
+ devices=devices,
max_steps=1,
)
trainer.fit(plan, train_dl)
@@ -275,7 +281,10 @@ def test_pyro_bayesian_regression(save_path):
np.testing.assert_array_equal(linear_median_new, linear_median)
-def test_pyro_bayesian_regression_jit():
+def test_pyro_bayesian_regression_jit(
+ accelerator: str,
+ devices: list | str | int,
+):
adata = synthetic_iid()
adata_manager = _create_indices_adata_manager(adata)
train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
@@ -284,8 +293,8 @@ def test_pyro_bayesian_regression_jit():
plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
plan.n_obs_training = len(train_dl.indices)
trainer = Trainer(
- accelerator="auto",
- devices="auto",
+ accelerator=accelerator,
+ devices=devices,
max_epochs=2,
callbacks=[PyroJitGuideWarmup(train_dl)],
)
@@ -532,7 +541,7 @@ def setup_anndata(
cls,
adata: AnnData,
**kwargs,
- ) -> Optional[AnnData]:
+ ) -> AnnData | None:
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
diff --git a/tests/train/test_distributed.py b/tests/train/test_distributed.py
deleted file mode 100644
index a22c0179a4..0000000000
--- a/tests/train/test_distributed.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import pytest
-
-import scvi
-
-
-@pytest.mark.optional
-def test_scvi_train_ddp():
- adata = scvi.data.synthetic_iid()
- scvi.model.SCVI.setup_anndata(adata)
- model = scvi.model.SCVI(adata)
-
- model.train(
- max_epochs=1,
- accelerator="cpu",
- devices=2,
- strategy="ddp_find_unused_parameters_true",
- )
From 8647b7a54ad54f9a62fb0d690604ef3698ba1ef8 Mon Sep 17 00:00:00 2001
From: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Date: Tue, 25 Jul 2023 02:48:13 -0700
Subject: [PATCH 7/9] Update cuda test image (#2202)
---
.github/workflows/test_linux_cuda.yml | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml
index 827fb2169d..3fe94f65fe 100644
--- a/.github/workflows/test_linux_cuda.yml
+++ b/.github/workflows/test_linux_cuda.yml
@@ -15,13 +15,11 @@ jobs:
strategy:
fail-fast: false
matrix:
- ubuntu: [latest]
- mamba: [latest]
python: ["3.11"]
cuda: ["11"]
container:
- image: martinkim0/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}
+ image: martinkim0/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}-base
options: --user root --gpus all
steps:
From 0ee065e262afcf6e1e25916e3f600ba5ac7e5bfd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Szabolcs=20Horv=C3=A1t?=
Date: Tue, 25 Jul 2023 12:16:03 +0200
Subject: [PATCH 8/9] change python-igraph to igraph (#2197)
Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com>
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 3b5a4a4ae5..9d358520a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -111,7 +111,7 @@ optional = [
tutorials = [
"leidenalg",
"pynndescent",
- "python-igraph",
+ "igraph",
"scikit-misc",
"scvi-tools[optional]",
] # dependencies for all tutorials
From 28a8c5de589b7698e318d3055ff83c361d0e79e2 Mon Sep 17 00:00:00 2001
From: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Date: Tue, 25 Jul 2023 07:25:30 -0700
Subject: [PATCH 9/9] Expose fixed model and train kwargs in autotune (#2203)
* Expose fixed model and train kwargs in autotune
* Add release note
---
docs/release_notes/index.md | 1 +
scvi/autotune/_manager.py | 67 ++++++++++++++++++++++++-------------
scvi/autotune/_tuner.py | 6 ++++
3 files changed, 50 insertions(+), 24 deletions(-)
diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md
index c5e471a284..234d57d74d 100644
--- a/docs/release_notes/index.md
+++ b/docs/release_notes/index.md
@@ -28,6 +28,7 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
per-group LFC.
- Expose {meth}`torch.save` keyword arguments in {class}`scvi.model.base.BaseModelClass.save`
and {class}`scvi.external.GIMVI.save` {pr}`2200`.
+- Add `model_kwargs` and `train_kwargs` arguments to {meth}`scvi.autotune.ModelTuner.fit` {pr}`2203`.
#### Changed
diff --git a/scvi/autotune/_manager.py b/scvi/autotune/_manager.py
index f54250401f..114d1b3bfd 100644
--- a/scvi/autotune/_manager.py
+++ b/scvi/autotune/_manager.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
import inspect
import logging
import os
import warnings
from collections import OrderedDict
from datetime import datetime
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable
import lightning.pytorch as pl
import rich
@@ -136,7 +138,7 @@ def _parse_func_params(func: Callable, parent: Any, tunable_type: str) -> dict:
return tunables
def _get_tunables(
- attr: Any, parent: Any = None, tunable_type: Optional[str] = None
+ attr: Any, parent: Any = None, tunable_type: str | None = None
) -> dict:
tunables = {}
if inspect.isfunction(attr):
@@ -159,7 +161,7 @@ def _get_metrics(model_cls: BaseModelClass) -> OrderedDict:
}
return registry
- def _get_search_space(self, search_space: dict) -> Tuple[dict, dict]:
+ def _get_search_space(self, search_space: dict) -> tuple[dict, dict]:
"""Parses a compact search space into separate kwargs dictionaries."""
model_kwargs = {}
train_kwargs = {}
@@ -210,7 +212,7 @@ def _validate_search_space(self, search_space: dict, use_defaults: bool) -> dict
return _search_space
def _validate_metrics(
- self, metric: str, additional_metrics: List[str]
+ self, metric: str, additional_metrics: list[str]
) -> OrderedDict:
"""Validates a set of metrics against the metric registry."""
registry_metrics = self._registry["metrics"]
@@ -240,7 +242,7 @@ def _validate_metrics(
return _metrics
@staticmethod
- def _get_primary_metric_and_mode(metrics: OrderedDict) -> Tuple[str, str]:
+ def _get_primary_metric_and_mode(metrics: OrderedDict) -> tuple[str, str]:
metric = list(metrics.keys())[0]
mode = metrics[metric]
return metric, mode
@@ -308,7 +310,7 @@ def _validate_scheduler_and_search_algorithm(
metrics: OrderedDict,
scheduler_kwargs: dict,
searcher_kwargs: dict,
- ) -> Tuple[Any, Any]:
+ ) -> tuple[Any, Any]:
"""Validates a scheduler and search algorithm pair for compatibility."""
supported = ["asha", "hyperband", "median", "pbt", "fifo"]
if scheduler not in supported:
@@ -361,7 +363,7 @@ def _validate_resources(self, resources: dict) -> dict:
# TODO: perform resource checking
return resources
- def _get_setup_info(self, adata: AnnOrMuData) -> Tuple[str, dict]:
+ def _get_setup_info(self, adata: AnnOrMuData) -> tuple[str, dict]:
"""Retrieves the method and kwargs used for setting up `adata` with the model class."""
manager = self._model_cls._get_most_recent_anndata_manager(adata)
setup_method_name = manager._registry.get(_SETUP_METHOD_NAME, "setup_anndata")
@@ -373,6 +375,8 @@ def _get_trainable(
self,
adata: AnnOrMuData,
metrics: OrderedDict,
+ model_kwargs: dict,
+ train_kwargs: dict,
resources: dict,
setup_method_name: str,
setup_kwargs: dict,
@@ -386,20 +390,27 @@ def _trainable(
model_cls: BaseModelClass,
adata: AnnOrMuData,
metric: str,
+ model_kwargs: dict,
+ train_kwargs: dict,
setup_method_name: str,
setup_kwargs: dict,
max_epochs: int,
accelerator: str,
devices: int,
) -> None:
- model_kwargs, train_kwargs = self._get_search_space(search_space)
+ _model_kwargs, _train_kwargs = self._get_search_space(search_space)
+ model_kwargs.update(_model_kwargs)
+ train_kwargs.update(_train_kwargs)
+
getattr(model_cls, setup_method_name)(adata, **setup_kwargs)
model = model_cls(adata, **model_kwargs)
+
# This is to get around lightning import changes
callback_cls = type(
"_TuneReportCallback", (TuneReportCallback, pl.Callback), {}
)
monitor = callback_cls(metric, on="validation_end")
+
model.train(
max_epochs=max_epochs,
accelerator=accelerator,
@@ -418,6 +429,8 @@ def _trainable(
model_cls=self._model_cls,
adata=adata,
metric=list(metrics.keys())[0],
+ model_kwargs=model_kwargs,
+ train_kwargs=train_kwargs,
setup_method_name=setup_method_name,
setup_kwargs=setup_kwargs,
max_epochs=max_epochs,
@@ -427,8 +440,8 @@ def _trainable(
return tune.with_resources(_wrap_params, resources=resources)
def _validate_experiment_name_and_logging_dir(
- self, experiment_name: Optional[str], logging_dir: Optional[str]
- ) -> Tuple[str, str]:
+ self, experiment_name: str | None, logging_dir: str | None
+ ) -> tuple[str, str]:
if experiment_name is None:
experiment_name = "tune_"
experiment_name += self._model_cls.__name__.lower() + "_"
@@ -442,26 +455,30 @@ def _get_tuner(
self,
adata: AnnOrMuData,
*,
- metric: Optional[str] = None,
- additional_metrics: Optional[List[str]] = None,
- search_space: Optional[dict] = None,
+ metric: str | None = None,
+ additional_metrics: list[str] | None = None,
+ search_space: dict | None = None,
+ model_kwargs: dict | None = None,
+ train_kwargs: dict | None = None,
use_defaults: bool = False,
- num_samples: Optional[int] = None,
- max_epochs: Optional[int] = None,
- scheduler: Optional[str] = None,
- scheduler_kwargs: Optional[dict] = None,
- searcher: Optional[str] = None,
- searcher_kwargs: Optional[dict] = None,
+ num_samples: int | None = None,
+ max_epochs: int | None = None,
+ scheduler: str | None = None,
+ scheduler_kwargs: dict | None = None,
+ searcher: str | None = None,
+ searcher_kwargs: dict | None = None,
reporter: bool = True,
- resources: Optional[dict] = None,
- experiment_name: Optional[str] = None,
- logging_dir: Optional[str] = None,
- ) -> Tuple[Any, dict]:
+ resources: dict | None = None,
+ experiment_name: str | None = None,
+ logging_dir: str | None = None,
+ ) -> tuple[Any, dict]:
metric = (
metric or self._get_primary_metric_and_mode(self._registry["metrics"])[0]
)
additional_metrics = additional_metrics or []
search_space = search_space or {}
+ model_kwargs = model_kwargs or {}
+ train_kwargs = train_kwargs or {}
num_samples = num_samples or 10 # TODO: better default
max_epochs = max_epochs or 100 # TODO: better default
scheduler = scheduler or "asha"
@@ -481,6 +498,8 @@ def _get_tuner(
_trainable = self._get_trainable(
adata,
_metrics,
+ model_kwargs,
+ train_kwargs,
_resources,
_setup_method_name,
_setup_args,
@@ -537,7 +556,7 @@ def _get_analysis(self, results: Any, config: dict) -> TuneAnalysis:
)
@staticmethod
- def _add_columns(table: rich.table.Table, columns: List[str]) -> rich.table.Table:
+ def _add_columns(table: rich.table.Table, columns: list[str]) -> rich.table.Table:
"""Adds columns to a :class:`~rich.table.Table` with default formatting."""
for i, column in enumerate(columns):
table.add_column(column, style=COLORS[i], **COLUMN_KWARGS)
diff --git a/scvi/autotune/_tuner.py b/scvi/autotune/_tuner.py
index c9065f39a0..9568b064fc 100644
--- a/scvi/autotune/_tuner.py
+++ b/scvi/autotune/_tuner.py
@@ -48,6 +48,12 @@ def fit(self, adata: AnnOrMuData, **kwargs) -> None:
provided as instantiated Ray Tune sample functions. Available
hyperparameters can be viewed with :meth:`~scvi.autotune.ModelTuner.info`.
Must be provided if `use_defaults` is `False`.
+ model_kwargs
+ Keyword arguments passed to the model class's constructor. Arguments must
+ not overlap with those in `search_space`.
+ train_kwargs
+ Keyword arguments passed to the model's `train` method. Arguments must not
+ overlap with those in `search_space`.
use_defaults
Whether to use the model class's default search space, which can be viewed
with :meth:`~scvi.autotune.ModelTuner.info`. If `True` and `search_space` is