Train a scVI model using Lamin#
This notebook demonstrates a scalable approach to training an scVI model on Census data using the Lamin dataloader. LaminDB is a data framework whose MappedCollection is designed to support efficient storage, management, and querying of scientific data, particularly in machine learning, bioinformatics, and data-science applications. It makes it easy to organize, share, and query complex datasets, such as those produced by research experiments or models. See the LaminDB documentation for more information.
Note
Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.
# Install the tutorial dependencies; this is a no-op outside Google Colab.
!pip install --quiet scvi-colab
from scvi_colab import install
# Installs scvi-tools and friends only when running on Colab.
install()
import time
import scanpy as sc
import scvi
from scvi.dataloaders import MappedCollectionDataModule
→ connected lamindb: anonymous/lamindb_collection_scanvi
# os.system("lamin init --storage ./lamindb_collection") # one time for github runner (comment)
import lamindb as ln
# ln.setup.init()
# Download the combined 10x PBMC (8k + 4k) dataset into the working directory,
# removing the extracted tarball contents once the AnnData is built.
pbmc_dataset = scvi.data.pbmc_dataset(save_path=".", remove_extracted_data=True)
INFO File ./gene_info_pbmc.csv already downloaded
INFO File ./pbmc_metadata.pickle already downloaded
INFO File ./pbmc8k/filtered_gene_bc_matrices.tar.gz already downloaded
INFO Extracting tar file
INFO Removing extracted data at ./pbmc8k/filtered_gene_bc_matrices
INFO File ./pbmc4k/filtered_gene_bc_matrices.tar.gz already downloaded
INFO Extracting tar file
INFO Removing extracted data at ./pbmc4k/filtered_gene_bc_matrices
# Fetch the Seurat v4 CITE-seq PBMC reference dataset (uses the cached file if present).
pbmc_seurat_v4_cite_seq = scvi.data.pbmc_seurat_v4_cite_seq(save_path=".")
INFO File ./pbmc_seurat_v4.h5ad already downloaded
# Use the cell-cycle phase as a stand-in "batch" label for the CITE-seq data,
# and make sure both datasets carry string-typed batch columns.
pbmc_seurat_v4_cite_seq.obs["batch"] = pbmc_seurat_v4_cite_seq.obs["Phase"].astype(str)
pbmc_dataset.obs["batch"] = pbmc_dataset.obs["batch"].astype(str)
import numpy as np

# Genes present in both datasets (sorted, unique), used to align the two objects.
gene_intersection = np.intersect1d(
    pbmc_dataset.var["gene_symbols"].values,
    pbmc_seurat_v4_cite_seq.var.index.values,
)
# Subset both AnnData objects to the shared genes. `.copy()` materializes the
# views up front: the in-place `.obs` / `.var_names` writes that follow would
# otherwise trigger AnnData's implicit view-to-copy conversion (and warnings).
pbmc_dataset_filtered = pbmc_dataset[
    :, pbmc_dataset.var["gene_symbols"].isin(gene_intersection)
].copy()
pbmc_seurat_v4_cite_seq_filtered = pbmc_seurat_v4_cite_seq[
    :, pbmc_seurat_v4_cite_seq.var_names.isin(gene_intersection)
].copy()
# Re-key the 10x variables by gene symbol so both objects share var_names.
pbmc_dataset_filtered.var_names = pbmc_dataset_filtered.var["gene_symbols"].values
# Harmonize the 10x labels: collapse both monocyte subsets into "Monocytes"
# and bucket megakaryocytes under "Other". Unlisted labels pass through.
_pbmc_relabel = {
    "FCGR3A+ Monocytes": "Monocytes",
    "CD14+ Monocytes": "Monocytes",
    "Megakaryocytes": "Other",
}
pbmc_dataset_filtered.obs["cell_type"] = (
    pbmc_dataset_filtered.obs["str_labels"].astype(str).replace(_pbmc_relabel)
)
pbmc_dataset_filtered
AnnData object with n_obs × n_vars = 11990 × 3315
obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type'
var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts'
uns: 'cell_types'
obsm: 'design', 'raw_qc', 'normalized_qc', 'qc_pc'
# Distribution of harmonized cell-type labels in the 10x PBMC data.
pbmc_dataset_filtered.obs["cell_type"].value_counts()
cell_type
CD4 T cells 4996
Monocytes 2578
B cells 1621
CD8 T cells 1448
Other 551
NK cells 457
Dendritic Cells 339
Name: count, dtype: int64
# Map the Seurat level-1 annotations onto the same harmonized vocabulary
# used for the 10x PBMC data above. Unlisted labels (e.g. "other T") pass through.
_l1_to_harmonized = {
    "other": "Other",
    "B": "B cells",
    "DC": "Dendritic Cells",
    "NK": "NK cells",
    "CD4 T": "CD4 T cells",
    "CD8 T": "CD8 T cells",
    "Mono": "Monocytes",
}
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] = (
    pbmc_seurat_v4_cite_seq_filtered.obs["celltype.l1"].astype(str).replace(_l1_to_harmonized)
)
pbmc_seurat_v4_cite_seq_filtered
AnnData object with n_obs × n_vars = 152094 × 3315
obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT', 'X_index', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Protein log library size', 'Number proteins detected', 'RNA log library size', 'batch', 'cell_type'
var: 'mt'
obsm: 'protein_counts'
# Distribution of harmonized cell-type labels in the CITE-seq data.
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"].value_counts()
cell_type
Monocytes 47217
CD4 T cells 40245
CD8 T cells 24399
NK cells 15384
B cells 13192
other T 5828
Dendritic Cells 3488
Other 2341
Name: count, dtype: int64
# Start tracking provenance for this run in LaminDB (creates a Transform + Run).
ln.track()
→ created Transform('hUtoQ4gffICo0000'), started new Run('miLp7BSR...') at 2025-05-08 09:03:42 UTC
# Persist both AnnData objects as LaminDB artifacts, then bundle them into a
# single collection the MappedCollection dataloader can stream from.
artifact1 = ln.Artifact.from_anndata(pbmc_dataset_filtered, key="part_one1.h5ad").save()
artifact2 = ln.Artifact.from_anndata(pbmc_seurat_v4_cite_seq_filtered, key="part_two1.h5ad").save()
collection = ln.Collection([artifact1, artifact2], key="gather")
collection.save()
Collection(uid='WYhawczHOi6K1dYG0004', is_latest=True, key='gather', hash='wPM0DpO-bWnuNH5Kqd1zEA', space_id=1, created_by_id=1, run_id=1, created_at=2025-05-08 09:03:45 UTC)
# Inspect the collection: it is backed by the two h5ad artifacts saved above.
artifacts = collection.artifacts.all()
artifacts.df()
uid | key | description | suffix | kind | otype | size | hash | n_files | n_observations | _hash_type | _key_is_virtual | _overwrite_versions | space_id | storage_id | schema_id | version | is_latest | run_id | created_at | created_by_id | _aux | _branch_code | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
id | |||||||||||||||||||||||
9 | cq4K7ESTSPnef6Ol0000 | part_one1.h5ad | None | .h5ad | dataset | AnnData | 49787918 | Qa1fqQcaprCKNBrFOF1l1g | None | 11990 | md5 | True | False | 1 | 1 | None | None | True | 1 | 2025-05-08 09:03:44.063000+00:00 | 1 | None | 1 |
10 | 0DZgfpOHWcVI3tgI0000 | part_two1.h5ad | None | .h5ad | dataset | AnnData | 998889103 | DozCh58Pwbgf9-wzsb976e | None | 152094 | sha1-fl | True | False | 1 | 1 | None | None | True | 1 | 2025-05-08 09:03:44.913000+00:00 | 1 | None | 1 |
# Build the streaming datamodule over the collection. The "inner" join keeps
# only the genes shared by every member artifact.
batch_keys = "batch"
datamodule = MappedCollectionDataModule(
    collection, batch_key=batch_keys, shuffle=True, batch_size=1024, join="inner"
)
print(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch)
164084 3315 5
# print(datamodule.registry)  # uncomment to inspect the datamodule's registry
# Initialize SCVI directly from the datamodule's registry — no AnnData needed.
model = scvi.model.SCVI(registry=datamodule.registry)
# Train the dataloader-backed SCVI model, timing the run.
import gc

gc.collect()  # free any stale objects before the long training loop
t0 = time.time()
model.train(
    max_epochs=100,
    batch_size=1024,
    plan_kwargs={"lr": 0.003, "compile": False},
    early_stopping=False,
    datamodule=datamodule.inference_dataloader(),
)
print(f"Elapsed time: {time.time() - t0:.2f} seconds")
Elapsed time: 485.27 seconds
# Last few epochs of the training ELBO to check convergence.
model.history["elbo_train"].tail()
elbo_train | |
---|---|
epoch | |
95 | 1853.080566 |
96 | 1852.66333 |
97 | 1852.367554 |
98 | 1852.291748 |
99 | 1852.092651 |
# Save the model weights; the datamodule is passed so its setup can be restored
# on load (no AnnData is stored alongside the model).
model.save("lamin_model", save_anndata=False, overwrite=True, datamodule=datamodule)
model.history.keys()
dict_keys(['kl_weight', 'train_loss_step', 'train_loss_epoch', 'elbo_train', 'reconstruction_loss_train', 'kl_local_train', 'kl_global_train'])
# A dataloader-trained model has no attached AnnData, so every downstream
# query must be given an inference dataloader explicitly.
scvi_loader = datamodule.inference_dataloader()
scvi_latent = model.get_latent_representation(dataloader=scvi_loader)
# Materialize the joined collection in memory so the embedding can be plotted.
adata = collection.load(join="inner")
adata.obsm["scvi"] = scvi_latent
adata.obs
batch | cell_type | artifact_uid | |
---|---|---|---|
AAACCTGAGCTAGTGG-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 |
AAACCTGCACATTAGC-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 |
AAACCTGCACTGTTAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 |
AAACCTGCATAGTAAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 |
AAACCTGCATGAACCT-1 | 0 | CD8 T cells | cq4K7ESTSPnef6Ol0000 |
... | ... | ... | ... |
E2L8_TTTGTTGGTCGTGATT | S | CD8 T cells | 0DZgfpOHWcVI3tgI0000 |
E2L8_TTTGTTGGTGTGCCTG | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 |
E2L8_TTTGTTGGTTAGTTCG | S | B cells | 0DZgfpOHWcVI3tgI0000 |
E2L8_TTTGTTGGTTGGCTAT | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 |
E2L8_TTTGTTGTCTCATGGA | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 |
164084 rows × 3 columns
# Build the kNN graph on the SCVI latent space and embed it with UMAP.
sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
sc.tl.umap(adata, neighbors_key="scvi")
sc.pl.umap(adata, color=batch_keys, title="batch_SCVI")

# Same UMAP, colored by the harmonized cell-type labels.
sc.pl.umap(adata, color="cell_type", title="cell_type_SCVI")

scanvi#
# SCANVI additionally needs the label column; model_name="SCANVI" switches the
# datamodule's registry to a SCANVI-compatible setup.
labels_keys = "cell_type"
datamodule_scanvi = MappedCollectionDataModule(
    collection,
    batch_key=batch_keys,
    label_key=labels_keys,
    shuffle=True,
    batch_size=1024,
    model_name="SCANVI",
    join="inner",
)
print(
    datamodule_scanvi.n_obs,
    datamodule_scanvi.n_vars,
    datamodule_scanvi.n_batch,
    datamodule_scanvi.n_labels,
)
164084 3315 5 9
# Create the SCANVI model from the datamodule's registry (adata=None because
# the data is streamed from the collection, not held in memory).
datamodule_scanvi.setup(stage="train")
model_scanvi = scvi.model.SCANVI(
    adata=None,
    registry=datamodule_scanvi.registry,
    datamodule=datamodule_scanvi,
)
# datamodule_scanvi.registry  # uncomment to inspect the registry
# Train the dataloader-backed SCANVI model, timing the run.
import gc

gc.collect()
t_scanvi = time.time()
model_scanvi.train(
    max_epochs=20,
    batch_size=1024,
    plan_kwargs={"lr": 0.01, "compile": False},
    early_stopping=False,
    n_samples_per_label=100,
    datamodule=datamodule_scanvi,
)
print(f"Elapsed time: {time.time() - t_scanvi:.2f} seconds")
INFO Training for 20 epochs.
Elapsed time: 135.02 seconds
# Save the SCANVI weights; the datamodule is passed so setup can be restored on load.
model_scanvi.save(
    "lamin_scanvi_model", save_anndata=False, overwrite=True, datamodule=datamodule_scanvi
)
# from pprint import pprint
# pprint(model_scanvi.registry)
model_scanvi.history.keys()
dict_keys(['train_loss_step', 'train_loss_epoch', 'elbo_train', 'reconstruction_loss_train', 'kl_local_train', 'kl_global_train', 'train_classification_loss', 'train_accuracy', 'train_f1_score', 'train_calibration_error'])
# Final training-classification accuracy of the SCANVI classifier head.
model_scanvi.history["train_accuracy"].tail()
train_accuracy | |
---|---|
epoch | |
15 | 0.993022 |
16 | 0.993101 |
17 | 0.993912 |
18 | 0.993973 |
19 | 0.994369 |
# As with SCVI, downstream queries need an explicit inference dataloader.
# NOTE: `inference_scanvi_dataloader` is reused below for `predict`.
inference_scanvi_dataloader = datamodule_scanvi.inference_dataloader()
latent_scanvi = model_scanvi.get_latent_representation(dataloader=inference_scanvi_dataloader)
adata.obsm["scanvi"] = latent_scanvi
# Build the kNN graph on the SCANVI latent space and embed it with UMAP.
sc.pp.neighbors(adata, use_rep="scanvi", key_added="scanvi")
sc.tl.umap(adata, neighbors_key="scanvi")
sc.pl.umap(adata, color=batch_keys, title="batch_SCANVI")

# Same UMAP, colored by the harmonized cell-type labels.
sc.pl.umap(adata, color="cell_type", title="cell_type_SCANVI")

# Predict labels with SCANVI and cross-tabulate them against the observed
# harmonized annotations.
adata.obs["predictions_scanvi"] = model_scanvi.predict(
    dataloader=inference_scanvi_dataloader, batch_size=1024
)
df = adata.obs.groupby(["cell_type", "predictions_scanvi"]).size().unstack(fill_value=0)
# Normalize each predicted-label column to sum to 1.
norm_df = df.div(df.sum(axis=0), axis=1)
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
Text(0, 0.5, 'Observed')

Run the standard in-memory AnnData workflow and compare#
# Standard scanpy preprocessing for the in-memory comparison run.
adata.layers["counts"] = adata.X.copy()  # preserve raw counts before normalizing
sc.pp.normalize_total(adata, target_sum=1e4)  # library-size normalization
sc.pp.log1p(adata)
adata.raw = adata  # freeze the state in `.raw`
# Optional HVG selection, disabled here (the gene set was already intersected):
# sc.pp.highly_variable_genes(
#     adata,
#     n_top_genes=top_n_hvg,
#     subset=True,
#     layer="counts",
#     flavor="seurat_v3",
#     batch_key="dataset_id",
# )
adata.obs
batch | cell_type | artifact_uid | predictions_scanvi | |
---|---|---|---|---|
AAACCTGAGCTAGTGG-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 | CD4 T cells |
AAACCTGCACATTAGC-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 | CD4 T cells |
AAACCTGCACTGTTAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 | Monocytes |
AAACCTGCATAGTAAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 | Monocytes |
AAACCTGCATGAACCT-1 | 0 | CD8 T cells | cq4K7ESTSPnef6Ol0000 | CD8 T cells |
... | ... | ... | ... | ... |
E2L8_TTTGTTGGTCGTGATT | S | CD8 T cells | 0DZgfpOHWcVI3tgI0000 | CD8 T cells |
E2L8_TTTGTTGGTGTGCCTG | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 | Monocytes |
E2L8_TTTGTTGGTTAGTTCG | S | B cells | 0DZgfpOHWcVI3tgI0000 | B cells |
E2L8_TTTGTTGGTTGGCTAT | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 | Monocytes |
E2L8_TTTGTTGTCTCATGGA | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 | Monocytes |
164084 rows × 4 columns
# Train a standard in-memory SCVI model on the same counts for comparison
# with the dataloader-based run above.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch", layer="counts")
model_census3 = scvi.model.SCVI(adata)
t_mem = time.time()
model_census3.train(max_epochs=100)
print(f"Elapsed time: {time.time() - t_mem:.2f} seconds")
Elapsed time: 548.92 seconds
# Final training ELBO of the in-memory SCVI run, for comparison with the dataloader run.
model_census3.history["elbo_train"].tail()
elbo_train | |
---|---|
epoch | |
95 | 1828.02417 |
96 | 1827.902954 |
97 | 1827.885864 |
98 | 1827.774658 |
99 | 1827.692383 |
# Embed the in-memory SCVI latent space ("dataloder" spelling is kept —
# it is an obsm key reused below and must stay consistent).
adata.obsm["scvi_non_dataloder"] = model_census3.get_latent_representation()
sc.pp.neighbors(adata, use_rep="scvi_non_dataloder", key_added="scvi_non_dataloder")
sc.tl.umap(adata, neighbors_key="scvi_non_dataloder")
sc.pl.umap(adata, color="batch", title="batch_SCVI_adata")

# Same UMAP, colored by cell type.
sc.pl.umap(adata, color="cell_type", title="cell_type_SCVI_adata")

scanvi#
# Current state of the combined AnnData after both dataloader-based analyses.
adata
AnnData object with n_obs × n_vars = 164084 × 3315
obs: 'batch', 'cell_type', 'artifact_uid', 'predictions_scanvi', '_scvi_batch', '_scvi_labels'
uns: 'scvi', 'umap', 'batch_colors', 'cell_type_colors', 'scanvi', 'log1p', '_scvi_uuid', '_scvi_manager_uuid', 'scvi_non_dataloder'
obsm: 'scvi', 'X_umap', 'scanvi', 'scvi_non_dataloder'
layers: 'counts'
obsp: 'scvi_distances', 'scvi_connectivities', 'scanvi_distances', 'scanvi_connectivities', 'scvi_non_dataloder_distances', 'scvi_non_dataloder_connectivities'
# Train an in-memory SCANVI model for comparison. "label_0" is a placeholder
# for the unlabeled category (every cell here carries a label).
scvi.model.SCANVI.setup_anndata(
    adata,
    layer="counts",
    labels_key="cell_type",
    unlabeled_category="label_0",
    batch_key=batch_keys,
)
model_census4 = scvi.model.SCANVI(adata)
t_scanvi_mem = time.time()
model_census4.train(max_epochs=100)
print(f"Elapsed time: {time.time() - t_scanvi_mem:.2f} seconds")
INFO Training for 100 epochs.
Elapsed time: 1110.17 seconds
# Final training accuracy of the in-memory SCANVI classifier.
model_census4.history["train_accuracy"].tail()
train_accuracy | |
---|---|
epoch | |
95 | 0.99784 |
96 | 0.997176 |
97 | 0.997765 |
98 | 0.99765 |
99 | 0.997522 |
# Embed the in-memory SCANVI latent space and plot by batch.
adata.obsm["scanvi_non_dataloder"] = model_census4.get_latent_representation()
sc.pp.neighbors(adata, use_rep="scanvi_non_dataloder", key_added="scanvi_non_dataloder")
sc.tl.umap(adata, neighbors_key="scanvi_non_dataloder")
sc.pl.umap(adata, color=["batch"], title=["SCANVI__non_dataloder_" + x for x in ["batch"]])

# Same UMAP, colored by cell type.
sc.pl.umap(adata, color="cell_type", title="SCANVI_non_dataloder")

# Predicted labels from the in-memory SCANVI model (uses the attached AnnData).
adata.obs["predictions_scanvi_non_dataloder"] = model_census4.predict()
# Confusion matrix for the in-memory SCANVI predictions, mirroring the
# dataloader-based plot above.
df = (
    adata.obs.groupby(["cell_type", "predictions_scanvi_non_dataloder"])
    .size()
    .unstack(fill_value=0)
)
# Normalize each predicted-label column to sum to 1.
norm_df = df.div(df.sum(axis=0), axis=1)
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
Text(0, 0.5, 'Observed')

Compute integration metrics
from scib_metrics.benchmark import Benchmarker

# Benchmark all embeddings with scib-metrics.
# NOTE(review): "X_pca" is listed, but no sc.pp.pca call is visible in this
# notebook — confirm the PCA embedding exists in adata.obsm before running.
bm = Benchmarker(
    adata,
    batch_key="batch",
    label_key="cell_type",
    embedding_obsm_keys=["X_pca", "scvi", "scanvi", "scvi_non_dataloder", "scanvi_non_dataloder"],
    n_jobs=-1,
)
bm.benchmark()
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
# Render the metric comparison table (raw scores; no min-max scaling).
bm.plot_results_table(min_max_scale=False)

<plottable.table.Table at 0x7438a6749940>