CITE-seq reference mapping with totalVI

With totalVI, we can train a reference model that can be used for mapping new query data. Here we dive into this functionality.

If you use totalVI, please consider citing:

  • Gayoso, A., Steier, Z., Lopez, R., Regier, J., Nazor, K. L., Streets, A., & Yosef, N. (2021). Joint probabilistic modeling of single-cell multi-omic data with totalVI. Nature Methods, 18(3), 272-282.

Uncomment the following lines in Google Colab in order to install scvi-tools:

# !pip install --quiet scvi-colab
# from scvi_colab import install

# install()
import os
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pooch
import scanpy as sc
import scrublet as scr
import scvi
import torch
import umap
from scvi.model import TOTALVI
from sklearn.ensemble import RandomForestClassifier

Building a reference model

scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.0.3

You can modify save_dir below to change where the data files for this tutorial are saved.

sc.set_figure_params(figsize=(4, 4))
save_dir = tempfile.TemporaryDirectory()

%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

This dataset was filtered as described in the scvi-tools manuscript (low quality cells, doublets, lowly expressed genes, etc.).


Here we use the argument mask_protein_batches=5, which treats the protein data in 5 of the 24 batches of this dataset as missing. This helps totalVI integrate query data that measure only a subset of the proteins in this dataset, or no proteins at all (i.e., scRNA-seq data only).
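As a rough illustration of what masking means: the protein counts of a few randomly chosen batches are set to zero, and, as we understand it, totalVI then treats any protein whose counts are all zero within a batch as unmeasured in that batch (this is what the "Found batches with missing protein expression" message below reports). The helpers in this sketch are hypothetical and work on plain Python lists; they are not the scvi-tools implementation.

```python
import random


def mask_protein_batches(protein_counts, batches, n_mask, seed=0):
    """Hypothetical helper: zero all protein counts in n_mask randomly chosen batches."""
    rng = random.Random(seed)
    masked = set(rng.sample(sorted(set(batches)), n_mask))
    masked_counts = [
        [0.0] * len(row) if batch in masked else list(row)
        for row, batch in zip(protein_counts, batches)
    ]
    return masked_counts, masked


def observed_proteins_per_batch(protein_counts, batches):
    """For each batch, flag which proteins have any nonzero count (True = observed)."""
    sums = {}
    for row, batch in zip(protein_counts, batches):
        acc = sums.setdefault(batch, [0.0] * len(row))
        for j, value in enumerate(row):
            acc[j] += value
    return {batch: [total > 0 for total in acc] for batch, acc in sums.items()}


# toy data: 4 batches x 2 cells, 3 proteins per cell, all counts positive
batches = ["P1_0", "P1_0", "P1_3", "P1_3", "P2_0", "P2_0", "P2_3", "P2_3"]
counts = [[float(i + j + 1) for j in range(3)] for i in range(len(batches))]

masked_counts, masked = mask_protein_batches(counts, batches, n_mask=1)
observed = observed_proteins_per_batch(masked_counts, batches)
print(masked, observed)
```

Only the masked batch ends up with every protein flagged as unobserved; all other batches keep their measurements.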

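adata = scvi.data.pbmc_seurat_v4_cite_seq(save_path=save_dir.name, mask_protein_batches=5)
adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata
# keep 4,000 highly variable genes, selected on the raw counts and accounting for batches
sc.pp.highly_variable_genes(adata, n_top_genes=4000, flavor="seurat_v3", batch_key="orig.ident", subset=True, layer="counts")
adata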
INFO     Downloading file at /tmp/tmpgcskgeq2/pbmc_seurat_v4.h5ad                                                  
Downloading...: 100%|██████████| 1002579/1002579.0 [01:33<00:00, 10717.95it/s]
AnnData object with n_obs × n_vars = 152094 × 4000
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT', 'X_index', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Protein log library size', 'Number proteins detected', 'RNA log library size'
    var: 'mt', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: 'log1p', 'hvg'
    obsm: 'protein_counts'
    layers: 'counts'

TOTALVI.setup_anndata(
    adata,
    layer="counts",
    batch_key="orig.ident",
    protein_expression_obsm_key="protein_counts",
)
INFO     Using column names from columns of adata.obsm['protein_counts']                                           
INFO     Found batches with missing protein expression                                                             

Prepare and run model

Here we use the scArches-specific totalVI parameters, as shown in the scArches tutorial.

Rather than training the model here, which would take 1 hour, we instead download a pretrained object.
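One way to see why the scArches settings in arches_params swap batch normalization for layer normalization: layer norm rescales each cell using only that cell's own activations, whereas batch norm uses statistics pooled over the cells in a minibatch (or running averages of them), so its output for a query cell depends on which other cells it is normalized with. The toy functions below are plain Python, omit the learnable scale and shift parameters, and are only meant to illustrate that difference; they are not the PyTorch layers totalVI uses.

```python
import math


def layer_norm(cell, eps=1e-5):
    """Normalize one cell's activations with that cell's own mean and variance."""
    mean = sum(cell) / len(cell)
    var = sum((v - mean) ** 2 for v in cell) / len(cell)
    return [(v - mean) / math.sqrt(var + eps) for v in cell]


def batch_norm(cells, eps=1e-5):
    """Normalize each feature with the mean and variance across the cells in a minibatch."""
    out = [list(cell) for cell in cells]
    for j in range(len(cells[0])):
        column = [cell[j] for cell in cells]
        mean = sum(column) / len(column)
        var = sum((v - mean) ** 2 for v in column) / len(column)
        for i, cell in enumerate(cells):
            out[i][j] = (cell[j] - mean) / math.sqrt(var + eps)
    return out


query_cell = [1.0, 2.0, 4.0]
minibatch_a = [[1.0, 2.0, 3.0], [2.0, 3.0, 5.0]]
minibatch_b = [[10.0, 20.0, 30.0], [12.0, 25.0, 31.0]]

# Layer norm: the query cell's output is the same whatever it is batched with.
ln = layer_norm(query_cell)
# Batch norm: the query cell's output changes with its minibatch companions.
bn_a = batch_norm([query_cell] + minibatch_a)[0]
bn_b = batch_norm([query_cell] + minibatch_b)[0]
print(ln)
print(bn_a, bn_b)
```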

# arches_params = dict(
#     use_layer_norm="both",
#     use_batch_norm="none",
#     n_layers_decoder=2,
#     n_layers_encoder=2,
# )

# model = TOTALVI(adata, **arches_params)
# model.train(max_epochs=250)

# model_dir = os.path.join(save_dir.name, "seurat_reference_model")
# model.save(model_dir, overwrite=True)
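def download_model(save_path: str, fname: str = "legacy_seurat_reference_model"):
    """Download and unzip the pre-trained reference model; return its directory."""
    paths = pooch.retrieve(
        url="",  # download URL of the pretrained model archive
        known_hash=None,
        fname=fname,
        path=save_path,
        processor=pooch.Unzip(),  # returns the list of extracted file paths
    )
    return str(Path(paths[0]).parent)

legacy_model_dir = download_model(save_dir.name)
model_dir = os.path.join(save_dir.name, "seurat_reference_model")
TOTALVI.convert_legacy_save(legacy_model_dir, model_dir)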

model = TOTALVI.load(model_dir, adata=adata)
INFO     File /tmp/tmpgcskgeq2/seurat_reference_model/ already downloaded                                  
INFO     Found batches with missing protein expression                                                             
INFO     Computing empirical prior initialization for protein background.                                          
Anndata setup with scvi-tools version 1.0.3.

Setup via `TOTALVI.setup_anndata` with arguments:
'protein_expression_obsm_key': 'protein_counts',
'protein_names_uns_key': '_protein_names',
'batch_key': 'orig.ident',
'layer': 'counts',
'size_factor_key': None,
'categorical_covariate_keys': None,
'continuous_covariate_keys': None

         Summary Statistics          
┃     Summary Stat Key      Value  ┃
│         n_batch             24   │
│         n_cells           152094 │
│ n_extra_categorical_covs    0    │
│ n_extra_continuous_covs     0    │
│         n_labels            1    │
│        n_proteins          217   │
│          n_vars            4000  │
                 Data Registry                 
┃ Registry Key      scvi-tools Location      ┃
│      X           adata.layers['counts']    │
│    batch        adata.obs['_scvi_batch']   │
│    labels      adata.obs['_scvi_labels']   │
│   proteins    adata.obsm['protein_counts'] │
                     labels State Registry                      
┃      Source Location       Categories  scvi-tools Encoding ┃
│ adata.obs['_scvi_labels']      0                0          │
                     batch State Registry                     
┃     Source Location      Categories  scvi-tools Encoding ┃
│ adata.obs['orig.ident']     P1_0              0          │
│                             P1_3              1          │
│                             P1_7              2          │
│                             P2_0              3          │
│                             P2_3              4          │
│                             P2_7              5          │
│                             P3_0              6          │
│                             P3_3              7          │
│                             P3_7              8          │
│                             P4_0              9          │
│                             P4_3             10          │
│                             P4_7             11          │
│                             P5_0             12          │
│                             P5_3             13          │
│                             P5_7             14          │
│                             P6_0             15          │
│                             P6_3             16          │
│                             P6_7             17          │
│                             P7_0             18          │
│                             P7_3             19          │
│                             P7_7             20          │
│                             P8_0             21          │
│                             P8_3             22          │
│                             P8_7             23          │
plt.plot(model.history["elbo_train"].iloc[10:], label="train")
plt.plot(model.history["elbo_validation"].iloc[10:], label="validation")
plt.title("Negative ELBO over training epochs")
plt.legend()

TOTALVI_LATENT_KEY = "X_totalvi"
adata.obsm[TOTALVI_LATENT_KEY] = model.get_latent_representation()

Train a classifier on the latent space

The representation of the reference data will not change when the model gets updated with query data. Therefore, we can train any simple classifier on the latent space. Here we use a Random Forest from scikit-learn.
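The sketch below shows the same train-on-reference, predict-on-query pattern with a nearest-centroid classifier in plain Python, swapped in for the random forest only to keep the example dependency-free. The latent vectors and labels are toy values, not totalVI output.

```python
import math
from collections import defaultdict


def fit_centroids(latent, labels):
    """Nearest-centroid 'training': average the latent vectors of each label."""
    sums, counts = {}, defaultdict(int)
    for z, y in zip(latent, labels):
        acc = sums.setdefault(y, [0.0] * len(z))
        for j, v in enumerate(z):
            acc[j] += v
        counts[y] += 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}


def predict(centroids, latent):
    """Assign each query vector the label of the closest centroid (Euclidean distance)."""
    return [min(centroids, key=lambda y: math.dist(z, centroids[y])) for z in latent]


# toy reference latent space with two well-separated labels
ref_latent = [[0.0, 0.1], [0.2, 0.0], [3.0, 3.1], [2.9, 3.0]]
ref_labels = ["CD14 Mono", "CD14 Mono", "NK", "NK"]
centroids = fit_centroids(ref_latent, ref_labels)

# query cells embedded into the same (unchanged) latent space
query_latent = [[0.1, 0.2], [3.2, 2.8]]
print(predict(centroids, query_latent))  # ['CD14 Mono', 'NK']
```

The classifier only ever sees latent coordinates, which is why it can be reused unchanged once the query cells are embedded in the same space.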

By storing the classifier object in our totalVI instance (and giving it an attribute that ends with _), we can automatically save and load this classifier along with our model.
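A minimal sketch of that naming convention, assuming a plain pickle file: attributes whose names end in an underscore are collected, written out, and set again on a fresh object. TinyModel, save_user_attributes and load_user_attributes are hypothetical names; scvi-tools' own save and load code handles this differently in its details.

```python
import os
import pickle
import tempfile


class TinyModel:
    """Hypothetical stand-in for a model object (not the scvi-tools class)."""

    def __init__(self):
        self.n_latent = 20  # no trailing underscore: not saved by the helpers below


def user_attributes(obj):
    """Collect instance attributes whose names end with an underscore."""
    return {name: value for name, value in vars(obj).items() if name.endswith("_")}


def save_user_attributes(obj, path):
    with open(path, "wb") as f:
        pickle.dump(user_attributes(obj), f)


def load_user_attributes(path):
    obj = TinyModel()
    with open(path, "rb") as f:
        for name, value in pickle.load(f).items():
            setattr(obj, name, value)
    return obj


toy_model = TinyModel()
toy_model.n_latent = 10  # changed, but will not survive the round trip
toy_model.latent_space_classifer_ = {"CD14 Mono": [0.1, 0.05]}  # any picklable object

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "user_attributes.pkl")
    save_user_attributes(toy_model, path)
    restored = load_user_attributes(path)

print(restored.latent_space_classifer_, restored.n_latent)  # {'CD14 Mono': [0.1, 0.05]} 20
```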

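# integer category codes, so that predictions can later be mapped back to label names
y_train = adata.obs["celltype.l2"].astype("category").cat.codes.to_numpy()
X_train = adata.obsm[TOTALVI_LATENT_KEY]
clf = RandomForestClassifier(random_state=0, n_jobs=-1)
clf.fit(X_train, y_train)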

model.latent_space_classifer_ = clf

Inspect reference model

We can view the latent space using UMAP and plot the study-derived cell type labels.

Here we use the umap-learn API directly instead of scanpy, so that we can later apply the same fitted UMAP object to the query data.
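What matters here is the fit/transform split: the operator is fit once on the reference, and afterwards only transform is called, so new cells are placed into the existing embedding without moving the reference points. The toy class below (centering, then keeping the first two coordinates) is a hypothetical stand-in for UMAP; it is not what umap-learn computes, it only follows the same fit_transform / transform contract.

```python
class CenteredProjection:
    """Hypothetical fitted embedding operator with a fit_transform / transform split."""

    def fit_transform(self, X):
        # learn per-feature means on the reference data only
        n = len(X)
        self.mean_ = [sum(row[j] for row in X) / n for j in range(len(X[0]))]
        return self.transform(X)

    def transform(self, X):
        # reuse the statistics learned in fit_transform; nothing is refit here
        return [[row[0] - self.mean_[0], row[1] - self.mean_[1]] for row in X]


reference = [[1.0, 2.0, 0.5], [3.0, 4.0, 0.1]]
op = CenteredProjection()
reference_embedding = op.fit_transform(reference)
query_embedding = op.transform([[2.0, 3.0, 9.9]])
print(reference_embedding, query_embedding)  # [[-1.0, -1.0], [1.0, 1.0]] [[0.0, 0.0]]
```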


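TOTALVI_UMAP_KEY = "X_totalvi_umap"
X = adata.obsm[TOTALVI_LATENT_KEY]
# default UMAP hyperparameters; random_state makes the layout reproducible
trans = umap.UMAP(random_state=0)
adata.obsm[TOTALVI_UMAP_KEY] = trans.fit_transform(X)
# trailing underscore: the fitted UMAP operator is saved along with the model
model.umap_op_ = trans
sc.pl.embedding(
    adata, basis=TOTALVI_UMAP_KEY, color=["celltype.l2", "orig.ident"], frameon=False, ncols=1
)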

Map query data

Now we download our query data and walk through the preprocessing needed to map it onto our reference model.

In this case our query data are also PBMCs and come from the following study:

  • Arunachalam, Prabhu S., et al. “Systems biological assessment of immunity to mild versus severe COVID-19 infection in humans.” Science 369.6508 (2020): 1210-1220.

This dataset is also CITE-seq, but contains a subset of the proteins of the reference data.

query_path = os.path.join(save_dir.name, "covid_cite.h5ad")

query = sc.read(
    query_path, backup_url=""
)


First we remove any doublets from the query dataset using Scrublet. This has to be run per batch, which in this dataset is denoted with "set".
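Scrublet's documentation recommends running it on each sample separately, presumably because simulated doublets should only combine cells that could have been captured together. The loop that follows uses a split, score, scatter pattern: group cells by batch, score each group, and write the scores back in the original cell order. Here is that pattern in plain Python, with a hypothetical toy scorer standing in for Scrublet and an arbitrary 0.9 cutoff.

```python
from collections import defaultdict


def per_batch(values, batches, score_fn):
    """Run score_fn separately on each batch and scatter the results back to cell order."""
    indices = defaultdict(list)
    for i, batch in enumerate(batches):
        indices[batch].append(i)
    scores = [None] * len(values)
    for idx in indices.values():
        batch_scores = score_fn([values[i] for i in idx])
        for i, score in zip(idx, batch_scores):
            scores[i] = score
    return scores


def toy_score(batch_values):
    """Hypothetical scorer: each cell's total counts relative to its batch maximum."""
    top = max(batch_values)
    return [v / top for v in batch_values]


totals = [100, 400, 50, 200]
sets = ["A", "A", "B", "B"]
scores = per_batch(totals, sets, toy_score)
keep = [i for i, s in enumerate(scores) if s < 0.9]  # drop cells flagged as doublets
print(scores, keep)  # [0.25, 1.0, 0.25, 1.0] [0, 2]
```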

query.obs["doublet_scores"] = 0.0
query.obs["predicted_doublets"] = False
for s in np.unique(query.obs["set"]):
    mask = (query.obs["set"] == s).to_numpy()
    counts_matrix = query[mask].X.copy()
    scrub = scr.Scrublet(counts_matrix)
    doublet_scores, predicted_doublets = scrub.scrub_doublets()
    query.obs.loc[mask, "doublet_scores"] = doublet_scores
    query.obs.loc[mask, "predicted_doublets"] = predicted_doublets
query = query[~query.obs["predicted_doublets"]].copy()
Simulating doublets...
Embedding transcriptomes using PCA...
Calculating doublet scores...
Automatically set threshold at doublet score = 0.37
Detected doublet rate = 3.6%
Estimated detectable doublet fraction = 40.3%
Overall doublet rate:
	Expected   = 10.0%
	Estimated  = 9.0%
Elapsed time: 32.2 seconds
Simulating doublets...
Embedding transcriptomes using PCA...
Calculating doublet scores...
Automatically set threshold at doublet score = 0.25
Detected doublet rate = 6.2%
Estimated detectable doublet fraction = 56.4%
Overall doublet rate:
	Expected   = 10.0%
	Estimated  = 11.0%
Elapsed time: 32.9 seconds

Next we do the following:

  1. Store the counts in a layer, perform standard preprocessing, and subset to the reference genes

  2. Add blank metadata that we will later fill in with predicted labels

  3. Rename the batch key to correspond to the reference data

  4. Add proteins with all 0 counts for the missing proteins in this query dataset
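Step 4 can be pictured as reindexing the query's protein columns to the reference panel: measured proteins are moved into the reference order, proteins the query panel lacks become all-zero columns, and query-only proteins are dropped (as the .loc column selection below also does). A plain-Python sketch with made-up protein names:

```python
def align_proteins(query_counts, query_names, reference_names):
    """Reorder query protein columns to the reference order, filling absent ones with 0.0."""
    position = {name: j for j, name in enumerate(query_names)}
    return [
        [row[position[name]] if name in position else 0.0 for name in reference_names]
        for row in query_counts
    ]


reference_names = ["CD3", "CD4", "CD8", "CD14"]
query_names = ["CD14", "CD3", "CD19"]  # CD19 is not in the reference and is dropped
query_counts = [[5.0, 1.0, 2.0], [0.0, 7.0, 3.0]]
aligned = align_proteins(query_counts, query_names, reference_names)
print(aligned)  # [[1.0, 0.0, 0.0, 5.0], [7.0, 0.0, 0.0, 0.0]]
```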

query.layers["counts"] = query.X.copy()
sc.pp.normalize_total(query, target_sum=1e4)
sc.pp.log1p(query)
query.raw = query
# subset to reference vars
query = query[:, adata.var_names].copy()

query.obsm["protein_counts"] = query.obsm["pro_exp"].copy()
query.obs["celltype.l2"] = "Unknown"
query.obs["orig.ident"] = query.obs["set"]
query.obsm["X_umap"] = query.obs[["UMAP1", "UMAP2"]].values

# reorganize query proteins, missing proteins become all 0
for p in adata.obsm["protein_counts"].columns:
    if p not in query.obsm["protein_counts"].columns:
        query.obsm["protein_counts"][p] = 0.0
# ensure columns are in same order
query.obsm["protein_counts"] = query.obsm["protein_counts"].loc[
    :, adata.obsm["protein_counts"].columns
]

For later visualization, we add some metadata to denote whether each cell comes from the reference or the query dataset.

adata.obs["dataset_name"] = "Reference"
query.obs["dataset_name"] = "Query"

Query model training

query_model = TOTALVI.load_query_data(query, model)
query_model.train(
    max_epochs=150, plan_kwargs={"weight_decay": 0.0, "scale_adversarial_loss": 0.0}
)
INFO     Found batches with missing protein expression                                                             
INFO     Computing empirical prior initialization for protein background.                                          
Epoch 150/150: 100%|██████████| 150/150 [08:08<00:00,  3.26s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=1.07e+3]

Now we retrieve the latent representation of the query data.

TOTALVI_QUERY_LATENT_KEY = "X_totalvi_scarches"

query.obsm[TOTALVI_QUERY_LATENT_KEY] = query_model.get_latent_representation(query)

Query cell type prediction

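We can now apply the random forest trained on the reference latent space to the latent representation of the query cells from the updated model. Because the classifier was stored as an attribute of the reference model, it is also available on query_model.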

# predict cell types of query
predictions = query_model.latent_space_classifer_.predict(query.obsm[TOTALVI_QUERY_LATENT_KEY])
categories = adata.obs["celltype.l2"].astype("category").cat.categories
cat_preds = [categories[i] for i in predictions]
query.obs["celltype.l2"] = cat_preds
query.obs["predicted_l2_scarches"] = cat_preds
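Indexing categories with the predictions assumes the classifier was trained on integer category codes. pandas assigns those codes in the order of .cat.categories, which for string labels is sorted by default, so the same encoding maps codes back to names. A plain-Python sketch of that round trip, on toy labels:

```python
def encode(labels):
    """Mimic pandas' default category encoding: sorted unique labels, one code per cell."""
    categories = sorted(set(labels))
    index = {category: i for i, category in enumerate(categories)}
    return [index[label] for label in labels], categories


def decode(codes, categories):
    """Map integer codes (e.g. classifier predictions) back to label strings."""
    return [categories[i] for i in codes]


labels = ["NK", "CD14 Mono", "B naive", "NK"]
codes, categories = encode(labels)
print(codes, categories)  # [2, 1, 0, 2] ['B naive', 'CD14 Mono', 'NK']
print(decode(codes, categories) == labels)  # True
```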

Evaluate label transfer

Now we plot the predicted labels and the labels from the query study (here denoted as "final_clust_review") on the UMAP coordinates used in the query study. Overall the two agree strongly, but there are also some important differences. For example, the study found a population of monocytes with interferon signaling; because the reference contains no such cell type, the classifier can only predict them to be CD14 monocytes.
sc.pl.umap(query, color=["celltype.l2", "final_clust_review"], frameon=False, ncols=1)
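To put numbers on the agreement, one option is to cross-tabulate predicted labels against the study labels; with pandas, pd.crosstab on the two obs columns gives the full table. The dependency-free sketch below does the same with collections.Counter and reports, for each study label, the most frequent predicted label and its share of cells. All labels here are made up for illustration.

```python
from collections import Counter


def crosstab(predicted, study):
    """Count each (predicted label, study label) pair, like a contingency table."""
    return Counter(zip(predicted, study))


def best_match_per_study_label(predicted, study):
    """For each study label: the most frequent predicted label and its share of cells."""
    table = crosstab(predicted, study)
    result = {}
    for label in sorted(set(study)):
        row = {p: n for (p, s), n in table.items() if s == label}
        best = max(row, key=row.get)
        result[label] = (best, row[best] / sum(row.values()))
    return result


predicted = ["CD14 Mono", "CD14 Mono", "CD14 Mono", "NK", "NK", "CD16 Mono"]
study = ["Mono", "IFN Mono", "IFN Mono", "NK", "NK", "NK"]
print(best_match_per_study_label(predicted, study))
```

In this toy example the "IFN Mono" cluster maps entirely to "CD14 Mono", mirroring the situation described above where the reference has no matching cell type.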

Use reference UMAP

We can also use the same UMAP operator from the reference step to visualize our query data.


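TOTALVI_UMAP_PROJECT_KEY = "X_umap_project"
query.obsm[TOTALVI_UMAP_PROJECT_KEY] = query_model.umap_op_.transform(query.obsm[TOTALVI_QUERY_LATENT_KEY])
sc.pl.embedding(query, basis=TOTALVI_UMAP_PROJECT_KEY,
                color=["celltype.l2", "final_clust_review"], frameon=False, ncols=1)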

Finally, we can combine the two UMAP embeddings and visualize them jointly.

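umap_adata = sc.AnnData(
    np.concatenate([query.obsm[TOTALVI_UMAP_PROJECT_KEY], adata.obsm[TOTALVI_UMAP_KEY]], axis=0)
)
umap_adata.obs["celltype"] = np.concatenate(
    [query.obs["celltype.l2"].values, adata.obs["celltype.l2"].values]
)
umap_adata.obs["dataset"] = np.concatenate(
    [query.shape[0] * ["query"], adata.shape[0] * ["reference"]]
)
umap_adata.obsm["X_umap"] = umap_adata.X
# shuffle the plotting order so neither dataset is drawn entirely on top of the other
inds = np.random.permutation(np.arange(umap_adata.shape[0]))
sc.pl.umap(umap_adata[inds], color=["celltype", "dataset"], frameon=False, ncols=1)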

Clean up

Uncomment the following line to remove all data files created in this tutorial:

# save_dir.cleanup()