Train a scVI model using Lamin#
This notebook demonstrates a scalable approach to training an scVI model on PBMC data using the Lamin dataloader. LaminDB is a database system designed to support efficient storage, management, and querying of scientific data, particularly in machine learning, bioinformatics, and data science applications; the dataloader used here is built on its MappedCollection. LaminDB makes it easy to organize, share, and query complex datasets, such as those produced by research, experiments, or models. See the LaminDB documentation for more information.
Note
Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.
!pip install --quiet scvi-colab
from scvi_colab import install
install()
import time
import scanpy as sc
import scvi
from scvi.dataloaders import MappedCollectionDataModule
→ connected lamindb: anonymous/lamindb_collection_scanvi
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
We'll start by initializing the lamindb backend and reading the well-known PBMC datasets.
# os.system("lamin init --storage ./lamindb_collection")
import lamindb as ln
# ln.setup.init()
pbmc_dataset = scvi.data.pbmc_dataset(
save_path=".",
remove_extracted_data=True,
)
INFO File ./gene_info_pbmc.csv already downloaded
INFO File ./pbmc_metadata.pickle already downloaded
INFO File ./pbmc8k/filtered_gene_bc_matrices.tar.gz already downloaded
INFO Extracting tar file
INFO Removing extracted data at ./pbmc8k/filtered_gene_bc_matrices
INFO File ./pbmc4k/filtered_gene_bc_matrices.tar.gz already downloaded
INFO Extracting tar file
INFO Removing extracted data at ./pbmc4k/filtered_gene_bc_matrices
pbmc_seurat_v4_cite_seq = scvi.data.pbmc_seurat_v4_cite_seq(save_path=".")
INFO File ./pbmc_seurat_v4.h5ad already downloaded
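The commented-out `lamin init --storage ./lamindb_collection` call earlier in this section can also be driven from Python. The following is only a minimal sketch: the helper name `lamin_init_command` is hypothetical, and the command is built and printed but not executed here, since running it requires the `lamin` CLI to be installed.

```python
import shlex


def lamin_init_command(storage: str) -> list[str]:
    """Build the CLI call shown (commented out) in the notebook."""
    return ["lamin", "init", "--storage", storage]


cmd = lamin_init_command("./lamindb_collection")
print(shlex.join(cmd))

# To actually run it (assumes the lamin CLI is on PATH):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing the command as a list to `subprocess.run(..., check=True)` avoids shell quoting issues and raises an error if initialization fails, which `os.system` would silently ignore.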
Preprocessing of the data#
Here we read two PBMC datasets so that we can later demonstrate integration. We keep only the genes shared by both datasets and consolidate the cell type names so that they are aligned.
pbmc_seurat_v4_cite_seq.obs["batch"] = pbmc_seurat_v4_cite_seq.obs.Phase
pbmc_seurat_v4_cite_seq.obs["batch"] = pbmc_seurat_v4_cite_seq.obs["batch"].astype("str")
pbmc_dataset.obs["batch"] = pbmc_dataset.obs["batch"].astype("str")
import numpy as np
gene_intersection = np.intersect1d(
pbmc_dataset.var.gene_symbols.values, pbmc_seurat_v4_cite_seq.var.index.values
)
pbmc_dataset_filtered = pbmc_dataset[:, pbmc_dataset.var["gene_symbols"].isin(gene_intersection)]
pbmc_seurat_v4_cite_seq_filtered = pbmc_seurat_v4_cite_seq[
:, pbmc_seurat_v4_cite_seq.var_names.isin(gene_intersection)
]
pbmc_dataset_filtered.var_names = pbmc_dataset_filtered.var["gene_symbols"].values
pbmc_dataset_filtered.obs["cell_type"] = pbmc_dataset_filtered.obs["str_labels"].astype("str")
pbmc_dataset_filtered.obs.loc[
pbmc_dataset_filtered.obs["cell_type"] == "FCGR3A+ Monocytes", "cell_type"
] = "Monocytes"
pbmc_dataset_filtered.obs.loc[
pbmc_dataset_filtered.obs["cell_type"] == "CD14+ Monocytes", "cell_type"
] = "Monocytes"
pbmc_dataset_filtered.obs.loc[
pbmc_dataset_filtered.obs["cell_type"] == "Megakaryocytes", "cell_type"
] = "Other"
pbmc_dataset_filtered
AnnData object with n_obs × n_vars = 11990 × 3315
obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type'
var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts'
uns: 'cell_types'
obsm: 'design', 'raw_qc', 'normalized_qc', 'qc_pc'
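The gene-intersection step above can be illustrated on toy data. This is a small sketch with made-up gene lists (the arrays `genes_a` and `genes_b` are hypothetical); it shows that `np.intersect1d` returns the shared names sorted, while a boolean mask built with `isin` keeps each dataset's own gene order.

```python
import numpy as np

# Hypothetical gene names standing in for the two datasets' var names.
genes_a = np.array(["CD3E", "MS4A1", "LYZ", "NKG7"])
genes_b = np.array(["LYZ", "GNLY", "CD3E", "MS4A1"])

# Shared genes, returned sorted and de-duplicated.
shared = np.intersect1d(genes_a, genes_b)

# Boolean masks keep the original per-dataset order of the surviving genes.
kept_a = genes_a[np.isin(genes_a, shared)]
kept_b = genes_b[np.isin(genes_b, shared)]

print(shared)
print(kept_a)
print(kept_b)
```

Both filtered datasets end up with the same set of genes but not necessarily in the same order, which is why the gene names, rather than positions, have to be used to align them later.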
The cell types in the first dataset, with their counts, are:
pbmc_dataset_filtered.obs["cell_type"].value_counts()
cell_type
CD4 T cells 4996
Monocytes 2578
B cells 1621
CD8 T cells 1448
Other 551
NK cells 457
Dendritic Cells 339
Name: count, dtype: int64
We repeat the relabelling for the other dataset.
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] = pbmc_seurat_v4_cite_seq_filtered.obs[
"celltype.l1"
].astype("str")
pbmc_seurat_v4_cite_seq_filtered.obs.loc[
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] == "other", "cell_type"
] = "Other"
pbmc_seurat_v4_cite_seq_filtered.obs.loc[
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] == "B", "cell_type"
] = "B cells"
pbmc_seurat_v4_cite_seq_filtered.obs.loc[
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] == "DC", "cell_type"
] = "Dendritic Cells"
pbmc_seurat_v4_cite_seq_filtered.obs.loc[
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] == "NK", "cell_type"
] = "NK cells"
pbmc_seurat_v4_cite_seq_filtered.obs.loc[
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] == "CD4 T", "cell_type"
] = "CD4 T cells"
pbmc_seurat_v4_cite_seq_filtered.obs.loc[
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] == "CD8 T", "cell_type"
] = "CD8 T cells"
pbmc_seurat_v4_cite_seq_filtered.obs.loc[
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"] == "Mono", "cell_type"
] = "Monocytes"
pbmc_seurat_v4_cite_seq_filtered
AnnData object with n_obs × n_vars = 152094 × 3315
obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT', 'X_index', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Protein log library size', 'Number proteins detected', 'RNA log library size', 'batch', 'cell_type'
var: 'mt'
obsm: 'protein_counts'
pbmc_seurat_v4_cite_seq_filtered.obs["cell_type"].value_counts()
cell_type
Monocytes 47217
CD4 T cells 40245
CD8 T cells 24399
NK cells 15384
B cells 13192
other T 5828
Dendritic Cells 3488
Other 2341
Name: count, dtype: int64
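The chain of `.loc` assignments above can be expressed more compactly with a dictionary. The sketch below uses a toy `Series` in place of `obs["celltype.l1"]` (the values in `raw` are made up for illustration); the mapping itself mirrors the renames performed above.

```python
import pandas as pd

# Toy labels standing in for obs["celltype.l1"] (illustrative values only).
raw = pd.Series(["B", "Mono", "CD4 T", "other", "other T", "DC"])

# Same renames as the .loc assignments above.
rename = {
    "other": "Other",
    "B": "B cells",
    "DC": "Dendritic Cells",
    "NK": "NK cells",
    "CD4 T": "CD4 T cells",
    "CD8 T": "CD8 T cells",
    "Mono": "Monocytes",
}

# Labels missing from the mapping pass through unchanged.
cell_type = raw.map(lambda label: rename.get(label, label))
print(cell_type.tolist())
```

As in the output above, a label such as `other T` that has no entry in the mapping is left as-is and remains its own category.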
In the next part we create artifacts from these AnnData objects and combine them into a collection. Artifacts and collections are how lamindb interacts with the data. From this point on, training no longer uses the in-memory AnnData objects.
ln.track()
→ created Transform('hUtoQ4gffICo0000'), started new Run('miLp7BSR...') at 2025-05-08 09:03:42 UTC
artifact1 = ln.Artifact.from_anndata(pbmc_dataset_filtered, key="part_one1.h5ad").save()
artifact2 = ln.Artifact.from_anndata(pbmc_seurat_v4_cite_seq_filtered, key="part_two1.h5ad").save()
collection = ln.Collection([artifact1, artifact2], key="gather")
collection.save()
Collection(uid='WYhawczHOi6K1dYG0004', is_latest=True, key='gather', hash='wPM0DpO-bWnuNH5Kqd1zEA', space_id=1, created_by_id=1, run_id=1, created_at=2025-05-08 09:03:45 UTC)
# Inspect the collection to see the h5ad files it consists of
artifacts = collection.artifacts.all()
artifacts.df()
| id | uid | key | description | suffix | kind | otype | size | hash | n_files | n_observations | _hash_type | _key_is_virtual | _overwrite_versions | space_id | storage_id | schema_id | version | is_latest | run_id | created_at | created_by_id | _aux | _branch_code |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 9 | cq4K7ESTSPnef6Ol0000 | part_one1.h5ad | None | .h5ad | dataset | AnnData | 49787918 | Qa1fqQcaprCKNBrFOF1l1g | None | 11990 | md5 | True | False | 1 | 1 | None | None | True | 1 | 2025-05-08 09:03:44.063000+00:00 | 1 | None | 1 |
| 10 | 0DZgfpOHWcVI3tgI0000 | part_two1.h5ad | None | .h5ad | dataset | AnnData | 998889103 | DozCh58Pwbgf9-wzsb976e | None | 152094 | sha1-fl | True | False | 1 | 1 | None | None | True | 1 | 2025-05-08 09:03:44.913000+00:00 | 1 | None | 1 |
We can now define the batch key and the datamodule, which replaces the default AnnDataLoader of scvi-tools.
batch_keys = "batch"
datamodule = MappedCollectionDataModule(
collection,
batch_key=batch_keys,
batch_size=1024,
shuffle=True,
join="inner",
)
print(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch)
164084 3315 5
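The numbers printed above can be sanity-checked by hand. The sketch below reuses the `n_observations` values from the artifact table; the batch labels are assumptions: `"0"` and `"1"` for the first dataset, and the usual Seurat cell-cycle phases `G1`, `S`, and `G2M` for the second (only `G1` and `S` are visible in the outputs shown in this notebook).

```python
# n_observations per artifact, taken from the artifact table above.
n_obs_per_artifact = {"part_one1.h5ad": 11990, "part_two1.h5ad": 152094}
total_cells = sum(n_obs_per_artifact.values())

# Assumed batch labels per dataset (see the caveat in the text).
batches_part_one = {"0", "1"}
batches_part_two = {"G1", "S", "G2M"}

# The collection sees the union of the batch categories.
all_batches = sorted(batches_part_one | batches_part_two)

print(total_cells, len(all_batches))
```

Under these assumptions, the total matches `datamodule.n_obs` and the number of distinct labels matches `datamodule.n_batch`.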
From here we continue as usual: we define the model (from the datamodule registry rather than an AnnDataManager) and train it.
# Init the model
model = scvi.model.SCVI(registry=datamodule.registry)
# Training the model
import gc
gc.collect()
start = time.time()
model.train(
max_epochs=100,
batch_size=1024,
plan_kwargs={"lr": 0.003, "compile": False},
early_stopping=False,
datamodule=datamodule.inference_dataloader(),
)
end = time.time()
print(f"Elapsed time: {end - start:.2f} seconds")
Elapsed time: 485.27 seconds
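The `start`/`end` timing pattern is repeated for every training run in this notebook. One possible way to factor it out is a small context manager; this is a sketch, and the names `timed` and `timings` are hypothetical helpers, not part of scvi-tools.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Record the wall-clock time of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start
        print(f"{label}: {results[label]:.2f} seconds")


timings = {}
with timed("toy workload", timings):
    # Stand-in for a call such as model.train(...).
    total = sum(i * i for i in range(100_000))
```

Collecting the durations in one dictionary makes the later comparison between the Lamin and AnnData runs a simple lookup.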
model.history["elbo_train"].tail()
| epoch | elbo_train |
|---|---|
| 95 | 1853.080566 |
| 96 | 1852.66333 |
| 97 | 1852.367554 |
| 98 | 1852.291748 |
| 99 | 1852.092651 |
# Save the model
model.save("lamin_model", save_anndata=False, overwrite=True, datamodule=datamodule)
model.history.keys()
dict_keys(['kl_weight', 'train_loss_step', 'train_loss_epoch', 'elbo_train', 'reconstruction_loss_train', 'kl_local_train', 'kl_global_train'])
# Downstream model outputs are obtained through the inference dataloader.
# With a datamodule, this dataloader must be passed to every downstream function.
inference_dataloader = datamodule.inference_dataloader()
latent = model.get_latent_representation(dataloader=inference_dataloader)
# Load the collection as an AnnData object so that we can plot UMAPs.
# To save time, we could also load only a subset of it.
adata = collection.load(join="inner")
adata.obsm["scvi"] = latent
adata.obs
| | batch | cell_type | artifact_uid |
|---|---|---|---|
| AAACCTGAGCTAGTGG-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 |
| AAACCTGCACATTAGC-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 |
| AAACCTGCACTGTTAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 |
| AAACCTGCATAGTAAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 |
| AAACCTGCATGAACCT-1 | 0 | CD8 T cells | cq4K7ESTSPnef6Ol0000 |
| ... | ... | ... | ... |
| E2L8_TTTGTTGGTCGTGATT | S | CD8 T cells | 0DZgfpOHWcVI3tgI0000 |
| E2L8_TTTGTTGGTGTGCCTG | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 |
| E2L8_TTTGTTGGTTAGTTCG | S | B cells | 0DZgfpOHWcVI3tgI0000 |
| E2L8_TTTGTTGGTTGGCTAT | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 |
| E2L8_TTTGTTGTCTCATGGA | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 |
164084 rows × 3 columns
# We can now generate the neighbors and the UMAP.
sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
sc.tl.umap(adata, neighbors_key="scvi")
sc.pl.umap(adata, color=batch_keys, title="batch_SCVI")
sc.pl.umap(adata, color="cell_type", title="cell_type_SCVI")
scanvi#
We will repeat the process we just used for SCVI to train a SCANVI model.
labels_keys = "cell_type"
datamodule_scanvi = MappedCollectionDataModule(
collection,
batch_key=batch_keys,
label_key=labels_keys,
batch_size=1024,
shuffle=True,
model_name="SCANVI",
join="inner",
)
print(
datamodule_scanvi.n_obs,
datamodule_scanvi.n_vars,
datamodule_scanvi.n_batch,
datamodule_scanvi.n_labels,
)
164084 3315 5 9
# We can now create the scanVI model object and train it:
datamodule_scanvi.setup(stage="train")
model_scanvi = scvi.model.SCANVI(
adata=None,
registry=datamodule_scanvi.registry,
datamodule=datamodule_scanvi,
)
# Training the model
import gc
gc.collect()
start3 = time.time()
model_scanvi.train(
max_epochs=20,
batch_size=1024,
plan_kwargs={"lr": 0.01, "compile": False},
early_stopping=False,
n_samples_per_label=100,
datamodule=datamodule_scanvi,
)
end3 = time.time()
print(f"Elapsed time: {end3 - start3:.2f} seconds")
INFO Training for 20 epochs.
Elapsed time: 135.02 seconds
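The training call above passes `n_samples_per_label=100`, which, as we understand it, limits how many labelled cells of each class are drawn per epoch. The sketch below illustrates that idea on toy labels with the standard library only; it is not the sampler scvi-tools uses, and the function name `sample_per_label` is hypothetical.

```python
import random
from collections import Counter, defaultdict


def sample_per_label(labels, n_per_label, seed=0):
    """Pick at most `n_per_label` indices for each distinct label."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    chosen = []
    for indices in by_label.values():
        k = min(n_per_label, len(indices))
        chosen.extend(rng.sample(indices, k))
    return sorted(chosen)


# Toy labels: an abundant class and a rare one.
labels = ["Monocytes"] * 5 + ["NK cells"] * 2
chosen = sample_per_label(labels, n_per_label=3)
picked = Counter(labels[i] for i in chosen)
print(picked)
```

Capping the abundant classes this way keeps rare cell types from being swamped in the classification loss.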
# Save the model
model_scanvi.save(
"lamin_scanvi_model", save_anndata=False, overwrite=True, datamodule=datamodule_scanvi
)
model_scanvi.history.keys()
dict_keys(['train_loss_step', 'train_loss_epoch', 'elbo_train', 'reconstruction_loss_train', 'kl_local_train', 'kl_global_train', 'train_classification_loss', 'train_accuracy', 'train_f1_score', 'train_calibration_error'])
model_scanvi.history["train_accuracy"].tail()
| epoch | train_accuracy |
|---|---|
| 15 | 0.993022 |
| 16 | 0.993101 |
| 17 | 0.993912 |
| 18 | 0.993973 |
| 19 | 0.994369 |
# Downstream model outputs are obtained through the inference dataloader.
# With a datamodule, this dataloader must be passed to every downstream function.
inference_scanvi_dataloader = datamodule_scanvi.inference_dataloader()
latent_scanvi = model_scanvi.get_latent_representation(dataloader=inference_scanvi_dataloader)
adata.obsm["scanvi"] = latent_scanvi
# We can now generate the neighbors and the UMAP.
sc.pp.neighbors(adata, use_rep="scanvi", key_added="scanvi")
sc.tl.umap(adata, neighbors_key="scanvi")
sc.pl.umap(adata, color=batch_keys, title="batch_SCANVI")
sc.pl.umap(adata, color="cell_type", title="cell_type_SCANVI")
Because it is a SCANVI model, we can also produce cell type predictions.
adata.obs["predictions_scanvi"] = model_scanvi.predict(
dataloader=inference_scanvi_dataloader, batch_size=1024
)
df = adata.obs.groupby(["cell_type", "predictions_scanvi"]).size().unstack(fill_value=0)
norm_df = df / df.sum(axis=0)
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
Text(0, 0.5, 'Observed')
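The normalization used for the heatmap above, `df / df.sum(axis=0)`, divides each column by its total, so every predicted label sums to 1. The toy example below (with made-up labels) contrasts this with normalizing each observed row instead.

```python
import pandas as pd

# Made-up observed and predicted labels for six cells.
obs = pd.DataFrame(
    {
        "cell_type": ["B", "B", "T", "T", "T", "NK"],
        "prediction": ["B", "T", "T", "T", "T", "NK"],
    }
)

counts = obs.groupby(["cell_type", "prediction"]).size().unstack(fill_value=0)

# Column-wise: of the cells predicted as X, which fraction was observed as Y?
by_predicted = counts / counts.sum(axis=0)

# Row-wise: of the cells observed as Y, which fraction was predicted as X?
by_observed = counts.div(counts.sum(axis=1), axis=0)

print(by_predicted)
print(by_observed)
```

The column-wise version reads like a per-class precision, while the row-wise version reads like a per-class recall; which one is more informative depends on the question being asked.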
Run regularly using adata and compare#
We will use the adata we already loaded to train SCVI and SCANVI models through the regular AnnData workflow, in order to compare the results with the Lamin runs.
adata.layers["counts"] = adata.X.copy() # preserve counts
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata # freeze the state in `.raw`
adata.obs
| | batch | cell_type | artifact_uid | predictions_scanvi |
|---|---|---|---|---|
| AAACCTGAGCTAGTGG-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 | CD4 T cells |
| AAACCTGCACATTAGC-1 | 0 | CD4 T cells | cq4K7ESTSPnef6Ol0000 | CD4 T cells |
| AAACCTGCACTGTTAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 | Monocytes |
| AAACCTGCATAGTAAG-1 | 0 | Monocytes | cq4K7ESTSPnef6Ol0000 | Monocytes |
| AAACCTGCATGAACCT-1 | 0 | CD8 T cells | cq4K7ESTSPnef6Ol0000 | CD8 T cells |
| ... | ... | ... | ... | ... |
| E2L8_TTTGTTGGTCGTGATT | S | CD8 T cells | 0DZgfpOHWcVI3tgI0000 | CD8 T cells |
| E2L8_TTTGTTGGTGTGCCTG | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 | Monocytes |
| E2L8_TTTGTTGGTTAGTTCG | S | B cells | 0DZgfpOHWcVI3tgI0000 | B cells |
| E2L8_TTTGTTGGTTGGCTAT | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 | Monocytes |
| E2L8_TTTGTTGTCTCATGGA | G1 | Monocytes | 0DZgfpOHWcVI3tgI0000 | Monocytes |
164084 rows × 4 columns
scvi.model.SCVI.setup_anndata(adata, batch_key="batch", layer="counts")
# model_census3 = scvi.model.SCVI.load("census_model", adata=adata)
model_census3 = scvi.model.SCVI(adata)
start2 = time.time()
model_census3.train(
max_epochs=100,
)
end2 = time.time()
print(f"Elapsed time: {end2 - start2:.2f} seconds")
Elapsed time: 548.92 seconds
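Lamin training took 485.27 seconds versus 548.92 seconds here, so it was about 12% faster than using the AnnDataLoader. Note that the two runs are not fully matched: the Lamin run passed batch_size=1024 and lr=0.003, while this run uses the default training settings.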
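The relative difference between the two SCVI training times can be computed directly from the elapsed times printed in this notebook. This is a small arithmetic sketch; the variable names are illustrative.

```python
# Elapsed times printed above, in seconds.
lamin_seconds = 485.27
anndata_seconds = 548.92

# Fraction of wall time saved relative to the AnnData run.
relative_saving = (anndata_seconds - lamin_seconds) / anndata_seconds

# Equivalent speedup factor.
speedup = anndata_seconds / lamin_seconds

print(f"{relative_saving:.1%} less wall time, {speedup:.2f}x speedup")
```

These are single-run timings, so small differences between them should be read with some caution.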
model_census3.history["elbo_train"].tail()
| epoch | elbo_train |
|---|---|
| 95 | 1828.02417 |
| 96 | 1827.902954 |
| 97 | 1827.885864 |
| 98 | 1827.774658 |
| 99 | 1827.692383 |
adata.obsm["scvi_non_dataloder"] = model_census3.get_latent_representation()
sc.pp.neighbors(adata, use_rep="scvi_non_dataloder", key_added="scvi_non_dataloder")
sc.tl.umap(adata, neighbors_key="scvi_non_dataloder")
sc.pl.umap(adata, color="batch", title="batch_SCVI_adata")
sc.pl.umap(adata, color="cell_type", title="cell_type_SCVI_adata")
scanvi (regular)#
adata
AnnData object with n_obs × n_vars = 164084 × 3315
obs: 'batch', 'cell_type', 'artifact_uid', 'predictions_scanvi', '_scvi_batch', '_scvi_labels'
uns: 'scvi', 'umap', 'batch_colors', 'cell_type_colors', 'scanvi', 'log1p', '_scvi_uuid', '_scvi_manager_uuid', 'scvi_non_dataloder'
obsm: 'scvi', 'X_umap', 'scanvi', 'scvi_non_dataloder'
layers: 'counts'
obsp: 'scvi_distances', 'scvi_connectivities', 'scanvi_distances', 'scanvi_connectivities', 'scvi_non_dataloder_distances', 'scvi_non_dataloder_connectivities'
scvi.model.SCANVI.setup_anndata(
adata,
layer="counts",
labels_key="cell_type",
unlabeled_category="label_0",
batch_key=batch_keys,
)
# model_census4 = scvi.model.SCVI.load("census_model", adata=adata)
model_census4 = scvi.model.SCANVI(adata)
start4 = time.time()
model_census4.train(
max_epochs=100,
)
end4 = time.time()
print(f"Elapsed time: {end4 - start4:.2f} seconds")
INFO Training for 100 epochs.
Elapsed time: 1110.17 seconds
model_census4.history["train_accuracy"].tail()
| epoch | train_accuracy |
|---|---|
| 95 | 0.99784 |
| 96 | 0.997176 |
| 97 | 0.997765 |
| 98 | 0.99765 |
| 99 | 0.997522 |
adata.obsm["scanvi_non_dataloder"] = model_census4.get_latent_representation()
sc.pp.neighbors(adata, use_rep="scanvi_non_dataloder", key_added="scanvi_non_dataloder")
sc.tl.umap(adata, neighbors_key="scanvi_non_dataloder")
sc.pl.umap(adata, color=["batch"], title=["SCANVI__non_dataloder_" + x for x in ["batch"]])
sc.pl.umap(adata, color="cell_type", title="SCANVI_non_dataloder")
adata.obs["predictions_scanvi_non_dataloder"] = model_census4.predict()
df = (
adata.obs.groupby(["cell_type", "predictions_scanvi_non_dataloder"])
.size()
.unstack(fill_value=0)
)
norm_df = df / df.sum(axis=0)
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
_ = plt.xticks(np.arange(0.5, len(df.columns), 1), df.columns, rotation=90)
_ = plt.yticks(np.arange(0.5, len(df.index), 1), df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
Text(0, 0.5, 'Observed')
Compare results#
Compute integration metrics
from scib_metrics.benchmark import Benchmarker
bm = Benchmarker(
adata,
batch_key="batch",
label_key="cell_type",
embedding_obsm_keys=["X_pca", "scvi", "scanvi", "scvi_non_dataloder", "scanvi_non_dataloder"],
n_jobs=-1,
)
bm.benchmark()
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
INFO 8 clusters consist of a single batch or are too small. Skip.
bm.plot_results_table(min_max_scale=False)
<plottable.table.Table at 0x7438a6749940>
As expected, SCANVI outperforms SCVI because it uses the label information. However, as can be seen, the regular AnnDataLoader gives about 5% better integration results compared to the Lamin dataloader.