MrVI analysis over Tahoe100M cells dataset#
MrVI (Multi-resolution Variational Inference) is a model for analyzing multi-sample single-cell RNA-seq data. This tutorial show how to do run MrVI in PyTorch version over the Tahoe100M cells dataset and perform basic analysis.
!pip install --quiet scvi-colab
from scvi_colab import install
install()
import tempfile
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import scvi.hub
import seaborn as sns
import torch
from scvi.external import MRVI
run_autotune = False
# import inspect
# print(inspect.getsource(MRVI))
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.3.3
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()
%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"
pd.set_option("display.max_rows", 50)
pd.set_option("display.max_columns", 50)
pd.set_option("display.width", 1000)
Get the data#
We start by downloading the model from its hub in order to use its metadata Note that the model is very large therefore it will take time to being download.
# get the hub data
tahoe_hubmodel = scvi.hub.HubModel.pull_from_huggingface_hub(
repo_name="vevotx/Tahoe-100M-SCVI-v1", cache_dir="."
)
tahoe_hubmodel.model.adata.obs.head()
INFO Loading model...
INFO File ./models--vevotx--Tahoe-100M-SCVI-v1/snapshots/b5283a73fbbed812a95264ace360da538b20af89/model.pt
already downloaded
| sample | species | gene_count | tscp_count | mread_count | bc1_wind | bc2_wind | bc3_wind | bc1_well | bc2_well | bc3_well | id | drugname_drugconc | drug | INT_ID | NUM.SNPS | NUM.READS | demuxlet_call | BEST.LLK | NEXT.LLK | DIFF.LLK.BEST.NEXT | BEST.POSTERIOR | SNG.POSTERIOR | SNG.BEST.LLK | SNG.NEXT.LLK | SNG.ONLY.POSTERIOR | DBL.BEST.LLK | DIFF.LLK.SNG.DBL | sublibrary | BARCODE | pcnt_mito | S_score | G2M_score | phase | pass_filter | dataset | _scvi_batch | _scvi_labels | _scvi_observed_lib_size | plate | Cell_Name_Vevo | Cell_ID_Cellosaur | observed_lib_size | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BARCODE_SUB_LIB_ID | |||||||||||||||||||||||||||||||||||||||||||
| 01_001_052-lib_1105 | smp_1783 | hg38 | 1878 | 2893 | 3284 | 1 | 1 | 52 | A1 | A1 | E4 | recgIHRi9MiCIr4CO | [('8-Hydroxyquinoline', 0.05, 'uM')] | 8-Hydroxyquinoline | 1.0 | 199.0 | 215.0 | singlet | -50.74 | -59.04 | 8.30 | -55.0 | 1.0 | -50.74 | -87.96 | 0.0 | -59.04 | 8.30 | lib_1105 | 01_001_052 | 0.019357 | 0.174603 | 0.179670 | G2M | full | 0 | 0 | 0 | 2893 | 4 | PANC-1 | CVCL_0480 | 2893 |
| 01_001_105-lib_1105 | smp_1783 | hg38 | 1765 | 2434 | 2764 | 1 | 1 | 105 | A1 | A1 | p2.A9 | recgIHRi9MiCIr4CO | [('8-Hydroxyquinoline', 0.05, 'uM')] | 8-Hydroxyquinoline | 3.0 | 137.0 | 140.0 | singlet | -37.97 | -42.41 | 4.44 | -43.0 | 1.0 | -37.97 | -64.52 | 0.0 | -42.41 | 4.44 | lib_1105 | 01_001_105 | 0.029581 | 0.297619 | 0.342857 | G2M | full | 0 | 0 | 0 | 2434 | 4 | SW480 | CVCL_0546 | 2434 |
| 01_001_165-lib_1105 | smp_1783 | hg38 | 3174 | 5691 | 6454 | 1 | 1 | 165 | A1 | A1 | p2.F9 | recgIHRi9MiCIr4CO | [('8-Hydroxyquinoline', 0.05, 'uM')] | 8-Hydroxyquinoline | 4.0 | 379.0 | 396.0 | singlet | -129.66 | -130.65 | 0.99 | -130.0 | 1.0 | -129.66 | -186.89 | 0.0 | -130.65 | 0.99 | lib_1105 | 01_001_165 | 0.031629 | 0.031746 | 0.099084 | G2M | full | 0 | 0 | 0 | 5691 | 4 | SW1417 | CVCL_1717 | 5691 |
| 01_003_094-lib_1105 | smp_1783 | hg38 | 1380 | 1804 | 2050 | 1 | 3 | 94 | A1 | A3 | H10 | recgIHRi9MiCIr4CO | [('8-Hydroxyquinoline', 0.05, 'uM')] | 8-Hydroxyquinoline | 7.0 | 122.0 | 125.0 | singlet | -31.79 | -33.98 | 2.19 | -36.0 | 1.0 | -31.79 | -49.36 | 0.0 | -33.98 | 2.19 | lib_1105 | 01_003_094 | 0.017738 | -0.063492 | 0.019780 | G2M | full | 0 | 0 | 0 | 1804 | 4 | SW1417 | CVCL_1717 | 1804 |
| 01_003_164-lib_1105 | smp_1783 | hg38 | 1179 | 1514 | 1715 | 1 | 3 | 164 | A1 | A3 | p2.F8 | recgIHRi9MiCIr4CO | [('8-Hydroxyquinoline', 0.05, 'uM')] | 8-Hydroxyquinoline | 8.0 | 87.0 | 93.0 | singlet | -28.99 | -27.07 | -1.92 | -34.0 | 1.0 | -28.99 | -41.61 | 0.0 | -27.07 | -1.92 | lib_1105 | 01_003_164 | 0.023118 | -0.075397 | -0.070879 | G1 | full | 0 | 0 | 0 | 1514 | 4 | A498 | CVCL_1056 | 1514 |
# Load Cell Line Metadata
cell_lines = pd.read_csv(
"/home/access/PycharmProjects/scvi-tools/Tahoe100M/cell_line_metadata.h5ad"
)
cell_lines.head()
| Unnamed: 0 | cell_name | Cell_ID_DepMap | Cell_ID_Cellosaur | Organ | Driver_Gene_Symbol | Driver_VarZyg | Driver_VarType | Driver_ProtEffect_or_CdnaEffect | Driver_Mech_InferDM | Driver_GeneType_DM | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | A549 | ACH-000681 | CVCL_0023 | Lung | CDKN2A | Hom | Deletion | DEL | LoF | Suppressor |
| 1 | 1 | A549 | ACH-000681 | CVCL_0023 | Lung | CDKN2B | Hom | Deletion | DEL | LoF | Suppressor |
| 2 | 2 | A549 | ACH-000681 | CVCL_0023 | Lung | KRAS | Hom | Missense | p.G12S | GoF | Oncogene |
| 3 | 3 | A549 | ACH-000681 | CVCL_0023 | Lung | SMARCA4 | Hom | Frameshift | p.Q729fs | LoF | Suppressor |
| 4 | 4 | A549 | ACH-000681 | CVCL_0023 | Lung | STK11 | Hom | Stopgain | p.Q37* | LoF | Suppressor |
# Load the .h5ad file
adata = sc.read_h5ad(
"/home/access/PycharmProjects/scvi-tools/Tahoe100M/tahoe100m_sample_100000_rand.h5ad"
)
adata.obs.head()
| drug | sample | BARCODE_SUB_LIB_ID | cell_line_id | moa-fine | canonical_smiles | pubchem_cid | plate | mean_gene_count | mean_tscp_count | mean_mread_count | mean_pcnt_mito | drugname_drugconc | targets | moa-broad | human-approved | clinical-trials | gpt-notes-approval | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Niclosamide (olamine) | smp_2257 | 91_109_060-lib_2401 | CVCL_1577 | unclear | C1=CC(=C(C=C1[N+](=O)[O-])Cl)NC(=O)C2=C(C=CC(=... | 14992.0 | plate8 | 1553.713059 | 2518.258820 | 2984.088782 | 0.037475 | [('Niclosamide (olamine)', 0.5, 'uM')] | STAT3 | inhibitor/antagonist | yes | yes | Approved for parasitic infections, investigate... |
| 1 | Peretinoin | smp_2258 | 92_149_042-lib_2401 | CVCL_0504 | Retinoic receptor agonist | CC(=CCCC(=CCCC(=CC=CC(=CC(=O)O)C)C)C)C | 6437836.0 | plate8 | 1588.514329 | 2725.711632 | 3232.796996 | 0.085470 | [('Peretinoin', 0.5, 'uM')] | RXRA, RXRB, RXRG | unclear | no | yes | Studied for liver cancer, not approved for hum... |
| 2 | Peretinoin | smp_2258 | 92_101_110-lib_2401 | CVCL_0459 | Retinoic receptor agonist | CC(=CCCC(=CCCC(=CC=CC(=CC(=O)O)C)C)C)C | 6437836.0 | plate8 | 1588.514329 | 2725.711632 | 3232.796996 | 0.085470 | [('Peretinoin', 0.5, 'uM')] | RXRA, RXRB, RXRG | unclear | no | yes | Studied for liver cancer, not approved for hum... |
| 3 | Niclosamide (olamine) | smp_2257 | 91_187_166-lib_2401 | CVCL_0366 | unclear | C1=CC(=C(C=C1[N+](=O)[O-])Cl)NC(=O)C2=C(C=CC(=... | 14992.0 | plate8 | 1553.713059 | 2518.258820 | 2984.088782 | 0.037475 | [('Niclosamide (olamine)', 0.5, 'uM')] | STAT3 | inhibitor/antagonist | yes | yes | Approved for parasitic infections, investigate... |
| 4 | Niclosamide (olamine) | smp_2257 | 91_186_141-lib_2401 | CVCL_0504 | unclear | C1=CC(=C(C=C1[N+](=O)[O-])Cl)NC(=O)C2=C(C=CC(=... | 14992.0 | plate8 | 1553.713059 | 2518.258820 | 2984.088782 | 0.037475 | [('Niclosamide (olamine)', 0.5, 'uM')] | STAT3 | inhibitor/antagonist | yes | yes | Approved for parasitic infections, investigate... |
We use a subset of data, show the plates stratification and perform HVG filtering following by merging the metadata and split to train and test
adata.obs.plate.value_counts()
plate
plate4 28225
plate8 28225
plate3 28224
plate7 15326
Name: count, dtype: int64
# HVG filtering
sc.pp.highly_variable_genes(
adata, n_top_genes=15000, inplace=True, subset=True, flavor="seurat_v3", batch_key="plate"
)
adata
AnnData object with n_obs × n_vars = 100000 × 15000
obs: 'drug', 'sample', 'BARCODE_SUB_LIB_ID', 'cell_line_id', 'moa-fine', 'canonical_smiles', 'pubchem_cid', 'plate', 'mean_gene_count', 'mean_tscp_count', 'mean_mread_count', 'mean_pcnt_mito', 'drugname_drugconc', 'targets', 'moa-broad', 'human-approved', 'clinical-trials', 'gpt-notes-approval'
var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
uns: 'hvg'
# merge metadata
adata.obs = adata.obs.merge(
tahoe_hubmodel.model.adata.obs[
[
"Cell_Name_Vevo",
"dataset",
"phase",
"observed_lib_size",
"S_score",
"G2M_score",
"sublibrary",
]
],
how="left",
left_on="BARCODE_SUB_LIB_ID",
right_index=True,
)
adata.layers["counts"] = adata.X.copy() # preserve counts
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata # freeze the state in `.raw`
from sklearn.model_selection import train_test_split
train_ind, valid_ind = train_test_split(
adata.obs.plate.index.astype(int), test_size=0.9, stratify=adata.obs.plate
)
Init the model#
We will initialize the MRVI model with its “pytorch” backend. A JAX backend version can be also be used using backend=”jax”.
sample_key = "sample" # target covariate sample/cell_line_id
batch_key = "plate" # nuisance variable identifier
MRVI.setup_anndata(
adata, sample_key=sample_key, batch_key=batch_key, layer="counts", backend="torch"
)
Train mrVI#
import gc
import time
gc.collect()
start = time.time()
model = MRVI(adata, backend="torch")
model.train(
max_epochs=400,
early_stopping=True,
plan_kwargs={"lr": 1e-3, "n_epochs_kl_warmup": 40},
batch_size=512,
early_stopping_patience=5,
check_val_every_n_epoch=1,
datasplitter_kwargs={"external_indexing": [np.array(train_ind), np.array(valid_ind)]},
)
end = time.time()
print(f"Elapsed time: {end - start:.2f} seconds")
Monitored metric elbo_validation did not improve in the last 5 records. Best score: 1306.802. Signaling Trainer to stop.
Elapsed time: 340.90 seconds
train_ind
Index([12671, 38287, 63621, 80870, 1403, 26510, 88553, 19564, 38816, 33316,
...
41111, 87852, 54111, 79936, 61032, 1657, 72773, 28364, 96382, 16134], dtype='int64', length=10000)
valid_ind
Index([37550, 90863, 33241, 24035, 10737, 21757, 92804, 99954, 43105, 33898,
...
74013, 65779, 83502, 95378, 2676, 88516, 70633, 13290, 17217, 2395], dtype='int64', length=90000)
plt.plot(model.history["elbo_validation"])
plt.xlabel("Epoch")
plt.ylabel("Validation ELBO")
plt.show()
plt.plot(model.history["reconstruction_loss_validation"])
plt.xlabel("Epoch")
plt.ylabel("Validation reconstruction_loss")
plt.show()
plt.plot(model.history["kl_local_validation"])
plt.xlabel("Epoch")
plt.ylabel("Validation KL")
plt.show()
Visualize cell embeddings and sample distances#
The latent representations of the cells can also be accessed and visualized using the get_latent_representation method. MrVI learns two latent representations: u and z. u is designed to capture broad cell states invariant to sample and nuisance covariates, while z augments u with sample-specific effects but remains corrected for nuisance covariate effects.
# run PCA then generate UMAP plots
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_pcs=50, n_neighbors=50)
sc.tl.umap(adata, min_dist=0.1)
u = model.get_latent_representation()
adata.obsm["X_mrVI_Torch"] = u
sc.pp.neighbors(adata, use_rep="X_mrVI_Torch")
sc.tl.umap(adata, min_dist=0.3)
u.shape
(100000, 10)
Train regular SCVI model for comparison#
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key=batch_key)
model_scvi = scvi.model.SCVI(adata)
model_scvi.train(
max_epochs=100,
early_stopping=True,
check_val_every_n_epoch=1,
datasplitter_kwargs={"external_indexing": [np.array(train_ind), np.array(valid_ind)]},
)
plt.plot(model_scvi.history["elbo_validation"])
plt.xlabel("Epoch")
plt.ylabel("Validation ELBO")
plt.show()
SCVI_LATENT_KEY = "X_scVI"
latent = model_scvi.get_latent_representation()
adata.obsm[SCVI_LATENT_KEY] = latent
latent.shape
(100000, 10)
# use scVI latent space for UMAP generation
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.umap(adata, min_dist=0.3)
sc.pl.umap(
adata,
color=["plate", "cell_line_id"],
title=["Plate ID SCVI", "Cell Line ID SCVI"],
ncols=2,
frameon=False,
)
Compare results#
from scib_metrics.benchmark import BatchCorrection, Benchmarker, BioConservation
bm = Benchmarker(
adata[list(np.random.choice(np.arange(adata.n_obs), size=1000, replace=False)), :],
batch_key="plate",
bio_conservation_metrics=BioConservation(
isolated_labels=True,
nmi_ari_cluster_labels_leiden=True,
silhouette_label=True,
clisi_knn=True,
nmi_ari_cluster_labels_kmeans=True,
),
batch_correction_metrics=BatchCorrection(
bras=True,
pcr_comparison=True,
kbet_per_label=True,
graph_connectivity=False,
ilisi_knn=True,
),
label_key="cell_line_id",
embedding_obsm_keys=["X_pca", "X_scVI", "X_mrVI_Torch"],
n_jobs=-1,
)
bm.benchmark()
bm.plot_results_table(min_max_scale=False)