Topic Modeling with Amortized LDA
In this tutorial, we will explore how to run the amortized Latent Dirichlet Allocation (LDA) model implementation in scvi-tools. LDA is a topic modeling method first introduced in the natural language processing field. By treating each cell as a document and each gene expression count as a word, we can carry the method over to the single-cell biology field.
Below, we will train the model over a dataset, plot the topics over a UMAP of the reference set, and inspect the topics for characteristic gene sets.
As an example, we use the PBMC 10K dataset from 10x Genomics.
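To make the document/word analogy concrete, here is a toy sketch with hypothetical counts and gene names (not data from this tutorial): each row of the cell-by-gene count matrix plays the role of a document, each gene is a vocabulary word, and each count is a word occurrence.

import pandas as pd

# Hypothetical toy counts: rows are cells ("documents"), columns are
# genes ("vocabulary words"), and entries are counts ("word counts").
counts = pd.DataFrame(
    [[5, 0, 2], [0, 7, 1]],
    index=["cell_1", "cell_2"],
    columns=["CD3D", "LYZ", "ACTB"],
)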
[ ]:
!pip install --quiet scvi-colab
from scvi_colab import install
install()
[1]:
import os
import pandas as pd
import scanpy as sc
import scvi
[2]:
sc.set_figure_params(figsize=(4, 4))
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'
Global seed set to 0
Load and process data
Load the 10x Genomics PBMC dataset. Generally, it is good practice for LDA to remove ubiquitous genes, to prevent the model from modeling them as a separate topic. Here, we first filter out all mitochondrial genes, then select the top 1000 highly variable genes from the remainder with the seurat_v3 method.
[3]:
save_path = "data"
adata = sc.read(
    os.path.join(save_path, "pbmc_10k_protein_v3.h5ad"),
    backup_url="https://github.com/YosefLab/scVI-data/raw/master/pbmc_10k_protein_v3.h5ad?raw=true",
)
adata.layers["counts"] = adata.X.copy() # preserve counts
sc.pp.normalize_total(adata, target_sum=10e4)
sc.pp.log1p(adata)
adata.raw = adata # freeze the state in `.raw`
adata = adata[:, ~adata.var_names.str.startswith("MT-")]
sc.pp.highly_variable_genes(
    adata, flavor="seurat_v3", layer="counts", n_top_genes=1000, subset=True
)
Trying to set attribute `._uns` of view, copying.
Create and fit the AmortizedLDA model
Here, we initialize and fit an AmortizedLDA model on the dataset. We choose to model 10 topics in this case.
[4]:
n_topics = 10
scvi.model.AmortizedLDA.setup_anndata(adata, layer="counts")
model = scvi.model.AmortizedLDA(adata, n_topics=n_topics)
Note
By default, we train with KL annealing, which means the effective loss will generally not decrease steadily at the beginning of training. Our Pyro implementations report this training loss term as elbo_train in the progress bar, which is misleading since the annealed loss is not the true ELBO. We plan on correcting this in the future.
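For intuition, KL annealing scales the KL term of the ELBO by a weight that ramps from 0 to 1 early in training, so the reported loss can increase as the weight grows even while the fit improves. A minimal sketch of a linear schedule, assuming a fixed number of warmup steps (illustrative only, not the exact schedule scvi-tools uses):

def kl_weight(step, n_warmup_steps=400):
    # Weight on the KL term ramps linearly from 0 to 1 during warmup;
    # the reported loss equals the full ELBO only once the weight is 1.
    return min(1.0, step / n_warmup_steps)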
[5]:
model.train()
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 1000/1000: 100%|██████████| 1000/1000 [12:06<00:00, 1.38it/s, v_num=1, elbo_train=1.87e+7]
Visualizing learned topics
Calling model.get_latent_representation() computes a Monte Carlo estimate of the topic proportions for each cell. An estimate is needed because the model approximates the Dirichlet distribution with a logistic-normal distribution, whose mean has no analytic form. The number of samples used to compute the latent representation can be configured with the optional argument n_samples.
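For intuition, the estimate amounts to drawing Gaussian samples, mapping each onto the probability simplex with a softmax, and averaging. A minimal NumPy sketch, where mu and sigma stand in for a cell's variational parameters (this is not the scvi-tools internals):

import numpy as np

def logistic_normal_mean(mu, sigma, n_samples=25, seed=0):
    # Monte Carlo estimate of the mean of a logistic-normal distribution,
    # which has no closed form: sample Gaussians, softmax, then average.
    rng = np.random.default_rng(seed)
    z = rng.normal(mu, sigma, size=(n_samples, len(mu)))
    p = np.exp(z - z.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    return p.mean(axis=0)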
[6]:
topic_prop = model.get_latent_representation()
topic_prop.head()
[6]:
| index | topic_0 | topic_1 | topic_2 | topic_3 | topic_4 | topic_5 | topic_6 | topic_7 | topic_8 | topic_9 |
|---|---|---|---|---|---|---|---|---|---|---|
| AAACCCAAGATTGTGA-1 | 0.000474 | 0.021273 | 0.116839 | 0.005243 | 0.000534 | 0.852167 | 0.000484 | 0.001973 | 0.000373 | 0.000640 |
| AAACCCACATCGGTTA-1 | 0.000340 | 0.005373 | 0.000217 | 0.000532 | 0.000214 | 0.992013 | 0.000402 | 0.000553 | 0.000161 | 0.000195 |
| AAACCCAGTACCGCGT-1 | 0.002523 | 0.018910 | 0.602798 | 0.011032 | 0.002445 | 0.356765 | 0.002098 | 0.001950 | 0.000712 | 0.000767 |
| AAACCCAGTATCGAAA-1 | 0.011468 | 0.004124 | 0.003098 | 0.004504 | 0.006355 | 0.002744 | 0.003459 | 0.003941 | 0.004475 | 0.955833 |
| AAACCCAGTCGTCATA-1 | 0.000981 | 0.000920 | 0.000645 | 0.000751 | 0.001073 | 0.000595 | 0.000921 | 0.000838 | 0.000661 | 0.992613 |
[7]:
# Save topic proportions in obsm and obs columns.
adata.obsm["X_LDA"] = topic_prop
for i in range(n_topics):
    adata.obs[f"LDA_topic_{i}"] = topic_prop[f"topic_{i}"]
Plot UMAP
[9]:
sc.tl.pca(adata, svd_solver="arpack")
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=20)
sc.tl.umap(adata)
sc.tl.leiden(adata, key_added="leiden_scVI", resolution=0.8)
# Save UMAP to custom .obsm field.
adata.obsm["raw_counts_umap"] = adata.obsm["X_umap"].copy()
[10]:
sc.pl.embedding(adata, "raw_counts_umap", color=["leiden_scVI"], frameon=False)

Color UMAP by topic proportions
By coloring the UMAP by topic proportions, we find that the learned topics are generally dominant in cells that lie close together in UMAP space. In some cases, a topic is dominant in multiple clusters, which indicates similarity between these clusters despite them being far apart in the plot. This is not surprising, since UMAP preserves local rather than global structure, so distances between far-apart clusters are not meaningful.
[11]:
sc.pl.embedding(
    adata,
    "raw_counts_umap",
    color=[f"LDA_topic_{i}" for i in range(n_topics)],
    frameon=False,
)

Plot UMAP in topic space
[12]:
sc.pp.neighbors(adata, use_rep="X_LDA", n_neighbors=20, metric="hellinger")
sc.tl.umap(adata)
# Save UMAP to custom .obsm field.
adata.obsm["topic_space_umap"] = adata.obsm["X_umap"].copy()
[13]:
sc.pl.embedding(
    adata,
    "topic_space_umap",
    color=[f"LDA_topic_{i}" for i in range(n_topics)],
    frameon=False,
)

Find top genes per topic
Similar to the topic proportions, model.get_feature_by_topic() returns a Monte Carlo estimate of the gene-by-topic matrix, which contains the weight of each gene within each topic. Again, an estimate is required because the Dirichlet is approximated with a logistic-normal distribution. We can sort each topic's column of this matrix by the proportion allocated to each gene to determine the top genes characterizing that topic.
[14]:
feature_by_topic = model.get_feature_by_topic()
feature_by_topic.head()
[14]:
| index | topic_0 | topic_1 | topic_2 | topic_3 | topic_4 | topic_5 | topic_6 | topic_7 | topic_8 | topic_9 |
|---|---|---|---|---|---|---|---|---|---|---|
| AL645608.8 | 0.000001 | 0.000003 | 0.000002 | 0.000002 | 0.000041 | 3.110339e-06 | 0.000005 | 0.000003 | 0.000004 | 0.000002 |
| HES4 | 0.000008 | 0.000013 | 0.000008 | 0.000019 | 0.000843 | 8.526634e-06 | 0.000005 | 0.000012 | 0.000009 | 0.000006 |
| ISG15 | 0.001131 | 0.000169 | 0.000402 | 0.000596 | 0.000610 | 2.966939e-04 | 0.001417 | 0.000501 | 0.000305 | 0.001375 |
| TNFRSF18 | 0.000296 | 0.000003 | 0.000001 | 0.000002 | 0.000004 | 9.796551e-07 | 0.000553 | 0.000006 | 0.000025 | 0.000132 |
| TNFRSF4 | 0.000667 | 0.000004 | 0.000001 | 0.000004 | 0.000004 | 1.897183e-06 | 0.000922 | 0.000004 | 0.000006 | 0.000059 |
[15]:
# For each topic, sort genes by their weight and record the ranking.
rank_by_topic = pd.DataFrame()
for i in range(n_topics):
    topic_name = f"topic_{i}"
    topic = feature_by_topic[topic_name].sort_values(ascending=False)
    rank_by_topic[topic_name] = topic.index
    rank_by_topic[f"{topic_name}_prop"] = topic.values
[16]:
rank_by_topic.head()
[16]:
| rank | topic_0 | topic_0_prop | topic_1 | topic_1_prop | topic_2 | topic_2_prop | topic_3 | topic_3_prop | topic_4 | topic_4_prop | topic_5 | topic_5_prop | topic_6 | topic_6_prop | topic_7 | topic_7_prop | topic_8 | topic_8_prop | topic_9 | topic_9_prop |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | TMSB4X | 0.126198 | LYZ | 0.071308 | LYZ | 0.056180 | FTL | 0.068417 | FTL | 0.098839 | S100A9 | 0.133041 | ACTB | 0.164749 | CD74 | 0.122255 | IGKC | 0.189100 | ACTB | 0.090690 |
| 1 | TMSB10 | 0.086073 | ACTB | 0.071091 | S100A9 | 0.053984 | FTH1 | 0.051274 | ACTB | 0.070052 | S100A8 | 0.094452 | TMSB4X | 0.132325 | HLA-DRA | 0.089028 | IGLC2 | 0.099806 | TMSB4X | 0.088525 |
| 2 | ACTB | 0.084189 | HLA-DRA | 0.040470 | FTH1 | 0.050848 | ACTB | 0.047326 | TMSB4X | 0.059721 | LYZ | 0.056391 | TMSB10 | 0.082519 | TMSB4X | 0.059973 | IGHA1 | 0.065741 | GNLY | 0.069525 |
| 3 | JUNB | 0.035964 | CD74 | 0.036424 | FTL | 0.050564 | TMSB4X | 0.043904 | FTH1 | 0.057384 | FTL | 0.044824 | ACTG1 | 0.055351 | ACTB | 0.055289 | IGHM | 0.046249 | NKG7 | 0.052817 |
| 4 | FTL | 0.028489 | TMSB4X | 0.035223 | ACTB | 0.034482 | LYZ | 0.040371 | S100A4 | 0.029948 | ACTB | 0.039486 | S100A4 | 0.037912 | HLA-DRB1 | 0.050972 | CD74 | 0.045520 | TMSB10 | 0.040158 |