
Model hyperparameter tuning with scVI#

Warning

scvi.autotune development is still in progress. The API is subject to change.

Finding an effective set of model hyperparameters (e.g. learning rate, number of hidden layers) is an important part of training a model, since performance can depend heavily on these non-trainable parameters. Manually tuning a model typically means picking a set of hyperparameters to search over and then evaluating different configurations on a validation set with respect to a desired metric. This process is time-consuming and often requires prior intuition about the model and dataset pair, which is not always available.
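
For example, a manual search over a single hyperparameter might look like the sketch below. This is illustrative only: it assumes an AnnData object adata has already been loaded, and the candidate grid, epoch budget, and history key are arbitrary choices rather than recommendations.

import scvi

# Register the dataset with the model class (required before training).
scvi.model.SCVI.setup_anndata(adata)

best_loss, best_n_hidden = float("inf"), None
for n_hidden in (64, 128, 256):
    # Train one full model per candidate value of `n_hidden`.
    model = scvi.model.SCVI(adata, n_hidden=n_hidden)
    model.train(max_epochs=100, check_val_every_n_epoch=1)
    # Evaluate on the validation set (the ELBO is logged as a loss: lower is better).
    loss = model.history["elbo_validation"].iloc[-1].item()
    if loss < best_loss:
        best_loss, best_n_hidden = loss, n_hidden

Each configuration trains a complete model, so even this tiny grid is expensive; autotune automates exactly this loop, with smarter search strategies and early stopping.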

In this tutorial, we show how to use scvi-tools' autotune module (https://docs.scvi-tools.org/en/latest/api/user.html#model-hyperparameter-autotuning), which allows us to automatically find a good set of model hyperparameters using Ray Tune. We will use SCVI and a subsample of the heart cell atlas for the task of batch correction, but the principles outlined here can be applied to any model and dataset. In particular, we will go through the following steps:

  1. Installing required packages

  2. Loading and preprocessing the dataset

  3. Defining the tuner and discovering hyperparameters

  4. Running the tuner

  5. Comparing latent spaces

  6. Optional: Monitoring progress with Tensorboard

  7. Optional: Tuning over integration metrics with scib-metrics

Installing required packages#

[ ]:
!pip install --quiet hyperopt
!pip install --quiet "ray[tune]"
!pip install --quiet scvi-colab
from scvi_colab import install

install()
[1]:
import ray
import scanpy as sc
import scvi
from ray import tune
from scvi import autotune
Global seed set to 0

Loading and preprocessing the dataset#

[2]:
adata = scvi.data.heart_cell_atlas_subsampled()
adata
INFO     File data/hca_subsampled_20k.h5ad already downloaded
[2]:
AnnData object with n_obs × n_vars = 18641 × 26662
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_counts'
    uns: 'cell_type_colors'

The only preprocessing step we will perform in this case is to subset the dataset to the top 2,000 highly variable genes using scanpy, for faster model training.

[3]:
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat_v3", subset=True)
adata
[3]:
AnnData object with n_obs × n_vars = 18641 × 2000
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'cell_type_colors', 'hvg'

Defining the tuner and discovering hyperparameters#

The first part of our workflow is the same as the standard scvi-tools workflow: we start with our desired model class, and we register our dataset with it using setup_anndata. All datasets must be registered prior to hyperparameter tuning.

[4]:
model_cls = scvi.model.SCVI
model_cls.setup_anndata(adata)

Our main entry point to the autotune module is the ModelTuner class, a wrapper around ray.tune.Tuner (https://docs.ray.io/en/latest/tune/api_docs/execution.html#tuner) with additional functionality specific to scvi-tools. We can define a new ModelTuner by providing it with our model class.

[5]:
scvi_tuner = autotune.ModelTuner(model_cls)

ModelTuner will register all tunable hyperparameters in SCVI – these can be viewed by calling info(). By default, this method will display three tables:

  1. Tunable hyperparameters: The names of hyperparameters that can be tuned, their default values, and the internal classes they are defined in.

  2. Available metrics: The metrics that can be used to evaluate the performance of the model. One of these must be provided when running the tuner.

  3. Default search space: The default search space for the model class, which will be used if no search space is provided by the user.

[6]:
scvi_tuner.info()
ModelTuner registry for SCVI
                  Tunable hyperparameters
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ Hyperparameter           ┃ Default value ┃ Source       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
│ n_hidden                 │ 128           │ VAE          │
│ n_latent                 │ 10            │ VAE          │
│ n_layers                 │ 1             │ VAE          │
│ dropout_rate             │ 0.1           │ VAE          │
│ dispersion               │ gene          │ VAE          │
│ gene_likelihood          │ zinb          │ VAE          │
│ latent_distribution      │ normal        │ VAE          │
│ encode_covariates        │ False         │ VAE          │
│ deeply_inject_covariates │ True          │ VAE          │
│ use_batch_norm           │ both          │ VAE          │
│ use_layer_norm           │ none          │ VAE          │
│ optimizer                │ Adam          │ TrainingPlan │
│ lr                       │ 0.001         │ TrainingPlan │
│ weight_decay             │ 1e-06         │ TrainingPlan │
│ eps                      │ 0.01          │ TrainingPlan │
│ n_steps_kl_warmup        │ None          │ TrainingPlan │
│ n_epochs_kl_warmup       │ 400           │ TrainingPlan │
│ reduce_lr_on_plateau     │ False         │ TrainingPlan │
│ lr_factor                │ 0.6           │ TrainingPlan │
│ lr_patience              │ 30            │ TrainingPlan │
│ lr_threshold             │ 0.0           │ TrainingPlan │
│ lr_min                   │ 0             │ TrainingPlan │
│ max_kl_weight            │ 1.0           │ TrainingPlan │
│ min_kl_weight            │ 0.0           │ TrainingPlan │
└──────────────────────────┴───────────────┴──────────────┘
        Available metrics
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Metric          ┃ Mode       ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ validation_loss │ min        │
└─────────────────┴────────────┘
                      Default search space
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Hyperparameter ┃ Sample function ┃ Arguments   ┃ Keyword arguments ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ n_hidden       │ choice          │ [[64, 128]] │ {}                │
└────────────────┴─────────────────┴─────────────┴───────────────────┘

Running the tuner#

Now that we know what hyperparameters are available to us, we can define a search space using the search space API in ray.tune. For this tutorial, we choose a simple search space with two model hyperparameters and one training plan hyperparameter. These can all be combined into a single dictionary that we pass into the fit method.

[7]:
search_space = {
    "n_hidden": tune.choice([64, 128, 256]),
    "n_layers": tune.choice([1, 2, 3]),
    "lr": tune.loguniform(1e-4, 1e-2),
}
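
Here, tune.choice samples uniformly from a discrete set of values, while tune.loguniform samples continuously on a logarithmic scale, the usual choice for learning rates since their useful values span several orders of magnitude.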

There are a few more arguments we should be aware of before fitting the tuner:

  • num_samples: The total number of hyperparameter sets to sample from search_space. This is the total number of models that will be trained.

    For example, if we set num_samples=2, we might sample two models with the following hyperparameter configurations:

    model1 = {
        "n_hidden": 64,
        "n_layers": 1,
        "lr": 0.001,
    }
    model2 = {
        "n_hidden": 64,
        "n_layers": 3,
        "lr": 0.0001,
    }
    
  • max_epochs: The maximum number of epochs to train each model for.

    Note: This does not mean that each model will be trained for max_epochs. Depending on the scheduler used, some trials are likely to be stopped early.

  • resources: A dictionary of maximum resources to allocate for the whole experiment. This allows us to run concurrent trials on limited hardware.

Now, we can call fit on the tuner to start the hyperparameter sweep. This returns a TuneAnalysis dataclass containing the best set of hyperparameters found, along with other information about the run.

[8]:
ray.init(log_to_driver=False)
results = scvi_tuner.fit(
    adata,
    metric="validation_loss",
    search_space=search_space,
    num_samples=5,
    max_epochs=100,
    resources={"cpu": 20, "gpu": 1},
)

Tune Status

Current time: 2023-01-30 13:44:04
Running for: 00:03:01.04
Memory: 10.9/125.7 GiB

System Info

Using AsyncHyperBand: num_stopped=5
Bracket: Iter 64.000: -466.8729553222656 | Iter 32.000: -467.76768493652344 | Iter 16.000: -473.09934997558594 | Iter 8.000: -481.9038391113281 | Iter 4.000: -493.9166717529297 | Iter 2.000: -516.9609985351562 | Iter 1.000: -559.62353515625
Resources requested: 0/20 CPUs, 0/1 GPUs, 0.0/74.12 GiB heap, 0.0/35.76 GiB objects (0.0/1.0 accelerator_type:G)

Trial Status

Trial name               status       loc                     n_hidden   n_layers   lr            validation_loss
_trainable_cb47e_00000   TERMINATED   128.32.142.133:515229   128        1          0.000109758   468.995
_trainable_cb47e_00001   TERMINATED   128.32.142.133:515229   256        1          0.00906226    469.203
_trainable_cb47e_00002   TERMINATED   128.32.142.133:515229   128        3          0.00276226    616.49
_trainable_cb47e_00003   TERMINATED   128.32.142.133:515229   128        1          0.00110585    516.961
_trainable_cb47e_00004   TERMINATED   128.32.142.133:515229   256        3          0.000817148   595.537
2023-01-30 13:44:04,458 INFO tune.py:762 -- Total run time: 181.26 seconds (181.03 seconds for the tuning loop).
[9]:
print(results.model_kwargs)
print(results.train_kwargs)
{'n_hidden': 128, 'n_layers': 1}
{'plan_kwargs': {}}
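
With the best configuration in hand, we can train a final model. A minimal sketch, reusing adata and results from above (the epoch budget of 100 simply mirrors the sweep):

# Train a final model with the best hyperparameters found by the tuner.
# `model_kwargs` go to the model constructor, `train_kwargs` to `train`.
model = scvi.model.SCVI(adata, **results.model_kwargs)
model.train(max_epochs=100, **results.train_kwargs)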

Comparing latent spaces#

Work in progress: please check back in the next release!
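
In the meantime, a minimal sketch of such a comparison using standard scvi-tools and scanpy calls. It assumes the final model trained from results.model_kwargs above; the obsm key name X_scVI_tuned is an arbitrary choice:

# Store the tuned model's latent space and visualize it with UMAP.
adata.obsm["X_scVI_tuned"] = model.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_scVI_tuned")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["cell_type", "cell_source"])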

Optional: Monitoring progress with Tensorboard#

Work in progress: please check back in the next release!
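
In the meantime, Ray Tune writes per-trial results in TensorBoard format, by default under ~/ray_results, so a sweep can already be monitored with the standard TensorBoard CLI (the exact directory layout may differ across Ray versions):

!tensorboard --logdir ~/ray_results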

Optional: Tuning over integration metrics with scib-metrics#

Work in progress: please check back in the next release!
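
In the meantime, the integration quality of a trained model can be assessed with scib-metrics directly. A rough sketch using its Benchmarker class on the latent space stored above; treat the exact signature as an assumption and check the scib-metrics documentation:

from scib_metrics.benchmark import Benchmarker

# Score the tuned latent space for batch correction and bio-conservation.
# NOTE: constructor signature is an assumption; consult the scib-metrics docs.
bm = Benchmarker(
    adata,
    batch_key="cell_source",
    label_key="cell_type",
    embedding_obsm_keys=["X_scVI_tuned"],
)
bm.benchmark()
bm.plot_results_table()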