Train a scVI model using Lamin#

This notebook demonstrates a scalable approach to training an scVI model on Census data using the Lamin dataloader. LaminDB is a data framework whose MappedCollection is designed to support efficient storage, management, and querying of scientific data, particularly in machine learning and bioinformatics applications. It allows complex datasets — such as those produced by experiments or models — to be organized, shared, and queried easily. See the LaminDB documentation for more information.

Note

Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.

!pip install --quiet scvi-colab
from scvi_colab import install

install()
import time

import scanpy as sc
import scvi
from scvi.dataloaders import MappedCollectionDataModule
→ connected lamindb: anonymous/lamindb_collection_scanvi
# os.system("lamin init --storage ./lamindb_collection")  # one time for github runner (comment)
import lamindb as ln
# ln.setup.init()
# Download and load the 10x PBMC dataset (pbmc4k + pbmc8k); files are cached,
# so re-running the cell only re-extracts the tarballs.
pbmc_dataset = scvi.data.pbmc_dataset(
    save_path=".",
    remove_extracted_data=True,  # delete extracted matrices after loading to save disk
)
INFO     File ./gene_info_pbmc.csv already downloaded                                                              
INFO     File ./pbmc_metadata.pickle already downloaded                                                            
INFO     File ./pbmc8k/filtered_gene_bc_matrices.tar.gz already downloaded                                         
INFO     Extracting tar file                                                                                       
INFO     Removing extracted data at ./pbmc8k/filtered_gene_bc_matrices                                             
INFO     File ./pbmc4k/filtered_gene_bc_matrices.tar.gz already downloaded                                         
INFO     Extracting tar file                                                                                       
INFO     Removing extracted data at ./pbmc4k/filtered_gene_bc_matrices
# Download the PBMC Seurat v4 CITE-seq dataset (cached after the first run).
pbmc_seurat_v4_cite_seq = scvi.data.pbmc_seurat_v4_cite_seq(save_path=".")
INFO     File ./pbmc_seurat_v4.h5ad already downloaded
# Use the cell-cycle phase as the "batch" covariate for the CITE-seq data, and
# store batch as plain strings in both datasets so the two batch vocabularies
# can be merged by the dataloader.
pbmc_seurat_v4_cite_seq.obs["batch"] = pbmc_seurat_v4_cite_seq.obs.Phase.astype("str")
pbmc_dataset.obs["batch"] = pbmc_dataset.obs["batch"].astype("str")
import numpy as np

# Restrict both datasets to the genes they share so they can be combined later.
gene_intersection = np.intersect1d(
    pbmc_dataset.var.gene_symbols.values, pbmc_seurat_v4_cite_seq.var.index.values
)
# Materialize the subsets with .copy(): subsetting an AnnData returns a lazy
# view, and the `.var_names` / `.obs[...]` assignments below would otherwise
# mutate a view (triggering ImplicitModificationWarning and an implicit copy
# at a surprising point).
pbmc_dataset_filtered = pbmc_dataset[
    :, pbmc_dataset.var["gene_symbols"].isin(gene_intersection)
].copy()
pbmc_seurat_v4_cite_seq_filtered = pbmc_seurat_v4_cite_seq[
    :, pbmc_seurat_v4_cite_seq.var_names.isin(gene_intersection)
].copy()
# Re-index on gene symbols so both objects share the same variable index.
pbmc_dataset_filtered.var_names = pbmc_dataset_filtered.var["gene_symbols"].values
# Harmonize the 10x PBMC labels with the CITE-seq annotation used below:
# collapse both monocyte subsets into "Monocytes" and bucket megakaryocytes
# under "Other".
_label_harmonization = {
    "FCGR3A+ Monocytes": "Monocytes",
    "CD14+ Monocytes": "Monocytes",
    "Megakaryocytes": "Other",
}
pbmc_dataset_filtered.obs["cell_type"] = (
    pbmc_dataset_filtered.obs["str_labels"]
    .astype("str")
    .replace(_label_harmonization)
)
# Inspect the harmonized 10x PBMC object.
pbmc_dataset_filtered
AnnData object with n_obs × n_vars = 11990 × 3315
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type'
    var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts'
    uns: 'cell_types'
    obsm: 'design', 'raw_qc', 'normalized_qc', 'qc_pc'
# Per-label cell counts after harmonization.
pbmc_dataset_filtered.obs["cell_type"].value_counts()
cell_type
CD4 T cells        4996
Monocytes          2578
B cells            1621
CD8 T cells        1448
Other               551
NK cells            457
Dendritic Cells     339
Name: count, dtype: int64
# Map the Seurat v4 level-1 annotation onto the same label vocabulary used for
# the 10x PBMC dataset above ("other T" intentionally remains unchanged).
_l1_to_harmonized = {
    "other": "Other",
    "B": "B cells",
    "DC": "Dendritic Cells",
    "NK": "NK cells",
    "CD4 T": "CD4 T cells",
    "CD8 T": "CD8 T cells",
    "Mono": "Monocytes",
}
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] = (
    pbmc_seurat_v4_cite_seq_filtered.obs["celltype.l1"]
    .astype("str")
    .replace(_l1_to_harmonized)
)
# Inspect the harmonized CITE-seq object.
pbmc_seurat_v4_cite_seq_filtered
AnnData object with n_obs × n_vars = 152094 × 3315
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT', 'X_index', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Protein log library size', 'Number proteins detected', 'RNA log library size', 'batch', 'cell_type'
    var: 'mt'
    obsm: 'protein_counts'
# Per-label cell counts after harmonization.
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"].value_counts()
cell_type
Monocytes          47217
CD4 T cells        40245
CD8 T cells        24399
NK cells           15384
B cells            13192
other T             5828
Dendritic Cells     3488
Other               2341
Name: count, dtype: int64
# Start tracking this notebook run in LaminDB (creates a Transform and a Run).
ln.track()
→ created Transform('hUtoQ4gffICo0000'), started new Run('miLp7BSR...') at 2025-05-08 09:03:42 UTC
# Persist both AnnData objects as LaminDB artifacts and bundle them into a
# single collection that the dataloader will stream from.
artifact1 = ln.Artifact.from_anndata(pbmc_dataset_filtered, key="part_one1.h5ad").save()
artifact2 = ln.Artifact.from_anndata(pbmc_seurat_v4_cite_seq_filtered, key="part_two1.h5ad").save()

collection = ln.Collection([artifact1, artifact2], key="gather")
collection.save()
Collection(uid='WYhawczHOi6K1dYG0004', is_latest=True, key='gather', hash='wPM0DpO-bWnuNH5Kqd1zEA', space_id=1, created_by_id=1, run_id=1, created_at=2025-05-08 09:03:45 UTC)
# List the collection's artifacts to confirm it holds the two h5ad files.
artifacts = collection.artifacts.all()
artifacts.df()
uid key description suffix kind otype size hash n_files n_observations _hash_type _key_is_virtual _overwrite_versions space_id storage_id schema_id version is_latest run_id created_at created_by_id _aux _branch_code
id
9 cq4K7ESTSPnef6Ol0000 part_one1.h5ad None .h5ad dataset AnnData 49787918 Qa1fqQcaprCKNBrFOF1l1g None 11990 md5 True False 1 1 None None True 1 2025-05-08 09:03:44.063000+00:00 1 None 1
10 0DZgfpOHWcVI3tgI0000 part_two1.h5ad None .h5ad dataset AnnData 998889103 DozCh58Pwbgf9-wzsb976e None 152094 sha1-fl True False 1 1 None None True 1 2025-05-08 09:03:44.913000+00:00 1 None 1
# Build the datamodule: MappedCollectionDataModule streams minibatches directly
# from the collection's h5ad artifacts without loading them into memory.
batch_keys = "batch"
datamodule = MappedCollectionDataModule(
    collection,
    batch_key=batch_keys,    # obs column used as the batch covariate
    batch_size=1024,
    shuffle=True,
    join="inner",            # keep only genes shared by all artifacts
)
print(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch)
164084 3315 5
# Initialize SCVI from the datamodule's registry alone — no AnnData object is
# attached to the model; dimensions come from the registry.
model = scvi.model.SCVI(registry=datamodule.registry)
# Train scVI from the Lamin dataloader (no in-memory AnnData), timing the run.
import gc

gc.collect()  # free leftover objects before the training run
start = time.time()
model.train(
    max_epochs=100,
    batch_size=1024,
    # lr=0.003 is higher than the default; compile=False disables torch.compile
    plan_kwargs={"lr": 0.003, "compile": False},
    early_stopping=False,
    # the datamodule's dataloader replaces the usual AnnData-backed one
    datamodule=datamodule.inference_dataloader(),
)
end = time.time()
print(f"Elapsed time: {end - start:.2f} seconds")
Elapsed time: 485.27 seconds
# Final training-ELBO values to confirm convergence behavior.
model.history["elbo_train"].tail()
elbo_train
epoch
95 1853.080566
96 1852.66333
97 1852.367554
98 1852.291748
99 1852.092651
# Save the model; save_anndata=False because there is no AnnData attached —
# the datamodule must be passed instead.
model.save("lamin_model", save_anndata=False, overwrite=True, datamodule=datamodule)
model.history.keys()
dict_keys(['kl_weight', 'train_loss_step', 'train_loss_epoch', 'elbo_train', 'reconstruction_loss_train', 'kl_local_train', 'kl_global_train'])
# Downstream analyses on a datamodule-trained model always go through an
# inference dataloader (the model holds no AnnData to iterate over).
inference_dataloader = datamodule.inference_dataloader()
latent = model.get_latent_representation(dataloader=inference_dataloader)
# Load the full collection into memory as one AnnData (inner join over genes)
# so the latent space can be visualized with scanpy.
adata = collection.load(join="inner")
adata.obsm["scvi"] = latent
# Preview the merged obs table (batch, cell_type, artifact provenance).
adata.obs
batch cell_type artifact_uid
AAACCTGAGCTAGTGG-1 0 CD4 T cells cq4K7ESTSPnef6Ol0000
AAACCTGCACATTAGC-1 0 CD4 T cells cq4K7ESTSPnef6Ol0000
AAACCTGCACTGTTAG-1 0 Monocytes cq4K7ESTSPnef6Ol0000
AAACCTGCATAGTAAG-1 0 Monocytes cq4K7ESTSPnef6Ol0000
AAACCTGCATGAACCT-1 0 CD8 T cells cq4K7ESTSPnef6Ol0000
... ... ... ...
E2L8_TTTGTTGGTCGTGATT S CD8 T cells 0DZgfpOHWcVI3tgI0000
E2L8_TTTGTTGGTGTGCCTG G1 Monocytes 0DZgfpOHWcVI3tgI0000
E2L8_TTTGTTGGTTAGTTCG S B cells 0DZgfpOHWcVI3tgI0000
E2L8_TTTGTTGGTTGGCTAT G1 Monocytes 0DZgfpOHWcVI3tgI0000
E2L8_TTTGTTGTCTCATGGA G1 Monocytes 0DZgfpOHWcVI3tgI0000

164084 rows × 3 columns

# Neighbors + UMAP on the scVI latent space; color by batch to inspect mixing.
sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
sc.tl.umap(adata, neighbors_key="scvi")
sc.pl.umap(adata, color=batch_keys, title="batch_SCVI")
../../../../_images/151c2cb210e599dd59c52a85d1b9f206298973859587ec819a4c8055881d8ae7.png
# Same embedding, colored by harmonized cell type.
sc.pl.umap(adata, color="cell_type", title="cell_type_SCVI")
../../../../_images/4754ea6d25adafffc3a0595c62b98cc5af59bc678681dfd4115c3149028f468d.png

scanvi#

# Build a second datamodule for scANVI: same collection, but with a label key
# and model_name="SCANVI" so labels are registered for the classifier.
labels_keys = "cell_type"
datamodule_scanvi = MappedCollectionDataModule(
    collection,
    batch_key=batch_keys,
    label_key=labels_keys,
    batch_size=1024,
    shuffle=True,
    model_name="SCANVI",
    join="inner",
)
print(
    datamodule_scanvi.n_obs,
    datamodule_scanvi.n_vars,
    datamodule_scanvi.n_batch,
    datamodule_scanvi.n_labels,
)
164084 3315 5 9
# Create the scANVI model from the datamodule's registry (adata=None — the
# datamodule supplies the data); setup(stage="train") prepares the loaders.
datamodule_scanvi.setup(stage="train")
model_scanvi = scvi.model.SCANVI(
    adata=None,
    registry=datamodule_scanvi.registry,
    datamodule=datamodule_scanvi,
)
# Train scANVI from the datamodule, timing the run.
# NOTE(review): `import gc` duplicates the earlier cell; harmless but could be
# consolidated into the top import cell.
import gc

gc.collect()
start3 = time.time()
model_scanvi.train(
    max_epochs=20,
    batch_size=1024,
    plan_kwargs={"lr": 0.01, "compile": False},
    early_stopping=False,
    n_samples_per_label=100,  # cells sampled per label for the classifier loss
    datamodule=datamodule_scanvi,
)
end3 = time.time()
print(f"Elapsed time: {end3 - start3:.2f} seconds")
INFO     Training for 20 epochs.
Elapsed time: 135.02 seconds
# Save the scANVI model; as above, the datamodule replaces the saved AnnData.
model_scanvi.save(
    "lamin_scanvi_model", save_anndata=False, overwrite=True, datamodule=datamodule_scanvi
)
model_scanvi.history.keys()
dict_keys(['train_loss_step', 'train_loss_epoch', 'elbo_train', 'reconstruction_loss_train', 'kl_local_train', 'kl_global_train', 'train_classification_loss', 'train_accuracy', 'train_f1_score', 'train_calibration_error'])
# Final training classification accuracy.
model_scanvi.history["train_accuracy"].tail()
train_accuracy
epoch
15 0.993022
16 0.993101
17 0.993912
18 0.993973
19 0.994369
# Extract the scANVI latent space through its inference dataloader, then
# recompute neighbors/UMAP on it.
inference_scanvi_dataloader = datamodule_scanvi.inference_dataloader()
latent_scanvi = model_scanvi.get_latent_representation(dataloader=inference_scanvi_dataloader)
adata.obsm["scanvi"] = latent_scanvi
# Neighbors + UMAP on the scANVI embedding, colored by batch.
sc.pp.neighbors(adata, use_rep="scanvi", key_added="scanvi")
sc.tl.umap(adata, neighbors_key="scanvi")
sc.pl.umap(adata, color=batch_keys, title="batch_SCANVI")
../../../../_images/15dc7c74365228cb7b06c3239d27d1c1bddc96b0fd36afe8c15519775bdcf09c.png
# Same embedding, colored by harmonized cell type.
sc.pl.umap(adata, color="cell_type", title="cell_type_SCANVI")
../../../../_images/13aa3dcec3d878f9a74ccfb7503e182f2190d7e3aa41b7c36ce6dc5557e435d1.png
# Predict cell types with the trained scANVI classifier, streaming through the
# same inference dataloader used above.
adata.obs["predictions_scanvi"] = model_scanvi.predict(
    dataloader=inference_scanvi_dataloader, batch_size=1024
)
# Confusion matrix: rows = observed labels, columns = predicted labels,
# normalized so each predicted-label column sums to 1.
df = adata.obs.groupby(["cell_type", "predictions_scanvi"]).size().unstack(fill_value=0)
norm_df = df / df.sum(axis=0)
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 8))
ax.pcolor(norm_df)
ax.set_xticks(np.arange(0.5, len(df.columns), 1))
ax.set_xticklabels(df.columns, rotation=90)
ax.set_yticks(np.arange(0.5, len(df.index), 1))
ax.set_yticklabels(df.index)
ax.set_xlabel("Predicted")
ax.set_ylabel("Observed")
Text(0, 0.5, 'Observed')
../../../../_images/8b8b0f240cc1219a0b65eea02fa4f7eec2d001e5c9ac975c9498bb1584cef155.png

Run the regular AnnData-based workflow and compare#

# Prepare adata for the conventional (in-memory) workflow: keep raw counts in a
# layer for the models, then log-normalize .X for plotting.
adata.layers["counts"] = adata.X.copy()  # preserve counts before normalization
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata  # freeze the state in `.raw`
# (HVG selection intentionally skipped — genes were already intersected above.)
adata.obs
batch cell_type artifact_uid predictions_scanvi
AAACCTGAGCTAGTGG-1 0 CD4 T cells cq4K7ESTSPnef6Ol0000 CD4 T cells
AAACCTGCACATTAGC-1 0 CD4 T cells cq4K7ESTSPnef6Ol0000 CD4 T cells
AAACCTGCACTGTTAG-1 0 Monocytes cq4K7ESTSPnef6Ol0000 Monocytes
AAACCTGCATAGTAAG-1 0 Monocytes cq4K7ESTSPnef6Ol0000 Monocytes
AAACCTGCATGAACCT-1 0 CD8 T cells cq4K7ESTSPnef6Ol0000 CD8 T cells
... ... ... ... ...
E2L8_TTTGTTGGTCGTGATT S CD8 T cells 0DZgfpOHWcVI3tgI0000 CD8 T cells
E2L8_TTTGTTGGTGTGCCTG G1 Monocytes 0DZgfpOHWcVI3tgI0000 Monocytes
E2L8_TTTGTTGGTTAGTTCG S B cells 0DZgfpOHWcVI3tgI0000 B cells
E2L8_TTTGTTGGTTGGCTAT G1 Monocytes 0DZgfpOHWcVI3tgI0000 Monocytes
E2L8_TTTGTTGTCTCATGGA G1 Monocytes 0DZgfpOHWcVI3tgI0000 Monocytes

164084 rows × 4 columns

# Conventional scVI run on the in-memory AnnData, training on the raw counts
# layer, for comparison against the dataloader-based run above.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch", layer="counts")
model_census3 = scvi.model.SCVI(adata)
start2 = time.time()
model_census3.train(
    max_epochs=100,  # default plan settings; contrast with the custom lr above
)
end2 = time.time()
print(f"Elapsed time: {end2 - start2:.2f} seconds")
Elapsed time: 548.92 seconds
# Final training-ELBO values for the conventional run.
model_census3.history["elbo_train"].tail()
elbo_train
epoch
95 1828.02417
96 1827.902954
97 1827.885864
98 1827.774658
99 1827.692383
# UMAP on the conventional scVI embedding.
# NOTE(review): "dataloder" is a typo for "dataloader", but the key is reused
# downstream (e.g. by the Benchmarker), so renaming must be done everywhere at once.
adata.obsm["scvi_non_dataloder"] = model_census3.get_latent_representation()
sc.pp.neighbors(adata, use_rep="scvi_non_dataloder", key_added="scvi_non_dataloder")
sc.tl.umap(adata, neighbors_key="scvi_non_dataloder")
sc.pl.umap(adata, color="batch", title="batch_SCVI_adata")
../../../../_images/2c3f8dae4eae91ecd042b1d6e0765dd8f10dbb31f0eb0c57053a20067b8f21c8.png
# Same embedding, colored by cell type.
sc.pl.umap(adata, color="cell_type", title="cell_type_SCVI_adata")
../../../../_images/5e9620c93362ebe97aaffef25fccea4a92960cff72e5e0e11dc2d715b74ecbba.png

scanvi#

# Inspect the accumulated embeddings and annotations on adata.
adata
AnnData object with n_obs × n_vars = 164084 × 3315
    obs: 'batch', 'cell_type', 'artifact_uid', 'predictions_scanvi', '_scvi_batch', '_scvi_labels'
    uns: 'scvi', 'umap', 'batch_colors', 'cell_type_colors', 'scanvi', 'log1p', '_scvi_uuid', '_scvi_manager_uuid', 'scvi_non_dataloder'
    obsm: 'scvi', 'X_umap', 'scanvi', 'scvi_non_dataloder'
    layers: 'counts'
    obsp: 'scvi_distances', 'scvi_connectivities', 'scanvi_distances', 'scanvi_connectivities', 'scvi_non_dataloder_distances', 'scvi_non_dataloder_connectivities'
# Conventional scANVI run on the in-memory AnnData.
scvi.model.SCANVI.setup_anndata(
    adata,
    layer="counts",
    labels_key="cell_type",
    # placeholder unlabeled category — presumably no cell carries this label,
    # so all cells are treated as labeled; verify against the labels column
    unlabeled_category="label_0",
    batch_key=batch_keys,
)
model_census4 = scvi.model.SCANVI(adata)
start4 = time.time()
model_census4.train(
    max_epochs=100,
)
end4 = time.time()
print(f"Elapsed time: {end4 - start4:.2f} seconds")
INFO     Training for 100 epochs.
Elapsed time: 1110.17 seconds
# Final training classification accuracy for the conventional scANVI run.
model_census4.history["train_accuracy"].tail()
train_accuracy
epoch
95 0.99784
96 0.997176
97 0.997765
98 0.99765
99 0.997522
# UMAP on the conventional scANVI embedding (same "dataloder" typo kept for
# consistency with the other keys).
adata.obsm["scanvi_non_dataloder"] = model_census4.get_latent_representation()
sc.pp.neighbors(adata, use_rep="scanvi_non_dataloder", key_added="scanvi_non_dataloder")
sc.tl.umap(adata, neighbors_key="scanvi_non_dataloder")
sc.pl.umap(adata, color=["batch"], title=["SCANVI__non_dataloder_" + x for x in ["batch"]])
../../../../_images/a8afd8b1f1b9161d18d70abde7ffb0baf31bc30e5657b7a3c6186253e71e6a6e.png
# Same embedding, colored by cell type.
sc.pl.umap(adata, color="cell_type", title="SCANVI_non_dataloder")
../../../../_images/2211c4aee51a6d3bd15e0d2005812302df031d40e0f73963883e14206b74cc64.png
# Same confusion-matrix heatmap for the conventional (non-dataloader) scANVI
# model: rows = observed labels, columns = predicted labels, columns sum to 1.
adata.obs["predictions_scanvi_non_dataloder"] = model_census4.predict()
df = (
    adata.obs.groupby(["cell_type", "predictions_scanvi_non_dataloder"])
    .size()
    .unstack(fill_value=0)
)
norm_df = df / df.sum(axis=0)
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 8))
ax.pcolor(norm_df)
ax.set_xticks(np.arange(0.5, len(df.columns), 1))
ax.set_xticklabels(df.columns, rotation=90)
ax.set_yticks(np.arange(0.5, len(df.index), 1))
ax.set_yticklabels(df.index)
ax.set_xlabel("Predicted")
ax.set_ylabel("Observed")
Text(0, 0.5, 'Observed')
../../../../_images/1e208e00a8071ca8a2d8c618f19ef7eced0d78abf0a93cfb525e86cc0f5f50d6.png

Compute integration metrics

from scib_metrics.benchmark import Benchmarker

# Benchmark all four embeddings (plus a PCA baseline) with scib-metrics.
# NOTE(review): "X_pca" is not computed anywhere in this notebook — presumably
# Benchmarker derives it internally when missing; confirm against scib-metrics docs.
bm = Benchmarker(
    adata,
    batch_key="batch",
    label_key="cell_type",
    embedding_obsm_keys=["X_pca", "scvi", "scanvi", "scvi_non_dataloder", "scanvi_non_dataloder"],
    n_jobs=-1,
)
bm.benchmark()
INFO     8 clusters consist of a single batch or are too small. Skip.
INFO     8 clusters consist of a single batch or are too small. Skip.
INFO     8 clusters consist of a single batch or are too small. Skip.
INFO     8 clusters consist of a single batch or are too small. Skip.
INFO     8 clusters consist of a single batch or are too small. Skip.
# Render the metrics table with raw (non-min-max-scaled) scores.
bm.plot_results_table(min_max_scale=False)
../../../../_images/4d6b905939199410f79cc576b85e0818b2e09644227e7a099f22a5c7d2a20b85.png
<plottable.table.Table at 0x7438a6749940>