Annotation with CellAssign#
Assigning single-cell RNA-seq data to known cell types#
CellAssign is a probabilistic model that uses prior knowledge of cell-type marker genes to annotate scRNA data into predefined cell types. Unlike other methods for assigning cell types, CellAssign does not require labeled single cell data and only needs to know whether or not each given gene is a marker of each cell type. The original paper and R code are linked below.
Code: https://github.com/Irrationone/cellassign
This notebook will demonstrate how to use CellAssign on follicular lymphoma and high-grade serous ovarian cancer (HGSC) scRNA-seq data.
Uncomment the following lines in Google Colab in order to install scvi-tools:
# !pip install --quiet scvi-colab
# from scvi_colab import install
# install()
import os
import tempfile
import gdown
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import scvi
import seaborn as sns
import torch
from scvi.external import CellAssign
# Set the scvi-tools global seed setting so repeated runs start from the same seed.
scvi.settings.seed = 0
# Record which scvi-tools version this notebook was last executed with.
print(f"Last run with scvi-tools version: {scvi.__version__}")
Last run with scvi-tools version: 1.0.3
You can modify `save_dir` below to change where the data files for this tutorial are saved.
# Default scanpy figure styling for this notebook: 4x4 figures drawn without a frame.
sc.set_figure_params(figsize=(4, 4), frameon=False)
# Ask PyTorch for its "high" float32 matmul precision setting.
torch.set_float32_matmul_precision("high")
# Temporary directory that receives the downloaded files; its path is `save_dir.name`.
save_dir = tempfile.TemporaryDirectory()
# IPython magics (only valid in a notebook/IPython session): white figure
# background and "retina" inline figure format.
%config InlineBackend.print_figure_kwargs={"facecolor" : "w"}
%config InlineBackend.figure_format="retina"
To demonstrate CellAssign, we use the data from the original publication, which we converted into h5ad format. The data are originally available from here:
https://zenodo.org/record/3372746
def download_data(save_path: str) -> tuple[str, str, str, str]:
    """Download the four tutorial data files from Google Drive into ``save_path``.

    Parameters
    ----------
    save_path
        Directory in which the downloaded files are written.

    Returns
    -------
    tuple of str
        Local paths, in this order, of ``sce_follicular.h5ad``,
        ``sce_hgsc.h5ad``, ``fl_celltype.csv`` and ``hgsc_celltype.csv``.

    Raises
    ------
    RuntimeError
        If ``gdown.download`` returns ``None`` for a file, which some gdown
        releases use to signal a failed download instead of raising.
    """
    # (file name, Google Drive file id), listed in the order the paths are returned.
    drive_files = (
        ("sce_follicular.h5ad", "10l6m2KKKioCZnQlRHomheappHh-jTFmx"),
        ("sce_hgsc.h5ad", "1Pae7VEcoZbKRvtllGAEWG4SOLWSjjtCO"),
        ("fl_celltype.csv", "1tJSOI9ve0i78WmszMLx2ul8F8tGycBTd"),
        ("hgsc_celltype.csv", "1Mk5uPdnPC4IMRnuG5N4uFvypT8hPdJ74"),
    )

    local_paths = []
    for file_name, drive_id in drive_files:
        destination = os.path.join(save_path, file_name)
        result = gdown.download(
            f"https://drive.google.com/uc?id={drive_id}",
            destination,
            quiet=False,
        )
        # Fail here with a clear message rather than later, when the missing
        # file is opened by the loading code.
        if result is None:
            raise RuntimeError(
                f"Download of {file_name} (Google Drive id {drive_id}) failed."
            )
        local_paths.append(destination)

    return tuple(local_paths)
# Fetch the four data files into the temporary directory and keep their local paths.
sce_follicular_path, sce_hgsc_path, fl_celltype_path, hgsc_celltype_path = (
    download_data(save_dir.name)
)
Follicular Lymphoma Data#
Load follicular lymphoma data and marker gene matrix (see Supplementary Table 2 from the original paper).
# Load the follicular lymphoma AnnData and the marker table; the first CSV
# column becomes the table's index (used later to select genes).
follicular_adata = sc.read(sce_follicular_path)
fl_celltype_markers = pd.read_csv(fl_celltype_path, index_col=0)
# Cast cell (obs) and gene (var) identifiers to strings before de-duplicating them.
follicular_adata.obs.index = follicular_adata.obs.index.astype("str")
follicular_adata.var.index = follicular_adata.var.index.astype("str")
follicular_adata.var_names_make_unique()
follicular_adata.obs_names_make_unique()
# Bare expression: shows the AnnData summary as the notebook cell output.
follicular_adata
AnnData object with n_obs × n_vars = 9156 × 33694
obs: 'Sample', 'dataset', 'patient', 'timepoint', 'progression_status', 'patient_progression', 'sample_barcode', 'is_cell_control', 'total_features_by_counts', 'log10_total_features_by_counts', 'total_counts', 'log10_total_counts', 'pct_counts_in_top_50_features', 'pct_counts_in_top_100_features', 'pct_counts_in_top_200_features', 'pct_counts_in_top_500_features', 'total_features_by_counts_endogenous', 'log10_total_features_by_counts_endogenous', 'total_counts_endogenous', 'log10_total_counts_endogenous', 'pct_counts_endogenous', 'pct_counts_in_top_50_features_endogenous', 'pct_counts_in_top_100_features_endogenous', 'pct_counts_in_top_200_features_endogenous', 'pct_counts_in_top_500_features_endogenous', 'total_features_by_counts_feature_control', 'log10_total_features_by_counts_feature_control', 'total_counts_feature_control', 'log10_total_counts_feature_control', 'pct_counts_feature_control', 'pct_counts_in_top_50_features_feature_control', 'pct_counts_in_top_100_features_feature_control', 'pct_counts_in_top_200_features_feature_control', 'pct_counts_in_top_500_features_feature_control', 'total_features_by_counts_mitochondrial', 'log10_total_features_by_counts_mitochondrial', 'total_counts_mitochondrial', 'log10_total_counts_mitochondrial', 'pct_counts_mitochondrial', 'pct_counts_in_top_50_features_mitochondrial', 'pct_counts_in_top_100_features_mitochondrial', 'pct_counts_in_top_200_features_mitochondrial', 'pct_counts_in_top_500_features_mitochondrial', 'total_features_by_counts_ribosomal', 'log10_total_features_by_counts_ribosomal', 'total_counts_ribosomal', 'log10_total_counts_ribosomal', 'pct_counts_ribosomal', 'pct_counts_in_top_50_features_ribosomal', 'pct_counts_in_top_100_features_ribosomal', 'pct_counts_in_top_200_features_ribosomal', 'pct_counts_in_top_500_features_ribosomal', 'size_factor', 'cellassign_cluster_broad', 'cellassign_cluster_specific', 'B.cells..broad.', 'T.cells..broad.', 'other..broad.', 'B.cells', 'Cytotoxic.T.cells', 'CD4.T.cells', 
'Tfh', 'other', 'celltype', 'malignant_status_manual', 'celltype_full', 'G1', 'S', 'G2M', 'Cell_Cycle', 't_seurat_cluster', 't_seurat_0.8_cluster', 't_phenograph_cluster', 't_cluster', 'malignant_seurat_cluster', 'malignant_seurat_0.8_cluster', 'malignant_phenograph_cluster', 'malignant_cluster', 'b_seurat_cluster', 'b_seurat_0.8_cluster', 'b_phenograph_cluster', 'b_cluster', 'all_seurat_cluster', 'all_seurat_0.8_cluster', 'all_seurat_1.2_cluster', 'all_sc3_cluster', 'all_SC3_cluster', 'all_cluster', 'all_subset_seurat_cluster', 'all_subset_seurat_0.8_cluster', 'all_subset_seurat_1.2_cluster', 'all_subset_cluster'
var: 'ID', 'is_feature_control', 'is_feature_control_mitochondrial', 'is_feature_control_ribosomal', 'mean_counts', 'log10_mean_counts', 'n_cells_by_counts', 'pct_dropout_by_counts', 'total_counts', 'log10_total_counts'
uns: 'log.exprs.offset'
obsm: 'X_pca', 'X_tsne', 'X_umap'
layers: 'logcounts'
Create and fit CellAssign model#
The anndata object and cell type marker matrix should contain the same genes, so we index into `adata` to include only the genes from `marker_gene_mat`.
# Keep only the genes listed in the marker table's index; .copy() stores the
# subset as its own object.
follicular_bdata = follicular_adata[:, fl_celltype_markers.index].copy()
Then we set up anndata and initialize a `CellAssign` model. Here we set the `size_factor_key` to “size_factor”, which is a column in `bdata.obs`.
Note
A size factor may be defined manually as scaled library size (total UMI count) and should not be placed on the log scale, as the model applies the log transform internally. The library size should be computed before any gene subsetting (in this case, technically, a few notebook cells up).
This can be achieved as follows:
import numpy as np

# Total counts per cell. np.asarray(...).ravel() flattens the result to 1-D:
# summing a scipy sparse matrix along an axis gives an (n_cells, 1) np.matrix,
# not a vector.
lib_size = np.asarray(adata.X.sum(axis=1)).ravel()
# Scale so the mean size factor is 1; the values stay on the natural (non-log) scale.
adata.obs["size_factor"] = lib_size / np.mean(lib_size)
# Register the subset AnnData with CellAssign, using the "size_factor" column
# of .obs as the per-cell size factor.
CellAssign.setup_anndata(follicular_bdata, size_factor_key="size_factor")
# Build the model from the subset data and the marker table, then fit it with
# the default training arguments.
follicular_model = CellAssign(follicular_bdata, fl_celltype_markers)
follicular_model.train()
Epoch 400/400: 100%|██████████| 400/400 [00:20<00:00, 19.74it/s, v_num=1, train_loss_step=20.6, train_loss_epoch=19.9]Epoch 400/400: 100%|██████████| 400/400 [00:20<00:00, 19.80it/s, v_num=1, train_loss_step=20.6, train_loss_epoch=19.9]
Inspecting the convergence:
# Plot the "elbo_validation" entry of the training history (x-axis: epoch).
follicular_model.history["elbo_validation"].plot()
<Axes: xlabel='epoch'>

Predict and plot assigned cell types#
Predict the soft cell type assignment probability for each cell.
# Table of soft assignments: one row per cell, one column per cell type.
predictions = follicular_model.predict()
# Show the first rows as the notebook cell output.
predictions.head()
B cells | Cytotoxic T cells | CD4 T cells | Tfh | other | |
---|---|---|---|---|---|
0 | 1.000000e+00 | 2.514913e-19 | 5.140069e-19 | 7.447620e-16 | 5.360906e-15 |
1 | 1.000000e+00 | 3.205316e-21 | 1.535639e-24 | 2.355200e-18 | 9.440110e-17 |
2 | 1.000000e+00 | 3.232029e-26 | 3.399115e-29 | 3.017093e-23 | 1.899120e-21 |
3 | 1.000000e+00 | 1.341051e-43 | 8.725416e-40 | 6.327488e-38 | 9.754956e-34 |
4 | 1.627084e-14 | 2.318600e-10 | 5.974996e-03 | 9.940250e-01 | 8.278363e-16 |
We can visualize the assignment probabilities with a clustered heatmap of the probability matrix, with one row per cell and one column per cell type.
# Clustered heatmap of the assignment probabilities, drawn with the viridis colormap.
sns.clustermap(predictions, cmap="viridis")
<seaborn.matrix.ClusterGrid at 0x7f56b060c850>

We then create a UMAP plot labeled by maximum probability assignments from the CellAssign model. The top plot shows the cell types predicted by the original CellAssign implementation and the bottom plot shows our model’s predictions.
# Hard label per cell: the column (cell type) holding the highest predicted probability.
follicular_bdata.obs["cellassign_predictions"] = predictions.idxmax(axis=1).values
# celltype is the original CellAssign prediction
# One UMAP panel per entry in `color`; ncols=1 stacks the panels vertically.
sc.pl.umap(
follicular_bdata,
color=["celltype", "cellassign_predictions"],
frameon=False,
ncols=1,
)

Model reproducibility#
We see that the scvi-tools implementation closely reproduces the original implementation’s predictions.
# Cross-tabulate this model's labels (rows) against the "celltype" annotations
# stored in the dataset (columns).
df = follicular_bdata.obs
confusion_matrix = pd.crosstab(
    df["cellassign_predictions"],
    df["celltype"],
    rownames=["cellassign_predictions"],
    colnames=["Original predictions"],
)
# Row-normalize so each predicted label's row sums to 1. DataFrame.div with
# axis=0 aligns the row sums on the index; this avoids Series.ravel(), which
# pandas deprecated in 2.2.
confusion_matrix = confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    confusion_matrix,
    cmap=sns.diverging_palette(245, 320, s=60, as_cmap=True),
    ax=ax,
    square=True,
    cbar_kws={"shrink": 0.4, "aspect": 12},
)
<Axes: xlabel='Original predictions', ylabel='cellassign_predictions'>

HGSC Data#
We can repeat the same process for HGSC data.
# Load the HGSC AnnData and its marker table; the first CSV column becomes the
# table's index.
hgsc_adata = scvi.data.read_h5ad(sce_hgsc_path)
hgsc_celltype_markers = pd.read_csv(hgsc_celltype_path, index_col=0)
# NOTE(review): unlike the follicular section, the obs/var indices are not cast
# to str here - confirm that this is intentional.
hgsc_adata.var_names_make_unique()
hgsc_adata.obs_names_make_unique()
# Bare expression: shows the AnnData summary as the notebook cell output.
hgsc_adata
AnnData object with n_obs × n_vars = 4848 × 33694
obs: 'Sample', 'dataset', 'patient', 'timepoint', 'site', 'sample_barcode', 'is_cell_control', 'total_features_by_counts', 'log10_total_features_by_counts', 'total_counts', 'log10_total_counts', 'pct_counts_in_top_50_features', 'pct_counts_in_top_100_features', 'pct_counts_in_top_200_features', 'pct_counts_in_top_500_features', 'total_features_by_counts_endogenous', 'log10_total_features_by_counts_endogenous', 'total_counts_endogenous', 'log10_total_counts_endogenous', 'pct_counts_endogenous', 'pct_counts_in_top_50_features_endogenous', 'pct_counts_in_top_100_features_endogenous', 'pct_counts_in_top_200_features_endogenous', 'pct_counts_in_top_500_features_endogenous', 'total_features_by_counts_feature_control', 'log10_total_features_by_counts_feature_control', 'total_counts_feature_control', 'log10_total_counts_feature_control', 'pct_counts_feature_control', 'pct_counts_in_top_50_features_feature_control', 'pct_counts_in_top_100_features_feature_control', 'pct_counts_in_top_200_features_feature_control', 'pct_counts_in_top_500_features_feature_control', 'total_features_by_counts_mitochondrial', 'log10_total_features_by_counts_mitochondrial', 'total_counts_mitochondrial', 'log10_total_counts_mitochondrial', 'pct_counts_mitochondrial', 'pct_counts_in_top_50_features_mitochondrial', 'pct_counts_in_top_100_features_mitochondrial', 'pct_counts_in_top_200_features_mitochondrial', 'pct_counts_in_top_500_features_mitochondrial', 'total_features_by_counts_ribosomal', 'log10_total_features_by_counts_ribosomal', 'total_counts_ribosomal', 'log10_total_counts_ribosomal', 'pct_counts_ribosomal', 'pct_counts_in_top_50_features_ribosomal', 'pct_counts_in_top_100_features_ribosomal', 'pct_counts_in_top_200_features_ribosomal', 'pct_counts_in_top_500_features_ribosomal', 'size_factor', 'cellassign_cluster_broad', 'cellassign_cluster_specific', 'B.cells..broad.', 'T.cells..broad.', 'Monocyte.Macrophage..broad.', 'Epithelial.cells..broad.', 'Ovarian.stromal.cells..broad.', 
'Ovarian.myofibroblast..broad.', 'Vascular.smooth.muscle.cells..broad.', 'Endothelial.cells..broad.', 'other..broad.', 'B.cells', 'CD4.T.cells', 'Cytotoxic.T.cells', 'Monocyte.Macrophage', 'Epithelial.cells', 'Ovarian.stromal.cells', 'Ovarian.myofibroblast', 'Vascular.smooth.muscle.cells', 'Endothelial.cells', 'other', 'celltype', 'G1', 'S', 'G2M', 'Cell_Cycle', 'epithelial_seurat_cluster', 'epithelial_seurat_0.2_cluster', 'epithelial_phenograph_cluster', 'epithelial_sc3_cluster', 'epithelial_SC3_cluster', 'epithelial_cluster', 'all_seurat_cluster', 'all_seurat_0.8_cluster', 'all_seurat_1.2_cluster', 'all_sc3_cluster', 'all_SC3_cluster', 'all_cluster', 'all_subset_seurat_cluster', 'all_subset_seurat_0.8_cluster', 'all_subset_seurat_1.2_cluster', 'all_subset_cluster'
var: 'ID', 'is_feature_control', 'is_feature_control_mitochondrial', 'is_feature_control_ribosomal', 'mean_counts', 'log10_mean_counts', 'n_cells_by_counts', 'pct_dropout_by_counts', 'total_counts', 'log10_total_counts'
uns: 'log.exprs.offset'
obsm: 'X_pca', 'X_tsne', 'X_umap'
layers: 'logcounts'
Create and fit CellAssign model#
# Restrict the HGSC data to the genes in the marker table's index, stored as
# its own copy.
hgsc_bdata = hgsc_adata[:, hgsc_celltype_markers.index].copy()
# Register the data with the "size_factor" column of .obs as the per-cell
# size factor.
CellAssign.setup_anndata(hgsc_bdata, size_factor_key="size_factor")
# Build and fit the model with the default training arguments.
hgsc_model = CellAssign(hgsc_bdata, hgsc_celltype_markers)
hgsc_model.train()
Epoch 400/400: 100%|██████████| 400/400 [00:14<00:00, 26.89it/s, v_num=1, train_loss_step=43.1, train_loss_epoch=40.8]Epoch 400/400: 100%|██████████| 400/400 [00:14<00:00, 27.03it/s, v_num=1, train_loss_step=43.1, train_loss_epoch=40.8]
# Plot the "elbo_validation" entry of the training history (x-axis: epoch).
hgsc_model.history["elbo_validation"].plot()
<Axes: xlabel='epoch'>

Predict and plot assigned cell types#
# Predicted cell type assignments for the HGSC cells, as returned by the fitted model.
predictions_hgsc = hgsc_model.predict()
# Clustered heatmap of those predictions, drawn with the viridis colormap.
sns.clustermap(predictions_hgsc, cmap="viridis")
<seaborn.matrix.ClusterGrid at 0x7f560079a0d0>

# Hard label per cell: the column holding the highest predicted value.
hgsc_bdata.obs["cellassign_predictions"] = predictions_hgsc.idxmax(axis=1).values
# One UMAP panel per entry in `color`; ncols=1 stacks the panels vertically.
sc.pl.umap(
hgsc_bdata,
color=["celltype", "cellassign_predictions"],
ncols=1,
frameon=False,
)

Model reproducibility#
# Cross-tabulate this model's labels (rows) against the "celltype" annotations
# stored in the dataset (columns).
df = hgsc_bdata.obs
confusion_matrix = pd.crosstab(
    df["cellassign_predictions"],
    df["celltype"],
    rownames=["cellassign_predictions"],
    colnames=["Original predictions"],
)
# Row-normalize so each predicted label's row sums to 1. DataFrame.div with
# axis=0 aligns the row sums on the index; this avoids Series.ravel(), which
# pandas deprecated in 2.2.
confusion_matrix = confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    confusion_matrix,
    cmap=sns.diverging_palette(245, 320, s=60, as_cmap=True),
    ax=ax,
    square=True,
    cbar_kws={"shrink": 0.4, "aspect": 12},
)
<Axes: xlabel='Original predictions', ylabel='cellassign_predictions'>
