From 143a9aab02bc470cc9a2cb17648c4cc2e5b867f6 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 19 Jul 2023 06:29:18 -0700 Subject: [PATCH 1/9] Update README with new badges (#2193) * Update README with new badges * Update build --- README.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 7294b6cbad..f31f4e2f6f 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,14 @@ scvi-tools -[![Stars](https://img.shields.io/github/stars/scverse/scvi-tools?logo=GitHub&color=yellow)](https://github.com/YosefLab/scvi-tools/stargazers) +[![Stars](https://img.shields.io/github/stars/scverse/scvi-tools?logo=GitHub&color=yellow)](https://github.com/scverse/scvi-tools/stargazers) [![PyPI](https://img.shields.io/pypi/v/scvi-tools.svg)](https://pypi.org/project/scvi-tools) -[![Documentation Status](https://readthedocs.org/projects/scvi/badge/?version=latest)](https://scvi.readthedocs.io/en/stable/?badge=stable) -![Build -Status](https://github.com/scverse/scvi-tools/workflows/scvi-tools/badge.svg) -[![Coverage](https://codecov.io/gh/scverse/scvi-tools/branch/master/graph/badge.svg)](https://codecov.io/gh/YosefLab/scvi-tools) -[![Code -Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black) -[![Downloads](https://pepy.tech/badge/scvi-tools)](https://pepy.tech/project/scvi-tools) -[![Project chat](https://img.shields.io/badge/zulip-join_chat-brightgreen.svg)](https://scverse.zulipchat.com/) +[![PyPIDownloads](https://pepy.tech/badge/scvi-tools)](https://pepy.tech/project/scvi-tools) +[![CondaDownloads](https://img.shields.io/conda/dn/conda-forge/scvi-tools?logo=Anaconda)](https://anaconda.org/conda-forge/scvi-tools) +[![Docs](https://readthedocs.org/projects/scvi/badge/?version=latest)](https://scvi.readthedocs.io/en/stable/?badge=stable) +[![Build](https://github.com/scverse/scvi-tools/actions/workflows/build.yml/badge.svg)](https://github.com/scverse/scvi-tools/actions/workflows/build.yml/) +[![Coverage](https://codecov.io/gh/scverse/scvi-tools/branch/main/graph/badge.svg)](https://codecov.io/gh/scverse/scvi-tools) +[![Discourse](https://img.shields.io/discourse/posts?color=yellow&logo=discourse&server=https%3A%2F%2Fdiscourse.scverse.org)](https://discourse.scverse.org/) +[![Chat](https://img.shields.io/badge/zulip-join_chat-brightgreen.svg)](https://scverse.zulipchat.com/) [scvi-tools](https://scvi-tools.org/) (single-cell variational inference tools) is a package for probabilistic modeling and analysis of single-cell omics From b7563e8a09f7cc4ec25e1290e84d326c541ef345 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Wed, 19 Jul 2023 12:12:44 -0700 Subject: [PATCH 2/9] Reformat optional dependencies (#2188) * Reformat pyproject * Update dependencies * Potential docs fix * Fix test dependencies * Reformat dependencies * Update release workflow to 3.11 --- .github/workflows/build.yml | 9 ++- .github/workflows/release.yml | 8 +-- .github/workflows/test_linux_cpu.yml | 9 ++- .github/workflows/test_linux_cpu_pre.yml | 9 ++- .github/workflows/test_linux_cuda.yml | 4 +- .github/workflows/test_macos_cpu.yml | 9 ++- .github/workflows/test_windows_cpu.yml | 9 ++- pyproject.toml | 83 ++++++++++++++---------- readthedocs.yml | 5 +- 9 files changed, 90 insertions(+), 55 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c383b986bf..557fb64832 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,17 +9,22 @@ on: jobs: package: runs-on: ubuntu-latest + steps: - uses: actions/checkout@v3 - - name: Set up Python 3.10 + + - name: Set up Python 3.11 uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" cache: "pip" cache-dependency-path: "**/pyproject.toml" + - name: Install build dependencies run: python -m pip install --upgrade pip wheel twine build + - name: Build package run: python -m build + - name: Check package run: twine check --strict dist/*.whl diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1dd25654d3..da01223c2b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,19 +9,19 @@ jobs: release: name: Release runs-on: ubuntu-latest + steps: - # will use ref/SHA that triggered it - name: Checkout code uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11 uses: actions/setup-python@v4 with: - python-version: "3.9" + python-version: "3.11" - name: Install hatch run: | - pip install hatch + python -m pip install --upgrade hatch - name: Build project for distribution run: hatch build diff --git a/.github/workflows/test_linux_cpu.yml b/.github/workflows/test_linux_cpu.yml index b14de9c4f4..67499a84d4 100644 --- a/.github/workflows/test_linux_cpu.yml +++ b/.github/workflows/test_linux_cpu.yml @@ -34,18 +34,22 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} cache: "pip" cache-dependency-path: "**/pyproject.toml" + - name: Install test dependencies run: | python -m pip install --upgrade pip wheel - - name: Install dependencies + + - name: Install scvi-tools test dependencies run: | - pip install ".[dev,pymde,autotune,hub]" + pip install ".[tests]" + - name: Test env: MPLBACKEND: agg @@ -53,5 +57,6 @@ jobs: DISPLAY: :42 run: | pytest -v --cov --color=yes + - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/.github/workflows/test_linux_cpu_pre.yml b/.github/workflows/test_linux_cpu_pre.yml index a5046dcce2..b6678c64f1 100644 --- a/.github/workflows/test_linux_cpu_pre.yml +++ b/.github/workflows/test_linux_cpu_pre.yml @@ -36,18 +36,22 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} cache: "pip" cache-dependency-path: "**/pyproject.toml" + - name: Install test dependencies run: | python -m pip install --upgrade pip wheel - - name: Install dependencies + + - name: Install scvi-tools test dependencies run: | - pip install --pre ".[dev,pymde,autotune,hub]" + pip install --pre ".[tests]" + - name: Test env: MPLBACKEND: agg @@ -55,5 +59,6 @@ jobs: DISPLAY: :42 run: | pytest -v --cov --color=yes + - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 53a7b8ca1c..6c1bcd92db 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -21,7 +21,7 @@ jobs: cuda: ["11"] container: - image: martinkim0/scvi-tools:ubuntu-${{ matrix.ubuntu }}-mamba-${{ matrix.mamba}}-python-${{ matrix.python }}-cuda-${{ matrix.cuda }} + image: martinkim0/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }} options: --user root --gpus all steps: @@ -30,7 +30,7 @@ jobs: - name: Install dependencies run: | - pip install ".[dev,pymde,autotune,hub]" + pip install ".[tests]" - name: Test env: diff --git a/.github/workflows/test_macos_cpu.yml b/.github/workflows/test_macos_cpu.yml index c1147ffe6c..e35c1117ae 100644 --- a/.github/workflows/test_macos_cpu.yml +++ b/.github/workflows/test_macos_cpu.yml @@ -36,18 +36,22 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} cache: "pip" cache-dependency-path: "**/pyproject.toml" + - name: Install test dependencies run: | python -m pip install --upgrade pip wheel - - name: Install dependencies + + - name: Install scvi-tools test dependencies run: | - pip install ".[dev,pymde,autotune,hub]" + pip install ".[tests]" + - name: Test env: MPLBACKEND: agg @@ -55,5 +59,6 @@ jobs: DISPLAY: :42 run: | pytest -v --cov --color=yes + - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/.github/workflows/test_windows_cpu.yml b/.github/workflows/test_windows_cpu.yml index 2a0909054d..5d5d874dd7 100644 --- a/.github/workflows/test_windows_cpu.yml +++ b/.github/workflows/test_windows_cpu.yml @@ -36,18 +36,22 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} cache: "pip" cache-dependency-path: "**/pyproject.toml" + - name: Install test dependencies run: | python -m pip install --upgrade pip wheel - - name: Install dependencies + + - name: Install scvi-tools test dependencies run: | - pip install ".[dev,pymde,autotune,hub]" + pip install ".[tests]" + - name: Test env: MPLBACKEND: agg @@ -55,5 +59,6 @@ jobs: DISPLAY: :42 run: | pytest -v --cov --color=yes + - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/pyproject.toml b/pyproject.toml index a24fcbd342..3b5a4a4ae5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] dependencies = [ "anndata>=0.7.5", - "chex<=0.1.8", + "chex<=0.1.8", # see https://github.com/scverse/scvi-tools/pull/2187 "docrep>=0.3.2", "flax", "jax>=0.4.4", @@ -60,50 +60,63 @@ dependencies = [ [project.optional-dependencies] -dev = [ - "black", +tests = [ "pytest", "pytest-cov", + "scvi-tools[optional]" +] # dependencies for running the test suite +editing = [ + "black", "flake8", - "scanpy>=1.6", - "loompy>=3.0.6", "jupyter", "nbformat", "nbconvert", "pre-commit", "ruff", - "pymde", - "genomepy", - "cellxgene-census" -] +] # dependencies for editing and committing code +dev = ["scvi-tools[editing,tests]"] # dependencies for dev work + docs = [ - "docutils>=0.8,!=0.18.*,!=0.19.*", - "sphinx>=4.1", - "ipython", - "sphinx-book-theme>=1.0.1", - "sphinx_copybutton", - "sphinx-design", - "sphinxext-opengraph", - "sphinx-hoverxref", - "sphinxcontrib-bibtex>=1.0.0", - "myst-parser", - "myst-nb", - "sphinx-autodoc-typehints", -] + "docutils>=0.8,!=0.18.*,!=0.19.*", # see https://github.com/scverse/cookiecutter-scverse/pull/205 + "sphinx>=4.1", + "ipython", + "sphinx-book-theme>=1.0.1", + "sphinx_copybutton", + "sphinx-design", + "sphinxext-opengraph", + "sphinx-hoverxref", + "sphinxcontrib-bibtex>=1.0.0", + "myst-parser", + "myst-nb", + "sphinx-autodoc-typehints", +] # basic docs dependencies +docsbuild = ["scvi-tools[docs,optional]"] # docs build dependencies + autotune = [ - "hyperopt>=0.2", - "ray[tune]>=2.5.0", - "ipython", - "scib-metrics>=0.3", -] -pymde = ["pymde"] -tutorials = ["scanpy", "leidenalg", "python-igraph", "loompy", "scikit-misc", "pynndescent", "pymde", "huggingface_hub", "genomepy"] -hub = ["huggingface_hub"] -regseq = [ - "biopython>=1.81", - "genomepy", -] -census = ["cellxgene-census"] + "hyperopt>=0.2", + "ray[tune]>=2.5.0", + "ipython", + "scib-metrics>=0.3", +] # scvi.autotune +census = ["cellxgene-census"] # scvi.data.cellxgene +hub = ["huggingface_hub"] # scvi.hub dependencies +pymde = ["pymde"] # scvi.model.utils.mde dependencies +regseq = ["biopython>=1.81", "genomepy"] # scvi.data.add_dna_sequence +loompy = ["loompy>=3.0.6"] # read loom +scanpy = ["scanpy>=1.6"] # scvi.criticism and read 10x +optional = [ + "scvi-tools[autotune,census,hub,loompy,pymde,regseq,scanpy]" +] # all optional user functionality + +tutorials = [ + "leidenalg", + "pynndescent", + "python-igraph", + "scikit-misc", + "scvi-tools[optional]", +] # dependencies for all tutorials + +all = ["scvi-tools[dev,docs,tutorials]"] # all dependencies [tool.hatch.build.targets.wheel] packages = ['scvi'] diff --git a/readthedocs.yml b/readthedocs.yml index 96fe2a79be..6af386e710 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -10,10 +10,7 @@ python: - method: pip path: . extra_requirements: - - docs - - pymde - - hub - - autotune + - docsbuild submodules: include: - "docs/tutorials/notebooks" From 1d1082d48f7304e263a7c9cdbe1e198e7c367c13 Mon Sep 17 00:00:00 2001 From: Valeh Valiollah Pour Amiri <4193454+watiss@users.noreply.github.com> Date: Thu, 20 Jul 2023 09:25:46 -0700 Subject: [PATCH 3/9] Store per-group LFC information #2173) --- docs/release_notes/index.md | 8 ++++++++ scvi/criticism/_ppc.py | 21 ++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index e99e76df78..4d4fa4f8a7 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -22,11 +22,19 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits - Add `load_sparse_tensor` argument in {class}`scvi.data.AnnTorchDataset` for directly loading SciPy CSR and CSC data structures to their PyTorch counterparts, leading to faster data loading depending on the sparsity of the data {pr}`2158`. +- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` + method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary` + stores the summary dataframe, and the `lfc_per_model_per_group` key stores the + per-group LFC. #### Changed - Replace `sparse` with `sparse_format` argument in {meth}`scvi.data.synthetic_iid` for increased flexibility over dataset format {pr}`2163`. +- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` + method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary` + stores the summary dataframe, and the `lfc_per_model_per_group` key stores the + per-group LFC. #### Removed diff --git a/scvi/criticism/_ppc.py b/scvi/criticism/_ppc.py index 0be9c54029..f859321381 100644 --- a/scvi/criticism/_ppc.py +++ b/scvi/criticism/_ppc.py @@ -371,6 +371,8 @@ def differential_expression( ], ) i = 0 + self.metrics[METRIC_DIFF_EXP] = {} + self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"] = {} for g in groups: raw_group_data = sc.get.rank_genes_groups_df( adata_de, group=g, key=UNS_NAME_RGG_RAW @@ -378,6 +380,8 @@ def differential_expression( raw_group_data.set_index("names", inplace=True) for model in de_keys.keys(): gene_overlap_f1s = [] + rgds = [] + sgds = [] lfc_maes = [] lfc_pearsons = [] lfc_spearmans = [] @@ -408,6 +412,8 @@ def differential_expression( raw_group_data["logfoldchanges"], sample_group_data["logfoldchanges"], ) + rgds.append(rgd) + sgds.append(sgd) lfc_maes.append(np.mean(np.abs(rgd - sgd))) lfc_pearsons.append(pearsonr(rgd, sgd)[0]) lfc_spearmans.append(spearmanr(rgd, sgd)[0]) @@ -430,6 +436,19 @@ def differential_expression( df.loc[i, "lfc_spearman"] = np.mean(lfc_spearmans) df.loc[i, "roc_auc"] = np.mean(roc_aucs) df.loc[i, "pr_auc"] = np.mean(pr_aucs) + rgd, sgd = pd.DataFrame(rgds).mean(axis=0), pd.DataFrame(sgds).mean( + axis=0 + ) + if ( + model + not in self.metrics[METRIC_DIFF_EXP][ + "lfc_per_model_per_group" + ].keys() + ): + self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"][model] = {} + self.metrics[METRIC_DIFF_EXP]["lfc_per_model_per_group"][model][ + g + ] = pd.DataFrame([rgd, sgd], index=["raw", "approx"]).T i += 1 - self.metrics[METRIC_DIFF_EXP] = df + self.metrics[METRIC_DIFF_EXP]["summary"] = df From 15108a42a1bbb75e5e6fb4f66c99038e046877de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Jul 2023 00:44:05 -0700 Subject: [PATCH 4/9] [pre-commit.ci] pre-commit autoupdate (#2199) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.0.278 → v0.0.280](https://github.com/astral-sh/ruff-pre-commit/compare/v0.0.278...v0.0.280) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 164bf11844..beaa6daeb9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: # https://github.com/jupyterlab/jupyterlab/issues/12675 language_version: "17.9.1" - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.278 + rev: v0.0.280 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 3a7f19e933d9388f45b350ea673bab4bcb05407d Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 25 Jul 2023 01:45:09 -0700 Subject: [PATCH 5/9] Expose `torch.save` kwargs (#2200) * Expose torch.save kwargs * Add release note * Annotations * Fix docs annotation --- docs/release_notes/index.md | 6 ++- scvi/external/gimvi/_model.py | 72 +++++++++++++++++++--------------- scvi/model/base/_base_model.py | 60 ++++++++++++++++------------ 3 files changed, 81 insertions(+), 57 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 4d4fa4f8a7..c5e471a284 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -22,16 +22,18 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits - Add `load_sparse_tensor` argument in {class}`scvi.data.AnnTorchDataset` for directly loading SciPy CSR and CSC data structures to their PyTorch counterparts, leading to faster data loading depending on the sparsity of the data {pr}`2158`. -- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` +- Add per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary` stores the summary dataframe, and the `lfc_per_model_per_group` key stores the per-group LFC. +- Expose {meth}`torch.save` keyword arguments in {class}`scvi.model.base.BaseModelClass.save` + and {class}`scvi.external.GIMVI.save` {pr}`2200`. #### Changed - Replace `sparse` with `sparse_format` argument in {meth}`scvi.data.synthetic_iid` for increased flexibility over dataset format {pr}`2163`. -- Added per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` +- Add per-group LFC information to the {meth}`scvi.criticism.PosteriorPredictiveCheck.differential_expression` method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary` stores the summary dataframe, and the `lfc_per_model_per_group` key stores the per-group LFC. diff --git a/scvi/external/gimvi/_model.py b/scvi/external/gimvi/_model.py index 6b1044d720..4cf7d2978b 100644 --- a/scvi/external/gimvi/_model.py +++ b/scvi/external/gimvi/_model.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging import os import warnings from itertools import cycle -from typing import List, Optional, Union import numpy as np import torch @@ -77,8 +78,8 @@ def __init__( self, adata_seq: AnnData, adata_spatial: AnnData, - generative_distributions: Optional[List[str]] = None, - model_library_size: Optional[List[bool]] = None, + generative_distributions: list[str] | None = None, + model_library_size: list[bool] | None = None, n_latent: int = 10, **model_kwargs, ): @@ -162,13 +163,13 @@ def train( self, max_epochs: int = 200, accelerator: str = "auto", - devices: Union[int, List[int], str] = "auto", + devices: int | list[int] | str = "auto", kappa: int = 5, train_size: float = 0.9, - validation_size: Optional[float] = None, + validation_size: float | None = None, shuffle_set_split: bool = True, batch_size: int = 128, - plan_kwargs: Optional[dict] = None, + plan_kwargs: dict | None = None, **kwargs, ): """Train the model. @@ -254,7 +255,7 @@ def train( self.to_device(device) self.is_trained_ = True - def _make_scvi_dls(self, adatas: List[AnnData] = None, batch_size=128): + def _make_scvi_dls(self, adatas: list[AnnData] = None, batch_size=128): if adatas is None: adatas = self.adatas post_list = [self._make_data_loader(ad) for ad in adatas] @@ -266,10 +267,10 @@ def _make_scvi_dls(self, adatas: List[AnnData] = None, batch_size=128): @torch.inference_mode() def get_latent_representation( self, - adatas: List[AnnData] = None, + adatas: list[AnnData] = None, deterministic: bool = True, batch_size: int = 128, - ) -> List[np.ndarray]: + ) -> list[np.ndarray]: """Return the latent space embedding for each dataset. Parameters @@ -307,12 +308,12 @@ def get_latent_representation( @torch.inference_mode() def get_imputed_values( self, - adatas: List[AnnData] = None, + adatas: list[AnnData] = None, deterministic: bool = True, normalized: bool = True, - decode_mode: Optional[int] = None, + decode_mode: int | None = None, batch_size: int = 128, - ) -> List[np.ndarray]: + ) -> list[np.ndarray]: """Return imputed values for all genes for each dataset. Parameters @@ -376,9 +377,10 @@ def get_imputed_values( def save( self, dir_path: str, - prefix: Optional[str] = None, + prefix: str | None = None, overwrite: bool = False, save_anndata: bool = False, + save_kwargs: dict | None = None, **anndata_write_kwargs, ): """Save the state of the model. @@ -398,6 +400,8 @@ def save( already exists at `dir_path`, error will be raised. save_anndata If True, also saves the anndata + save_kwargs + Keyword arguments passed into :func:`~torch.save`. anndata_write_kwargs Kwargs for anndata write function """ @@ -409,6 +413,7 @@ def save( ) file_name_prefix = prefix or "" + save_kwargs = save_kwargs or {} seq_adata = self.adatas[0] spatial_adata = self.adatas[1] @@ -442,6 +447,7 @@ def save( "attr_dict": user_attributes, }, model_save_path, + **save_kwargs, ) @classmethod @@ -449,12 +455,12 @@ def save( def load( cls, dir_path: str, - adata_seq: Optional[AnnData] = None, - adata_spatial: Optional[AnnData] = None, + adata_seq: AnnData | None = None, + adata_spatial: AnnData | None = None, accelerator: str = "auto", - device: Union[int, str] = "auto", - prefix: Optional[str] = None, - backup_url: Optional[str] = None, + device: int | str = "auto", + prefix: str | None = None, + backup_url: str | None = None, ): """Instantiate a model from the saved output. @@ -578,21 +584,24 @@ def convert_legacy_save( dir_path: str, output_dir_path: str, overwrite: bool = False, - prefix: Optional[str] = None, + prefix: str | None = None, + **save_kwargs, ) -> None: """Converts a legacy saved GIMVI model ( AnnDataManager: """Manager instance associated with self.adata.""" return self._adata_manager - def to_device(self, device: Union[str, int]): + def to_device(self, device: str | int): """Move model to device. Parameters @@ -169,7 +171,7 @@ def _get_setup_method_args(**setup_locals) -> dict: @staticmethod def _create_modalities_attr_dict( - modalities: Dict[str, str], setup_method_args: dict + modalities: dict[str, str], setup_method_args: dict ) -> attrdict: """Preprocesses a ``modalities`` dictionary used in ``setup_mudata()`` to map modality names. @@ -224,7 +226,7 @@ def _register_manager_for_instance(self, adata_manager: AnnDataManager): instance_manager_store = self._per_instance_manager_store[self.id] instance_manager_store[adata_id] = adata_manager - def deregister_manager(self, adata: Optional[AnnData] = None): + def deregister_manager(self, adata: AnnData | None = None): """Deregisters the :class:`~scvi.data.AnnDataManager` instance associated with `adata`. If `adata` is `None`, deregisters all :class:`~scvi.data.AnnDataManager` instances @@ -263,7 +265,7 @@ def deregister_manager(self, adata: Optional[AnnData] = None): @classmethod def _get_most_recent_anndata_manager( cls, adata: AnnOrMuData, required: bool = False - ) -> Optional[AnnDataManager]: + ) -> AnnDataManager | None: """Retrieves the :class:`~scvi.data.AnnDataManager` for a given AnnData object specific to this model class. Checks for the most recent :class:`~scvi.data.AnnDataManager` created for the given AnnData object via @@ -305,7 +307,7 @@ def _get_most_recent_anndata_manager( def get_anndata_manager( self, adata: AnnOrMuData, required: bool = False - ) -> Optional[AnnDataManager]: + ) -> AnnDataManager | None: """Retrieves the :class:`~scvi.data.AnnDataManager` for a given AnnData object specific to this model instance. Requires ``self.id`` has been set. Checks for an :class:`~scvi.data.AnnDataManager` @@ -384,8 +386,8 @@ def get_from_registry( def _make_data_loader( self, adata: AnnOrMuData, - indices: Optional[Sequence[int]] = None, - batch_size: Optional[int] = None, + indices: Sequence[int] | None = None, + batch_size: int | None = None, shuffle: bool = False, data_loader_class=None, **data_loader_kwargs, @@ -435,7 +437,7 @@ def _make_data_loader( return dl def _validate_anndata( - self, adata: Optional[AnnOrMuData] = None, copy_if_view: bool = True + self, adata: AnnOrMuData | None = None, copy_if_view: bool = True ) -> AnnData: """Validate anndata has been properly registered, transfer if necessary.""" if adata is None: @@ -557,9 +559,10 @@ def train(self): def save( self, dir_path: str, - prefix: Optional[str] = None, + prefix: str | None = None, overwrite: bool = False, save_anndata: bool = False, + save_kwargs: dict | None = None, **anndata_write_kwargs, ): """Save the state of the model. @@ -579,6 +582,8 @@ def save( already exists at `dir_path`, error will be raised. save_anndata If True, also saves the anndata + save_kwargs + Keyword arguments passed into :func:`~torch.save`. anndata_write_kwargs Kwargs for :meth:`~anndata.AnnData.write` """ @@ -590,6 +595,8 @@ def save( ) file_name_prefix = prefix or "" + save_kwargs = save_kwargs or {} + if save_anndata: file_suffix = "" if isinstance(self.adata, AnnData): @@ -621,6 +628,7 @@ def save( "attr_dict": user_attributes, }, model_save_path, + **save_kwargs, ) @classmethod @@ -628,11 +636,11 @@ def save( def load( cls, dir_path: str, - adata: Optional[AnnOrMuData] = None, + adata: AnnOrMuData | None = None, accelerator: str = "auto", - device: Union[int, str] = "auto", - prefix: Optional[str] = None, - backup_url: Optional[str] = None, + device: int | str = "auto", + prefix: str | None = None, + backup_url: str | None = None, ): """Instantiate a model from the saved output. @@ -720,7 +728,8 @@ def convert_legacy_save( dir_path: str, output_dir_path: str, overwrite: bool = False, - prefix: Optional[str] = None, + prefix: str | None = None, + **save_kwargs, ) -> None: """Converts a legacy saved model ( None: + def view_setup_args(dir_path: str, prefix: str | None = None) -> None: """Print args used to setup a saved model. Parameters @@ -813,7 +825,7 @@ def view_setup_args(dir_path: str, prefix: Optional[str] = None) -> None: AnnDataManager.view_setup_method_args(registry) @staticmethod - def load_registry(dir_path: str, prefix: Optional[str] = None) -> dict: + def load_registry(dir_path: str, prefix: str | None = None) -> dict: """Return the full registry saved with the model. Parameters @@ -839,7 +851,7 @@ def load_registry(dir_path: str, prefix: Optional[str] = None) -> dict: return attr_dict.pop("registry_") def view_anndata_setup( - self, adata: Optional[AnnOrMuData] = None, hide_state_registries: bool = False + self, adata: AnnOrMuData | None = None, hide_state_registries: bool = False ) -> None: """Print summary of the setup for the initial AnnData or a given AnnData object. @@ -867,7 +879,7 @@ class BaseMinifiedModeModelClass(BaseModelClass): """Abstract base class for scvi-tools models that can handle minified data.""" @property - def minified_data_type(self) -> Union[MinifiedDataType, None]: + def minified_data_type(self) -> MinifiedDataType | None: """The type of minified data associated with this model, if applicable.""" return ( self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) From c8df1b0acdcca332aa03dcb5c29a276e4640a594 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 25 Jul 2023 02:32:05 -0700 Subject: [PATCH 6/9] Add `accelerator` and `devices` options in pytest (#2201) * Add accelerator and devices args to pytest * Update cuda test workflow * Fix annotations --- .github/workflows/test_linux_cuda.yml | 2 +- tests/conftest.py | 35 +++++++++++++----------- tests/dataloaders/test_datasplitter.py | 12 ++++++--- tests/model/test_scvi.py | 21 --------------- tests/models/test_pyro.py | 37 ++++++++++++++++---------- tests/train/test_distributed.py | 17 ------------ 6 files changed, 52 insertions(+), 72 deletions(-) delete mode 100644 tests/model/test_scvi.py delete mode 100644 tests/train/test_distributed.py diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 6c1bcd92db..827fb2169d 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -38,7 +38,7 @@ jobs: PLATFORM: ubuntu DISPLAY: :42 run: | - pytest -v --cov --color=yes --cuda + pytest -v --cov --color=yes --accelerator cuda --devices auto - name: Upload coverage uses: codecov/codecov-action@v3 diff --git a/tests/conftest.py b/tests/conftest.py index 8ce99f0c70..129b17196f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ from distutils.dir_util import copy_tree import pytest -import torch import scvi @@ -29,10 +28,16 @@ def pytest_addoption(parser): help="Run tests that are optional.", ) parser.addoption( - "--cuda", - action="store_true", - default=False, - help="Run tests that required a CUDA backend.", + "--accelerator", + action="store", + default="cpu", + help="Option to specify which accelerator to use for tests.", + ) + parser.addoption( + "--devices", + action="store", + default="auto", + help="Option to specify which devices to use for tests.", ) @@ -59,14 +64,6 @@ def pytest_collection_modifyitems(config, items): if not run_optional and ("optional" in item.keywords): item.add_marker(skip_optional) - run_cuda = config.getoption("--cuda") - skip_cuda = pytest.mark.skip(reason="need --cuda option to run") - for item in items: - # All tests marked with `pytest.mark.cuda` get skipped unless - # `--cuda` passed - if not run_cuda and ("cuda" in item.keywords): - item.add_marker(skip_cuda) - @pytest.fixture(scope="session") def save_path(tmpdir_factory): @@ -91,6 +88,12 @@ def model_fit(request): @pytest.fixture(scope="session") -def cuda(): - """Docstring for cuda.""" - assert torch.cuda.is_available() +def accelerator(request): + """Docstring for accelerator.""" + return request.config.getoption("--accelerator") + + +@pytest.fixture(scope="session") +def devices(request): + """Docstring for devices.""" + return request.config.getoption("--devices") diff --git a/tests/dataloaders/test_datasplitter.py b/tests/dataloaders/test_datasplitter.py index d8b2648a26..5671d640d5 100644 --- a/tests/dataloaders/test_datasplitter.py +++ b/tests/dataloaders/test_datasplitter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from math import ceil, floor import numpy as np @@ -41,12 +43,16 @@ def test_datasplitter_shuffle(): @pytest.mark.parametrize( "sparse_format", ["csr_matrix", "csr_array", "csc_matrix", "csc_array"] ) -def test_datasplitter_load_sparse_tensor(sparse_format: str): +def test_datasplitter_load_sparse_tensor( + sparse_format: str, + accelerator: str, + devices: list | str | int, +): adata = scvi.data.synthetic_iid(sparse_format=sparse_format) TestSparseModel.setup_anndata(adata) model = TestSparseModel(adata) model.train( - accelerator="cpu", - devices=1, + accelerator=accelerator, + devices=devices, expected_sparse_layout=sparse_format.split("_")[0], ) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py deleted file mode 100644 index 9fde0b532e..0000000000 --- a/tests/model/test_scvi.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -import scvi - - -@pytest.mark.cuda -def test_scvi_train_ddp(devices: int = -1): - adata = scvi.data.synthetic_iid() - scvi.model.SCVI.setup_anndata(adata) - model = scvi.model.SCVI(adata) - - model.train( - max_epochs=1, - check_val_every_n_epoch=1, - accelerator="gpu", - devices=devices, - strategy="ddp_find_unused_parameters_true", - ) - - assert model.is_trained - assert len(model.history) > 0 diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index 26a75b9eae..1e8184c8b1 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import os -from typing import Optional import numpy as np import pyro @@ -157,7 +158,7 @@ def setup_anndata( cls, adata: AnnData, **kwargs, - ) -> Optional[AnnData]: + ) -> AnnData | None: setup_method_args = cls._get_setup_method_args(**locals()) # add index for each cell (provided to pyro plate for correct minibatching) @@ -187,7 +188,10 @@ def _create_indices_adata_manager(adata: AnnData) -> AnnDataManager: return adata_manager -def test_pyro_bayesian_regression_low_level(): +def test_pyro_bayesian_regression_low_level( + accelerator: str, + devices: list | str | int, +): adata = synthetic_iid() adata_manager = _create_indices_adata_manager(adata) train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128) @@ -196,8 +200,8 @@ def test_pyro_bayesian_regression_low_level(): plan = LowLevelPyroTrainingPlan(model) plan.n_obs_training = len(train_dl.indices) trainer = Trainer( - accelerator="auto", - devices="auto", + accelerator=accelerator, + devices=devices, max_epochs=2, callbacks=[PyroModelGuideWarmup(train_dl)], ) @@ -213,7 +217,9 @@ def test_pyro_bayesian_regression_low_level(): ] -def test_pyro_bayesian_regression(save_path): +def test_pyro_bayesian_regression( + accelerator: str, devices: list | str | int, save_path: str +): adata = synthetic_iid() adata_manager = _create_indices_adata_manager(adata) train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128) @@ -222,8 +228,8 @@ def test_pyro_bayesian_regression(save_path): plan = PyroTrainingPlan(model) plan.n_obs_training = len(train_dl.indices) trainer = Trainer( - accelerator="auto", - devices="auto", + accelerator=accelerator, + devices=devices, max_epochs=2, ) trainer.fit(plan, train_dl) @@ -258,8 +264,8 @@ def test_pyro_bayesian_regression(save_path): plan = PyroTrainingPlan(new_model) plan.n_obs_training = len(train_dl.indices) trainer = Trainer( - accelerator="auto", - devices="auto", + accelerator=accelerator, + devices=devices, max_steps=1, ) trainer.fit(plan, train_dl) @@ -275,7 +281,10 @@ def test_pyro_bayesian_regression(save_path): np.testing.assert_array_equal(linear_median_new, linear_median) -def test_pyro_bayesian_regression_jit(): +def test_pyro_bayesian_regression_jit( + accelerator: str, + devices: list | str | int, +): adata = synthetic_iid() adata_manager = _create_indices_adata_manager(adata) train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128) @@ -284,8 +293,8 @@ def test_pyro_bayesian_regression_jit(): plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO()) plan.n_obs_training = len(train_dl.indices) trainer = Trainer( - accelerator="auto", - devices="auto", + accelerator=accelerator, + devices=devices, max_epochs=2, callbacks=[PyroJitGuideWarmup(train_dl)], ) @@ -532,7 +541,7 @@ def setup_anndata( cls, adata: AnnData, **kwargs, - ) -> Optional[AnnData]: + ) -> AnnData | None: setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ diff --git a/tests/train/test_distributed.py b/tests/train/test_distributed.py deleted file mode 100644 index a22c0179a4..0000000000 --- a/tests/train/test_distributed.py +++ /dev/null @@ -1,17 +0,0 @@ -import pytest - -import scvi - - -@pytest.mark.optional -def test_scvi_train_ddp(): - adata = scvi.data.synthetic_iid() - scvi.model.SCVI.setup_anndata(adata) - model = scvi.model.SCVI(adata) - - model.train( - max_epochs=1, - accelerator="cpu", - devices=2, - strategy="ddp_find_unused_parameters_true", - ) From 8647b7a54ad54f9a62fb0d690604ef3698ba1ef8 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 25 Jul 2023 02:48:13 -0700 Subject: [PATCH 7/9] Update cuda test image (#2202) --- .github/workflows/test_linux_cuda.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 827fb2169d..3fe94f65fe 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -15,13 +15,11 @@ jobs: strategy: fail-fast: false matrix: - ubuntu: [latest] - mamba: [latest] python: ["3.11"] cuda: ["11"] container: - image: martinkim0/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }} + image: martinkim0/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}-base options: --user root --gpus all steps: From 0ee065e262afcf6e1e25916e3f600ba5ac7e5bfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szabolcs=20Horv=C3=A1t?= Date: Tue, 25 Jul 2023 12:16:03 +0200 Subject: [PATCH 8/9] change python-igraph to igraph (#2197) Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3b5a4a4ae5..9d358520a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ optional = [ tutorials = [ "leidenalg", "pynndescent", - "python-igraph", + "igraph", "scikit-misc", "scvi-tools[optional]", ] # dependencies for all tutorials From 28a8c5de589b7698e318d3055ff83c361d0e79e2 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Tue, 25 Jul 2023 07:25:30 -0700 Subject: [PATCH 9/9] Expose fixed model and train kwargs in autotune (#2203) * Expose fixed model and train kwargs in autotune * Add release note --- docs/release_notes/index.md | 1 + scvi/autotune/_manager.py | 67 ++++++++++++++++++++++++------------- scvi/autotune/_tuner.py | 6 ++++ 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index c5e471a284..234d57d74d 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -28,6 +28,7 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits per-group LFC. - Expose {meth}`torch.save` keyword arguments in {class}`scvi.model.base.BaseModelClass.save` and {class}`scvi.external.GIMVI.save` {pr}`2200`. +- Add `model_kwargs` and `train_kwargs` arguments to {meth}`scvi.autotune.ModelTuner.fit` {pr}`2203`. #### Changed diff --git a/scvi/autotune/_manager.py b/scvi/autotune/_manager.py index f54250401f..114d1b3bfd 100644 --- a/scvi/autotune/_manager.py +++ b/scvi/autotune/_manager.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import inspect import logging import os import warnings from collections import OrderedDict from datetime import datetime -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable import lightning.pytorch as pl import rich @@ -136,7 +138,7 @@ def _parse_func_params(func: Callable, parent: Any, tunable_type: str) -> dict: return tunables def _get_tunables( - attr: Any, parent: Any = None, tunable_type: Optional[str] = None + attr: Any, parent: Any = None, tunable_type: str | None = None ) -> dict: tunables = {} if inspect.isfunction(attr): @@ -159,7 +161,7 @@ def _get_metrics(model_cls: BaseModelClass) -> OrderedDict: } return registry - def _get_search_space(self, search_space: dict) -> Tuple[dict, dict]: + def _get_search_space(self, search_space: dict) -> tuple[dict, dict]: """Parses a compact search space into separate kwargs dictionaries.""" model_kwargs = {} train_kwargs = {} @@ -210,7 +212,7 @@ def _validate_search_space(self, search_space: dict, use_defaults: bool) -> dict return _search_space def _validate_metrics( - self, metric: str, additional_metrics: List[str] + self, metric: str, additional_metrics: list[str] ) -> OrderedDict: """Validates a set of metrics against the metric registry.""" registry_metrics = self._registry["metrics"] @@ -240,7 +242,7 @@ def _validate_metrics( return _metrics @staticmethod - def _get_primary_metric_and_mode(metrics: OrderedDict) -> Tuple[str, str]: + def _get_primary_metric_and_mode(metrics: OrderedDict) -> tuple[str, str]: metric = list(metrics.keys())[0] mode = metrics[metric] return metric, mode @@ -308,7 +310,7 @@ def _validate_scheduler_and_search_algorithm( metrics: OrderedDict, scheduler_kwargs: dict, searcher_kwargs: dict, - ) -> Tuple[Any, Any]: + ) -> tuple[Any, Any]: """Validates a scheduler and search algorithm pair for compatibility.""" supported = ["asha", "hyperband", "median", "pbt", "fifo"] if scheduler not in supported: @@ -361,7 +363,7 @@ def _validate_resources(self, resources: dict) -> dict: # TODO: perform resource checking return resources - def _get_setup_info(self, adata: AnnOrMuData) -> Tuple[str, dict]: + def _get_setup_info(self, adata: AnnOrMuData) -> tuple[str, dict]: """Retrieves the method and kwargs used for setting up `adata` with the model class.""" manager = self._model_cls._get_most_recent_anndata_manager(adata) setup_method_name = manager._registry.get(_SETUP_METHOD_NAME, "setup_anndata") @@ -373,6 +375,8 @@ def _get_trainable( self, adata: AnnOrMuData, metrics: OrderedDict, + model_kwargs: dict, + train_kwargs: dict, resources: dict, setup_method_name: str, setup_kwargs: dict, @@ -386,20 +390,27 @@ def _trainable( model_cls: BaseModelClass, adata: AnnOrMuData, metric: str, + model_kwargs: dict, + train_kwargs: dict, setup_method_name: str, setup_kwargs: dict, max_epochs: int, accelerator: str, devices: int, ) -> None: - model_kwargs, train_kwargs = self._get_search_space(search_space) + _model_kwargs, _train_kwargs = self._get_search_space(search_space) + model_kwargs.update(_model_kwargs) + train_kwargs.update(_train_kwargs) + getattr(model_cls, setup_method_name)(adata, **setup_kwargs) model = model_cls(adata, **model_kwargs) + # This is to get around lightning import changes callback_cls = type( "_TuneReportCallback", (TuneReportCallback, pl.Callback), {} ) monitor = callback_cls(metric, on="validation_end") + model.train( max_epochs=max_epochs, accelerator=accelerator, @@ -418,6 +429,8 @@ def _trainable( model_cls=self._model_cls, adata=adata, metric=list(metrics.keys())[0], + model_kwargs=model_kwargs, + train_kwargs=train_kwargs, setup_method_name=setup_method_name, setup_kwargs=setup_kwargs, max_epochs=max_epochs, @@ -427,8 +440,8 @@ def _trainable( return tune.with_resources(_wrap_params, resources=resources) def _validate_experiment_name_and_logging_dir( - self, experiment_name: Optional[str], logging_dir: Optional[str] - ) -> Tuple[str, str]: + self, experiment_name: str | None, logging_dir: str | None + ) -> tuple[str, str]: if experiment_name is None: experiment_name = "tune_" experiment_name += self._model_cls.__name__.lower() + "_" @@ -442,26 +455,30 @@ def _get_tuner( self, adata: AnnOrMuData, *, - metric: Optional[str] = None, - additional_metrics: Optional[List[str]] = None, - search_space: Optional[dict] = None, + metric: str | None = None, + additional_metrics: list[str] | None = None, + search_space: dict | None = None, + model_kwargs: dict | None = None, + train_kwargs: dict | None = None, use_defaults: bool = False, - num_samples: Optional[int] = None, - max_epochs: Optional[int] = None, - scheduler: Optional[str] = None, - scheduler_kwargs: Optional[dict] = None, - searcher: Optional[str] = None, - searcher_kwargs: Optional[dict] = None, + num_samples: int | None = None, + max_epochs: int | None = None, + scheduler: str | None = None, + scheduler_kwargs: dict | None = None, + searcher: str | None = None, + searcher_kwargs: dict | None = None, reporter: bool = True, - resources: Optional[dict] = None, - experiment_name: Optional[str] = None, - logging_dir: Optional[str] = None, - ) -> Tuple[Any, dict]: + resources: dict | None = None, + experiment_name: str | None = None, + logging_dir: str | None = None, + ) -> tuple[Any, dict]: metric = ( metric or self._get_primary_metric_and_mode(self._registry["metrics"])[0] ) additional_metrics = additional_metrics or [] search_space = search_space or {} + model_kwargs = model_kwargs or {} + train_kwargs = train_kwargs or {} num_samples = num_samples or 10 # TODO: better default max_epochs = max_epochs or 100 # TODO: better default scheduler = scheduler or "asha" @@ -481,6 +498,8 @@ def _get_tuner( _trainable = self._get_trainable( adata, _metrics, + model_kwargs, + train_kwargs, _resources, _setup_method_name, _setup_args, @@ -537,7 +556,7 @@ def _get_analysis(self, results: Any, config: dict) -> TuneAnalysis: ) @staticmethod - def _add_columns(table: rich.table.Table, columns: List[str]) -> rich.table.Table: + def _add_columns(table: rich.table.Table, columns: list[str]) -> rich.table.Table: """Adds columns to a :class:`~rich.table.Table` with default formatting.""" for i, column in enumerate(columns): table.add_column(column, style=COLORS[i], **COLUMN_KWARGS) diff --git a/scvi/autotune/_tuner.py b/scvi/autotune/_tuner.py index c9065f39a0..9568b064fc 100644 --- a/scvi/autotune/_tuner.py +++ b/scvi/autotune/_tuner.py @@ -48,6 +48,12 @@ def fit(self, adata: AnnOrMuData, **kwargs) -> None: provided as instantiated Ray Tune sample functions. Available hyperparameters can be viewed with :meth:`~scvi.autotune.ModelTuner.info`. Must be provided if `use_defaults` is `False`. + model_kwargs + Keyword arguments passed to the model class's constructor. Arguments must + not overlap with those in `search_space`. + train_kwargs + Keyword arguments passed to the model's `train` method. Arguments must not + overlap with those in `search_space`. use_defaults Whether to use the model class's default search space, which can be viewed with :meth:`~scvi.autotune.ModelTuner.info`. If `True` and `search_space` is