diff --git a/tuning/autotune_new_model.ipynb b/tuning/autotune_new_model.ipynb index ecab445..ca408c8 100644 --- a/tuning/autotune_new_model.ipynb +++ b/tuning/autotune_new_model.ipynb @@ -1,676 +1,745 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Using autotune with a new model class" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{warning}\n", - "`scvi.autotune` development is still in progress. The API is subject to change.\n", - "```" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This tutorial provides an overview of how to prepare a new model to interact with [`scvi.autotune.ModelTuner`](https://docs.scvi-tools.org/en/latest/api/reference/scvi.autotune.ModelTuner.html#scvi.autotune.ModelTuner). For a high-level overview of `scvi.autotune`, see the tutorial for [model hyperparameter tuning with scVI](<>). This tutorial also assumes a general understanding of how models are implemented in `scvi-tools` as covered in the [model development tutorial](https://docs.scvi-tools.org/en/latest/tutorials/notebooks/model_user_guide.html).\n", - "\n", - "In particular, we will go through the following steps:\n", - "\n", - "1. Installing required packages\n", - "1. Creating a new model class\n", - "1. Exposing tunable hyperparameters\n", - "1. Exposing logged metrics\n", - "1. 
Using `TunableMixin`" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Installing required packages" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Uncomment the following lines in Google Colab in order to install `scvi-tools`:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# !pip install --quiet scvi-colab\n", - "# from scvi_colab import install\n", - "\n", - "# install()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import scvi\n", - "from flax.core import freeze\n", - "from ray import tune\n", - "from scvi._decorators import classproperty\n", - "from scvi._types import Tunable, TunableMixin\n", - "from scvi.autotune import ModelTuner" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 0\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Last run with scvi-tools version: 1.1.0\n" - ] - } - ], - "source": [ - "scvi.settings.seed = 0\n", - "print(\"Last run with scvi-tools version:\", scvi.__version__)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Creating a new model class" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To showcase how to use `scvi.autotune.ModelTuner` with a new model class, we will create a simple linear regression model with an $\\ell_1$ penalty in Jax (*i.e.*, Lasso)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "class Lasso:\n", - " \"\"\"Linear regression model with l1 penalty in Jax.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " n_input: int,\n", - " n_output: int,\n", - " l1_weight: float = 0.0,\n", - " rng: jax.random.PRNGKey = jax.random.PRNGKey(0),\n", - " ):\n", - " k1, k2 = jax.random.split(rng)\n", - " self.l1_weight = l1_weight\n", - " self.params = freeze(\n", - " {\n", - " \"w\": jax.random.normal(k1, (n_input, n_output)),\n", - " \"b\": jax.random.normal(k2, (n_output,)),\n", - " }\n", - " )\n", - "\n", - " def forward(self, params, x):\n", - " \"\"\"Forward pass.\"\"\"\n", - " return jnp.dot(x, params[\"w\"]) + params[\"b\"]\n", - "\n", - " def loss(self, params, x, y):\n", - " \"\"\"Mean squared error loss with L1 regularization.\"\"\"\n", - " mse = jnp.mean((self.forward(params, x) - y) ** 2)\n", - " l1 = self.l1_weight * jnp.sum(jnp.abs(self.params[\"w\"]))\n", - " return mse + l1\n", - "\n", - " def train(self, x, y, learning_rate: float = 1e-3, n_epochs: int = 500):\n", - " \"\"\"Train the model using gradient descent.\"\"\"\n", - " losses = []\n", - " for _ in range(n_epochs):\n", - " loss = self.loss(self.params, x, y)\n", - " grads = jax.grad(self.loss)(self.params, x, y)\n", - " self.params = freeze(\n", - " jax.tree_util.tree_map(\n", - " lambda p, g: p - learning_rate * g, self.params, grads\n", - " )\n", - " )\n", - " losses.append(loss)\n", - " return losses" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Exposing tunable hyperparameters" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For the model class above, we would like to expose the following hyperparameters as tunables: `l1_weight` and `learning_rate`, since these are non-trainable. 
We need two modifications to allow this:\n", - "\n", - "- Annotate the hyperparameters with the `Tunable` typing class\n", - "- Add a `_tunables` class property or attribute referencing functions that contain the tunable hyperparameters" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "class LassoTunable(Lasso):\n", - " \"\"\"Linear regression model with l1 penalty in Jax.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " n_input: int,\n", - " n_output: int,\n", - " l1_weight: Tunable[float] = 0.0, # <<===== Add this\n", - " rng: jax.random.PRNGKey = jax.random.PRNGKey(0),\n", - " ):\n", - " super().__init__(n_input, n_output, l1_weight, rng)\n", - "\n", - " def train(\n", - " self,\n", - " x,\n", - " y,\n", - " learning_rate: Tunable[float] = 1e-3, # <<===== Add this\n", - " n_epochs: int = 500,\n", - " ):\n", - " super().train(x, y, learning_rate, n_epochs)\n", - "\n", - " # <<===== Add this =====>> #\n", - " @classproperty\n", - " def _tunables(cls):\n", - " return [cls.__init__, cls.train]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, we can set up a `ModelTuner` instance with our new model and quickly check everything is working as expected with `info()`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/martinkim/dev/scvi-tools/scvi/autotune/_manager.py:57: UserWarning: No default search space available for LassoTunable.\n", - " self._defaults = self._get_defaults(self._model_cls)\n" - ] - }, - { - "data": { - "text/html": [ - "
ModelTuner registry for LassoTunable\n",
-                            "
\n" - ], - "text/plain": [ - "ModelTuner registry for LassoTunable\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
             Tunable hyperparameters             \n",
-                            "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n",
-                            "┃ Hyperparameter  Default value     Source    ┃\n",
-                            "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
-                            "│   l1_weight          0.0       LassoTunable │\n",
-                            "│ learning_rate       0.001      LassoTunable │\n",
-                            "└────────────────┴───────────────┴──────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Tunable hyperparameters \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDefault value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Source \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33m l1_weight \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.0 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mlearning_rate \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.001 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", - "└────────────────┴───────────────┴──────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
       Available metrics        \n",
-                            "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n",
-                            "┃     Metric          Mode    ┃\n",
-                            "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n",
-                            "│ validation_loss     min     │\n",
-                            "└─────────────────┴────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Available metrics \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Mode \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mvalidation_loss\u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m min \u001b[0m\u001b[38;5;128m \u001b[0m│\n", - "└─────────────────┴────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                        Default search space                         \n",
-                            "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
-                            "┃ Hyperparameter  Sample function  Arguments   Keyword arguments ┃\n",
-                            "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
-                            "└────────────────┴─────────────────┴────────────┴───────────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Default search space \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSample function\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mArguments \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mKeyword arguments\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", - "└────────────────┴─────────────────┴────────────┴───────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "tuner = ModelTuner(LassoTunable)\n", - "tuner.info()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Exposing logged metrics" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To populate the metrics table, we add two new lines of code: a call to `ray.tune.report` in our `train` function that logs our loss, and a corresponding class property called `_metrics` that lists the key of the metric we log." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "class LassoTunable(Lasso):\n", - " \"\"\"Linear regression model with l1 penalty in Jax.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " n_input: int,\n", - " n_output: int,\n", - " l1_weight: Tunable[float] = 0.0,\n", - " rng: jax.random.PRNGKey = jax.random.PRNGKey(0),\n", - " ):\n", - " super().__init__(n_input, n_output, l1_weight, rng)\n", - "\n", - " def train(self, x, y, learning_rate: Tunable[float] = 1e-3, n_epochs: int = 500):\n", - " \"\"\"Train the model using gradient descent.\"\"\"\n", - " losses = []\n", - " for _ in range(n_epochs):\n", - " loss = self.loss(self.params, x, y)\n", - " grads = jax.grad(self.loss)(self.params, x, y)\n", - " self.params = freeze(\n", - " jax.tree_util.tree_map(\n", - " lambda p, g: p - learning_rate * g, self.params, grads\n", - " )\n", - " )\n", - " tune.report({\"mse_l1_loss\": loss}) # <<===== Add this\n", - " losses.append(loss)\n", - " return losses\n", - "\n", - " @classproperty\n", - " def _tunables(cls):\n", - " return [cls.__init__, cls.train]\n", - "\n", - " # <<===== Add this =====>> #\n", - " @classproperty\n", - " def _metrics(cls):\n", - " return [\"mse_l1_loss\"]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We see that our tuner instance has detected our desired metric, so now we can pass `mse_l1_loss` to `ModelTuner.fit` to be optimized." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
ModelTuner registry for LassoTunable\n",
-                            "
\n" - ], - "text/plain": [ - "ModelTuner registry for LassoTunable\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
             Tunable hyperparameters             \n",
-                            "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n",
-                            "┃ Hyperparameter  Default value     Source    ┃\n",
-                            "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
-                            "│   l1_weight          0.0       LassoTunable │\n",
-                            "│ learning_rate       0.001      LassoTunable │\n",
-                            "└────────────────┴───────────────┴──────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Tunable hyperparameters \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDefault value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Source \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33m l1_weight \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.0 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mlearning_rate \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.001 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", - "└────────────────┴───────────────┴──────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
       Available metrics        \n",
-                            "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n",
-                            "┃     Metric          Mode    ┃\n",
-                            "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n",
-                            "│ validation_loss     min     │\n",
-                            "└─────────────────┴────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Available metrics \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Mode \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mvalidation_loss\u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m min \u001b[0m\u001b[38;5;128m \u001b[0m│\n", - "└─────────────────┴────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                        Default search space                         \n",
-                            "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
-                            "┃ Hyperparameter  Sample function  Arguments   Keyword arguments ┃\n",
-                            "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
-                            "└────────────────┴─────────────────┴────────────┴───────────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Default search space \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSample function\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mArguments \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mKeyword arguments\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", - "└────────────────┴─────────────────┴────────────┴───────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "tuner = scvi.autotune.ModelTuner(LassoTunable)\n", - "tuner.info()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Using `TunableMixin`" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In practice, if a new model class is being developed using the base classes of `scvi-tools`, a simpler way to expose tunable hyperparameters and metrics is to use the `TunableMixin` class. This mixin class provides a flexible, default implementation of `_tunables` and `_metrics` that only requires the user to annotate keyword arguments with `Tunable`.\n", - "\n", - "It also allows for the recursive discovery of tunable hyperparameters, as is the case when higher-level model classes define modules as attributes, for example." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "class LassoModel(TunableMixin):\n", - " _module_cls = LassoTunable\n", - "\n", - " def __init__(self, adata, *args, **kwargs):\n", - " self.adata = adata\n", - " self.module = self._module_cls(*args, **kwargs)\n", - "\n", - " def model_func1(self, x, y):\n", - " pass\n", - "\n", - " def model_func2(self, x):\n", - " pass\n", - "\n", - " # etc..." 
- ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Additionally, if the model uses Lightning for the training procedure, calling `ray.tune.report` is not required as the integration is handled with a callback." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/martinkim/dev/scvi-tools/scvi/autotune/_manager.py:57: UserWarning: No default search space available for LassoModel.\n", - " self._defaults = self._get_defaults(self._model_cls)\n" - ] - }, - { - "data": { - "text/html": [ - "
ModelTuner registry for LassoModel\n",
-                            "
\n" - ], - "text/plain": [ - "ModelTuner registry for LassoModel\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
             Tunable hyperparameters             \n",
-                            "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n",
-                            "┃ Hyperparameter  Default value     Source    ┃\n",
-                            "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
-                            "│   l1_weight          0.0       LassoTunable │\n",
-                            "│ learning_rate       0.001      LassoTunable │\n",
-                            "└────────────────┴───────────────┴──────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Tunable hyperparameters \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDefault value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Source \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33m l1_weight \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.0 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mlearning_rate \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.001 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", - "└────────────────┴───────────────┴──────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
       Available metrics        \n",
-                            "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n",
-                            "┃     Metric          Mode    ┃\n",
-                            "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n",
-                            "│ validation_loss     min     │\n",
-                            "└─────────────────┴────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Available metrics \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Mode \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", - "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mvalidation_loss\u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m min \u001b[0m\u001b[38;5;128m \u001b[0m│\n", - "└─────────────────┴────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
                        Default search space                         \n",
-                            "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
-                            "┃ Hyperparameter  Sample function  Arguments   Keyword arguments ┃\n",
-                            "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
-                            "└────────────────┴─────────────────┴────────────┴───────────────────┘\n",
-                            "
\n" - ], - "text/plain": [ - "\u001b[3m Default search space \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSample function\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mArguments \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mKeyword arguments\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", - "└────────────────┴─────────────────┴────────────┴───────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "tuner = scvi.autotune.ModelTuner(LassoModel)\n", - "tuner.info()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "scvi-gpu", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "2f978838050607ec9770689d8200902a4128a2ce208b502e911dd714d57e924e" - } - } + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using autotune with a new model class" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{warning}\n", + "`scvi.autotune` development is still in progress. 
The API is subject to change.\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tutorial provides an overview of how to prepare a new model to interact with [`scvi.autotune.ModelTuner`](https://docs.scvi-tools.org/en/latest/api/reference/scvi.autotune.ModelTuner.html#scvi.autotune.ModelTuner). For a high-level overview of `scvi.autotune`, see the tutorial for [model hyperparameter tuning with scVI](https://docs.scvi-tools.org/en/latest/tutorials/notebooks/tuning/autotune_scvi.html). This tutorial also assumes a general understanding of how models are implemented in `scvi-tools` as covered in the [model development tutorial](https://docs.scvi-tools.org/en/latest/tutorials/notebooks/model_user_guide.html).\n", + "\n", + "In particular, we will go through the following steps:\n", + "\n", + "1. Installing required packages\n", + "1. Creating a new model class\n", + "1. Exposing tunable hyperparameters\n", + "1. Exposing logged metrics\n", + "1. Using `TunableMixin`" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installing required packages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Uncomment the following lines in Google Colab in order to install `scvi-tools`:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:53.968764Z", + "iopub.status.busy": "2023-12-07T18:54:53.968632Z", + "iopub.status.idle": "2023-12-07T18:54:53.971169Z", + "shell.execute_reply": "2023-12-07T18:54:53.970760Z" + } + }, + "outputs": [], + "source": [ + "# !pip install --quiet scvi-colab\n", + "# from scvi_colab import install\n", + "\n", + "# install()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:53.972701Z", + "iopub.status.busy": "2023-12-07T18:54:53.972484Z", + "iopub.status.idle": "2023-12-07T18:54:57.174587Z", + "shell.execute_reply": 
"2023-12-07T18:54:57.174126Z" + } + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import scvi\n", + "from flax.core import freeze\n", + "from ray import tune\n", + "from scvi._decorators import classproperty\n", + "from scvi._types import Tunable, TunableMixin\n", + "from scvi.autotune import ModelTuner" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:57.176253Z", + "iopub.status.busy": "2023-12-07T18:54:57.176140Z", + "iopub.status.idle": "2023-12-07T18:54:57.179176Z", + "shell.execute_reply": "2023-12-07T18:54:57.178905Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 0\n" + ] }, - "nbformat": 4, - "nbformat_minor": 2 + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last run with scvi-tools version: 1.1.0\n" + ] + } + ], + "source": [ + "scvi.settings.seed = 0\n", + "print(\"Last run with scvi-tools version:\", scvi.__version__)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a new model class" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To showcase how to use `scvi.autotune.ModelTuner` with a new model class, we will create a simple linear regression model with an $\\ell_1$ penalty in Jax (*i.e.*, Lasso)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:57.198115Z", + "iopub.status.busy": "2023-12-07T18:54:57.197992Z", + "iopub.status.idle": "2023-12-07T18:54:57.969599Z", + "shell.execute_reply": "2023-12-07T18:54:57.969115Z" + } + }, + "outputs": [], + "source": [ + "class Lasso:\n", + " \"\"\"Linear regression model with l1 penalty in Jax.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " n_input: int,\n", + " n_output: int,\n", + " l1_weight: float = 0.0,\n", + " rng: jax.random.PRNGKey = jax.random.PRNGKey(0),\n", + " ):\n", + " k1, k2 = jax.random.split(rng)\n", + " self.l1_weight = l1_weight\n", + " self.params = freeze(\n", + " {\n", + " \"w\": jax.random.normal(k1, (n_input, n_output)),\n", + " \"b\": jax.random.normal(k2, (n_output,)),\n", + " }\n", + " )\n", + "\n", + " def forward(self, params, x):\n", + " \"\"\"Forward pass.\"\"\"\n", + " return jnp.dot(x, params[\"w\"]) + params[\"b\"]\n", + "\n", + " def loss(self, params, x, y):\n", + " \"\"\"Mean squared error loss with L1 regularization.\"\"\"\n", + " mse = jnp.mean((self.forward(params, x) - y) ** 2)\n", + " l1 = self.l1_weight * jnp.sum(jnp.abs(params[\"w\"]))\n", + " return mse + l1\n", + "\n", + " def train(self, x, y, learning_rate: float = 1e-3, n_epochs: int = 500):\n", + " \"\"\"Train the model using gradient descent.\"\"\"\n", + " losses = []\n", + " for _ in range(n_epochs):\n", + " loss = self.loss(self.params, x, y)\n", + " grads = jax.grad(self.loss)(self.params, x, y)\n", + " self.params = freeze(\n", + " jax.tree_util.tree_map(\n", + " lambda p, g: p - learning_rate * g, self.params, grads\n", + " )\n", + " )\n", + " losses.append(loss)\n", + " return losses" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exposing tunable hyperparameters" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + 
"For the model class above, we would like to expose the following hyperparameters as tunables: `l1_weight` and `learning_rate`, since these are non-trainable. We need two modifications to allow this:\n", + "\n", + "- Annotate the hyperparameters with the `Tunable` typing class\n", + "- Add a `_tunables` class property or attribute referencing functions that contain the tunable hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:57.971456Z", + "iopub.status.busy": "2023-12-07T18:54:57.971337Z", + "iopub.status.idle": "2023-12-07T18:54:57.974931Z", + "shell.execute_reply": "2023-12-07T18:54:57.974495Z" + } + }, + "outputs": [], + "source": [ + "class LassoTunable(Lasso):\n", + " \"\"\"Linear regression model with l1 penalty in Jax.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " n_input: int,\n", + " n_output: int,\n", + " l1_weight: Tunable[float] = 0.0, # <<===== Add this\n", + " rng: jax.random.PRNGKey = jax.random.PRNGKey(0),\n", + " ):\n", + " super().__init__(n_input, n_output, l1_weight, rng)\n", + "\n", + " def train(\n", + " self,\n", + " x,\n", + " y,\n", + " learning_rate: Tunable[float] = 1e-3, # <<===== Add this\n", + " n_epochs: int = 500,\n", + " ):\n", + " return super().train(x, y, learning_rate, n_epochs)\n", + "\n", + " # <<===== Add this =====>> #\n", + " @classproperty\n", + " def _tunables(cls):\n", + " return [cls.__init__, cls.train]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we can set up a `ModelTuner` instance with our new model and quickly check everything is working as expected with `info()`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:57.976450Z", + "iopub.status.busy": "2023-12-07T18:54:57.976302Z", + "iopub.status.idle": "2023-12-07T18:54:57.985810Z", + "shell.execute_reply": "2023-12-07T18:54:57.985547Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/env/lib/python3.11/site-packages/scvi/autotune/_manager.py:57: UserWarning: No default search space available for LassoTunable.\n", + " self._defaults = self._get_defaults(self._model_cls)\n" + ] + }, + { + "data": { + "text/html": [ + "
ModelTuner registry for LassoTunable\n",
+       "
\n" + ], + "text/plain": [ + "ModelTuner registry for LassoTunable\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
             Tunable hyperparameters             \n",
+       "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n",
+       "┃ Hyperparameter  Default value     Source    ┃\n",
+       "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
+       "│   l1_weight          0.0       LassoTunable │\n",
+       "│ learning_rate       0.001      LassoTunable │\n",
+       "└────────────────┴───────────────┴──────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Tunable hyperparameters \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDefault value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Source \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33m l1_weight \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.0 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mlearning_rate \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.001 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", + "└────────────────┴───────────────┴──────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
       Available metrics        \n",
+       "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n",
+       "┃     Metric          Mode    ┃\n",
+       "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n",
+       "│ validation_loss     min     │\n",
+       "└─────────────────┴────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Available metrics \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Mode \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mvalidation_loss\u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m min \u001b[0m\u001b[38;5;128m \u001b[0m│\n", + "└─────────────────┴────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
                        Default search space                         \n",
+       "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Hyperparameter  Sample function  Arguments   Keyword arguments ┃\n",
+       "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
+       "└────────────────┴─────────────────┴────────────┴───────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Default search space \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSample function\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mArguments \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mKeyword arguments\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", + "└────────────────┴─────────────────┴────────────┴───────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tuner = ModelTuner(LassoTunable)\n", + "tuner.info()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exposing logged metrics" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To populate the metrics table, we add two new lines of code: a call to `ray.tune.report` in our `train` function that logs our loss, and a corresponding class property called `_metrics` that lists the key of the metric we log." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:57.987224Z", + "iopub.status.busy": "2023-12-07T18:54:57.987117Z", + "iopub.status.idle": "2023-12-07T18:54:57.990823Z", + "shell.execute_reply": "2023-12-07T18:54:57.990554Z" + } + }, + "outputs": [], + "source": [ + "class LassoTunable(Lasso):\n", + " \"\"\"Linear regression model with l1 penalty in Jax.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " n_input: int,\n", + " n_output: int,\n", + " l1_weight: Tunable[float] = 0.0,\n", + " rng: jax.random.PRNGKey = jax.random.PRNGKey(0),\n", + " ):\n", + " super().__init__(n_input, n_output, l1_weight, rng)\n", + "\n", + " def train(self, x, y, learning_rate: Tunable[float] = 1e-3, n_epochs: int = 500):\n", + " \"\"\"Train the model using gradient descent.\"\"\"\n", + " losses = []\n", + " for _ in range(n_epochs):\n", + " loss = self.loss(self.params, x, y)\n", + " grads = jax.grad(self.loss)(self.params, x, y)\n", + " self.params = freeze(\n", + " jax.tree_util.tree_map(\n", + " lambda p, g: p - learning_rate * g, self.params, grads\n", + " )\n", + " )\n", + " tune.report({\"mse_l1_loss\": loss}) # <<===== Add this\n", + " losses.append(loss)\n", + " return losses\n", + "\n", + " @classproperty\n", + " def _tunables(cls):\n", + " return [cls.__init__, cls.train]\n", + "\n", + " # <<===== Add this =====>> #\n", + " @classproperty\n", + " def _metrics(cls):\n", + " return [\"mse_l1_loss\"]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Re-creating the tuner and calling `info()` lets us check which metrics it has registered; once `mse_l1_loss` is listed there, we can pass it to `ModelTuner.fit` to be optimized." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:57.992163Z", + "iopub.status.busy": "2023-12-07T18:54:57.992054Z", + "iopub.status.idle": "2023-12-07T18:54:57.998183Z", + "shell.execute_reply": "2023-12-07T18:54:57.997887Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
ModelTuner registry for LassoTunable\n",
+       "
\n" + ], + "text/plain": [ + "ModelTuner registry for LassoTunable\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
             Tunable hyperparameters             \n",
+       "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n",
+       "┃ Hyperparameter  Default value     Source    ┃\n",
+       "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
+       "│   l1_weight          0.0       LassoTunable │\n",
+       "│ learning_rate       0.001      LassoTunable │\n",
+       "└────────────────┴───────────────┴──────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Tunable hyperparameters \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDefault value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Source \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33m l1_weight \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.0 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mlearning_rate \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.001 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", + "└────────────────┴───────────────┴──────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
       Available metrics        \n",
+       "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n",
+       "┃     Metric          Mode    ┃\n",
+       "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n",
+       "│ validation_loss     min     │\n",
+       "└─────────────────┴────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Available metrics \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Mode \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mvalidation_loss\u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m min \u001b[0m\u001b[38;5;128m \u001b[0m│\n", + "└─────────────────┴────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
                        Default search space                         \n",
+       "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Hyperparameter  Sample function  Arguments   Keyword arguments ┃\n",
+       "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
+       "└────────────────┴─────────────────┴────────────┴───────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Default search space \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSample function\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mArguments \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mKeyword arguments\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", + "└────────────────┴─────────────────┴────────────┴───────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tuner = scvi.autotune.ModelTuner(LassoTunable)\n", + "tuner.info()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using `TunableMixin`" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In practice, if a new model class is being developed using the base classes of `scvi-tools`, a simpler way to expose tunable hyperparameters and metrics is to use the `TunableMixin` class. This mixin class provides a flexible, default implementation of `_tunables` and `_metrics` that only requires the user to annotate keyword arguments with `Tunable`.\n", + "\n", + "It also allows for the recursive discovery of tunable hyperparameters, as is the case when higher-level model classes define modules as attributes, for example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:57.999692Z", + "iopub.status.busy": "2023-12-07T18:54:57.999459Z", + "iopub.status.idle": "2023-12-07T18:54:58.001659Z", + "shell.execute_reply": "2023-12-07T18:54:58.001397Z" + } + }, + "outputs": [], + "source": [ + "class LassoModel(TunableMixin):\n", + " _module_cls = LassoTunable\n", + "\n", + " def __init__(self, adata, *args, **kwargs):\n", + " self.adata = adata\n", + " self.module = self._module_cls(*args, **kwargs)\n", + "\n", + " def model_func1(self, x, y):\n", + " pass\n", + "\n", + " def model_func2(self, x):\n", + " pass\n", + "\n", + " # etc..." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additionally, if the model uses Lightning for the training procedure, calling `ray.tune.report` is not required as the integration is handled with a callback." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2023-12-07T18:54:58.003124Z", + "iopub.status.busy": "2023-12-07T18:54:58.002907Z", + "iopub.status.idle": "2023-12-07T18:54:58.009191Z", + "shell.execute_reply": "2023-12-07T18:54:58.008932Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/env/lib/python3.11/site-packages/scvi/autotune/_manager.py:57: UserWarning: No default search space available for LassoModel.\n", + " self._defaults = self._get_defaults(self._model_cls)\n" + ] + }, + { + "data": { + "text/html": [ + "
ModelTuner registry for LassoModel\n",
+       "
\n" + ], + "text/plain": [ + "ModelTuner registry for LassoModel\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
             Tunable hyperparameters             \n",
+       "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n",
+       "┃ Hyperparameter  Default value     Source    ┃\n",
+       "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
+       "│   l1_weight          0.0       LassoTunable │\n",
+       "│ learning_rate       0.001      LassoTunable │\n",
+       "└────────────────┴───────────────┴──────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Tunable hyperparameters \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDefault value\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Source \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33m l1_weight \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.0 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mlearning_rate \u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m 0.001 \u001b[0m\u001b[38;5;128m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mLassoTunable\u001b[0m\u001b[32m \u001b[0m│\n", + "└────────────────┴───────────────┴──────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
       Available metrics        \n",
+       "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n",
+       "┃     Metric          Mode    ┃\n",
+       "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n",
+       "│ validation_loss     min     │\n",
+       "└─────────────────┴────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Available metrics \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Mode \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", + "│\u001b[38;5;33m \u001b[0m\u001b[38;5;33mvalidation_loss\u001b[0m\u001b[38;5;33m \u001b[0m│\u001b[38;5;128m \u001b[0m\u001b[38;5;128m min \u001b[0m\u001b[38;5;128m \u001b[0m│\n", + "└─────────────────┴────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
                        Default search space                         \n",
+       "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Hyperparameter  Sample function  Arguments   Keyword arguments ┃\n",
+       "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
+       "└────────────────┴─────────────────┴────────────┴───────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Default search space \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mHyperparameter\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSample function\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mArguments \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mKeyword arguments\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", + "└────────────────┴─────────────────┴────────────┴───────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tuner = scvi.autotune.ModelTuner(LassoModel)\n", + "tuner.info()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "scvi-gpu", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + }, + "vscode": { + "interpreter": { + "hash": "2f978838050607ec9770689d8200902a4128a2ce208b502e911dd714d57e924e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 }