Introduction to scvi-tools

In this introductory tutorial, we go through the different steps of an scvi-tools workflow.

While we focus on scVI in this tutorial, the API is consistent across all models.

Note

Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.

!pip install --quiet scvi-colab
from scvi_colab import install

install()
import os
import tempfile

import scanpy as sc
import scvi
import seaborn as sns
import torch
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.4.2

Note

You can modify save_dir below to change where the data files for this tutorial are saved.

sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

Loading and preparing data

Let us first load a subsampled version of the heart cell atlas dataset described in Litviňuková et al. (2020). scvi-tools has many “built-in” datasets as well as support for loading arbitrary .csv, .loom, and .h5ad (AnnData) files. Please see our tutorial on data loading for more examples.

  • Litviňuková, M., Talavera-López, C., Maatz, H., Reichart, D., Worth, C. L., Lindberg, E. L., … & Teichmann, S. A. (2020). Cells of the adult human heart. Nature, 588(7838), 466-472.

Important

All scvi-tools models require AnnData objects as input.

adata = scvi.data.heart_cell_atlas_subsampled(save_path=save_dir.name)
adata
INFO     Downloading file at /tmp/tmp2lq3dsue/hca_subsampled_20k.h5ad
AnnData object with n_obs × n_vars = 18641 × 26662
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_counts'
    uns: 'cell_type_colors'

Now we preprocess the data, for example to remove genes that are very lowly expressed and other outliers. Here we only filter out genes with fewer than 3 counts in total across all cells. For these tasks we prefer the Scanpy preprocessing module.

sc.pp.filter_genes(adata, min_counts=3)
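Conceptually, filtering with min_counts=3 keeps only the genes whose total count across all cells is at least 3. A minimal NumPy sketch of that rule on a toy dense matrix (the matrix and the helper name are hypothetical; Scanpy itself also handles sparse matrices and updates the AnnData in place):

```python
import numpy as np


def filter_genes_by_total_counts(counts: np.ndarray, min_counts: int = 3) -> np.ndarray:
    """Return a boolean mask of genes (columns) with at least `min_counts` total counts."""
    gene_totals = counts.sum(axis=0)  # total counts per gene across all cells
    return gene_totals >= min_counts


# toy matrix: 3 cells x 4 genes
counts = np.array([
    [0, 1, 5, 0],
    [0, 1, 2, 0],
    [1, 0, 0, 2],
])
keep = filter_genes_by_total_counts(counts, min_counts=3)
filtered = counts[:, keep]  # drop genes that fail the threshold
print(keep)  # only the third gene (7 total counts) passes
```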

In scRNA-seq analysis, it’s common to normalize and log-transform the data. These values are not used by scvi-tools, but given their popularity in other tasks as well as for visualization, we store them separately in the AnnData object (via the .raw attribute), after keeping a copy of the raw counts in a layer.

Important

Unless otherwise specified, scvi-tools models require the raw counts (not library-size normalized or log-transformed values). scvi-tools models will run on non-negative real-valued data, but we strongly suggest checking that such non-count values are intended to represent pseudocounts (e.g. SoupX-corrected counts) and not some other normalized data in which the variance/covariance structure has changed dramatically.

adata.layers["counts"] = adata.X.copy()  # preserve counts
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata  # freeze the state in `.raw`
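The two Scanpy calls above amount to scaling each cell so that its counts sum to target_sum and then applying log(1 + x). A minimal NumPy sketch on a toy dense matrix, for intuition only (Scanpy's implementation also supports sparse input and additional options):

```python
import numpy as np


def normalize_total_log1p(counts: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """Scale each cell (row) to `target_sum` total counts, then apply log1p.

    Assumes every cell has at least one count (no all-zero rows).
    """
    counts = counts.astype(float)
    library_size = counts.sum(axis=1, keepdims=True)  # total counts per cell
    scaled = counts / library_size * target_sum        # each row now sums to target_sum
    return np.log1p(scaled)                            # natural log of (1 + x)


counts = np.array([[1, 3, 0], [10, 0, 10]])
logged = normalize_total_log1p(counts)
# undoing the log recovers rows that each sum to 1e4 (up to floating-point error)
print(np.expm1(logged).sum(axis=1))
```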

Finally, we perform feature selection to reduce the number of features (genes in this case) used as input to the scvi-tools model. For best practices on how and when to perform feature selection, please refer to the model-specific tutorial. For scVI, we recommend anywhere from 1,000 to 10,000 highly variable genes (HVGs), though the right number is context-dependent. Here we keep 1,200.

sc.pp.highly_variable_genes(
    adata,
    n_top_genes=1200,
    subset=True,
    layer="counts",
    flavor="seurat_v3",
    batch_key="cell_source",
)

Now it’s time to run setup_anndata(), which alerts scvi-tools to the locations of various matrices inside the anndata. It’s important to run this function with the correct arguments so scvi-tools is notified that your dataset has batches, annotations, etc. For example, if batches are registered with scvi-tools, the subsequent model will correct for batch effects. See the full documentation for details.

In this dataset, there is a “cell_source” categorical covariate, and within each “cell_source”, multiple values of “donor”, “gender”, and “age_group”. There are also two continuous covariates we’d like to correct for: “percent_mito” and “percent_ribo”. Categorical covariates are registered with the categorical_covariate_keys argument, and continuous covariates with the continuous_covariate_keys argument. If you only have one categorical covariate, you can also use the batch_key argument instead.

scvi.model.SCVI.setup_anndata(
    adata,
    layer="counts",
    categorical_covariate_keys=["cell_source", "donor"],
    continuous_covariate_keys=["percent_mito", "percent_ribo"],
)

Warning

If adata is modified after running setup_anndata, run setup_anndata again before creating an instance of a model.

Creating and training a model

While we highlight the scVI model here, the API is consistent across all scvi-tools models and is inspired by that of scikit-learn. For a full list of options, see the scvi documentation.

model = scvi.model.SCVI(adata)

We can see an overview of the model by printing it.

model
SCVI model with the following parameters: 
n_hidden: 128, n_latent: 10, n_layers: 1, dropout_rate: 0.1, dispersion: gene, gene_likelihood: zinb, 
latent_distribution: normal.
Training status: Not Trained
Model's adata is minified?: False

Important

All scvi-tools models run faster when using a GPU. By default, scvi-tools will use a GPU if one is found to be available. Please see the installation page for more information about installing scvi-tools when a GPU is available.

model.train(train_size=0.8, check_val_every_n_epoch=1)

Note how we trained the model with train_size=0.8, which means that 80% of the data is used for training and 20% for validation. The additional check_val_every_n_epoch=1 means that validation losses are recorded every epoch and can be used by callbacks such as early stopping, which halts training based on the validation loss before the model overfits.
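Early stopping of this kind typically watches a validation metric and stops once it has not improved for a fixed number of epochs (the "patience"). The following is a pure-Python sketch of that rule applied to a recorded list of validation losses; the function name, the patience value, and the min_delta tolerance are illustrative assumptions, not scvi-tools defaults:

```python
from typing import Sequence


def early_stopping_epoch(val_losses: Sequence[float], patience: int = 3, min_delta: float = 0.0) -> int:
    """Return the (0-based) epoch at which patience-based early stopping would halt.

    Training halts once `patience` consecutive epochs fail to improve the best
    validation loss by more than `min_delta`. If that never happens, the index
    of the last epoch is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss  # new best validation loss: reset the counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses) - 1


# validation loss improves until epoch 3, then rises (overfitting)
val = [10.0, 8.0, 7.0, 6.5, 6.6, 6.8, 7.1, 7.5]
print(early_stopping_epoch(val, patience=3))
```

In scvi-tools itself, early stopping is switched on with model.train(..., early_stopping=True); see the train() API for the monitored metric and the patience settings.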

Other parameters will not be discussed here, and we encourage the reader to view the rest of the API.

Saving and loading

Saving a model stores its neural network weights as well as the parameters used to initialize it.

model_dir = os.path.join(save_dir.name, "scvi_model")
model.save(model_dir, overwrite=True)
model = scvi.model.SCVI.load(model_dir, adata=adata)
INFO     File /tmp/tmp2lq3dsue/scvi_model/model.pt already downloaded

Loss curves

scvi-tools models store their training and validation losses in model.history, a dictionary keyed by metric name. We use it to plot the loss curves:

import matplotlib.pyplot as plt

hist = model.history


def plot_metric(train_key, val_key, title, ylabel):
    plt.figure(figsize=(6, 4))
    plt.plot(hist[train_key], label="train")
    plt.plot(hist[val_key], label="validation")
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


# Negative ELBO (evidence lower bound)
if "elbo_train" in hist:
    plot_metric("elbo_train", "elbo_validation", "ELBO", "ELBO")

The ELBO (evidence lower bound) is the objective the scVI model optimizes; the logged loss is the negative ELBO, which consists of two parts: the reconstruction loss and the KL divergence.
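Written out for a single cell with counts x and latent variable z, the negative ELBO splits into exactly these two terms. The notation follows the axis labels of the plots; for readability we omit the batch and covariate inputs that scVI's decoder also receives:

```latex
-\mathrm{ELBO}(x)
  = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z)\right]}_{\text{reconstruction loss}}
  + \underbrace{\mathrm{KL}\left(q(z \mid x) \,\|\, p(z)\right)}_{\text{KL divergence}}
```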

# Reconstruction loss
plot_metric(
    "reconstruction_loss_train",
    "reconstruction_loss_validation",
    "Reconstruction loss",
    "Negative log likelihood",
)

The reconstruction loss is the expected negative log-likelihood term of the ELBO. It measures how well the model reconstructs (decodes) the original gene counts from the latent embedding.
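Because this model was built with gene_likelihood: zinb (see the model summary above), the reconstruction term is a zero-inflated negative binomial (ZINB) negative log-likelihood of the observed counts. The sketch below evaluates that likelihood for a few counts with hand-picked parameters: mu is the NB mean, theta the inverse dispersion, and pi the zero-inflation probability. It uses the standard textbook ZINB parameterization; scvi-tools' internal implementation (for example, working with logits for numerical stability) differs in its details.

```python
import math
from typing import Sequence


def nb_log_pmf(x: int, mu: float, theta: float) -> float:
    """Log-probability of count x under a negative binomial with mean mu (> 0) and inverse dispersion theta (> 0)."""
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def zinb_log_pmf(x: int, mu: float, theta: float, pi: float) -> float:
    """Log-probability under a ZINB: with probability pi the count is a structural zero."""
    nb = nb_log_pmf(x, mu, theta)
    if x == 0:
        # a zero can come from the zero-inflation component or from the NB itself
        return math.log(pi + (1.0 - pi) * math.exp(nb))
    return math.log(1.0 - pi) + nb  # requires pi < 1 for positive counts


def zinb_nll(counts: Sequence[int], mu: Sequence[float], theta: Sequence[float], pi: Sequence[float]) -> float:
    """Negative log-likelihood of a vector of counts (e.g. one cell's genes)."""
    return -sum(zinb_log_pmf(x, m, t, p) for x, m, t, p in zip(counts, mu, theta, pi))


# three genes of one hypothetical cell
print(zinb_nll([0, 2, 5], mu=[1.0, 2.0, 4.0], theta=[2.0, 2.0, 2.0], pi=[0.1, 0.1, 0.1]))
```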

# Local KL divergence
plot_metric(
    "kl_local_train", "kl_local_validation", "KL divergence (latent z)", "KL(q(z|x)||p(z))"
)

The KL divergence regularizes the latent space: minimizing it keeps the approximate posterior q(z|x) close to the prior p(z).
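With the default latent_distribution: normal, the approximate posterior q(z|x) is a Gaussian with diagonal covariance and the prior p(z) is a standard normal, so this KL term has a closed form, KL = 1/2 * sum_k (sigma_k^2 + mu_k^2 - 1 - log sigma_k^2). A NumPy sketch of that formula (the example means and variances are made up):

```python
import numpy as np


def kl_diag_gaussian_to_standard_normal(mu: np.ndarray, var: np.ndarray) -> float:
    """KL(N(mu, diag(var)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * float(np.sum(var + mu**2 - 1.0 - np.log(var)))


# a 10-dimensional latent space, matching n_latent: 10 in the model summary
mu = np.zeros(10)
var = np.ones(10)
print(kl_diag_gaussian_to_standard_normal(mu, var))  # 0.0: posterior equals prior

# a posterior that is shifted and narrower than the prior pays a positive KL cost
mu_shifted = np.full(10, 0.5)
var_narrow = np.full(10, 0.25)
print(kl_diag_gaussian_to_standard_normal(mu_shifted, var_narrow))
```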

When training scVI models, we want to avoid overfitting (for example, the validation ELBO starts to increase, the KL divergence is unstable, or the reconstruction loss explodes) while still training long enough to avoid underfitting (for example, the training ELBO has not converged).

Both are important for the reliability of downstream analyses based on the model.
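These checks can be partly automated. The sketch below flags two of the symptoms described above from per-epoch loss curves: a validation loss that has risen noticeably above its minimum (possible overfitting), and a training loss that was still dropping by more than a small relative amount per epoch over the last few epochs (possible underfitting). The function name, window size, and tolerances are arbitrary illustrative choices, not scvi-tools settings. With a trained model you could pass, for example, hist["elbo_train"]["elbo_train"].to_numpy(), assuming each history entry is a one-column DataFrame named after its key.

```python
import numpy as np


def diagnose_training(train_loss, val_loss, window: int = 10,
                      rel_tol: float = 1e-3, overfit_tol: float = 0.01) -> dict:
    """Heuristic flags for over- and underfitting from per-epoch loss curves.

    Both curves must contain at least two epochs.
    """
    train = np.asarray(train_loss, dtype=float)
    val = np.asarray(val_loss, dtype=float)

    # overfitting: the final validation loss sits clearly above its best value
    best_val = val.min()
    overfitting = (val[-1] - best_val) > overfit_tol * abs(best_val)

    # underfitting: training loss still falling by more than rel_tol per epoch on average
    recent = train[-(window + 1):]
    mean_rel_drop = (recent[0] - recent[-1]) / (abs(recent[0]) * (len(recent) - 1))
    underfitting = mean_rel_drop > rel_tol

    return {
        "overfitting": bool(overfitting),
        "underfitting": bool(underfitting),
        "best_val_epoch": int(val.argmin()),
    }


flags = diagnose_training(
    train_loss=[10, 8, 6, 5, 4, 3, 2, 1],
    val_loss=[10.0, 8.0, 7.0, 6.5, 6.6, 6.8, 7.1, 7.5],
)
print(flags)
```

Flags like these are only a starting point; inspecting the curves, as above, remains the main check.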

Obtaining model outputs

It’s often useful to store the outputs of scvi-tools back into the original anndata, as it permits interoperability with Scanpy.

SCVI_LATENT_KEY = "X_scVI"

latent = model.get_latent_representation()
adata.obsm[SCVI_LATENT_KEY] = latent
latent.shape
(18641, 10)

The model.get...() functions default to using the AnnData that was used to initialize the model. It’s also possible to query a subset of the AnnData, or even a completely independent AnnData object, as long as it is organized in an equivalent fashion.

Here we show the usage of get_latent_representation(), which returns the model's learned latent space (embeddings), and get_normalized_expression(), which returns the model's estimated expression in gene space.

adata_subset = adata[adata.obs.cell_type == "Fibroblast"]
latent_subset = model.get_latent_representation(adata_subset)
latent_subset.shape
INFO     Received view of anndata, making copy.                                                                    
INFO     Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup
(2446, 10)
denoised = model.get_normalized_expression(adata_subset, library_size=1e4)
denoised.iloc[:5, :5]
ISG15 TNFRSF18 VWA1 HES5 SPSB1
GTCAAGTCATGCCACG-1-HCAHeart7702879 1.087156 0.006268 2.812034 0.027343 4.835241
GAGTCATTCTCCGTGT-1-HCAHeart8287128 1.302901 0.054053 2.712585 0.009016 25.552311
CCTCTGATCGTGACAT-1-HCAHeart7702881 1.066492 0.096505 3.442979 0.008418 1.798697
CGCCATTCATCATCTT-1-H0035_apex 0.174524 0.007493 0.618598 0.003379 9.349258
TCGTAGAGTAGGACTG-1-H0015_septum 0.725282 0.059405 1.161942 0.028128 6.372717
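One way to read the denoised values above: in scVI's default configuration the decoder outputs, for each cell, a vector of gene proportions that sums to 1 (a softmax over genes), and library_size rescales that vector, so each cell's values across all modeled genes should add up to roughly library_size (for example, denoised.sum(axis=1) should be close to 1e4). A NumPy sketch with made-up decoder logits, not scvi-tools code:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
decoder_logits = rng.normal(size=(5, 1200))  # 5 cells x 1200 HVGs (hypothetical decoder output)
proportions = softmax(decoder_logits)         # each row sums to 1
normalized = proportions * 1e4                # analogous to library_size=1e4
print(normalized.sum(axis=1))                 # each row sums to 1e4
```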

Let’s store the normalized values back in the anndata.

SCVI_NORMALIZED_KEY = "scvi_normalized"

adata.layers[SCVI_NORMALIZED_KEY] = model.get_normalized_expression(
    library_size=10e4  # note: 10e4 == 1e5 counts per cell, not the 1e4 used above
)

If you wish to obtain a full gene-space matrix with batch correction, you must use model.get_normalized_expression(..., transform_batch=<reference_batch>). Without transform_batch, each cell is decoded conditioned on its own batch, so the result is a per-batch decoded expression that retains batch effects rather than a harmonized expression across batches.
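The difference between decoding each cell in its own batch and decoding all cells in one reference batch can be seen with a toy linear decoder. In the sketch below, two cells share the same latent vector z but come from different batches; the decoder weights W, the per-batch offsets B, and the decode helper are all made up for illustration and are not the scVI decoder:

```python
import numpy as np

rng = np.random.default_rng(0)
n_latent, n_genes, n_batches = 10, 6, 2
W = rng.normal(size=(n_latent, n_genes))   # hypothetical latent-to-gene weights
B = rng.normal(size=(n_batches, n_genes))  # hypothetical per-batch gene offsets


def decode(z: np.ndarray, batch: np.ndarray) -> np.ndarray:
    """Toy decoder: softmax over genes of z @ W + B[batch]."""
    logits = z @ W + B[batch]
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    expr = np.exp(logits)
    return expr / expr.sum(axis=1, keepdims=True)


z = np.tile(rng.normal(size=(1, n_latent)), (2, 1))  # two cells with identical latent state
own_batch = np.array([0, 1])

per_batch = decode(z, own_batch)                # like the default: each cell in its own batch
reference = decode(z, np.zeros(2, dtype=int))   # like transform_batch set to batch 0
print(np.abs(per_batch[0] - per_batch[1]).max())  # > 0: the batch effect remains
print(np.abs(reference[0] - reference[1]).max())  # 0.0: both cells decoded in batch 0
```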

Interoperability with Scanpy

Scanpy is a powerful Python library for visualization and downstream analysis of scRNA-seq data. We show here how to feed the objects produced by scvi-tools into a Scanpy workflow.

Visualization without batch correction

Warning

We use UMAP to qualitatively assess our low-dimension embeddings of cells. We do not advise using UMAP or any similar approach quantitatively. We do recommend using the embeddings produced by scVI as a plug-in replacement of what you would get from PCA, as we show below.

First, we demonstrate the presence of nuisance variation with respect to nuclei/whole cell, age group, and donor by plotting UMAP embeddings computed from the top 30 principal components of the log-normalized data.

# run PCA then generate UMAP plots
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=20)
sc.tl.umap(adata, min_dist=0.3)
sc.pl.umap(
    adata,
    color=["cell_type"],
    frameon=False,
)
sc.pl.umap(
    adata,
    color=["donor", "cell_source"],
    ncols=2,
    frameon=False,
)

We see that while the cell types are generally well separated, nuisance variation plays a large part in the variation of the data.

Visualization with batch correction (scVI)

Now, let us try using the scVI latent space to generate the same UMAP plots to see if scVI successfully accounts for batch effects in the data.

# use scVI latent space for UMAP generation
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.umap(adata, min_dist=0.3)
sc.pl.umap(
    adata,
    color=["cell_type"],
    frameon=False,
)
sc.pl.umap(
    adata,
    color=["donor", "cell_source"],
    ncols=2,
    frameon=False,
)

We can see that scVI was able to correct for nuisance variation due to nuclei/whole cell, age group, and donor, while maintaining separation of cell types.

Clustering on the scVI latent space

The user will note that we imported curated labels from the original publication. Our interface with Scanpy makes it easy to cluster the data with Scanpy from scVI’s latent space and then reinject the resulting clusters into scVI (e.g., for differential expression).

# neighbors were already computed using scVI
SCVI_CLUSTERS_KEY = "leiden_scVI"
sc.tl.leiden(adata, key_added=SCVI_CLUSTERS_KEY, resolution=0.5)
sc.pl.umap(
    adata,
    color=[SCVI_CLUSTERS_KEY],
    frameon=False,
)

Differential expression

We can also use many scvi-tools models for differential expression. For further details on the methods underlying these functions as well as additional options, please see the API docs.

adata.obs.cell_type.head()
AACTCCCCACGAGAGT-1-HCAHeart7844001                      Myeloid
ATAACGCAGAGCTGGT-1-HCAHeart7829979    Ventricular_Cardiomyocyte
GTCAAGTCATGCCACG-1-HCAHeart7702879                   Fibroblast
GGTGATTCAAATGAGT-1-HCAHeart8102858                  Endothelial
AGAGAATTCTTAGCAG-1-HCAHeart8102863                  Endothelial
Name: cell_type, dtype: category
Categories (11, object): ['Adipocytes', 'Atrial_Cardiomyocyte', 'Endothelial', 'Fibroblast', ..., 'Neuronal', 'Pericytes', 'Smooth_muscle_cells', 'Ventricular_Cardiomyocyte']

For example, a 1-vs-1 DE test is as simple as:

de_df = model.differential_expression(
    groupby="cell_type", group1="Endothelial", group2="Fibroblast"
)
de_df.head()
proba_m1 proba_m2 bayes_factor scale1 scale2 raw_mean1 raw_mean2 non_zeros_proportion1 non_zeros_proportion2 raw_normalized_mean1 raw_normalized_mean2 comparison group1 group2
EGFL7 0.9986 0.0014 6.569875 0.006987 0.000390 2.376779 0.036795 0.741543 0.025756 89.507553 1.169474 Endothelial vs Fibroblast Endothelial Fibroblast
PECAM1 0.9968 0.0032 5.741396 0.005235 0.000642 2.065984 0.075634 0.653930 0.054374 60.612019 3.404117 Endothelial vs Fibroblast Endothelial Fibroblast
VWF 0.9958 0.0042 5.468460 0.014344 0.000672 5.072563 0.054375 0.808226 0.032298 169.693512 2.207696 Endothelial vs Fibroblast Endothelial Fibroblast
SLC9A3R2 0.9950 0.0050 5.293303 0.010908 0.000279 4.451492 0.045380 0.712339 0.034342 111.582703 1.657324 Endothelial vs Fibroblast Endothelial Fibroblast
STC1 0.9948 0.0052 5.253881 0.001578 0.000037 0.785346 0.004088 0.198832 0.003271 17.157606 0.194772 Endothelial vs Fibroblast Endothelial Fibroblast
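The bayes_factor column is on a natural-log scale: the values above agree with ln(proba_m1 / proba_m2) (for EGFL7, ln(0.9986 / 0.0014) is about 6.57). Reading it as log posterior odds also explains thresholds such as the bayes_factor > 3 filter used for marker selection below, which corresponds to a posterior probability of about 0.95. A small sketch of both conversions (the helper names are ours):

```python
import math


def log_bayes_factor(p: float) -> float:
    """Natural-log posterior odds of hypothesis M1 (probability p) versus its complement M2."""
    return math.log(p) - math.log(1.0 - p)


def probability_from_log_bayes_factor(log_bf: float) -> float:
    """Invert log posterior odds back to a probability (logistic function)."""
    return 1.0 / (1.0 + math.exp(-log_bf))


# proba_m1 values for EGFL7, PECAM1, and VWF from the table above
for gene, p in [("EGFL7", 0.9986), ("PECAM1", 0.9968), ("VWF", 0.9958)]:
    print(gene, round(log_bayes_factor(p), 3))

# posterior probability implied by a log Bayes factor threshold of 3
print(round(probability_from_log_bayes_factor(3.0), 4))
```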

We can also do a 1-vs-all DE test, which compares each cell type with the rest of the dataset:

de_df = model.differential_expression(groupby="cell_type", mode="change")
de_df.head()
proba_de proba_not_de bayes_factor scale1 scale2 pseudocounts delta lfc_mean lfc_median lfc_std ... raw_mean1 raw_mean2 non_zeros_proportion1 non_zeros_proportion2 raw_normalized_mean1 raw_normalized_mean2 is_de_fdr_0.05 comparison group1 group2
SLC19A3 0.9988 0.0012 6.724225 0.004704 0.000038 0.000012 0.25 8.246908 8.084410 2.934052 ... 2.910343 0.004650 0.572414 0.004325 51.325165 0.268085 True Adipocytes vs Rest Adipocytes Rest
GPAM 0.9974 0.0026 5.949637 0.021378 0.000182 0.000012 0.25 7.260304 7.302740 2.326387 ... 17.372416 0.035791 0.896552 0.031520 280.340485 1.565905 True Adipocytes vs Rest Adipocytes Rest
ADIPOQ 0.9970 0.0030 5.806135 0.003019 0.000041 0.000012 0.25 7.871395 7.550627 3.345690 ... 2.324136 0.003622 0.593103 0.003352 33.361748 0.217763 True Adipocytes vs Rest Adipocytes Rest
PNPLA3 0.9966 0.0034 5.680571 0.002892 0.000030 0.000012 0.25 7.149240 7.132330 2.578820 ... 2.020689 0.002811 0.537931 0.002757 28.458160 0.127947 True Adipocytes vs Rest Adipocytes Rest
LGALS12 0.9964 0.0036 5.623212 0.001283 0.000014 0.000012 0.25 7.710620 7.627983 2.920600 ... 0.689655 0.000973 0.351724 0.000973 10.667459 0.059477 True Adipocytes vs Rest Adipocytes Rest

5 rows × 22 columns
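In mode="change", proba_de is the posterior probability that the absolute log fold change (LFC) exceeds the threshold in the delta column (0.25 here), and is_de_fdr_0.05 flags genes selected at a target false discovery rate of 0.05. The sketch below illustrates both ideas on simulated LFC samples: proba_de is estimated as the fraction of samples with |LFC| >= delta, and genes are then selected with a posterior expected FDR rule (rank genes by proba_de and keep the largest top set whose average proba_not_de stays at or below the target). This is an assumed reading of the procedure rather than the scvi-tools source, and the simulated effects are made up:

```python
import numpy as np


def proba_de_from_lfc_samples(lfc_samples: np.ndarray, delta: float = 0.25) -> np.ndarray:
    """Fraction of posterior LFC samples (rows = samples, cols = genes) with |LFC| >= delta."""
    return (np.abs(lfc_samples) >= delta).mean(axis=0)


def select_de_at_fdr(proba_de: np.ndarray, fdr_target: float = 0.05) -> np.ndarray:
    """Boolean mask of genes selected by posterior expected FDR control."""
    order = np.argsort(-proba_de)                 # most confident genes first
    proba_not_de_sorted = 1.0 - proba_de[order]
    # expected FDR if the top-k genes are called DE, for k = 1..n_genes
    expected_fdr = np.cumsum(proba_not_de_sorted) / np.arange(1, len(order) + 1)
    n_selected = int(np.sum(expected_fdr <= fdr_target))
    is_de = np.zeros(len(proba_de), dtype=bool)
    is_de[order[:n_selected]] = True
    return is_de


rng = np.random.default_rng(0)
true_lfc = np.array([3.0, 1.0, 0.1, 0.0])  # hypothetical per-gene effects
lfc_samples = true_lfc + rng.normal(scale=0.3, size=(5000, 4))
proba_de = proba_de_from_lfc_samples(lfc_samples, delta=0.25)
is_de = select_de_at_fdr(proba_de, fdr_target=0.05)
print(proba_de.round(3))
print(is_de)
```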

We now extract the top markers for each cell type using the DE results.

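markers = {}
cats = adata.obs.cell_type.cat.categories
for c in cats:
    cid = f"{c} vs Rest"
    # DE results for this cell type against all other cells
    cell_type_df = de_df.loc[de_df.comparison == cid]

    # keep genes that are up-regulated in this cell type
    cell_type_df = cell_type_df[cell_type_df.lfc_mean > 0]

    # keep genes with strong evidence of DE (log Bayes factor > 3)
    # that are detected in more than 10% of the cells of this cell type
    cell_type_df = cell_type_df[cell_type_df["bayes_factor"] > 3]
    cell_type_df = cell_type_df[cell_type_df["non_zeros_proportion1"] > 0.1]

    # top 3 remaining genes, in the order returned by differential_expression()
    markers[c] = cell_type_df.index.tolist()[:3]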
sc.tl.dendrogram(adata, groupby="cell_type", use_rep="X_scVI")
sc.pl.dotplot(
    adata,
    markers,
    groupby="cell_type",
    dendrogram=True,
    color_map="Blues",
    swap_axes=True,
    use_raw=True,
    standard_scale="var",
)

We can also visualize the scVI normalized gene expression values with the layer option.

sc.pl.heatmap(
    adata,
    markers,
    groupby="cell_type",
    layer="scvi_normalized",
    standard_scale="var",
    dendrogram=True,
    figsize=(8, 12),
)

Logging information

Verbosity varies in the following way:

  • logger.setLevel(logging.WARNING) will show a progress bar.

  • logger.setLevel(logging.INFO) will show global logs including the number of jobs done.

  • logger.setLevel(logging.DEBUG) will show detailed logs for each training (e.g., the parameters tested).

This function’s behaviour can be customized; please refer to its documentation for information about the different parameters available.

In general, you can use scvi.settings.verbosity to set the verbosity of the scvi package. Note that verbosity corresponds to the logging levels of the standard Python logging module. By default, the verbosity level is set to INFO (=20). As a reminder, the logging levels are:

Level      Numeric value
CRITICAL   50
ERROR      40
WARNING    30
INFO       20
DEBUG      10
NOTSET     0
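The levels listed above are the constants of Python's standard logging module, so scvi.settings.verbosity can be set with them directly, for example scvi.settings.verbosity = logging.WARNING. The standard-library sketch below (independent of scvi-tools; the logger name is hypothetical) prints the levels and shows that a logger set to WARNING drops INFO messages:

```python
import logging

# the six standard levels and their numeric values
for name in ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]:
    print(f"{name:<8} {getattr(logging, name)}")

# a logger only emits records at or above its level
logger = logging.getLogger("verbosity_demo")  # hypothetical logger name
logger.setLevel(logging.WARNING)
print(logger.isEnabledFor(logging.INFO))   # False: INFO (20) < WARNING (30)
print(logger.isEnabledFor(logging.ERROR))  # True: ERROR (40) >= WARNING (30)
```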