MrVI Quick Start Tutorial

MrVI (Multi-resolution Variational Inference) is a model for analyzing multi-sample single-cell RNA-seq data. This tutorial walks through its main features. MrVI is particularly suited to datasets with comparable observations across many samples, that is, observations derived from the same tissue or from the same cell line. This comparability is what allows MrVI to provide accurate estimates at single-cell resolution.

!pip install --quiet scvi-colab
from scvi_colab import install

import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import seaborn as sns
from scvi.external import MRVI

scvi.settings.seed = 0  # optional: ensures reproducibility
print("Last run with scvi-tools version:", scvi.__version__)
save_dir = tempfile.TemporaryDirectory()
Last run with scvi-tools version: 1.1.3

Preprocessing and model fitting

For this tutorial, we use a subset of the COVID-19 single-cell RNA-seq dataset from Stephenson et al. (2021, Nature Medicine) to demonstrate the functionality of MrVI. Specifically, this subset includes PBMCs from 16 donors in the Newcastle cohort (one of the sites comprising the dataset), randomly subsampled to 30,000 cells.

adata_path = os.path.join(, "haniffa_tutorial_subset.h5ad")

adata =, backup_url="")
    adata, n_top_genes=10000, inplace=True, subset=True, flavor="seurat_v3"
AnnData object with n_obs × n_vars = 30000 × 10000
    obs: 'sample_id', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'full_clustering', 'initial_clustering', 'Resample', 'Collection_Day', 'Sex', 'Age_interval', 'Swab_result', 'Status', 'Smoker', 'Status_on_day_collection', 'Status_on_day_collection_summary', 'Days_from_onset', 'Site', 'time_after_LPS', 'Worst_Clinical_Status', 'Outcome', 'patient_id', 'age_int'
    var: 'feature_types', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg'

Before training, we need to specify which covariates in obs should be used as target (sample_key) and nuisance (batch_key) variables. In this tutorial, we use donor IDs (patient_id) as the target variable and leave the batch variable unset, since the data are already restricted to the Newcastle cohort (recorded in Site).

In addition, we will focus on the following obs keys for the analysis:

  • initial_clustering: coarse cell-type annotations from the original study.

  • Status: whether the donor had COVID-19 or was healthy.

  • Days_from_onset: the number of days between the onset of symptoms and sample collection.
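Target covariates such as Status and Days_from_onset describe a donor, not an individual cell, so they should take a single value within each sample. A quick sanity check before setup is to collapse the cell-level obs table to one row per sample. The sketch below uses made-up metadata and a hypothetical helper, sample_level_table, which is not part of the MrVI API. It relies only on pandas groupby.

```python
import pandas as pd

# Made-up cell-level metadata mimicking a few adata.obs columns (illustration only).
obs = pd.DataFrame(
    {
        "patient_id": ["p1", "p1", "p2", "p2", "p3"],
        "Status": ["Covid", "Covid", "Healthy", "Healthy", "Covid"],
        "Days_from_onset": ["10", "10", "Healthy", "Healthy", "25"],
        "initial_clustering": ["CD14", "B_cell", "CD14", "CD16", "DCs"],
    }
)


def sample_level_table(obs, sample_key, covariate_keys):
    """Hypothetical helper: collapse cell-level metadata to one row per sample.

    Raises ValueError if a covariate takes more than one value within a
    sample, because such a column cannot be a sample-level covariate.
    """
    grouped = obs.groupby(sample_key)[covariate_keys]
    n_unique = grouped.nunique()
    varying = [key for key in covariate_keys if (n_unique[key] > 1).any()]
    if varying:
        raise ValueError(f"Covariates vary within samples: {varying}")
    return grouped.first()


sample_info = sample_level_table(obs, "patient_id", ["Status", "Days_from_onset"])
print(sample_info)
```

A cell-level column such as initial_clustering fails this check, which is why it serves as a grouping variable for cells rather than as a sample covariate.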

sample_key = "patient_id"  # target covariate
# batch_key="Site"  # nuisance variable identifier
MRVI.setup_anndata(adata, sample_key=sample_key)
model = MRVI(adata)
INFO     Jax module moved to cuda:0.Note: Pytorch lightning will show GPU is not being used for the Trainer.       
Epoch 400/400: 100%|██████████| 400/400 [06:34<00:00,  1.01it/s, v_num=1, train_loss_step=1.45e+3, train_loss_epoch=1.37e+3]

Once the model is trained, we can plot the validation ELBO over training epochs to check whether the model has converged.

plt.ylabel("Validation ELBO")
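Judging convergence from this plot is ultimately a visual call. One possible heuristic, which is our own assumption and not something MrVI provides, is to compare the average over the last few epochs with the average over the window just before it. The sketch below runs on a synthetic curve. With a trained model, you would pass the per-epoch validation values recorded in the training history instead. The helper name has_converged and both thresholds are hypothetical.

```python
import numpy as np


def has_converged(values, window=20, rel_tol=1e-3):
    """Hypothetical heuristic: compare the mean of the last `window` epochs
    with the mean of the `window` epochs before it.

    Returns True when the relative change is below `rel_tol`. Both
    thresholds are arbitrary defaults chosen for illustration.
    """
    values = np.asarray(values, dtype=float)
    if values.size < 2 * window:
        return False  # not enough epochs to compare two windows
    recent = values[-window:].mean()
    previous = values[-2 * window : -window].mean()
    return bool(abs(recent - previous) / abs(previous) < rel_tol)


# Synthetic per-epoch curve decaying toward a plateau (made-up numbers).
epochs = np.arange(400)
curve = 1370.0 + 2000.0 * np.exp(-epochs / 30.0)

print(has_converged(curve))  # full curve: flat at the end
print(has_converged(curve[:60]))  # early epochs: still dropping
```

A windowed mean is used rather than the last two points so that epoch-to-epoch noise in the validation values does not dominate the comparison.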

Visualize cell embeddings and sample distances

The latent representations of the cells can be accessed with the get_latent_representation method. MrVI learns two latent representations, u and z. The u representation is designed to capture broad cell states that are invariant to sample and nuisance covariates, while z augments u with sample-specific effects but remains corrected for nuisance covariate effects.

Here, we visualize the u latent space in 2D using minimum-distortion embeddings (MDE).

u = model.get_latent_representation()
# or z = model.get_latent_representation(give_z=True) to get z instead of u
u_mde = scvi.model.utils.mde(u)
adata.obsm["u_mde"] = u_mde
    color=["initial_clustering", "Status"],
INFO     Using cuda:0 for `pymde.preserve_neighbors`.                                                              

Sample distances can be computed with the get_local_sample_distances method, which characterizes sample relationships for any cell in the dataset. This method can return cell-specific distances (keep_cell=True) as well as distances averaged within cell subpopulations defined by the groupby argument. Setting keep_cell=False skips the cell-specific distances, which reduces the memory footprint of the returned object when many samples are present.

dists = model.get_local_sample_distances(
    keep_cell=False, groupby="initial_clustering", batch_size=32
)
d1 = dists.loc[{"initial_clustering_name": "CD16"}].initial_clustering
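To make the keep_cell trade-off concrete, the sketch below builds made-up cell-specific distance matrices (one n_samples × n_samples matrix per cell), averages them within hypothetical cell-type labels, and estimates the memory needed to keep every per-cell matrix. This is a simplified illustration in plain numpy, not MrVI's implementation. It assumes float32 storage (4 bytes per value).

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up cell-specific sample distances: one symmetric (n_samples x n_samples)
# matrix per cell, analogous to what keep_cell=True retains.
n_cells, n_samples = 6, 4
points = rng.normal(size=(n_cells, n_samples, 3))
cell_dists = np.linalg.norm(points[:, :, None, :] - points[:, None, :, :], axis=-1)

# Hypothetical cell-type labels playing the role of the groupby column.
labels = np.array(["CD14", "CD14", "CD16", "CD16", "CD16", "B_cell"])

# Average the per-cell matrices within each group.
group_dists = {g: cell_dists[labels == g].mean(axis=0) for g in np.unique(labels)}


def cell_distance_bytes(n_cells, n_samples, itemsize=4):
    """Bytes needed to store one n_samples x n_samples matrix per cell."""
    return n_cells * n_samples**2 * itemsize


print(cell_distance_bytes(30_000, 16) / 1e6, "MB")  # 16 donors, as here
print(cell_distance_bytes(30_000, 500) / 1e9, "GB")  # hypothetical 500 samples
```

Because storage grows with the square of the number of samples, the 16 donors here need roughly 31 MB. A hypothetical 500-sample study with the same number of cells would need about 30 GB, which is where keep_cell=False pays off.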

The following cell defines utility functions that perform hierarchical clustering on the sample distances and extract sample metadata of interest for annotating the distance matrices.

from matplotlib.colors import to_hex
from scipy.cluster.hierarchy import linkage, optimal_leaf_ordering
from scipy.spatial.distance import squareform

def get_sample_colors():
    cmap = sns.color_palette("viridis", as_cmap=True)

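    def get_onset_colors(x):
        # Healthy donors have no onset date: color them gray
        if x == "Healthy":
            return to_hex(np.array([0.5, 0.5, 0.5, 1.0]))
        else:
            # Rescale days since onset (divided by 30) before the colormap lookup
            x_ = int(x) / 30.0
            return to_hex(cmap(x_))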

    covid_map = {
        "Covid": "red",
        "Healthy": "green",
    }
    sample_info = model.sample_info.set_index("sample_id")
    covid_colors =
    onset_colors =
    colors = pd.DataFrame(
        {
            "covid": covid_colors,
            "onset": onset_colors,
        }
    )
    return colors

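def get_dendrogram(dists):
    # Convert the square sample-distance matrix to condensed form
    ds = squareform(dists)
    # Ward hierarchical clustering, then reorder leaves so neighbors are most similar
    Z = linkage(ds, method="ward")
    Z = optimal_leaf_ordering(Z, ds)
    return Z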
Z = get_dendrogram(d1)
colors = get_sample_colors()

<seaborn.matrix.ClusterGrid at 0x7dc3004bda90>
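The get_dendrogram helper chains three scipy steps: squareform converts a square distance matrix to condensed form, linkage(..., method="ward") builds the hierarchy, and optimal_leaf_ordering flips branches so that adjacent leaves are as similar as possible. The sketch below applies the same steps to made-up, shuffled sample coordinates and checks that leaves_list recovers the two groups as contiguous blocks. The data and variable names are illustrative only.

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage, optimal_leaf_ordering
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(0)

# Made-up "sample" coordinates: two well-separated groups of four samples,
# shuffled so that group membership is not visible from the row order.
points = np.vstack(
    [rng.normal(0.0, 0.1, size=(4, 2)), rng.normal(5.0, 0.1, size=(4, 2))]
)
labels = np.array(["A"] * 4 + ["B"] * 4)
perm = rng.permutation(len(points))
points, labels = points[perm], labels[perm]

# Square, symmetric distance matrix with a zero diagonal, like d1.
dist_matrix = squareform(pdist(points))

# Same steps as get_dendrogram: condensed form, Ward linkage, leaf reordering.
condensed = squareform(dist_matrix)
toy_linkage = optimal_leaf_ordering(linkage(condensed, method="ward"), condensed)

# Leaf order used to reorder rows and columns of the distance matrix.
order = leaves_list(toy_linkage)
reordered = dist_matrix[np.ix_(order, order)]
print(labels[order])
```

Reordering rows and columns by the leaf order places similar samples next to each other. This is what makes block structure visible in a clustered distance heatmap.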

Differential expression and differential abundance analysis

In this section, we compute differential expression (DE) estimates linked to specific sample-level covariates of interest. For a list of target covariates, MrVI returns covariate-specific effect sizes and p-values for each cell, which allows a detailed analysis of how each covariate influences gene expression across cell types. We can also visualize a per-cell summary of the overall effect size of a covariate, which conveys the magnitude of that covariate's estimated latent effect on gene expression.

sample_cov_keys = ["Status"]  # Replace with your sample covariate of interest
model.sample_info["Status"] = model.sample_info["Status"].cat.reorder_categories(
    ["Healthy", "Covid"]
)  # Reorder categories such that the coefficient corresponds to Covid
de_res = model.differential_expression(
adata.obs["Covid_DE_eff_size"] = de_res.effect_size.sel(covariate="Status_Covid").values
    color=["initial_clustering", "Covid_DE_eff_size"],
    vmax=np.quantile(de_res.effect_size.values, 0.95),
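The reorder_categories call above matters because, under reference (treatment) coding, the first category becomes the baseline and the coefficient is named after the remaining level. The plain pandas sketch below illustrates this convention with pd.get_dummies. We assume MrVI's covariate encoding follows the same convention. The Status_Covid coefficient name used above suggests this, but the sketch does not verify it.

```python
import pandas as pd

# Made-up per-sample Status values; categories default to alphabetical order.
status = pd.Series(["Covid", "Healthy", "Covid", "Healthy"], dtype="category")
print(status.cat.categories.tolist())  # 'Covid' comes first alphabetically

# Without reordering, dropping the first level makes Covid the reference.
default_design = pd.get_dummies(status, prefix="Status", drop_first=True)

# Reorder so that Healthy is the reference (first) level, as in the tutorial.
status = status.cat.reorder_categories(["Healthy", "Covid"])
design = pd.get_dummies(status, prefix="Status", drop_first=True)

print(default_design.columns.tolist())
print(design.columns.tolist())
```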

For the cell types with the largest COVID-associated effect sizes, we can examine which genes have the highest average log fold changes (LFCs).

cell_types = ["CD16", "DCs", "CD14"]
top_genes_per_cell_type = {}

for cell_type in cell_types:
    cell_idxs = adata[(adata.obs["initial_clustering"] == cell_type)].obs.index
    top_genes = set(
        de_res.sel(cell_name=cell_idxs, covariate="Status_Covid")
    top_genes_per_cell_type[cell_type] = top_genes

all_top_genes = list(set.union(*top_genes_per_cell_type.values()))

# Add B Cells for comparison

avg_lfcs = []
for cell_type in cell_types:
    cell_idxs = adata[(adata.obs["initial_clustering"] == cell_type)].obs.index
        de_res.sel(cell_name=cell_idxs, gene=all_top_genes).mean(dim="cell_name").lfc.values

heatmap_data = pd.DataFrame(
    np.concatenate(avg_lfcs, axis=0), index=cell_types, columns=all_top_genes
)

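# clustermap creates its own figure, so pass figsize here instead of calling plt.figure
g = sns.clustermap(
    heatmap_data, annot=True, cmap="viridis", fmt=".2f", figsize=(10, 8)
)
g.ax_heatmap.set_xlabel("Gene")  # columns of heatmap_data are genes
g.ax_heatmap.set_ylabel("Cell Type")  # rows of heatmap_data are cell types
g.figure.suptitle("Average LFCs attributed to Covid Status")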
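Conceptually, the selection above averages the per-cell LFCs within each cell type, keeps the genes with the highest averages, and takes the union across cell types as heatmap columns. The sketch below reproduces these steps with pandas on made-up data rather than on the xarray de_res object. The toy_* names, k=2, and the simulated effects are all hypothetical.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Made-up per-cell LFCs (15 cells x 8 genes) with cell-type labels; each
# cell type gets one strongly up-regulated gene.
genes = [f"gene{i}" for i in range(8)]
toy_labels = np.array(["CD14"] * 5 + ["CD16"] * 5 + ["DCs"] * 5)
lfc = pd.DataFrame(rng.normal(size=(15, 8)), columns=genes)
lfc.loc[toy_labels == "CD14", "gene1"] += 5.0
lfc.loc[toy_labels == "CD16", "gene4"] += 5.0
lfc.loc[toy_labels == "DCs", "gene6"] += 5.0

# 1) Average LFC per gene within each cell type (rows: cell types).
avg_lfc = lfc.groupby(toy_labels).mean()

# 2) Top-k genes per cell type, then their union across cell types.
k = 2
toy_top = {ct: set(row.nlargest(k).index) for ct, row in avg_lfc.iterrows()}
toy_union = sorted(set.union(*toy_top.values()))

# 3) Heatmap input: cell types x union of top genes.
toy_heatmap = avg_lfc[toy_union]
print(toy_heatmap.round(2))
```

Taking the union means that every cell type is shown on the same gene set. A gene that is top-ranked in one cell type can then be compared against its average LFC in the others.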

Next, we demonstrate how MrVI can estimate local, covariate-linked differential abundance (DA) in single-cell data. Given sample-level target covariates, MrVI returns log-likelihood values of cell-state abundance for each covariate value. Here, differential abundance refers to changes in cellular composition associated with the target covariates, where composition is defined over the u space.

da_res = model.differential_abundance(sample_cov_keys=sample_cov_keys)
covid_log_probs = da_res.Status_log_probs.loc[{"Status": "Covid"}]
healthy_log_probs = da_res.Status_log_probs.loc[{"Status": "Healthy"}]
covid_healthy_log_prob_ratio = covid_log_probs - healthy_log_probs

Taking the difference between the two sets of log-likelihoods gives a log-likelihood ratio between the two sample groups. Here, a positive value indicates that a cell state is enriched in donors with COVID. We see that B cells, despite showing little DE attributed to COVID, are enriched in COVID patients.

adata.obs["Covid_DA_lfc"] = covid_healthy_log_prob_ratio.values
    color=["initial_clustering", "Covid_DA_lfc"],
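A difference of log-likelihoods is the log of their ratio. Exponentiating covid_healthy_log_prob_ratio therefore gives how many times more likely a cell state is under the COVID group than under the healthy group. The following is a minimal numpy sketch with made-up per-cell values, assuming natural logarithms.

```python
import numpy as np

# Made-up per-cell likelihoods of the cell state under each group,
# stored as natural-log values like the Status_log_probs entries.
covid_lp = np.log(np.array([0.02, 0.10, 0.05, 0.01]))
healthy_lp = np.log(np.array([0.02, 0.05, 0.10, 0.04]))

# Difference of log-likelihoods = log of the likelihood ratio.
log_ratio = covid_lp - healthy_lp
ratio = np.exp(log_ratio)  # approximately [1, 2, 0.5, 0.25]
enriched_in_covid = log_ratio > 0

print(np.round(log_ratio, 3))
print(enriched_in_covid)
```

A value of zero means the cell state is equally likely under both groups. Positive and negative values are symmetric on the log scale, so a ratio of 2 and a ratio of 0.5 are equally strong enrichments in opposite directions.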