scBasset: Batch correction of scATACseq data#
Warning
SCBASSET’s development is still in progress. The current version may not fully reproduce the original implementation’s results.
In addition to performing representation learning on scATAC-seq data, scBasset can also be used to integrate data across several samples. This tutorial walks through the following:
Loading the dataset
Preprocessing the dataset with scanpy
Setting up and training the model
Visualizing the batch-corrected latent space with scanpy
Quantifying integration performance with scib-metrics
Note
Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.
!pip install --quiet scvi-colab
from scvi_colab import install
install()
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
import tempfile
import matplotlib.pyplot as plt
import scanpy as sc
import scvi
import seaborn as sns
import torch
from scib_metrics.benchmark import Benchmarker
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.1.0
Note
You can modify save_dir below to change where the data files for this tutorial are saved.
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()
%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"
Loading the dataset#
We will use the dataset from Buenrostro et al., 2018 throughout this tutorial, which contains single-cell chromatin accessibility profiles across 10 populations of human hematopoietic cell types.
adata = sc.read(
"data/buen_ad_sc.h5ad",
backup_url="https://storage.googleapis.com/scbasset_tutorial_data/buen_ad_sc.h5ad",
)
adata
AnnData object with n_obs × n_vars = 2034 × 103151
obs: 'cell_barcode', 'label', 'batch'
var: 'chr', 'start', 'end', 'n_cells'
uns: 'label_colors'
We see that batch information is stored in adata.obs["batch"]. In this case, batches correspond to different donors.
BATCH_KEY = "batch"
adata.obs[BATCH_KEY].value_counts()
batch
BM0828 533
BM1077 507
BM1137 402
BM1214 298
BM0106 203
other 91
Name: count, dtype: int64
We also have author-provided cell type labels available.
LABEL_KEY = "label"
adata.obs[LABEL_KEY].value_counts()
label
CMP 502
GMP 402
HSC 347
LMPP 160
MPP 142
pDC 141
MEP 138
CLP 78
mono 64
UNK 60
Name: count, dtype: int64
Preprocessing the dataset#
We now use scanpy to preprocess the data before giving it to the model. Here, we filter out rarely detected peaks (those detected in fewer than 5% of cells) so that the model trains faster.
print("before filtering:", adata.shape)
min_cells = int(adata.n_obs * 0.05) # threshold: 5% of cells
sc.pp.filter_genes(adata, min_cells=min_cells) # in-place filtering of regions
print("after filtering:", adata.shape)
before filtering: (2034, 103151)
after filtering: (2034, 33247)
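To make the threshold concrete, the following sketch reproduces the filtering rule on a toy matrix with NumPy. It assumes that a region is kept when it is detected in at least min_cells cells, which is how sc.pp.filter_genes applies its min_cells argument; the toy matrix and variable names are illustrative and not part of the tutorial data.

```python
import numpy as np

# Toy accessibility counts: 40 cells (rows) x 4 peaks (columns). Illustrative only.
n_cells = 40
counts = np.zeros((n_cells, 4), dtype=int)
counts[:10, 0] = 1  # peak 0: detected in 10 cells
counts[:1, 1] = 2   # peak 1: detected in 1 cell
# peak 2 is never detected
counts[:3, 3] = 1   # peak 3: detected in 3 cells

# Same threshold rule as in the tutorial: 5% of the cells.
min_cells = int(n_cells * 0.05)

# A peak is "detected" in a cell when its count is non-zero.
cells_per_peak = (counts > 0).sum(axis=0)

# Keep peaks detected in at least `min_cells` cells.
keep = cells_per_peak >= min_cells
filtered = counts[:, keep]

print("threshold:", min_cells)
print("cells per peak:", cells_per_peak)
print("after filtering:", filtered.shape)
```

On the real dataset the same rule gives a threshold of int(2034 * 0.05) = 101 cells, which is consistent with the n_cells values above 100 in the adata.var sample shown next.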
Taking a look at adata.var, we see that this dataset has already been processed to include the start and end positions of each peak, as well as the chromosomes on which they are located.
adata.var.sample(10)
(index) | chr | start | end | n_cells |
---|---|---|---|---|
218963 | chr8 | 121761544 | 121762104 | 107 |
227586 | chr9 | 117167843 | 117168397 | 125 |
223385 | chr9 | 34986390 | 34987016 | 470 |
90362 | chr17 | 15602531 | 15603282 | 542 |
48102 | chr12 | 14537791 | 14538412 | 111 |
83864 | chr16 | 29634123 | 29634443 | 110 |
206831 | chr7 | 112030880 | 112032276 | 390 |
176756 | chr5 | 72143780 | 72145204 | 363 |
100447 | chr18 | 29599335 | 29600153 | 265 |
23121 | chr10 | 11217571 | 11218248 | 102 |
We will use this information to add DNA sequences into adata.varm. This can be performed in-place with scvi.data.add_dna_sequence.
scvi.data.add_dna_sequence(
adata,
chr_var_key="chr",
start_var_key="start",
end_var_key="end",
genome_name="hg19",
genome_dir="data",
)
adata
Working...: 100%|██████████| 24/24 [00:01<00:00, 13.53it/s]
AnnData object with n_obs × n_vars = 2034 × 33247
obs: 'cell_barcode', 'label', 'batch'
var: 'chr', 'start', 'end', 'n_cells'
uns: 'label_colors'
varm: 'dna_sequence', 'dna_code'
The function adds two new fields into adata.varm: dna_sequence, containing bases for each position, and dna_code, containing bases encoded as integers.
adata.varm["dna_sequence"]
(index) | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 1334 | 1335 | 1336 | 1337 | 1338 | 1339 | 1340 | 1341 | 1342 | 1343 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | N | N | N | N | N | N | N | N | N | N | ... | A | G | C | C | G | G | G | C | A | C |
3 | A | A | G | G | A | C | A | C | T | C | ... | C | A | G | A | A | C | A | T | A | C |
5 | T | T | C | C | C | A | A | T | T | C | ... | C | T | T | G | G | T | T | G | T | G |
8 | A | A | G | A | G | G | T | T | T | A | ... | C | C | A | C | C | C | A | G | G | A |
9 | T | T | T | C | G | T | C | A | T | G | ... | A | C | T | G | A | A | A | C | C | C |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
237371 | C | T | G | C | A | G | G | C | T | G | ... | G | A | C | C | A | G | C | C | T | G |
237383 | C | T | G | A | T | A | A | G | C | T | ... | G | C | T | C | T | T | T | C | T | C |
237399 | T | A | A | G | C | C | A | T | G | A | ... | T | T | T | C | C | T | T | G | T | T |
237425 | T | T | T | T | T | T | G | C | T | A | ... | T | T | G | A | A | G | T | T | T | G |
237449 | G | G | T | T | G | G | G | G | T | T | ... | N | N | N | N | N | N | N | N | N | N |
33247 rows × 1344 columns
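Every peak ends up with a sequence of the same length (1344 columns) even though the peaks themselves differ in width. The sketch below shows one way this can work, under two assumptions that are not confirmed by this tutorial: that a fixed-length window is centered on each peak's midpoint, and that bases map to integers as in the hypothetical BASE_TO_INT dictionary. The helper names centered_window and encode are illustrative and are not scvi-tools functions.

```python
import numpy as np

SEQ_LEN = 1344  # matches the number of columns in adata.varm["dna_sequence"]


def centered_window(start, end, seq_len=SEQ_LEN):
    """Return (window_start, window_end) for a fixed-length window
    centered on the midpoint of the peak [start, end)."""
    mid = (start + end) // 2
    window_start = mid - seq_len // 2
    return window_start, window_start + seq_len


# Hypothetical base-to-integer mapping; the codes used by scvi-tools may differ.
BASE_TO_INT = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}


def encode(sequence):
    """Encode a DNA string as an integer array using BASE_TO_INT."""
    return np.array([BASE_TO_INT[base] for base in sequence.upper()], dtype=np.int64)


# Coordinates of the first peak listed in the adata.var sample above.
w_start, w_end = centered_window(121761544, 121762104)
print(w_start, w_end, w_end - w_start)

# "N" marks positions with no known base.
print(encode("NNAGCT"))
```

Under these assumptions, a 560 bp peak is padded symmetrically with flanking sequence to reach 1344 bp.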
Setting up and training the model#
Now, we are ready to register our data with scvi. We set up our data with the model using setup_anndata, which will ensure everything the model needs is in place for training.
In this stage, we can condition the model on covariates, which encourages the model to remove the impact of those covariates from the learned latent space. Since we are integrating our data across donors, we set the batch_key argument to the key in adata.obs that contains donor information (in our case, just "batch").
Additionally, since scBasset considers training mini-batches across regions rather than observations, we transpose the data prior to giving it to the model. The model also expects binary accessibility data, so we add a new layer with binary information.
bdata = adata.transpose()
bdata.layers["binary"] = (bdata.X.copy() > 0).astype(float)
scvi.external.SCBASSET.setup_anndata(
bdata, layer="binary", dna_code_key="dna_code", batch_key=BATCH_KEY
)
INFO Using column names from columns of adata.obsm['dna_code']
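The transposition and binarization can be illustrated on a toy array with NumPy. This is only a sketch of the data layout: the array and its values are made up, and it does not call scvi-tools.

```python
import numpy as np

# Toy count matrix: 2 cells (rows) x 3 regions (columns). Illustrative only.
counts = np.array(
    [
        [0, 3, 1],
        [2, 0, 0],
    ]
)

# Binarize: any non-zero count means the region is accessible in that cell.
binary = (counts > 0).astype(float)

# Transpose so that each row is a region and each column is a cell.
regions_by_cells = binary.T

print(regions_by_cells.shape)
print(regions_by_cells)
```

After this step each row describes one region's accessibility across all cells, which is the unit that the mini-batches are drawn from.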
We now create the model. We use a non-default argument (l2_reg_cell_embedding), which is designed to aid integration of scATAC-seq data.
model = scvi.external.SCBASSET(bdata, l2_reg_cell_embedding=1e-8)
model.view_anndata_setup()
Anndata setup with scvi-tools version 1.1.0.
Setup via `SCBASSET.setup_anndata` with arguments:
{'dna_code_key': 'dna_code', 'layer': 'binary', 'batch_key': 'batch'}
Summary Statistics
Summary Stat Key | Value |
---|---|
n_batch | 6 |
n_cells | 33247 |
n_dna_code | 1344 |
n_vars | 2034 |
Data Registry
Registry Key | scvi-tools Location |
---|---|
X | adata.layers['binary'] |
batch | adata.var['_scvi_batch'] |
dna_code | adata.obsm['dna_code'] |
batch State Registry
Source Location | Categories | scvi-tools Encoding |
---|---|---|
adata.var['batch'] | BM0106 | 0 |
 | BM0828 | 1 |
 | BM1077 | 2 |
 | BM1137 | 3 |
 | BM1214 | 4 |
 | other | 5 |
model.train()
Epoch 1000/1000: 100%|██████████| 1000/1000 [3:04:42<00:00, 11.02s/it, v_num=1, train_loss_step=0.316, train_loss_epoch=0.319]
Visualizing the batch-corrected latent space#
After training, we retrieve the integrated latent space and save it into adata.obsm.
LATENT_KEY = "X_scbasset"
adata.obsm[LATENT_KEY] = model.get_latent_representation()
adata.obsm[LATENT_KEY].shape
(2034, 32)
Now, we use scanpy to visualize the latent space by first computing the k-nearest-neighbor graph and then computing its UMAP representation, with parameters chosen to reproduce the original scBasset tutorial for this dataset.
sc.pp.neighbors(adata, use_rep=LATENT_KEY)
sc.tl.umap(adata, min_dist=1.0)
Quantifying integration performance#
Here we use the scib-metrics package, which contains scalable implementations of the metrics used in the scIB benchmarking suite. We can use these metrics to assess the quality of the integration.
bm = Benchmarker(
adata,
batch_key=BATCH_KEY,
label_key=LABEL_KEY,
embedding_obsm_keys=[LATENT_KEY],
n_jobs=-1,
)
bm.benchmark()
INFO UNK consists of a single batch or is too small. Skip.
INFO mono consists of a single batch or is too small. Skip.
df = bm.get_results(min_max_scale=False)
df
Embedding | Isolated labels | KMeans NMI | KMeans ARI | Silhouette label | cLISI | Silhouette batch | iLISI | KBET | Graph connectivity | PCR comparison | Batch correction | Bio conservation | Total |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
X_scbasset | 0.56907 | 0.553789 | 0.411057 | 0.514124 | 0.952487 | 0.870874 | 0.099883 | 0.134311 | 0.85471 | 0 | 0.391956 | 0.600105 | 0.516845 |
Metric Type | Bio conservation | Bio conservation | Bio conservation | Bio conservation | Bio conservation | Batch correction | Batch correction | Batch correction | Batch correction | Batch correction | Aggregate score | Aggregate score | Aggregate score |
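The three aggregate columns can be recomputed from the individual metrics. The sketch below assumes that each aggregate is an unweighted mean of its metric group and that the total weights bio conservation by 0.6 and batch correction by 0.4; these assumptions reproduce the values in the table, but the weights are not stated in this tutorial.

```python
# Metric values copied from the results table for X_scbasset.
bio_conservation = {
    "Isolated labels": 0.56907,
    "KMeans NMI": 0.553789,
    "KMeans ARI": 0.411057,
    "Silhouette label": 0.514124,
    "cLISI": 0.952487,
}
batch_correction = {
    "Silhouette batch": 0.870874,
    "iLISI": 0.099883,
    "KBET": 0.134311,
    "Graph connectivity": 0.85471,
    "PCR comparison": 0.0,
}

# Assumption: each aggregate is the unweighted mean of its group.
bio_score = sum(bio_conservation.values()) / len(bio_conservation)
batch_score = sum(batch_correction.values()) / len(batch_correction)

# Assumption: the total weights bio conservation 0.6 and batch correction 0.4.
total = 0.6 * bio_score + 0.4 * batch_score

print(f"Bio conservation: {bio_score:.6f}")
print(f"Batch correction: {batch_score:.6f}")
print(f"Total: {total:.6f}")
```

Read this way, the low batch correction score is driven mainly by the iLISI, KBET, and PCR comparison metrics, while silhouette batch and graph connectivity are high.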