Using autotune with a new model class

Warning

scvi.autotune development is still in progress. The API is subject to change.

This tutorial provides an overview of how to prepare a new model to interact with scvi.autotune.ModelTuner. For a high-level overview of scvi.autotune, see the tutorial for model hyperparameter tuning with scVI. This tutorial also assumes a general understanding of how models are implemented in scvi-tools as covered in the model development tutorial.

In particular, we will go through the following steps:

  1. Installing required packages

  2. Creating a new model class

  3. Exposing tunable hyperparameters

  4. Exposing logged metrics

  5. Using TunableMixin

Installing required packages

Note

Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.

!pip install --quiet scvi-colab
from scvi_colab import install

install()

import jax
import jax.numpy as jnp
import scvi
from flax.core import freeze
from ray import tune
from scvi._decorators import classproperty
from scvi._types import Tunable, TunableMixin
from scvi.autotune import ModelTuner
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.1.0

Creating a new model class

To showcase how to use scvi.autotune.ModelTuner with a new model class, we will create a simple linear regression model with an \(\ell_1\) penalty (i.e., Lasso) in JAX.

class Lasso:
    """Linear regression model with l1 penalty in Jax."""

    def __init__(
        self,
        n_input: int,
        n_output: int,
        l1_weight: float = 0.0,
        rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
    ):
        k1, k2 = jax.random.split(rng)
        self.l1_weight = l1_weight
        self.params = freeze(
            {
                "w": jax.random.normal(k1, (n_input, n_output)),
                "b": jax.random.normal(k2, (n_output,)),
            }
        )

    def forward(self, params, x):
        """Forward pass."""
        return jnp.dot(x, params["w"]) + params["b"]

    def loss(self, params, x, y):
        """Mean squared error loss with L1 regularization."""
        mse = jnp.mean((self.forward(params, x) - y) ** 2)
        # Use the params argument (not self.params) so the penalty contributes to the gradient
        l1 = self.l1_weight * jnp.sum(jnp.abs(params["w"]))
        return mse + l1

    def train(self, x, y, learning_rate: float = 1e-3, n_epochs: int = 500):
        """Train the model using gradient descent."""
        losses = []
        for _ in range(n_epochs):
            loss = self.loss(self.params, x, y)
            grads = jax.grad(self.loss)(self.params, x, y)
            self.params = freeze(
                jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, self.params, grads)
            )
            losses.append(loss)
        return losses
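For reference, the objective that Lasso.loss evaluates and the update that Lasso.train applies can be written as follows. Here \(X \in \mathbb{R}^{N \times p}\) are the inputs (\(p\) = n_input), \(Y \in \mathbb{R}^{N \times d}\) the targets (\(d\) = n_output), \(\lambda\) = l1_weight and \(\eta\) = learning_rate; the shapes of X and Y are our assumption about how the class is called. The \(1/(Nd)\) factor reflects that jnp.mean averages over every entry of the residual matrix.

\[
\mathcal{L}(W, b) = \frac{1}{N d} \sum_{i=1}^{N} \sum_{j=1}^{d} \big( (XW)_{ij} + b_j - Y_{ij} \big)^2 + \lambda \sum_{k=1}^{p} \sum_{j=1}^{d} \lvert W_{kj} \rvert
\]

\[
W \leftarrow W - \eta \, \nabla_W \mathcal{L}(W, b), \qquad b \leftarrow b - \eta \, \nabla_b \mathcal{L}(W, b), \qquad \text{repeated n\_epochs times.}
\]

Only \(W\) and \(b\) are fitted by gradient descent; \(\lambda\) and \(\eta\) stay fixed for the whole run, which is what makes them candidates for tuning.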

Exposing tunable hyperparameters

For the model class above, we would like to expose l1_weight and learning_rate as tunable hyperparameters, since neither is learned during training. This requires two modifications:

  • Annotate the hyperparameters with the Tunable typing class

  • Add a _tunables class property (or class attribute) listing the functions whose signatures contain the tunable hyperparameters
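To make the annotation idea concrete, here is a minimal standard-library sketch of how a tuner could recognize tagged parameters by inspecting a function's signature and type hints. MyTunable, TUNABLE_MARKER and find_tunables are hypothetical names, and wrapping the type in typing.Annotated is our assumption; this is not how scvi-tools implements Tunable internally.

```python
import inspect
from typing import Annotated, get_type_hints

# Hypothetical tag; scvi-tools' real Tunable may be implemented differently.
TUNABLE_MARKER = "tunable"


class MyTunable:
    """Hypothetical stand-in: MyTunable[float] becomes Annotated[float, TUNABLE_MARKER]."""

    def __class_getitem__(cls, tp):
        return Annotated[tp, TUNABLE_MARKER]


def find_tunables(func) -> dict:
    """Return {parameter name: default value} for parameters tagged with MyTunable."""
    hints = get_type_hints(func, include_extras=True)  # keep Annotated metadata
    found = {}
    for name, param in inspect.signature(func).parameters.items():
        # Unannotated or plain-typed parameters have no __metadata__.
        metadata = getattr(hints.get(name), "__metadata__", ())
        if TUNABLE_MARKER in metadata:
            found[name] = param.default
    return found


def train(x, y, learning_rate: MyTunable[float] = 1e-3, n_epochs: int = 500):
    """Same signature shape as LassoTunable.train; the body is irrelevant here."""


print(find_tunables(train))  # {'learning_rate': 0.001}
```

Only the tagged parameter is picked up; n_epochs is annotated too, but as a plain int, so it is left out.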

class LassoTunable(Lasso):
    """Linear regression model with l1 penalty in Jax."""

    def __init__(
        self,
        n_input: int,
        n_output: int,
        l1_weight: Tunable[float] = 0.0,  # <<===== Add this
        rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
    ):
        super().__init__(n_input, n_output, l1_weight, rng)

    def train(
        self,
        x,
        y,
        learning_rate: Tunable[float] = 1e-3,  # <<===== Add this
        n_epochs: int = 500,
    ):
        return super().train(x, y, learning_rate, n_epochs)  # return the per-epoch losses

    # <<===== Add this =====>> #
    @classproperty
    def _tunables(cls):
        return [cls.__init__, cls.train]

Now, we can set up a ModelTuner instance with our new model and quickly check that everything is working as expected with info().

tuner = ModelTuner(LassoTunable)
tuner.info()
ModelTuner registry for LassoTunable
             Tunable hyperparameters             
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Default value ┃    Source    ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│   l1_weight    │      0.0      │ LassoTunable │
│ learning_rate  │     0.001     │ LassoTunable │
└────────────────┴───────────────┴──────────────┘
       Available metrics        
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃     Metric      ┃    Mode    ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ validation_loss │    min     │
└─────────────────┴────────────┘
                        Default search space                         
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Sample function ┃ Arguments  ┃ Keyword arguments ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
└────────────────┴─────────────────┴────────────┴───────────────────┘
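The Tunable hyperparameters table pairs each tagged parameter with its default value and the class that defines it. The sketch below rebuilds rows of that shape from a _tunables class property. classproperty, MARKER, Toy and registry_rows are hypothetical stand-ins written for illustration; they are not the scvi-tools implementations of scvi._decorators.classproperty or ModelTuner.info().

```python
import inspect
from typing import Annotated, get_type_hints

MARKER = "tunable"  # hypothetical tag standing in for Tunable


class classproperty:
    """Minimal read-only property evaluated on the class (hypothetical stand-in)."""

    def __init__(self, fget):
        self.fget = fget

    def __get__(self, instance, owner):
        return self.fget(owner)


class Toy:
    def __init__(self, n_input: int, l1_weight: Annotated[float, MARKER] = 0.0):
        self.l1_weight = l1_weight

    def train(self, x, learning_rate: Annotated[float, MARKER] = 1e-3):
        pass

    @classproperty
    def _tunables(cls):
        # Functions whose signatures contain tunable hyperparameters
        return [cls.__init__, cls.train]


def registry_rows(model_cls) -> list:
    """Build (hyperparameter, default, source) rows from model_cls._tunables."""
    rows = []
    for func in model_cls._tunables:
        hints = get_type_hints(func, include_extras=True)
        source = func.__qualname__.split(".")[0]  # defining class, e.g. "Toy"
        for name, param in inspect.signature(func).parameters.items():
            if MARKER in getattr(hints.get(name), "__metadata__", ()):
                rows.append((name, param.default, source))
    return rows


for row in registry_rows(Toy):
    print(row)
# ('l1_weight', 0.0, 'Toy')
# ('learning_rate', 0.001, 'Toy')
```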

Exposing logged metrics

To populate the metrics table, we make two additions: a call to ray.tune.report in our train function that logs the loss at every epoch, and a corresponding _metrics class property that lists the key of the logged metric.
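The contract is that every key passed to the reporting call should also appear in _metrics, so the tuner knows which values it may optimize. Below is a hypothetical, Ray-free sketch of that contract: report only records the dictionaries it receives, and ToyModel halves a placeholder loss each epoch instead of taking a real gradient step.

```python
# Hypothetical stand-in for ray.tune.report: records every reported metrics dict.
reported: list[dict] = []


def report(metrics: dict) -> None:
    reported.append(dict(metrics))


class ToyModel:
    _metrics = ["mse_l1_loss"]  # keys the tuner is allowed to optimize

    def train(self, n_epochs: int = 3) -> None:
        loss = 1.0
        for _ in range(n_epochs):
            loss *= 0.5  # placeholder for one gradient step
            report({"mse_l1_loss": loss})  # one report per epoch


ToyModel().train()

# Every reported key should be declared in _metrics.
undeclared = {key for m in reported for key in m} - set(ToyModel._metrics)
print(len(reported), undeclared)  # 3 set()
print(reported[-1])  # {'mse_l1_loss': 0.125}
```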

class LassoTunable(Lasso):
    """Linear regression model with l1 penalty in Jax."""

    def __init__(
        self,
        n_input: int,
        n_output: int,
        l1_weight: Tunable[float] = 0.0,
        rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
    ):
        super().__init__(n_input, n_output, l1_weight, rng)

    def train(self, x, y, learning_rate: Tunable[float] = 1e-3, n_epochs: int = 500):
        """Train the model using gradient descent."""
        losses = []
        for _ in range(n_epochs):
            loss = self.loss(self.params, x, y)
            grads = jax.grad(self.loss)(self.params, x, y)
            self.params = freeze(
                jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, self.params, grads)
            )
            tune.report({"mse_l1_loss": loss})  # <<===== Add this
            losses.append(loss)
        return losses

    @classproperty
    def _tunables(cls):
        return [cls.__init__, cls.train]

    # <<===== Add this =====>> #
    @classproperty
    def _metrics(cls):
        return ["mse_l1_loss"]

With _metrics defined, mse_l1_loss can now be passed to ModelTuner.fit as the metric to optimize. Calling info() again shows the registry for the updated class.

tuner = scvi.autotune.ModelTuner(LassoTunable)
tuner.info()
ModelTuner registry for LassoTunable
             Tunable hyperparameters             
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Default value ┃    Source    ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│   l1_weight    │      0.0      │ LassoTunable │
│ learning_rate  │     0.001     │ LassoTunable │
└────────────────┴───────────────┴──────────────┘
       Available metrics        
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃     Metric      ┃    Mode    ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ validation_loss │    min     │
└─────────────────┴────────────┘
                        Default search space                         
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Sample function ┃ Arguments  ┃ Keyword arguments ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
└────────────────┴─────────────────┴────────────┴───────────────────┘
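What optimizing a metric in "min" mode amounts to can be illustrated with a plain grid search: train once per configuration, keep the final reported mse_l1_loss, and select the configuration with the lowest value. This is a hypothetical sketch on made-up 1-D data with hand-written gradients (train_trial and sign are illustrative helpers); ModelTuner.fit delegates trial scheduling and search to Ray Tune rather than looping over a fixed grid like this.

```python
import itertools

# Made-up 1-D data following y = 2x + 1 (illustration only).
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * x + 1.0 for x in xs]


def sign(v: float) -> float:
    """Subgradient of |v|, choosing 0 at v == 0."""
    return float((v > 0) - (v < 0))


def train_trial(l1_weight: float, learning_rate: float, n_epochs: int = 200) -> float:
    """Gradient descent on MSE + l1_weight * |w|; returns the final mse_l1_loss."""
    w, b, n = 0.0, 0.0, len(xs)
    for _ in range(n_epochs):
        resid = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = 2.0 / n * sum(r * x for r, x in zip(resid, xs)) + l1_weight * sign(w)
        grad_b = 2.0 / n * sum(resid)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    mse = sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n
    return mse + l1_weight * abs(w)


# One trial per (l1_weight, learning_rate) pair.
grid = itertools.product([0.0, 0.1, 1.0], [1e-3, 1e-2, 5e-2])
results = {cfg: train_trial(*cfg) for cfg in grid}

# Mode "min": the configuration with the lowest final metric wins.
best = min(results, key=results.get)
print("best (l1_weight, learning_rate):", best, "mse_l1_loss:", results[best])
```

On this noise-free toy data the unpenalized configuration with the largest step size, (0.0, 0.05), wins: any positive l1_weight can only add to the objective, and the larger learning rate converges furthest within the fixed epoch budget.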

Using TunableMixin

In practice, if a new model class is being developed using the base classes of scvi-tools, a simpler way to expose tunable hyperparameters and metrics is to use the TunableMixin class. This mixin class provides a flexible, default implementation of _tunables and _metrics that only requires the user to annotate keyword arguments with Tunable.

It also supports recursive discovery of tunable hyperparameters, for example when a higher-level model class stores its module class as an attribute (as LassoModel does with _module_cls below).
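A rough, hypothetical sketch of mixin-based discovery: ToyTunableMixin.discover scans a class's attributes, collects tagged parameters from its functions, and recurses into attributes that are themselves mixin subclasses, such as a _module_cls. All names here, and the use of typing.Annotated as the tag, are assumptions for illustration rather than the TunableMixin source.

```python
import inspect
from typing import Annotated, get_type_hints

MARKER = "tunable"  # hypothetical tag standing in for Tunable


def tunable_params(func) -> dict:
    """Return {name: default} for parameters of func tagged with MARKER."""
    hints = get_type_hints(func, include_extras=True)
    return {
        name: p.default
        for name, p in inspect.signature(func).parameters.items()
        if MARKER in getattr(hints.get(name), "__metadata__", ())
    }


class ToyTunableMixin:
    """Hypothetical mixin that discovers tunables without a hand-written _tunables list."""

    @classmethod
    def discover(cls) -> dict:
        found = {}
        for name in dir(cls):
            attr = getattr(cls, name)
            if inspect.isclass(attr) and issubclass(attr, ToyTunableMixin):
                found.update(attr.discover())  # recurse into e.g. _module_cls
            elif inspect.isfunction(attr):
                found.update(tunable_params(attr))
        return found


class ToyModule(ToyTunableMixin):
    def __init__(self, n_input: int = 1, l1_weight: Annotated[float, MARKER] = 0.0):
        self.l1_weight = l1_weight


class ToyModel(ToyTunableMixin):
    _module_cls = ToyModule  # discovered recursively

    def __init__(self, adata=None, **kwargs):
        self.adata = adata
        self.module = self._module_cls(**kwargs)

    def train(self, max_epochs: int = 10, lr: Annotated[float, MARKER] = 1e-3):
        pass


print(ToyModel.discover())  # {'l1_weight': 0.0, 'lr': 0.001}
```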

class LassoModel(TunableMixin):
    _module_cls = LassoTunable

    def __init__(self, adata, *args, **kwargs):
        self.adata = adata
        self.module = self._module_cls(*args, **kwargs)

    def model_func1(self, x, y):
        pass

    def model_func2(self, x):
        pass

    # etc...

Additionally, if the model is trained with Lightning, calling ray.tune.report is not required, since reporting is handled by a callback.
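For intuition about the callback-based integration, the sketch below forwards selected entries of trainer.callback_metrics to a reporting function at the end of each validation run. The hook name and signature mirror Lightning's Callback.on_validation_end(trainer, pl_module), but ReportMetricsCallback and report are hypothetical; the class does not subclass Lightning's Callback, and a SimpleNamespace stands in for the trainer so the example runs without Lightning or Ray.

```python
from types import SimpleNamespace

reported = []


def report(metrics: dict) -> None:
    """Hypothetical stand-in for the Ray Tune reporting call."""
    reported.append(metrics)


class ReportMetricsCallback:
    """Forwards chosen logged metrics after each validation run (illustrative only)."""

    def __init__(self, metrics: list):
        self.metrics = metrics

    def on_validation_end(self, trainer, pl_module) -> None:
        logged = trainer.callback_metrics
        report({key: float(logged[key]) for key in self.metrics if key in logged})


callback = ReportMetricsCallback(["validation_loss"])
for epoch_loss in [3.0, 2.5, 2.2]:
    # Fake trainer exposing callback_metrics, as a Lightning Trainer does.
    trainer = SimpleNamespace(
        callback_metrics={"validation_loss": epoch_loss, "train_loss": epoch_loss + 0.1}
    )
    callback.on_validation_end(trainer, pl_module=None)

print(reported)
```

In a real setup, such a callback would typically subclass lightning.pytorch.Callback and be passed to the Trainer through its callbacks argument, so the training code itself never calls the reporting function.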

tuner = scvi.autotune.ModelTuner(LassoModel)
tuner.info()
ModelTuner registry for LassoModel
             Tunable hyperparameters             
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Default value ┃    Source    ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│   l1_weight    │      0.0      │ LassoTunable │
│ learning_rate  │     0.001     │ LassoTunable │
└────────────────┴───────────────┴──────────────┘
       Available metrics        
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃     Metric      ┃    Mode    ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ validation_loss │    min     │
└─────────────────┴────────────┘
                        Default search space                         
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Sample function ┃ Arguments  ┃ Keyword arguments ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
└────────────────┴─────────────────┴────────────┴───────────────────┘