MrVI Quick Start Tutorial

MrVI (Multi-resolution Variational Inference) is a model for analyzing multi-sample single-cell RNA-seq data. This tutorial will guide you through the main features of MrVI. MrVI is particularly suited for single-cell RNA sequencing datasets with comparable observations across many samples. By comparable, we mean observations derived from the same tissue or from the same cell line. This ensures that MrVI can provide accurate, single-cell-resolution estimates.

!pip install --quiet scvi-colab
from scvi_colab import install

install()
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import seaborn as sns
from scvi.external import MRVI

scvi.settings.seed = 0  # optional: ensures reproducibility
print("Last run with scvi-tools version:", scvi.__version__)
save_dir = tempfile.TemporaryDirectory()
Last run with scvi-tools version: 1.4.2
# import inspect
# print(inspect.getsource(MRVI))

Preprocessing and model fitting

Note

For general pre-processing for various datatypes used by scvi-tools models, see the preprocessing tutorial.

For this tutorial, we will use a subset of the COVID-19 single-cell RNA dataset from Stephenson et al. 2021 (Nature Medicine, https://doi.org/10.1038/s41591-021-01329-2) for the purpose of demonstrating the functionality of MrVI. Specifically, this subset includes PBMCs from 16 donors in the Newcastle cohort (one of the sites comprising the dataset), randomly subsetted to 30,000 cells.

adata_path = os.path.join(save_dir.name, "haniffa_tutorial_subset.h5ad")

adata = sc.read(
    adata_path,
    backup_url="https://exampledata.scverse.org/scvi-tools/haniffa_tutorial_subset.h5ad",
)
sc.pp.highly_variable_genes(
    adata, n_top_genes=10000, inplace=True, subset=True, flavor="seurat_v3"
)
adata
AnnData object with n_obs × n_vars = 30000 × 10000
    obs: 'sample_id', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'full_clustering', 'initial_clustering', 'Resample', 'Collection_Day', 'Sex', 'Age_interval', 'Swab_result', 'Status', 'Smoker', 'Status_on_day_collection', 'Status_on_day_collection_summary', 'Days_from_onset', 'Site', 'time_after_LPS', 'Worst_Clinical_Status', 'Outcome', 'patient_id', 'age_int'
    var: 'feature_types', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg'

Before training, we need to specify which covariates in obs should be used as target (sample_key) and nuisance variables (batch_key). In this tutorial, we will use donor IDs (patient_id) as the target variable, and leave the batch variable empty since the data is already subsetted to the Newcastle cohort (denoted in Site).

Beyond these, the analysis will focus on the following obs keys:

  • initial_clustering: coarse cell-type annotations from the original study.

  • Status: whether the donor had COVID-19 or was healthy.

  • Days_from_onset: how many days it had been since the onset of symptoms before the sample was taken.

We will initialize the MRVI model with its PyTorch backend by passing backend="torch". A JAX backend can also be used by passing backend="jax".

sample_key = "patient_id"  # target covariate
# batch_key="Site"  # nuisance variable identifier
MRVI.setup_anndata(adata, sample_key=sample_key, backend="torch")
model = MRVI(adata)
model.train(max_epochs=400)

Once trained, we can plot the ELBO of the model to check if the model has converged.

plt.plot(model.history["elbo_validation"].iloc[5:])
plt.xlabel("Epoch")
plt.ylabel("Validation ELBO")
plt.show()
plt.plot(model.history["elbo_train"])
plt.xlabel("Epoch")
plt.ylabel("Training ELBO")
plt.show()
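
Reading convergence off a plot can be complemented with a simple numeric check. The snippet below is a minimal sketch, not part of scvi-tools: the helper has_plateaued, its window and rel_tol parameters, and the synthetic curve are all assumptions made for illustration. It compares the mean of the last few epochs with the mean of the epochs just before them.

```python
import numpy as np


def has_plateaued(values, window=20, rel_tol=1e-3):
    """Hypothetical helper: True if the mean of the last `window` values is
    within `rel_tol` (relative) of the mean of the preceding `window` values."""
    values = np.asarray(values, dtype=float)
    if values.size < 2 * window:
        return False  # not enough epochs to compare two windows
    recent = values[-window:].mean()
    previous = values[-2 * window : -window].mean()
    return bool(abs(recent - previous) <= rel_tol * abs(previous))


# Synthetic curve decaying towards 1000, standing in for a validation ELBO
epochs = np.arange(400)
synthetic_elbo = 1000.0 + 500.0 * np.exp(-epochs / 30.0)

print(has_plateaued(synthetic_elbo))       # full 400-epoch curve
print(has_plateaued(synthetic_elbo[:50]))  # only the first 50 epochs
```

With a trained model, the same helper could presumably be applied to the values stored in model.history["elbo_validation"]; the window and tolerance would need tuning for the dataset at hand.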

Visualize cell embeddings and sample distances

The latent representations of the cells can also be accessed and visualized using the get_latent_representation method. MrVI learns two latent representations: u and z. u is designed to capture broad cell states invariant to sample and nuisance covariates, while z augments u with sample-specific effects but remains corrected for nuisance covariate effects.

u = model.get_latent_representation()
adata.obsm["u"] = u
sc.pp.neighbors(adata, use_rep="u")
sc.tl.umap(adata, min_dist=0.3)
sc.pl.umap(
    adata,
    color=["initial_clustering", "Status"],
    frameon=False,
    ncols=1,
)

Sample distances can be computed using the get_local_sample_distances method, which characterizes sample relationships for any cell in the dataset. This method can return cell-specific distances (keep_cell=True) as well as distances averaged within cell subpopulations, which are defined by the groupby argument. Setting keep_cell=False omits the cell-specific distances, which reduces the memory footprint of the returned object when many samples are present.
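
To make the relationship between cell-specific and group-averaged distances concrete, here is a minimal NumPy sketch on synthetic data. It does not call MrVI: the array cell_dists, the labels in groups, and the averaging step are illustrative assumptions about what "averaged within cell subpopulations" amounts to, not the library's internal implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_samples = 6, 3

# Synthetic per-cell distance matrices with shape (cell, sample_x, sample_y)
raw = rng.random((n_cells, n_samples, n_samples))
cell_dists = raw + raw.transpose(0, 2, 1)  # make each matrix symmetric
diag = np.arange(n_samples)
cell_dists[:, diag, diag] = 0.0  # a sample is at distance 0 from itself

# Hypothetical cell-type label for each cell
groups = np.array(["CD14", "CD14", "CD16", "CD16", "CD16", "B_cell"])

# One (sample_x, sample_y) matrix per group: the mean over that group's cells
group_dists = {
    name: cell_dists[groups == name].mean(axis=0) for name in np.unique(groups)
}

print(cell_dists.shape)           # one matrix per cell
print(group_dists["CD16"].shape)  # one matrix per group
```

The per-cell array grows as n_cells × n_samples², while the grouped result grows only with the number of groups, which is why dropping the cell-level output can save a lot of memory.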

dists = model.get_local_sample_distances(groupby="initial_clustering")
d1 = dists.loc[{"initial_clustering_name": "CD16"}].initial_clustering
dists = model.get_local_sample_distances(groupby="patient_id")
# Distances from each cell's own sample to every sample:
# a vector to quantify sample effect heterogeneity within cell populations
cell_dists = dists["cell"]
sample_names = cell_dists["sample_x"].to_numpy()
dist_vec = []
celltype_anno = []
sample_anno = []
target_anno = []

for i in range(cell_dists.shape[0]):
    sample_temp = adata.obs["patient_id"].iloc[i]
    sample_temp_index = np.where(sample_names == sample_temp)[0][0]
    # Row of cell i's distance matrix that corresponds to its own sample
    vec_temp = cell_dists[i, sample_temp_index, :].to_numpy()
    dist_vec.append(vec_temp)
    # Annotate every distance with the cell's type, its own sample, and the target sample
    celltype_anno.append([adata.obs["initial_clustering"].iloc[i]] * len(vec_temp))
    sample_anno.append([sample_temp] * len(vec_temp))
    target_anno.append(cell_dists[i, :]["sample_y"].values.tolist())
dist_vec = np.concatenate(dist_vec, axis=0)
celltype_anno = np.concatenate(celltype_anno, axis=0)
sample_anno = np.concatenate(sample_anno, axis=0)
target_anno = np.concatenate(target_anno, axis=0)
# Overall distribution of distances
plt.hist(dist_vec, bins=100)
plt.show()
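
The per-cell loop above selects, for every cell, the row of its distance matrix that belongs to the cell's own sample. On large datasets this selection can be expressed with NumPy integer-array indexing instead. The following is a sketch on synthetic data; the arrays sample_names, cell_samples, and cell_dists are made-up stand-ins rather than MrVI output.

```python
import numpy as np

rng = np.random.default_rng(0)
sample_names = np.array(["s1", "s2", "s3"])
cell_samples = np.array(["s2", "s1", "s3", "s2"])  # sample of origin per cell
n_cells, n_samples = len(cell_samples), len(sample_names)

# Synthetic stand-in for the (cell, sample_x, sample_y) distance array
cell_dists = rng.random((n_cells, n_samples, n_samples))

# Position of each cell's own sample along the sample axis
lookup = {name: idx for idx, name in enumerate(sample_names)}
own_idx = np.array([lookup[s] for s in cell_samples])

# Row own_idx[i] of cell i's matrix, selected for all cells at once
own_rows = cell_dists[np.arange(n_cells), own_idx, :]

# The explicit loop gives the same result
looped = np.stack([cell_dists[i, own_idx[i], :] for i in range(n_cells)])
assert np.array_equal(own_rows, looped)

dist_vec = own_rows.ravel()  # flat vector, e.g. for a histogram
print(own_rows.shape, dist_vec.shape)
```

Applying this to real output would require converting the xarray object to a NumPy array first, which may itself be memory-hungry when there are many samples.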

The following cell provides utility functions to perform hierarchical clustering based on sample distances, and to extract the sample metadata used to annotate the distance matrices.

from matplotlib.colors import to_hex
from scipy.cluster.hierarchy import linkage, optimal_leaf_ordering
from scipy.spatial.distance import squareform


def get_sample_colors():
    cmap = sns.color_palette("viridis", as_cmap=True)

    def get_onset_colors(x):
        if x == "Healthy":
            return to_hex(np.array([0.5, 0.5, 0.5, 1.0]))
        else:
            x_ = int(x) / 30.0
            return to_hex(cmap(x_))

    covid_map = {
        "Covid": "red",
        "Healthy": "green",
    }
    sample_info = model.sample_info.set_index("sample_id")
    covid_colors = sample_info.Status.map(covid_map).values
    onset_colors = sample_info.Days_from_onset.map(get_onset_colors)
    colors = pd.DataFrame(
        {
            "covid": covid_colors,
            "onset": onset_colors,
        }
    )
    return colors


def get_dendrogram(dists):
    ds = squareform(dists)
    Z = linkage(ds, method="ward")
    Z = optimal_leaf_ordering(Z, ds)
    return Z
Z = get_dendrogram(d1)
colors = get_sample_colors()

sns.clustermap(
    d1.to_pandas(),
    row_linkage=Z,
    col_linkage=Z,
    xticklabels=False,
    yticklabels=False,
    row_colors=colors,
)
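
To see what the dendrogram helper does in isolation, the sketch below runs the same SciPy calls (squareform, linkage, optimal_leaf_ordering) on a toy distance matrix. The four "samples" and their positions are invented for illustration; leaves_list is used only to read out the resulting order.

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage, optimal_leaf_ordering
from scipy.spatial.distance import squareform

# Four toy samples on a line; samples 0 and 2 are close, as are 1 and 3
positions = np.array([0.0, 5.0, 0.1, 5.1])
dist_matrix = np.abs(positions[:, None] - positions[None, :])

# linkage expects the condensed (upper-triangle) form of the distances
condensed = squareform(dist_matrix)
Z = linkage(condensed, method="ward")
# Reorder leaves so that adjacent leaves are as similar as possible
Z = optimal_leaf_ordering(Z, condensed)

order = leaves_list(Z).tolist()
print(order)
```

In the resulting order, samples 0 and 2 sit next to each other, and so do samples 1 and 3. Passing the same Z as both row_linkage and col_linkage, as done above, keeps the heatmap symmetric.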

Differential expression and differential abundance analysis

In this section of the tutorial, we will explore how to compute differential expression (DE) estimates that are linked to specific covariates of interest at the sample level. For a list of target covariates, MrVI will return covariate-specific effect sizes and p-values for each cell. This allows for a detailed analysis of how different covariates influence gene expression across different cell types. Additionally, you can visualize a summary of the overall effect size of a covariate, which helps in understanding the magnitude of the estimated latent effects of each covariate on the gene expression.

sample_cov_keys = ["Status"]  # Replace with your sample covariate of interest
model.sample_info["Status"] = model.sample_info["Status"].cat.reorder_categories(
    ["Healthy", "Covid"]
)  # Reorder categories such that the coefficient corresponds to Covid
de_res = model.differential_expression(
    sample_cov_keys=sample_cov_keys, store_lfc=True, use_vmap=False
)
adata.obs["Covid_DE_eff_size"] = de_res.effect_size.sel(covariate="Status_Covid").values
sc.pl.umap(
    adata,
    color=["initial_clustering", "Covid_DE_eff_size"],
    frameon=False,
    ncols=1,
    vmax=np.quantile(de_res.effect_size.values, 0.95),
    cmap="viridis",
)

For the cell types with large effect sizes corresponding to the COVID status, we can look into which genes had the highest average LFCs.

cell_types = ["CD16", "DCs", "CD14"]
top_genes_per_cell_type = {}

for cell_type in cell_types:
    cell_idxs = adata[(adata.obs["initial_clustering"] == cell_type)].obs.index
    top_genes = set(
        de_res.sel(cell_name=cell_idxs, covariate="Status_Covid")
        .mean(dim="cell_name")
        .lfc.to_pandas()
        .abs()
        .nlargest(5)
        .index
    )
    top_genes_per_cell_type[cell_type] = top_genes

all_top_genes = list(set.union(*top_genes_per_cell_type.values()))

# Add B Cells for comparison
cell_types.append("B_cell")

avg_lfcs = []
for cell_type in cell_types:
    cell_idxs = adata[(adata.obs["initial_clustering"] == cell_type)].obs.index
    avg_lfcs.append(
        de_res.sel(cell_name=cell_idxs, gene=all_top_genes).mean(dim="cell_name").lfc.values
    )

heatmap_data = pd.DataFrame(
    np.concatenate(avg_lfcs, axis=0), index=cell_types, columns=all_top_genes
)

# clustermap creates its own figure, so the size is passed directly
g = sns.clustermap(heatmap_data, annot=True, cmap="viridis", fmt=".2f", figsize=(10, 8))
g.figure.suptitle("Average LFCs attributed to Covid Status", y=1.02)
# Rows are cell types and columns are genes
g.ax_heatmap.set_xlabel("Gene")
g.ax_heatmap.set_ylabel("Cell Type")
plt.show()
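
The gene selection above averages LFCs over the cells of a type and then ranks genes by absolute value. The sketch below reproduces that ranking with plain NumPy on invented numbers; the gene names and LFC values are hypothetical. It also keeps the signed mean LFC, which the absolute-value ranking alone discards.

```python
import numpy as np

# Hypothetical gene names and per-cell LFCs (rows: cells, columns: genes)
genes = np.array(["gene_a", "gene_b", "gene_c", "gene_d"])
lfc = np.array(
    [
        [2.0, -0.1, 0.3, -1.5],
        [1.0, 0.1, -0.3, -2.5],
        [1.5, 0.0, 0.0, -2.0],
    ]
)

mean_lfc = lfc.mean(axis=0)  # average over cells first
k = 2
top_idx = np.argsort(-np.abs(mean_lfc))[:k]  # largest |mean LFC| first

top_genes = genes[top_idx].tolist()
top_signed = dict(zip(top_genes, mean_lfc[top_idx].tolist()))
print(top_genes)
print(top_signed)
```

Because the mean is taken before the absolute value, a gene that is up in some cells and down in others of the same type (gene_c here) averages out to a small effect and is not selected.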

Next, we will demonstrate how MrVI can be used to estimate local, covariate-linked differential abundance (DA) in single-cell data. Provided with sample-level target covariates, MrVI will return log likelihood values corresponding to cell state abundance for each covariate. Differential abundance refers to the change in cellular composition correlated with target covariates; here, cellular composition is defined over the u space.

da_res = model.differential_abundance(sample_cov_keys=sample_cov_keys)
covid_log_probs = da_res.Status_log_probs.loc[{"Status": "Covid"}]
healthy_log_probs = da_res.Status_log_probs.loc[{"Status": "Healthy"}]
covid_healthy_log_prob_ratio = covid_log_probs - healthy_log_probs

We can take the difference between the log likelihood values to get a log likelihood ratio between two sample covariate groups. In this case, a positive value corresponds to an enrichment of a certain cell state for donors with COVID.
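
As a small worked example of this subtraction, the sketch below uses made-up log densities for three cell states. The values are purely illustrative and are not MrVI output.

```python
import numpy as np

# Hypothetical log densities of three cell states under each covariate group
covid_log_probs = np.array([-2.0, -3.0, -1.0])
healthy_log_probs = np.array([-3.0, -3.0, -0.5])

# Difference of logs equals the log of the density ratio
log_ratio = covid_log_probs - healthy_log_probs
density_ratio = np.exp(log_ratio)

# Positive log ratio: the state is denser among COVID donors
enriched_in_covid = log_ratio > 0

print(log_ratio)
print(density_ratio)
print(enriched_in_covid)
```

Here the first state is about e ≈ 2.7 times denser in the COVID group, the second is equally dense in both groups, and the third is depleted.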

adata.obs["Covid_DA_lfc"] = covid_healthy_log_prob_ratio.values
sc.pl.umap(
    adata,
    color=["initial_clustering", "Covid_DA_lfc"],
    frameon=False,
    ncols=1,
    vmin=-1,
    vmax=1,
    cmap="coolwarm",
    sort_order=False,
)

Finally, we will display the DA results above as violin plots to show changes in abundance across cell types. Displayed is the log density ratio between samples from Covid19 and healthy patients. Cell types whose median log density ratio exceeds 1 in absolute value are marked with a star.

cell_types = adata.obs["initial_clustering"].unique()
medians = adata.obs.groupby("initial_clustering", observed=True)["Covid_DA_lfc"].median()
median_order = medians.sort_values().index.tolist()
star_flags = {ct: (abs(medians[ct]) > 1) for ct in median_order}
adata.obs["cell_type_ordered_for_violin"] = pd.Categorical(
    adata.obs["initial_clustering"], categories=median_order, ordered=True
)
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(8, 10))
sns.violinplot(
    data=adata.obs,
    y="cell_type_ordered_for_violin",
    x="Covid_DA_lfc",
    hue=None,
    palette="tab20",
    cut=0,
    linewidth=1,
    ax=ax,
)
ax.axvline(0, color="red", linestyle="--", linewidth=1.5)
sns.despine(left=True)
ax.set_ylabel("Cell Types", fontsize=12)
ax.set_xlabel("Log Density Ratio Covid19 vs Healthy behavior", fontsize=12)
ax.set_xlim(-30, 300)
x_pos = np.min([adata.obs["Covid_DA_lfc"].max(), 300]) * 1.15  # slightly beyond max value
for i, ct in enumerate(median_order):
    if star_flags[ct]:
        ax.text(x=x_pos, y=i, s="★", fontsize=14, va="center", ha="left")
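
The ordering and star-flagging logic used for the violin plot can be checked on a toy table. The sketch below uses pandas with invented cell types and values; the column names cell_type and da_lfc and the threshold variable are assumptions for illustration.

```python
import pandas as pd

# Hypothetical per-cell log density ratios for three cell types
obs = pd.DataFrame(
    {
        "cell_type": ["B", "B", "T", "T", "Mono", "Mono"],
        "da_lfc": [-1.5, -1.3, 0.2, 0.4, 2.0, 2.4],
    }
)

threshold = 1.0  # flag cell types whose |median| exceeds this value

# Median per cell type, sorted from most depleted to most enriched
medians = obs.groupby("cell_type")["da_lfc"].median().sort_values()
median_order = medians.index.tolist()
star_flags = {ct: bool(abs(medians[ct]) > threshold) for ct in median_order}

print(median_order)
print(star_flags)
```

This flag is a simple threshold on the median rather than a formal statistical test, so it is best read as a visual aid.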