diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md
index c5e471a284..234d57d74d 100644
--- a/docs/release_notes/index.md
+++ b/docs/release_notes/index.md
@@ -28,6 +28,7 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
   per-group LFC.
 - Expose {meth}`torch.save` keyword arguments in {class}`scvi.model.base.BaseModelClass.save`
   and {class}`scvi.external.GIMVI.save` {pr}`2200`.
+- Add `model_kwargs` and `train_kwargs` arguments to {meth}`scvi.autotune.ModelTuner.fit` {pr}`2203`.

 #### Changed

diff --git a/scvi/autotune/_manager.py b/scvi/autotune/_manager.py
index f54250401f..114d1b3bfd 100644
--- a/scvi/autotune/_manager.py
+++ b/scvi/autotune/_manager.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import inspect
 import logging
 import os
 import warnings
 from collections import OrderedDict
 from datetime import datetime
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable

 import lightning.pytorch as pl
 import rich
@@ -136,7 +138,7 @@ def _parse_func_params(func: Callable, parent: Any, tunable_type: str) -> dict:
             return tunables

         def _get_tunables(
-            attr: Any, parent: Any = None, tunable_type: Optional[str] = None
+            attr: Any, parent: Any = None, tunable_type: str | None = None
         ) -> dict:
             tunables = {}
             if inspect.isfunction(attr):
@@ -159,7 +161,7 @@ def _get_metrics(model_cls: BaseModelClass) -> OrderedDict:
         }
         return registry

-    def _get_search_space(self, search_space: dict) -> Tuple[dict, dict]:
+    def _get_search_space(self, search_space: dict) -> tuple[dict, dict]:
         """Parses a compact search space into separate kwargs dictionaries."""
         model_kwargs = {}
         train_kwargs = {}
@@ -210,7 +212,7 @@ def _validate_search_space(self, search_space: dict, use_defaults: bool) -> dict
         return _search_space

     def _validate_metrics(
-        self, metric: str, additional_metrics: List[str]
+        self, metric: str, additional_metrics: list[str]
     ) -> OrderedDict:
         """Validates a set of metrics against the metric registry."""
         registry_metrics = self._registry["metrics"]
@@ -240,7 +242,7 @@ def _validate_metrics(
         return _metrics

     @staticmethod
-    def _get_primary_metric_and_mode(metrics: OrderedDict) -> Tuple[str, str]:
+    def _get_primary_metric_and_mode(metrics: OrderedDict) -> tuple[str, str]:
         metric = list(metrics.keys())[0]
         mode = metrics[metric]
         return metric, mode
@@ -308,7 +310,7 @@ def _validate_scheduler_and_search_algorithm(
         metrics: OrderedDict,
         scheduler_kwargs: dict,
         searcher_kwargs: dict,
-    ) -> Tuple[Any, Any]:
+    ) -> tuple[Any, Any]:
         """Validates a scheduler and search algorithm pair for compatibility."""
         supported = ["asha", "hyperband", "median", "pbt", "fifo"]
         if scheduler not in supported:
@@ -361,7 +363,7 @@ def _validate_resources(self, resources: dict) -> dict:
         # TODO: perform resource checking
         return resources

-    def _get_setup_info(self, adata: AnnOrMuData) -> Tuple[str, dict]:
+    def _get_setup_info(self, adata: AnnOrMuData) -> tuple[str, dict]:
         """Retrieves the method and kwargs used for setting up `adata` with the model class."""
         manager = self._model_cls._get_most_recent_anndata_manager(adata)
         setup_method_name = manager._registry.get(_SETUP_METHOD_NAME, "setup_anndata")
@@ -373,6 +375,8 @@ def _get_trainable(
         self,
         adata: AnnOrMuData,
         metrics: OrderedDict,
+        model_kwargs: dict,
+        train_kwargs: dict,
         resources: dict,
         setup_method_name: str,
         setup_kwargs: dict,
@@ -386,20 +390,27 @@ def _trainable(
             model_cls: BaseModelClass,
             adata: AnnOrMuData,
             metric: str,
+            model_kwargs: dict,
+            train_kwargs: dict,
             setup_method_name: str,
             setup_kwargs: dict,
             max_epochs: int,
             accelerator: str,
             devices: int,
         ) -> None:
-            model_kwargs, train_kwargs = self._get_search_space(search_space)
+            _model_kwargs, _train_kwargs = self._get_search_space(search_space)
+            model_kwargs.update(_model_kwargs)
+            train_kwargs.update(_train_kwargs)
+
             getattr(model_cls, setup_method_name)(adata, **setup_kwargs)
             model = model_cls(adata, **model_kwargs)
+
             # This is to get around lightning import changes
             callback_cls = type(
                 "_TuneReportCallback", (TuneReportCallback, pl.Callback), {}
             )
             monitor = callback_cls(metric, on="validation_end")
+
             model.train(
                 max_epochs=max_epochs,
                 accelerator=accelerator,
@@ -418,6 +429,8 @@ def _trainable(
             model_cls=self._model_cls,
             adata=adata,
             metric=list(metrics.keys())[0],
+            model_kwargs=model_kwargs,
+            train_kwargs=train_kwargs,
             setup_method_name=setup_method_name,
             setup_kwargs=setup_kwargs,
             max_epochs=max_epochs,
@@ -427,8 +440,8 @@
         return tune.with_resources(_wrap_params, resources=resources)

     def _validate_experiment_name_and_logging_dir(
-        self, experiment_name: Optional[str], logging_dir: Optional[str]
-    ) -> Tuple[str, str]:
+        self, experiment_name: str | None, logging_dir: str | None
+    ) -> tuple[str, str]:
         if experiment_name is None:
             experiment_name = "tune_"
             experiment_name += self._model_cls.__name__.lower() + "_"
@@ -442,26 +455,30 @@ def _get_tuner(
         self,
         adata: AnnOrMuData,
         *,
-        metric: Optional[str] = None,
-        additional_metrics: Optional[List[str]] = None,
-        search_space: Optional[dict] = None,
+        metric: str | None = None,
+        additional_metrics: list[str] | None = None,
+        search_space: dict | None = None,
+        model_kwargs: dict | None = None,
+        train_kwargs: dict | None = None,
         use_defaults: bool = False,
-        num_samples: Optional[int] = None,
-        max_epochs: Optional[int] = None,
-        scheduler: Optional[str] = None,
-        scheduler_kwargs: Optional[dict] = None,
-        searcher: Optional[str] = None,
-        searcher_kwargs: Optional[dict] = None,
+        num_samples: int | None = None,
+        max_epochs: int | None = None,
+        scheduler: str | None = None,
+        scheduler_kwargs: dict | None = None,
+        searcher: str | None = None,
+        searcher_kwargs: dict | None = None,
         reporter: bool = True,
-        resources: Optional[dict] = None,
-        experiment_name: Optional[str] = None,
-        logging_dir: Optional[str] = None,
-    ) -> Tuple[Any, dict]:
+        resources: dict | None = None,
+        experiment_name: str | None = None,
+        logging_dir: str | None = None,
+    ) -> tuple[Any, dict]:
         metric = (
             metric or self._get_primary_metric_and_mode(self._registry["metrics"])[0]
         )
         additional_metrics = additional_metrics or []
         search_space = search_space or {}
+        model_kwargs = model_kwargs or {}
+        train_kwargs = train_kwargs or {}
         num_samples = num_samples or 10  # TODO: better default
         max_epochs = max_epochs or 100  # TODO: better default
         scheduler = scheduler or "asha"
@@ -481,6 +498,8 @@ def _get_tuner(
         _trainable = self._get_trainable(
             adata,
             _metrics,
+            model_kwargs,
+            train_kwargs,
             _resources,
             _setup_method_name,
             _setup_args,
@@ -537,7 +556,7 @@ def _get_analysis(self, results: Any, config: dict) -> TuneAnalysis:
         )

     @staticmethod
-    def _add_columns(table: rich.table.Table, columns: List[str]) -> rich.table.Table:
+    def _add_columns(table: rich.table.Table, columns: list[str]) -> rich.table.Table:
         """Adds columns to a :class:`~rich.table.Table` with default formatting."""
         for i, column in enumerate(columns):
             table.add_column(column, style=COLORS[i], **COLUMN_KWARGS)
diff --git a/scvi/autotune/_tuner.py b/scvi/autotune/_tuner.py
index c9065f39a0..9568b064fc 100644
--- a/scvi/autotune/_tuner.py
+++ b/scvi/autotune/_tuner.py
@@ -48,6 +48,12 @@ def fit(self, adata: AnnOrMuData, **kwargs) -> None:
             provided as instantiated Ray Tune sample functions. Available
             hyperparameters can be viewed with :meth:`~scvi.autotune.ModelTuner.info`.
             Must be provided if `use_defaults` is `False`.
+        model_kwargs
+            Keyword arguments passed to the model class's constructor. Arguments must
+            not overlap with those in `search_space`.
+        train_kwargs
+            Keyword arguments passed to the model's `train` method. Arguments must not
+            overlap with those in `search_space`.
         use_defaults
             Whether to use the model class's default search space, which can be viewed
             with :meth:`~scvi.autotune.ModelTuner.info`. If `True` and `search_space` is
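For context, a minimal sketch of how the new `model_kwargs` / `train_kwargs` arguments could be used. This is an illustration, not part of the diff: the dataset helper, model class, and the specific hyperparameter, metric, and keyword values below are assumptions chosen for a plausible `SCVI` run.

```python
import scvi
from ray import tune

# Small synthetic dataset and a standard model class, chosen for illustration.
adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata)

tuner = scvi.autotune.ModelTuner(scvi.model.SCVI)

# Hyperparameters to search over, given as instantiated Ray Tune sample
# functions; the tunables for a model class can be listed with `tuner.info()`.
search_space = {
    "n_hidden": tune.choice([64, 128]),
    "lr": tune.loguniform(1e-4, 1e-2),
}

# `model_kwargs` is forwarded to the model constructor and `train_kwargs` to
# `train`; per the new docstring, neither may share keys with `search_space`.
# The specific values here are assumptions for the sketch.
tuner.fit(
    adata,
    search_space=search_space,
    model_kwargs={"gene_likelihood": "nb"},
    train_kwargs={"early_stopping": True},
    num_samples=5,
    max_epochs=20,
)
```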