Expose fixed model and train kwargs in autotune #2203

Merged · 2 commits · Jul 25, 2023
1 change: 1 addition & 0 deletions docs/release_notes/index.md
@@ -28,6 +28,7 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
   per-group LFC.
 - Expose {meth}`torch.save` keyword arguments in {class}`scvi.model.base.BaseModelClass.save`
   and {class}`scvi.external.GIMVI.save` {pr}`2200`.
+- Add `model_kwargs` and `train_kwargs` arguments to {meth}`scvi.autotune.ModelTuner.fit` {pr}`2203`.

 #### Changed
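For orientation, here is a minimal sketch of the feature this PR adds, with fixed kwargs passed alongside a tuned search space; the model class, hyperparameter names, and metric below are illustrative assumptions rather than values taken from this diff:

```python
from ray import tune

import scvi
from scvi.autotune import ModelTuner

adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata)

tuner = ModelTuner(scvi.model.SCVI)
tuner.fit(
    adata,
    metric="validation_loss",
    search_space={"n_hidden": tune.choice([64, 128])},  # sampled anew for each trial
    model_kwargs={"dropout_rate": 0.2},  # fixed constructor kwarg, shared by all trials
    train_kwargs={"early_stopping": True},  # fixed train() kwarg, shared by all trials
)
```

Previously, constructor and `train` arguments could only reach the trainable through the search space, so there was no way to pin a value without tuning it.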
67 changes: 43 additions & 24 deletions scvi/autotune/_manager.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import inspect
 import logging
 import os
 import warnings
 from collections import OrderedDict
 from datetime import datetime
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable

 import lightning.pytorch as pl
 import rich
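The new `from __future__ import annotations` line is what makes the `X | None` annotations introduced throughout this diff safe on Python versions before 3.10: annotations are stored as strings instead of being evaluated at definition time (PEP 563). A standalone illustration:

```python
from __future__ import annotations  # postpone evaluation of annotations (PEP 563)


def lookup(key: str | None = None) -> dict[str, int] | None:
    # Without the future import, `str | None` would be evaluated when the
    # function is defined and would raise TypeError on Python 3.9 and earlier;
    # with it, the annotation stays a plain string and is never evaluated.
    return None
```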
@@ -136,7 +138,7 @@ def _parse_func_params(func: Callable, parent: Any, tunable_type: str) -> dict:
             return tunables

         def _get_tunables(
-            attr: Any, parent: Any = None, tunable_type: Optional[str] = None
+            attr: Any, parent: Any = None, tunable_type: str | None = None
         ) -> dict:
             tunables = {}
             if inspect.isfunction(attr):
@@ -159,7 +161,7 @@ def _get_metrics(model_cls: BaseModelClass) -> OrderedDict:
         }
         return registry

-    def _get_search_space(self, search_space: dict) -> Tuple[dict, dict]:
+    def _get_search_space(self, search_space: dict) -> tuple[dict, dict]:
         """Parses a compact search space into separate kwargs dictionaries."""
         model_kwargs = {}
         train_kwargs = {}
@@ -210,7 +212,7 @@ def _validate_search_space(self, search_space: dict, use_defaults: bool) -> dict
         return _search_space

     def _validate_metrics(
-        self, metric: str, additional_metrics: List[str]
+        self, metric: str, additional_metrics: list[str]
     ) -> OrderedDict:
         """Validates a set of metrics against the metric registry."""
         registry_metrics = self._registry["metrics"]
@@ -240,7 +242,7 @@ def _validate_metrics(
         return _metrics

     @staticmethod
-    def _get_primary_metric_and_mode(metrics: OrderedDict) -> Tuple[str, str]:
+    def _get_primary_metric_and_mode(metrics: OrderedDict) -> tuple[str, str]:
         metric = list(metrics.keys())[0]
         mode = metrics[metric]
         return metric, mode
@@ -308,7 +310,7 @@ def _validate_scheduler_and_search_algorithm(
         metrics: OrderedDict,
         scheduler_kwargs: dict,
         searcher_kwargs: dict,
-    ) -> Tuple[Any, Any]:
+    ) -> tuple[Any, Any]:
         """Validates a scheduler and search algorithm pair for compatibility."""
         supported = ["asha", "hyperband", "median", "pbt", "fifo"]
         if scheduler not in supported:
@@ -361,7 +363,7 @@ def _validate_resources(self, resources: dict) -> dict:
         # TODO: perform resource checking
         return resources

-    def _get_setup_info(self, adata: AnnOrMuData) -> Tuple[str, dict]:
+    def _get_setup_info(self, adata: AnnOrMuData) -> tuple[str, dict]:
         """Retrieves the method and kwargs used for setting up `adata` with the model class."""
         manager = self._model_cls._get_most_recent_anndata_manager(adata)
         setup_method_name = manager._registry.get(_SETUP_METHOD_NAME, "setup_anndata")
@@ -373,6 +375,8 @@ def _get_trainable(
         self,
         adata: AnnOrMuData,
         metrics: OrderedDict,
+        model_kwargs: dict,
+        train_kwargs: dict,
         resources: dict,
         setup_method_name: str,
         setup_kwargs: dict,
@@ -386,20 +390,27 @@ def _trainable(
             model_cls: BaseModelClass,
             adata: AnnOrMuData,
             metric: str,
+            model_kwargs: dict,
+            train_kwargs: dict,
             setup_method_name: str,
             setup_kwargs: dict,
             max_epochs: int,
             accelerator: str,
             devices: int,
         ) -> None:
-            model_kwargs, train_kwargs = self._get_search_space(search_space)
+            _model_kwargs, _train_kwargs = self._get_search_space(search_space)
+            model_kwargs.update(_model_kwargs)
+            train_kwargs.update(_train_kwargs)

             getattr(model_cls, setup_method_name)(adata, **setup_kwargs)
             model = model_cls(adata, **model_kwargs)

             # This is to get around lightning import changes
             callback_cls = type(
                 "_TuneReportCallback", (TuneReportCallback, pl.Callback), {}
             )
             monitor = callback_cls(metric, on="validation_end")

             model.train(
                 max_epochs=max_epochs,
                 accelerator=accelerator,
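Note the merge order in `_trainable`: the user-supplied fixed kwargs are updated with the values sampled from the search space, so on any shared key the sampled value silently wins. This is why the docstring added in `_tuner.py` below requires that the dictionaries not overlap. A minimal illustration of the `dict.update` semantics:

```python
model_kwargs = {"dropout_rate": 0.2, "n_hidden": 64}  # fixed, user-supplied
_model_kwargs = {"n_hidden": 128}  # sampled from the search space for this trial

model_kwargs.update(_model_kwargs)  # sampled value overwrites the fixed one
assert model_kwargs == {"dropout_rate": 0.2, "n_hidden": 128}
```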
@@ -418,6 +429,8 @@ def _trainable(
             model_cls=self._model_cls,
             adata=adata,
             metric=list(metrics.keys())[0],
+            model_kwargs=model_kwargs,
+            train_kwargs=train_kwargs,
             setup_method_name=setup_method_name,
             setup_kwargs=setup_kwargs,
             max_epochs=max_epochs,
@@ -427,8 +440,8 @@ def _trainable(
         return tune.with_resources(_wrap_params, resources=resources)

     def _validate_experiment_name_and_logging_dir(
-        self, experiment_name: Optional[str], logging_dir: Optional[str]
-    ) -> Tuple[str, str]:
+        self, experiment_name: str | None, logging_dir: str | None
+    ) -> tuple[str, str]:
         if experiment_name is None:
             experiment_name = "tune_"
             experiment_name += self._model_cls.__name__.lower() + "_"
@@ -442,26 +455,30 @@ def _get_tuner(
         self,
         adata: AnnOrMuData,
         *,
-        metric: Optional[str] = None,
-        additional_metrics: Optional[List[str]] = None,
-        search_space: Optional[dict] = None,
+        metric: str | None = None,
+        additional_metrics: list[str] | None = None,
+        search_space: dict | None = None,
+        model_kwargs: dict | None = None,
+        train_kwargs: dict | None = None,
         use_defaults: bool = False,
-        num_samples: Optional[int] = None,
-        max_epochs: Optional[int] = None,
-        scheduler: Optional[str] = None,
-        scheduler_kwargs: Optional[dict] = None,
-        searcher: Optional[str] = None,
-        searcher_kwargs: Optional[dict] = None,
+        num_samples: int | None = None,
+        max_epochs: int | None = None,
+        scheduler: str | None = None,
+        scheduler_kwargs: dict | None = None,
+        searcher: str | None = None,
+        searcher_kwargs: dict | None = None,
         reporter: bool = True,
-        resources: Optional[dict] = None,
-        experiment_name: Optional[str] = None,
-        logging_dir: Optional[str] = None,
-    ) -> Tuple[Any, dict]:
+        resources: dict | None = None,
+        experiment_name: str | None = None,
+        logging_dir: str | None = None,
+    ) -> tuple[Any, dict]:
         metric = (
             metric or self._get_primary_metric_and_mode(self._registry["metrics"])[0]
         )
         additional_metrics = additional_metrics or []
         search_space = search_space or {}
+        model_kwargs = model_kwargs or {}
+        train_kwargs = train_kwargs or {}
         num_samples = num_samples or 10  # TODO: better default
         max_epochs = max_epochs or 100  # TODO: better default
         scheduler = scheduler or "asha"
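The `model_kwargs = model_kwargs or {}` normalization above is the standard way to give a dict parameter an empty default: the signature defaults to `None` and the body builds a fresh dict per call, avoiding Python's shared-mutable-default pitfall. A small sketch of the difference:

```python
def bad(kwargs: dict = {}) -> dict:
    kwargs["called"] = True  # mutates the one dict shared by every call
    return kwargs


def good(kwargs: dict | None = None) -> dict:
    kwargs = kwargs or {}  # fresh dict on each call
    kwargs["called"] = True
    return kwargs


assert bad() is bad()  # same object: state leaks across calls
assert good() is not good()  # independent objects
```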
@@ -481,6 +498,8 @@ def _get_tuner(
         _trainable = self._get_trainable(
             adata,
             _metrics,
+            model_kwargs,
+            train_kwargs,
             _resources,
             _setup_method_name,
             _setup_args,
@@ -537,7 +556,7 @@ def _get_analysis(self, results: Any, config: dict) -> TuneAnalysis:
         )

     @staticmethod
-    def _add_columns(table: rich.table.Table, columns: List[str]) -> rich.table.Table:
+    def _add_columns(table: rich.table.Table, columns: list[str]) -> rich.table.Table:
         """Adds columns to a :class:`~rich.table.Table` with default formatting."""
         for i, column in enumerate(columns):
             table.add_column(column, style=COLORS[i], **COLUMN_KWARGS)
6 changes: 6 additions & 0 deletions scvi/autotune/_tuner.py
@@ -48,6 +48,12 @@ def fit(self, adata: AnnOrMuData, **kwargs) -> None:
             provided as instantiated Ray Tune sample functions. Available
             hyperparameters can be viewed with :meth:`~scvi.autotune.ModelTuner.info`.
             Must be provided if `use_defaults` is `False`.
+        model_kwargs
+            Keyword arguments passed to the model class's constructor. Arguments must
+            not overlap with those in `search_space`.
+        train_kwargs
+            Keyword arguments passed to the model's `train` method. Arguments must not
+            overlap with those in `search_space`.
         use_defaults
             Whether to use the model class's default search space, which can be viewed
             with :meth:`~scvi.autotune.ModelTuner.info`. If `True` and `search_space` is
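The hunks shown here document the no-overlap requirement but do not visibly enforce it, so a caller who wants an explicit failure could run a guard like the hypothetical helper below before calling `fit` (the function and its check are illustrative, not part of scvi-tools):

```python
def assert_no_overlap(search_space: dict, **fixed_groups: dict) -> None:
    """Raise if a fixed kwarg would collide with a tuned hyperparameter."""
    for name, kwargs in fixed_groups.items():
        overlap = sorted(set(kwargs) & set(search_space))
        if overlap:
            raise ValueError(f"{name} overlaps with search_space: {overlap}")


assert_no_overlap(
    {"n_hidden": None},  # stand-in for a Ray Tune search space
    model_kwargs={"dropout_rate": 0.2},
    train_kwargs={"early_stopping": True},
)
```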