scBasset: Batch correction of scATAC-seq data#

Warning

SCBASSET’s development is still in progress. The current version may not fully reproduce the original implementation’s results.

In addition to performing representation learning on scATAC-seq data, scBasset can also be used to integrate data across several samples. This tutorial walks through the following:

  1. Loading the dataset

  2. Preprocessing the dataset with scanpy

  3. Setting up and training the model

  4. Visualizing the batch-corrected latent space with scanpy

  5. Quantifying integration performance with scib-metrics

Note

Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.

!pip install --quiet scvi-colab
from scvi_colab import install

install()
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

import tempfile

import matplotlib.pyplot as plt
import scanpy as sc
import scvi
import seaborn as sns
import torch
from scib_metrics.benchmark import Benchmarker

scvi.settings.seed = 0
sc.set_figure_params(figsize=(4, 4), frameon=False)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'
print("Last run with scvi-tools version:", scvi.__version__)
Last run with scvi-tools version: 1.1.0

Note

You can modify save_dir below to change where the data files for this tutorial are saved.

sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

Loading the dataset#

We will use the dataset from Buenrostro et al., 2018 throughout this tutorial, which contains single-cell chromatin accessibility profiles across 10 populations of human hematopoietic cell types.

adata = sc.read(
    "data/buen_ad_sc.h5ad",
    backup_url="https://storage.googleapis.com/scbasset_tutorial_data/buen_ad_sc.h5ad",
)
adata
AnnData object with n_obs × n_vars = 2034 × 103151
    obs: 'cell_barcode', 'label', 'batch'
    var: 'chr', 'start', 'end', 'n_cells'
    uns: 'label_colors'

We see that batch information is stored in adata.obs["batch"]. In this case, batches correspond to different donors.

BATCH_KEY = "batch"
adata.obs[BATCH_KEY].value_counts()
batch
BM0828    533
BM1077    507
BM1137    402
BM1214    298
BM0106    203
other      91
Name: count, dtype: int64

We also have author-provided cell type labels available.

LABEL_KEY = "label"
adata.obs[LABEL_KEY].value_counts()
label
CMP     502
GMP     402
HSC     347
LMPP    160
MPP     142
pDC     141
MEP     138
CLP      78
mono     64
UNK      60
Name: count, dtype: int64

Preprocessing the dataset#

We now use scanpy to preprocess the data before passing it to the model. Here, we filter out rarely detected peaks (those detected in fewer than 5% of cells) so that the model trains faster. Although sc.pp.filter_genes is named for genes, it filters the variables of adata, which in this dataset are peaks.

print("before filtering:", adata.shape)
min_cells = int(adata.n_obs * 0.05)  # threshold: 5% of cells
sc.pp.filter_genes(adata, min_cells=min_cells)  # in-place filtering of regions
print("after filtering:", adata.shape)
before filtering: (2034, 103151)
after filtering: (2034, 33247)
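
To make the filtering rule concrete, here is a minimal NumPy sketch of the same idea on a toy matrix. The toy data and variable names are illustrative assumptions, not part of the tutorial dataset; the sketch only mirrors the "keep variables detected in at least min_cells cells" rule.

```python
import numpy as np

# Toy count matrix (hypothetical data): 20 cells (rows) x 3 peaks (columns).
X = np.zeros((20, 3), dtype=int)
X[:10, 0] = 1  # peak 0 is detected in 10 cells
X[:1, 1] = 2   # peak 1 is detected in exactly 1 cell
# peak 2 is never detected

min_cells = int(X.shape[0] * 0.05)    # 5% of 20 cells, truncated to an integer
cells_per_peak = (X > 0).sum(axis=0)  # number of cells in which each peak is detected
keep = cells_per_peak >= min_cells    # keep peaks detected in at least min_cells cells
X_filtered = X[:, keep]

print("cells per peak:", cells_per_peak)
print("shape after filtering:", X_filtered.shape)
```

For the tutorial dataset, int(2034 * 0.05) truncates 101.7 to 101, so peaks detected in at least 101 cells are retained.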

Taking a look at adata.var, we see that this dataset has already been processed to include the start and end positions of each peak, as well as the chromosomes on which they are located.

adata.var.sample(10)
chr start end n_cells
218963 chr8 121761544 121762104 107
227586 chr9 117167843 117168397 125
223385 chr9 34986390 34987016 470
90362 chr17 15602531 15603282 542
48102 chr12 14537791 14538412 111
83864 chr16 29634123 29634443 110
206831 chr7 112030880 112032276 390
176756 chr5 72143780 72145204 363
100447 chr18 29599335 29600153 265
23121 chr10 11217571 11218248 102

We will use this information to add DNA sequences into adata.varm. This can be performed in-place with scvi.data.add_dna_sequence.

scvi.data.add_dna_sequence(
    adata,
    chr_var_key="chr",
    start_var_key="start",
    end_var_key="end",
    genome_name="hg19",
    genome_dir="data",
)
adata
Working...: 100%|██████████| 24/24 [00:01<00:00, 13.53it/s]
AnnData object with n_obs × n_vars = 2034 × 33247
    obs: 'cell_barcode', 'label', 'batch'
    var: 'chr', 'start', 'end', 'n_cells'
    uns: 'label_colors'
    varm: 'dna_sequence', 'dna_code'

The function adds two new fields into adata.varm: dna_sequence, containing bases for each position, and dna_code, containing bases encoded as integers.

adata.varm["dna_sequence"]
0 1 2 3 4 5 6 7 8 9 ... 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343
0 N N N N N N N N N N ... A G C C G G G C A C
3 A A G G A C A C T C ... C A G A A C A T A C
5 T T C C C A A T T C ... C T T G G T T G T G
8 A A G A G G T T T A ... C C A C C C A G G A
9 T T T C G T C A T G ... A C T G A A A C C C
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
237371 C T G C A G G C T G ... G A C C A G C C T G
237383 C T G A T A A G C T ... G C T C T T T C T C
237399 T A A G C C A T G A ... T T T C C T T G T T
237425 T T T T T T G C T A ... T T G A A G T T T G
237449 G G T T G G G G T T ... N N N N N N N N N N

33247 rows × 1344 columns
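
As a rough illustration of how dna_sequence relates to dna_code, the sketch below maps each base to an integer. The specific mapping (BASE_TO_INT) and the helper encode_sequence are assumptions made for this example; the integer codes that scvi-tools actually assigns may differ.

```python
import numpy as np

# Hypothetical base-to-integer mapping, for illustration only.
BASE_TO_INT = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}


def encode_sequence(seq: str) -> np.ndarray:
    """Encode a DNA string as integers, treating unrecognized characters as N."""
    unknown = BASE_TO_INT["N"]
    return np.array(
        [BASE_TO_INT.get(base, unknown) for base in seq.upper()],
        dtype=np.int64,
    )


codes = encode_sequence("NNACGT")
print(codes)
```

An integer encoding like this is compact to store and can be expanded to a one-hot representation when the sequence is fed to a convolutional model.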

Setting up and training the model#

Now, we are ready to register our data with scvi. We set up our data with the model using setup_anndata, which will ensure everything the model needs is in place for training.

In this stage, we can condition the model on covariates, which encourages the model to remove the impact of those covariates from the learned latent space. Since we are integrating our data across donors, we set the batch_key argument to the key in adata.obs that contains donor information (in our case, just "batch").

Additionally, since scBasset considers training mini-batches across regions rather than observations, we transpose the data prior to giving it to the model. The model also expects binary accessibility data, so we add a new layer with binary information.

bdata = adata.transpose()
bdata.layers["binary"] = (bdata.X.copy() > 0).astype(float)
scvi.external.SCBASSET.setup_anndata(
    bdata, layer="binary", dna_code_key="dna_code", batch_key=BATCH_KEY
)
INFO     Using column names from columns of adata.obsm['dna_code']                                                 
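
The transpose-and-binarize step can be sketched on a small NumPy array. The toy counts below are hypothetical and only show how the shapes and values change; the tutorial itself performs the same operations on the AnnData object.

```python
import numpy as np

# Toy counts (hypothetical data): 2 cells (rows) x 3 regions (columns).
counts = np.array(
    [
        [0, 3, 1],
        [2, 0, 0],
    ]
)

# Transpose so that regions become the rows (observations).
regions_by_cells = counts.T

# Binarize: any nonzero count is treated as "accessible".
binary = (regions_by_cells > 0).astype(float)

print(binary.shape)
print(binary)
```

Because of the transpose, the regions play the role of observations in bdata, which is why the setup summary reports n_cells as 33247 and n_vars as 2034.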

We now create the model. We use a non-default argument (l2_reg_cell_embedding), which is designed to aid integration of scATAC-seq data.

model = scvi.external.SCBASSET(bdata, l2_reg_cell_embedding=1e-8)
model.view_anndata_setup()
Anndata setup with scvi-tools version 1.1.0.

Setup via `SCBASSET.setup_anndata` with arguments:
{'dna_code_key': 'dna_code', 'layer': 'binary', 'batch_key': 'batch'}

     Summary Statistics     
┏━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃ Summary Stat Key  Value ┃
┡━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
│     n_batch         6   │
│     n_cells       33247 │
│    n_dna_code     1344  │
│      n_vars       2034  │
└──────────────────┴───────┘
               Data Registry               
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Registry Key    scvi-tools Location    ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      X         adata.layers['binary']  │
│    batch      adata.var['_scvi_batch'] │
│   dna_code     adata.obsm['dna_code']  │
└──────────────┴──────────────────────────┘
                  batch State Registry                   
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃  Source Location    Categories  scvi-tools Encoding ┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ adata.var['batch']    BM0106             0          │
│                       BM0828             1          │
│                       BM1077             2          │
│                       BM1137             3          │
│                       BM1214             4          │
│                       other              5          │
└────────────────────┴────────────┴─────────────────────┘
model.train()
Epoch 1000/1000: 100%|██████████| 1000/1000 [3:04:42<00:00, 11.08s/it, v_num=1, train_loss_step=0.316, train_loss_epoch=0.319]
fig, ax = plt.subplots()
model.history_["auroc_train"].plot(ax=ax)
model.history_["auroc_validation"].plot(ax=ax)
<Axes: xlabel='epoch'>
[Figure: training and validation AUROC by epoch]

Visualizing the batch-corrected latent space#

After training, we retrieve the integrated latent space and save it into adata.obsm.

LATENT_KEY = "X_scbasset"
adata.obsm[LATENT_KEY] = model.get_latent_representation()
adata.obsm[LATENT_KEY].shape
(2034, 32)

Now, we use scanpy to visualize the latent space by first computing the k-nearest-neighbor graph and then computing a UMAP embedding, with parameters chosen to reproduce the original scBasset tutorial for this dataset.

sc.pp.neighbors(adata, use_rep=LATENT_KEY)
sc.tl.umap(adata, min_dist=1.0)
sc.pl.umap(adata, color=LABEL_KEY)
sc.pl.umap(adata, color=BATCH_KEY)
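
For intuition about the neighbor graph that the embedding is built on, here is a brute-force k-nearest-neighbor sketch in NumPy. The function knn_indices and the toy latent matrix are assumptions for illustration; scanpy uses its own, more scalable neighbor search with different defaults.

```python
import numpy as np


def knn_indices(Z: np.ndarray, k: int) -> np.ndarray:
    """Return, for each row of Z, the indices of its k nearest other rows (Euclidean)."""
    sq_norms = (Z**2).sum(axis=1)
    # Squared pairwise distances via ||a||^2 + ||b||^2 - 2 a.b
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * Z @ Z.T
    np.fill_diagonal(d2, np.inf)  # exclude each point from its own neighbor list
    return np.argsort(d2, axis=1)[:, :k]


# Toy latent space (hypothetical): two well-separated pairs of points.
Z = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
nbrs = knn_indices(Z, k=1)
print(nbrs)
```

In a well-integrated latent space, a cell's nearest neighbors share its cell type but come from a mixture of batches.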

Quantifying integration performance#

Here we use the scib-metrics package, which contains scalable implementations of the metrics used in the scIB benchmarking suite. We can use these metrics to assess the quality of the integration.

bm = Benchmarker(
    adata,
    batch_key=BATCH_KEY,
    label_key=LABEL_KEY,
    embedding_obsm_keys=[LATENT_KEY],
    n_jobs=-1,
)
bm.benchmark()
INFO     UNK consists of a single batch or is too small. Skip.                                                     
INFO     mono consists of a single batch or is too small. Skip.                                                    
df = bm.get_results(min_max_scale=False)
df
Metric              Metric Type        X_scbasset
Isolated labels     Bio conservation   0.56907
KMeans NMI          Bio conservation   0.553789
KMeans ARI          Bio conservation   0.411057
Silhouette label    Bio conservation   0.514124
cLISI               Bio conservation   0.952487
Silhouette batch    Batch correction   0.870874
iLISI               Batch correction   0.099883
KBET                Batch correction   0.134311
Graph connectivity  Batch correction   0.85471
PCR comparison      Batch correction   0
Batch correction    Aggregate score    0.391956
Bio conservation    Aggregate score    0.600105
Total               Aggregate score    0.516845
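
The aggregate scores in the results can be checked by hand. The sketch below assumes that each aggregate is the unweighted mean of its group of metrics and that the total weights bio conservation by 0.6 and batch correction by 0.4. This weighting is inferred from the reported numbers and should be treated as an assumption rather than a statement of the package's API.

```python
# Metric values copied from the results above.
bio_metrics = {
    "Isolated labels": 0.56907,
    "KMeans NMI": 0.553789,
    "KMeans ARI": 0.411057,
    "Silhouette label": 0.514124,
    "cLISI": 0.952487,
}
batch_metrics = {
    "Silhouette batch": 0.870874,
    "iLISI": 0.099883,
    "KBET": 0.134311,
    "Graph connectivity": 0.85471,
    "PCR comparison": 0.0,
}

# Assumed aggregation: unweighted mean within each group.
bio_score = sum(bio_metrics.values()) / len(bio_metrics)
batch_score = sum(batch_metrics.values()) / len(batch_metrics)

# Assumed weighting for the total: 0.6 bio conservation, 0.4 batch correction.
total = 0.6 * bio_score + 0.4 * batch_score

print(f"Bio conservation: {bio_score:.6f}")
print(f"Batch correction: {batch_score:.6f}")
print(f"Total: {total:.6f}")
```

Under these assumptions the computed values agree with the reported aggregate scores to the displayed precision, and the low iLISI, KBET, and PCR comparison values are what pull the batch correction score down.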