Using autotune with a new model class#
Warning
scvi.autotune is still under active development, and its API is subject to change.
This tutorial provides an overview of how to prepare a new model class to interact with scvi.autotune.ModelTuner. For a high-level overview of scvi.autotune, see the tutorial on model hyperparameter tuning with scVI. This tutorial also assumes a general understanding of how models are implemented in scvi-tools, as covered in the model development tutorial.
In particular, we will go through the following steps:
Installing required packages
Creating a new model class
Exposing tunable hyperparameters
Exposing logged metrics
Using TunableMixin
Installing required packages#
Uncomment the following lines in Google Colab in order to install scvi-tools:
# !pip install --quiet scvi-colab
# from scvi_colab import install
# install()
import jax
import jax.numpy as jnp
import scvi
from flax.core import freeze
from ray import tune
from scvi._decorators import classproperty
from scvi._types import Tunable, TunableMixin
from scvi.autotune import ModelTuner
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.1.0
Creating a new model class#
To showcase how to use scvi.autotune.ModelTuner with a new model class, we will create a simple linear regression model with an \(\ell_1\) penalty in Jax (i.e., the Lasso).
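Concretely, for \(n\) samples \(x_i\) with targets \(y_i\), weights \(w\), and bias \(b\), training minimizes the standard Lasso objective (shown for a single output; the notation here is ours, not from the original tutorial):
\[
\mathcal{L}(w, b) = \frac{1}{n} \sum_{i=1}^{n} \left( x_i^\top w + b - y_i \right)^2 + \lambda \lVert w \rVert_1,
\]
where \(\lambda\) corresponds to the l1_weight hyperparameter in the code below.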
class Lasso:
    """Linear regression model with an l1 penalty in Jax."""

    def __init__(
        self,
        n_input: int,
        n_output: int,
        l1_weight: float = 0.0,
        rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
    ):
        k1, k2 = jax.random.split(rng)
        self.l1_weight = l1_weight
        self.params = freeze(
            {
                "w": jax.random.normal(k1, (n_input, n_output)),
                "b": jax.random.normal(k2, (n_output,)),
            }
        )

    def forward(self, params, x):
        """Forward pass."""
        return jnp.dot(x, params["w"]) + params["b"]

    def loss(self, params, x, y):
        """Mean squared error loss with l1 regularization."""
        mse = jnp.mean((self.forward(params, x) - y) ** 2)
        # Penalize the `params` argument (not `self.params`) so the l1 term
        # is differentiated along with the rest of the loss.
        l1 = self.l1_weight * jnp.sum(jnp.abs(params["w"]))
        return mse + l1

    def train(self, x, y, learning_rate: float = 1e-3, n_epochs: int = 500):
        """Train the model using gradient descent."""
        losses = []
        for _ in range(n_epochs):
            # Compute the loss and its gradients in a single pass.
            loss, grads = jax.value_and_grad(self.loss)(self.params, x, y)
            self.params = freeze(
                jax.tree_util.tree_map(
                    lambda p, g: p - learning_rate * g, self.params, grads
                )
            )
            losses.append(loss)
        return losses
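As a quick sanity check (not part of the original tutorial; the synthetic data below is made up for illustration), we can fit the model on a small regression problem and confirm that the loss decreases:
# Hypothetical example: fit the Lasso on synthetic data.
x = jax.random.normal(jax.random.PRNGKey(1), (100, 5))
true_w = jnp.array([[1.0], [-2.0], [0.0], [0.0], [3.0]])
y = x @ true_w + 0.1 * jax.random.normal(jax.random.PRNGKey(2), (100, 1))

model = Lasso(n_input=5, n_output=1, l1_weight=0.01)
losses = model.train(x, y, learning_rate=1e-2, n_epochs=100)
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")  # should decrease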
Exposing tunable hyperparameters#
For the model class above, we would like to expose l1_weight and learning_rate as tunable hyperparameters, since neither is learned during training. Two modifications are needed:
Annotate the hyperparameters with the Tunable typing class
Add a _tunables class property or attribute referencing the functions whose signatures contain the tunable hyperparameters
class LassoTunable(Lasso):
    """Linear regression model with an l1 penalty in Jax."""

    def __init__(
        self,
        n_input: int,
        n_output: int,
        l1_weight: Tunable[float] = 0.0,  # <<===== Add this
        rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
    ):
        super().__init__(n_input, n_output, l1_weight, rng)

    def train(
        self,
        x,
        y,
        learning_rate: Tunable[float] = 1e-3,  # <<===== Add this
        n_epochs: int = 500,
    ):
        # Return the parent's result so callers still receive the losses.
        return super().train(x, y, learning_rate, n_epochs)

    # <<===== Add this =====>> #
    @classproperty
    def _tunables(cls):
        return [cls.__init__, cls.train]
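Since Tunable is just a typing annotation, it has no effect at runtime; the annotated class can be instantiated and trained exactly as before. Reusing the synthetic x and y from the illustrative sanity check above:
# The Tunable annotations are inert at runtime; training works as before.
model = LassoTunable(n_input=5, n_output=1, l1_weight=0.01)
losses = model.train(x, y, learning_rate=1e-2, n_epochs=10)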
Now, we can set up a ModelTuner instance with our new model and quickly check that everything is working as expected with info().
tuner = ModelTuner(LassoTunable)
tuner.info()
ModelTuner registry for LassoTunable
Tunable hyperparameters
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Default value ┃ Source       ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ l1_weight      │ 0.0           │ LassoTunable │
│ learning_rate  │ 0.001         │ LassoTunable │
└────────────────┴───────────────┴──────────────┘

Available metrics
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric          ┃ Mode       ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ validation_loss │ min        │
└─────────────────┴────────────┘

Default search space
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Sample function ┃ Arguments  ┃ Keyword arguments ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
└────────────────┴─────────────────┴────────────┴───────────────────┘
Exposing logged metrics#
To populate the metrics table, we make two additions: a call to ray.tune.report in our train function that logs the loss, and a corresponding class property called _metrics that lists the key of the logged metric.
class LassoTunable(Lasso):
    """Linear regression model with an l1 penalty in Jax."""

    def __init__(
        self,
        n_input: int,
        n_output: int,
        l1_weight: Tunable[float] = 0.0,
        rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
    ):
        super().__init__(n_input, n_output, l1_weight, rng)

    def train(self, x, y, learning_rate: Tunable[float] = 1e-3, n_epochs: int = 500):
        """Train the model using gradient descent."""
        losses = []
        for _ in range(n_epochs):
            # Compute the loss and its gradients in a single pass.
            loss, grads = jax.value_and_grad(self.loss)(self.params, x, y)
            self.params = freeze(
                jax.tree_util.tree_map(
                    lambda p, g: p - learning_rate * g, self.params, grads
                )
            )
            # Report a plain Python float to Ray Tune.
            tune.report({"mse_l1_loss": float(loss)})  # <<===== Add this
            losses.append(loss)
        return losses

    @classproperty
    def _tunables(cls):
        return [cls.__init__, cls.train]

    # <<===== Add this =====>> #
    @classproperty
    def _metrics(cls):
        return ["mse_l1_loss"]
We see that our tuner instance has detected our desired metric, so now we can pass mse_l1_loss to ModelTuner.fit to be optimized.
tuner = scvi.autotune.ModelTuner(LassoTunable)
tuner.info()
ModelTuner registry for LassoTunable
Tunable hyperparameters
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Default value ┃ Source       ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ l1_weight      │ 0.0           │ LassoTunable │
│ learning_rate  │ 0.001         │ LassoTunable │
└────────────────┴───────────────┴──────────────┘

Available metrics
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric          ┃ Mode       ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ validation_loss │ min        │
└─────────────────┴────────────┘

Default search space
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Sample function ┃ Arguments  ┃ Keyword arguments ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
└────────────────┴─────────────────┴────────────┴───────────────────┘
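To sketch what a tuning run could look like, we might define a search space with Ray Tune's sampling functions and hand it to fit. The keyword arguments below (metric, search_space, num_samples) are assumptions modeled on the high-level tuning tutorial, not a confirmed signature; consult the ModelTuner.fit documentation before running.
# Hedged sketch: the fit() arguments are assumptions, so the call is left
# commented out. tune.loguniform is Ray Tune's log-uniform sampler.
search_space = {
    "l1_weight": tune.loguniform(1e-4, 1e-1),
    "learning_rate": tune.loguniform(1e-4, 1e-1),
}
# results = tuner.fit(
#     x, y,  # forwarded to LassoTunable.train
#     metric="mse_l1_loss",
#     search_space=search_space,
#     num_samples=10,
# )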
Using TunableMixin#
In practice, if a new model class is being developed using the base classes of scvi-tools, a simpler way to expose tunable hyperparameters and metrics is to use the TunableMixin class. This mixin provides a flexible, default implementation of _tunables and _metrics that only requires the user to annotate keyword arguments with Tunable.
It also allows for the recursive discovery of tunable hyperparameters, for example when higher-level model classes define modules as attributes.
class LassoModel(TunableMixin):
    # TunableMixin can discover the tunables declared on this module class.
    _module_cls = LassoTunable

    def __init__(self, adata, *args, **kwargs):
        self.adata = adata
        self.module = self._module_cls(*args, **kwargs)

    def model_func1(self, x, y):
        pass

    def model_func2(self, x):
        pass

    # etc...
Additionally, if the model uses Lightning for the training procedure, calling ray.tune.report is not required, as the integration is handled with a callback.
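As a rough illustration of that callback-based integration (this is not scvi-tools' internal code; the callback and metric names here are only an example), Ray ships a Lightning callback that forwards metrics logged via self.log to Tune:
# Illustrative sketch only: scvi-tools wires this up internally.
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback

# Forward Lightning's logged "validation_loss" metric to Ray Tune.
callback = TuneReportCallback({"validation_loss": "validation_loss"})
trainer = pl.Trainer(max_epochs=500, callbacks=[callback])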
tuner = scvi.autotune.ModelTuner(LassoModel)
tuner.info()
ModelTuner registry for LassoModel
Tunable hyperparameters
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Default value ┃ Source       ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ l1_weight      │ 0.0           │ LassoTunable │
│ learning_rate  │ 0.001         │ LassoTunable │
└────────────────┴───────────────┴──────────────┘

Available metrics
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric          ┃ Mode       ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ validation_loss │ min        │
└─────────────────┴────────────┘

Default search space
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Sample function ┃ Arguments  ┃ Keyword arguments ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
└────────────────┴─────────────────┴────────────┴───────────────────┘