DiagVI
DiagVI (Diagonal multi-modal integration Variational Inference; Python class DIAGVI) is a deep generative model for diagonal integration of unpaired multi-modal single-cell data that uses prior knowledge about cross-modal feature correspondences. These relationships (such as associative or repressive interactions) are encoded in a guidance graph whose weighted, signed edges can represent both positive and negative covariation between features across modalities.
DiagVI is inspired by the GLUE[1] architecture, which uses modality-specific variational autoencoders (VAEs) to project heterogeneous data types into a shared latent space. In contrast to GLUE’s adversarial alignment strategy, DiagVI aligns modalities using Unbalanced Optimal Transport (UOT) via the Sinkhorn divergence[2], explicitly accounting for differences in cell-type composition across modalities.
The advantages of DiagVI are:
Flexible two-modality integration of various data types (e.g., scRNA-seq, spatial transcriptomics, spatial proteomics).
Full feature utilization: all features (not only overlapping ones) contribute to model training via the guidance graph and modality-specific VAEs.
Biologically informed alignment using prior feature correspondences via the guidance graph.
Robust integration of modality-specific or rare cell populations via UOT.
Scalable to very large datasets (>1 million cells).
The limitations of DiagVI include:
Currently supports integration of two modalities only.
Requires prior information on cross-modal feature correspondences (explicitly or implicitly).
May require tuning of loss weights for optimal performance (for more information, see Practical guidance).
Effectively requires a GPU for fast inference.
Note
DiagVI requires additional dependencies (geomloss, torch_geometric) that are not installed by default. To use DiagVI, install scvi-tools with the diagvi extra:
pip install "scvi-tools[diagvi]"
Preliminaries
Input
DiagVI takes as input expression matrices \(\mathbf{X}_1 \in \mathbb{R}^{N \times G}\) with \(N\) cells and \(G\) features and \(\mathbf{X}_2 \in \mathbb{R}^{M \times P}\) with \(M\) cells and \(P\) features from two unpaired modalities.
For count data, such as scRNA-seq, DiagVI expects a raw count expression matrix, where rows correspond to individual cells and columns correspond to features (e.g., genes).
For continuous data, such as antibody-based single-cell or spatial proteomics, DiagVI expects a transformed (and optionally scaled) protein expression matrix, where rows correspond to cells and columns to features (e.g., marker proteins). Because antibody-based measurements are inherently relative, preprocessing is required. Typically, transformations such as arcsinh or log1p are applied to compress the dynamic range and stabilize variance, followed by feature-wise scaling (e.g., z-score or min–max scaling) to ensure comparability across markers. This yields approximately normal feature distributions, which can be modeled effectively with a normal likelihood. Empirically, we found that this approach performs better than modeling raw intensities (see also CYTOVI and the corresponding preprint[3]).
For integration with DiagVI, we recommend a simple two-step preprocessing strategy inspired by CYTOVI:
Arcsinh transformation to stabilize variance and compress dynamic range.
Feature-wise min–max scaling to rescale each marker to the [0, 1] range and account for differences in marker brightness.
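The two-step preprocessing above can be sketched with NumPy on a cells × markers intensity matrix. The arcsinh cofactor of 5 is an assumption (a value commonly used for mass cytometry), not a DiagVI or CytoVI default. Choose it per assay. In an AnnData workflow you would apply the same function to the protein matrix before setting up the model.

```python
import numpy as np


def arcsinh_minmax(x, cofactor=5.0):
    """Arcsinh-transform, then min-max scale each marker (column) to [0, 1]."""
    # Variance-stabilizing transform; the cofactor is an assay-specific assumption.
    x_t = np.arcsinh(np.asarray(x, dtype=float) / cofactor)
    # Feature-wise min-max scaling: one range per marker.
    col_min = x_t.min(axis=0)
    col_range = x_t.max(axis=0) - col_min
    col_range[col_range == 0] = 1.0  # guard against constant markers
    return (x_t - col_min) / col_range


rng = np.random.default_rng(0)
# Toy intensities: 1,000 cells x 3 markers with very different brightness.
intensities = rng.gamma(shape=2.0, scale=[10.0, 200.0, 5000.0], size=(1000, 3))
scaled = arcsinh_minmax(intensities)
print(scaled.min(axis=0), scaled.max(axis=0))  # every marker now spans [0, 1]
```

Scaling per marker rather than globally is what removes brightness differences between antibodies. A global scale would leave the dimmest marker compressed near zero.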
Optional: For both count and continuous data, DiagVI can additionally incorporate:
Experimental covariates, such as batch annotations, or confounding variables, such as donor sex, for one or both modalities. For simplicity, we describe the case of categorical batch identifiers \(s_n, s_m \in \{1,...,S\}\).
Cell label annotations that weakly inform the prior of the latent space and guide a classifier in semi-supervised training for one or both modalities. We assume categorical label identifiers \(c_n, c_m \in\{1,...,C\}\).
Both optional inputs are handled independently for each modality, enabling distinct batch structures or cell type annotations per modality.
Currently supported modalities include:
scRNA-seq
Spatial transcriptomics: image-based and sequencing-based, e.g., 10x Xenium, Vizgen MERSCOPE, NanoString CosMx, and 10x Visium
Spatial proteomics: antibody-based imaging assays, e.g., PhenoCycler (formerly CODEX), 4i, Imaging Mass Cytometry, and protein add-on panels for Xenium or CosMx
Single-cell proteomics: such as CITE-seq and other antibody-derived tag (ADT)-based assays
Other count-based or continuous feature measurements
Model components
DiagVI consists of several components which together define the overall training objective (see Training Objective).
Modality-specific variational autoencoders
DiagVI integrates two unpaired modalities by projecting their expression matrices \(\mathbf{X}_1 \in \mathbb{R}^{N \times G}\) and \(\mathbf{X}_2 \in \mathbb{R}^{M \times P}\) into a shared latent space in which each cell \(n\) in modality 1 has a latent representation \(\mathbf{z}_n \in \mathbb{R}^d\) and each cell \(m\) in modality 2 has a latent representation \(\mathbf{z}_m \in \mathbb{R}^d\). To learn this shared space, the observed data from each modality is modeled using an independently parametrized VAE.
Guidance graph
Projecting cells into a common low-dimensional space alone does not result in semantically consistent embeddings in which identical cell types are assigned to the same region across both modalities. To ensure biological consistency in the latent space, a guidance graph establishes logical associations between the two modalities by connecting linked features (see also GLUE[1]).
Feature correspondences are encoded in the guidance graph \(\mathcal{G} = (\mathcal{V}, \mathcal{E})\), where \(\mathcal{V} = \mathcal{V}_1 \cup \mathcal{V}_2\) with modality-specific feature sets \(\mathcal{V}_1\) and \(\mathcal{V}_2\), and \(\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}\) denotes the set of edges. Each edge \((i,j) \in \mathcal{E}\) is associated with
a weight \(w_{ij} \in (0, 1]\) reflecting the confidence of the link
a sign \(\sigma_{ij} \in \{-1,1\}\) specifying whether the interaction is associative (\(\sigma_{ij} = 1\)) or repressive (\(\sigma_{ij} = -1\))
The graph loss encourages the inner product between embeddings \(\mathbf{v}_i\) and \(\mathbf{v}_j\) of linked features to be large and positive for \(\sigma_{ij} = 1\) and large and negative for \(\sigma_{ij} = -1\), with strength modulated by edge weights.
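The following is a minimal NumPy sketch of a signed, weighted edge loss in the spirit of GLUE's graph likelihood. Each edge contributes \(-w_{ij} \log \mathrm{sigmoid}(\sigma_{ij}\, \mathbf{v}_i^\top \mathbf{v}_j)\). GLUE additionally samples non-edges (negative sampling), which this sketch omits. The normalization by total edge weight is an illustrative choice, and DiagVI's exact graph loss may differ.

```python
import numpy as np


def signed_graph_nll(v, edges, weights, signs):
    """Weighted mean of -log sigmoid(sign * <v_i, v_j>) over guidance-graph edges.

    v: (n_features, d) feature embeddings; edges: (n_edges, 2) index pairs;
    weights: (n_edges,) confidences in (0, 1]; signs: (n_edges,) in {-1, +1}.
    """
    i, j = edges[:, 0], edges[:, 1]
    logits = signs * np.sum(v[i] * v[j], axis=1)  # sigma_ij * <v_i, v_j>
    nll = np.logaddexp(0.0, -logits)              # stable log(1 + exp(-logit))
    return np.sum(weights * nll) / np.sum(weights)


# Features 0-1 are genes (modality 1), features 2-3 are proteins (modality 2).
edges = np.array([[0, 2], [1, 3]])
weights = np.array([1.0, 0.5])
signs = np.array([1.0, -1.0])  # 0~2 associative, 1~3 repressive

v_good = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, -1.0]])  # agrees with signs
v_bad = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, 1.0]])   # contradicts signs
loss_good = signed_graph_nll(v_good, edges, weights, signs)
loss_bad = signed_graph_nll(v_bad, edges, weights, signs)
```

Embeddings that respect the edge signs (inner product positive for the associative edge and negative for the repressive one) receive the lower loss.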
Unbalanced optimal transport
To ensure robust alignment of cells from different modalities within the shared latent space, DiagVI leverages unbalanced optimal transport (UOT). Specifically, it minimizes the de-biased Sinkhorn divergence between latent distributions using the GeomLoss library[4]. UOT relaxes the marginal constraints of classical optimal transport, allowing for unequal total mass between distributions. This enables robust integration in the presence of differing cell type proportions across modalities.
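To illustrate why relaxing the marginals helps with differing cell-type composition, here is a minimal NumPy sketch of entropic unbalanced OT between two point clouds. It is not DiagVI's or GeomLoss's implementation. It assumes uniform weights, a p = 2 ground cost \(\lVert x - y \rVert^2 / 2\), and GeomLoss-style parameters \(\varepsilon = \text{blur}^2\) and \(\rho = \text{reach}^2\). It returns the transport plan rather than the debiased divergence.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along one axis.
    m_max = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(m_max, axis=axis) + np.log(np.sum(np.exp(m - m_max), axis=axis))


def unbalanced_sinkhorn_plan(x, y, blur=0.3, reach=1.0, n_iter=500):
    """Entropic OT plan with KL-relaxed marginals between uniform point clouds."""
    eps, rho = blur**2, reach**2
    tau = rho / (rho + eps)  # tau -> 1 (rho -> inf) recovers balanced Sinkhorn
    cost = 0.5 * np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a = np.full(len(x), -np.log(len(x)))  # each cloud has total mass 1
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iter):
        # Log-domain Sinkhorn updates, damped by tau for the relaxed marginals.
        f = -tau * eps * _logsumexp(log_b[None, :] + (g[None, :] - cost) / eps, axis=1)
        g = -tau * eps * _logsumexp(log_a[:, None] + (f[:, None] - cost) / eps, axis=0)
    return np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / eps)


rng = np.random.default_rng(0)
# Modality 1 contains two "cell types"; modality 2 contains only the first one.
x = np.vstack([rng.normal(0.0, 0.3, (50, 2)), rng.normal([10.0, 0.0], 0.3, (50, 2))])
y = rng.normal(0.0, 0.3, (100, 2))
plan = unbalanced_sinkhorn_plan(x, y)
row_mass = plan.sum(axis=1)
print(f"shared type: {row_mass[:50].sum():.3f}, unmatched type: {row_mass[50:].sum():.2e}")
```

Balanced OT would be forced to move the unmatched population onto the shared one. Here, its mass is instead discarded at a KL cost of about \(\rho\) times that mass, which is far cheaper than transporting it.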
Classifier
In the semi-supervised setting, and inspired by RESOLVI[5], DiagVI includes a simple cell-type classifier that predicts labels \(c_n\) and \(c_m\) from the cell latent vectors \(\mathbf{z}_n\) and \(\mathbf{z}_m\), respectively. This classifier is trained jointly with the generative model.
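A common way to make such a classifier semi-supervised is to compute the cross-entropy only over labeled cells. The sketch below shows this with a linear classifier on latent vectors. It assumes a hypothetical convention that label -1 marks unlabeled cells. DiagVI's classifier may be a deeper network, and its handling of unlabeled cells is not specified here.

```python
import numpy as np


def masked_cross_entropy(z, labels, weight, bias):
    """Mean cross-entropy of a linear classifier, computed over labeled cells only."""
    logits = z @ weight + bias                            # (n_cells, n_classes)
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    labeled = labels >= 0                                 # -1 = unlabeled (hypothetical)
    if not labeled.any():
        return 0.0
    return -log_probs[labeled, labels[labeled]].mean()


rng = np.random.default_rng(0)
z = rng.normal(size=(6, 4))                # latent vectors for 6 cells
labels = np.array([0, 1, 2, -1, -1, 0])    # cells 3 and 4 are unlabeled
weight, bias = rng.normal(size=(4, 3)), np.zeros(3)
loss = masked_cross_entropy(z, labels, weight, bias)

z_perturbed = z.copy()
z_perturbed[3:5] += 100.0                  # changing unlabeled cells leaves the loss unchanged
```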
Descriptive model
DiagVI assumes that observations from each modality are generated from a shared \(d\)-dimensional latent space. For cell \(n\) and feature \(g\), the observed data \(x_{ng}\) is generated conditionally on the cell latent variable \(\mathbf{z}_n \in \mathbb{R}^{d}\), the feature latent variable \(\mathbf{v}_g \in \mathbb{R}^{d}\), and the batch \(s_n\) of cell \(n\).
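The prior on the feature latent variable is a standard multivariate normal:
\[ p(\mathbf{v}_g) = \mathcal{N}(\mathbf{v}_g \mid \mathbf{0}, \mathbf{I}_d) \]
The prior on the cell latent variable can be either a standard multivariate normal (default) or a Gaussian mixture with \(L\) components:
\[ p(\mathbf{z}_n) = \mathcal{N}(\mathbf{z}_n \mid \mathbf{0}, \mathbf{I}_d) \qquad \text{or} \qquad p(\mathbf{z}_n) = \sum_{l=1}^{L} \pi_l \, \mathcal{N}(\mathbf{z}_n \mid \boldsymbol{\mu}_l, \boldsymbol{\Sigma}_l), \]
where \(\pi_l\) are the mixture weights and \(\boldsymbol{\mu}_l\) and \(\boldsymbol{\Sigma}_l\) are the mean and covariance of component \(l\).
If cell type labels are provided and a Gaussian mixture prior is used, \(L\) is set to the number of unique cell types in the labeled data.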
Generative process
Generative model
The generative model follows the formulations used in GLUE[1], scVI[6], and CytoVI[3], and adapts depending on whether the observed data consist of discrete counts or continuous measurements. In both cases, the exact form of the likelihood depends on the distribution chosen for each modality.
Regardless of the data type and likelihood choice, the decoder is parametrized by the inner product between cell and feature latent embeddings, together with batch-specific scaling (\(\alpha_{s_n,g}\)) and bias (\(\beta_{s_n,g}\)) parameters:
\[ \eta_{ng} = \alpha_{s_n,g}\, \mathbf{z}_n^\top \mathbf{v}_g + \beta_{s_n,g} \]
For count-based modalities, denoised and normalized expression proportions \(\rho_{ng}\) are obtained via a softmax over features:
\[ \rho_{ng} = \frac{\exp(\eta_{ng})}{\sum_{g' \in \mathcal{V}_k} \exp(\eta_{ng'})} \]
The mean of the generative distribution is given by
\[ \mu_{ng} = l_n \, \rho_{ng}, \]
where \(l_n\) denotes the observed library size of cell \(n\). Counts are then modeled as
\[ x_{ng} \sim \mathrm{NB}(\mu_{ng}, \theta_{s_n,g}) \]
when using the negative binomial likelihood with mean \(\mu_{ng}\) and inverse dispersion \(\theta_{s_n,g}\). For other likelihood choices, see Likelihood Models.
For continuous modalities, no simplex constraint is enforced and no library size normalization is applied. When using the normal likelihood, continuous expression values are modeled as
\[ x_{ng} \sim \mathcal{N}(\eta_{ng}, \sigma^2_{s_n,g}) \]
In both settings, when a batch covariate is provided, DiagVI learns batch-specific versions of the scaling and bias parameters to account for batch effects within each modality. For each modality \(k\) and batch \(s\), the scaling and bias parameters satisfy
\[ \boldsymbol{\alpha}^{(k)}_{s} \in \mathbb{R}_{+}^{|\mathcal{V}_k|}, \qquad \boldsymbol{\beta}^{(k)}_{s} \in \mathbb{R}^{|\mathcal{V}_k|} \]
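The count branch of this generative process can be sketched in NumPy for a single batch, so that \(\alpha\), \(\beta\), and \(\theta\) reduce to per-feature vectors. This illustrates the equations rather than reproducing DiagVI's decoder code. The negative binomial is drawn as a gamma-Poisson mixture, one standard way to sample NB with mean \(\mu\) and inverse dispersion \(\theta\).

```python
import numpy as np


def decode_counts(z, v, alpha, beta, library, theta, rng):
    """Sample counts from a GLUE/scVI-style count decoder (single-batch sketch).

    z: (n_cells, d) cell embeddings; v: (n_features, d) feature embeddings;
    alpha, beta, theta: (n_features,) per-feature parameters of one batch;
    library: (n_cells,) observed library sizes.
    """
    eta = alpha * (z @ v.T) + beta                    # eta_ng = alpha_g <z_n, v_g> + beta_g
    exp_eta = np.exp(eta - eta.max(axis=1, keepdims=True))  # shift for a stable softmax
    rho = exp_eta / exp_eta.sum(axis=1, keepdims=True)      # softmax over features
    mu = library[:, None] * rho                       # mu_ng = l_n * rho_ng
    rate = rng.gamma(shape=theta, scale=mu / theta)   # gamma draw with mean mu
    return rng.poisson(rate), rho, mu                 # Poisson(gamma) = NB(mu, theta)


rng = np.random.default_rng(0)
n_cells, n_features, d = 5, 8, 3
counts, rho, mu = decode_counts(
    z=rng.normal(size=(n_cells, d)),
    v=rng.normal(size=(n_features, d)),
    alpha=np.ones(n_features),
    beta=np.zeros(n_features),
    library=np.full(n_cells, 1000.0),
    theta=np.full(n_features, 10.0),
    rng=rng,
)
```

Each row of rho lies on the simplex, and each row of mu sums to the cell's library size. For a continuous modality with the normal likelihood, the softmax and library-size steps would be skipped.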
This generative process is also summarized in the following graphical model:
DiagVI graphical model. Shaded nodes represent observed data, unshaded nodes represent latent variables.
Likelihood models
Depending on the modality and data characteristics, different likelihood functions can be used to reconstruct the input data. The choice of likelihood should reflect whether the data are counts or continuous measurements, and whether they exhibit excess zeros or background signal.
We generally recommend:
Negative binomial (nb) for count-based single-cell RNA-seq data.
Normal (normal) for transformed and scaled continuous measurements.
More specialized likelihoods such as zero-inflated or mixture models may be beneficial for sparse or background-contaminated measurements.
DiagVI supports the following likelihood functions:
| Likelihood | Distribution | Typical modality | Recommended preprocessing |
|---|---|---|---|
| nb | Negative Binomial | Count data (e.g., scRNA-seq) | Raw counts |
|  | Zero-Inflated Negative Binomial | Strongly zero-inflated count data | Raw counts |
|  | Negative Binomial Mixture | Protein counts with background signal (e.g., CITE-seq ADTs) | Raw counts |
| normal | Normal | Continuous data | Transformed (e.g., arcsinh or log1p) and feature-wise scaled |
|  | Log1p-Normal | Non-negative continuous data | Raw data, optionally scaled |
|  | Zero-Inflated Log-Normal | Sparse positive continuous data | Raw data, optionally scaled |
|  | Zero-Inflated Gamma | Sparse positive continuous data | Raw data, optionally scaled |
Latent variables
The exact set of variables instantiated during training depends on the chosen likelihood and on whether a Gaussian mixture prior is used for the cell latent space. Below we summarize the principal latent, decoder, and auxiliary variables used in DiagVI.
| Variable | Description | Code variable |
|---|---|---|
| \(\mathbf{z}_n \in \mathbb{R}^{d}\) | Low-dimensional latent representation of cell \(n\), capturing its underlying biological state. |  |
| \(\mathbf{v}_g \in \mathbb{R}^{d}\) | Low-dimensional embedding of feature \(g\), inferred via the graph encoder from the guidance graph. |  |
| \(c_n \in \{1,\dots,C\}\) | Observed cell type label of cell \(n\) (if available), used for semi-supervised training and to inform the Gaussian mixture prior. |  |
| \(\eta_{ng} \in \mathbb{R}\) | Decoder linear predictor obtained from the bilinear interaction between cell and feature embeddings. | N/A |
| \(\boldsymbol{\rho}_n \in \Delta^{\lvert\mathcal{V}_k\rvert-1}\) | Denoised, normalized expression proportions for cell \(n\) (count data); constrained to the probability simplex via softmax. |  |
| \(\mu_{ng} \in \mathbb{R}_+\) | Mean of the generative distribution for feature \(g\) in cell \(n\). |  |
| \(\theta_{s_n,g} \in (0,\infty)\) | Batch-specific inverse dispersion parameters used in count-based likelihoods. |  |
| \(\sigma_{s_n,g}^2 \in (0,\infty)\) | Batch-specific variance parameters used in continuous likelihoods. |  |
| \(\delta_g \in (0,\infty)\) | Dropout or zero-inflation parameters (used for zero-inflated likelihoods). |  |
| \(\pi_l\) | Mixture weights of the Gaussian mixture prior (if enabled). |  |
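Inference
Since the posterior distributions over the latent variables are intractable, DiagVI uses variational inference, specifically auto-encoding variational Bayes (see Variational Inference), to jointly learn model parameters and approximate posterior distributions.
For a given modality, the variational distribution factorizes as
\[ q_\eta(\mathbf{z}_n, \mathbf{V} \mid \mathbf{x}_n, \mathcal{G}) = q_\eta(\mathbf{z}_n \mid \mathbf{x}_n) \, q_\eta(\mathbf{V} \mid \mathcal{G}), \]
where \(\mathbf{V}\) collects the feature embeddings \(\mathbf{v}_g\).
Here \(\eta\) is a set of parameters corresponding to inference neural networks (encoders): a cell encoder for \(q_\eta(\mathbf{z}_n \mid \mathbf{x}_n)\) and the graph encoder for \(q_\eta(\mathbf{V} \mid \mathcal{G})\). These encoder parameters are distinct from the decoder linear predictor \(\eta_{ng}\).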
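For the default standard-normal cell prior, assuming the usual VAE choice of a diagonal Gaussian \(q_\eta(\mathbf{z}_n \mid \mathbf{x}_n)\), sampling and the KL term reduce to a few lines. The sketch below shows the reparameterization trick and the closed-form KL divergence. With the Gaussian mixture prior, the KL has no closed form and is typically estimated by Monte Carlo instead.

```python
import numpy as np


def reparameterize(mean, log_var, rng):
    """Draw z = mean + std * eps with eps ~ N(0, I) (reparameterization trick)."""
    return mean + np.exp(0.5 * log_var) * rng.standard_normal(mean.shape)


def kl_to_standard_normal(mean, log_var):
    """KL(N(mean, diag(exp(log_var))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(mean**2 + np.exp(log_var) - 1.0 - log_var, axis=-1)


rng = np.random.default_rng(0)
mean = np.array([[0.0, 0.0], [1.0, 0.0]])  # encoder means for two cells
log_var = np.zeros((2, 2))                 # unit variances
z = reparameterize(mean, log_var, rng)
kl = kl_to_standard_normal(mean, log_var)  # per-cell KL: 0.0 and 0.5
```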
Training objective
DiagVI is trained by minimizing a weighted sum of loss terms (weighted by lam_* parameters) corresponding to the Model Components introduced above:
Modality-specific VAEs: The data reconstruction loss (lam_data) measures how well each modality-specific decoder reconstructs the observed data, while the KL divergence (lam_kl) regularizes the cell latent variables by encouraging adherence to the prior.
Guidance graph: The graph reconstruction loss (lam_graph) enforces biological consistency between feature embeddings using the guidance graph.
UOT: The UOT alignment loss (lam_sinkhorn) aligns cell distributions across modalities using unbalanced optimal transport.
Classifier: The classification loss (lam_class, optional) enables supervised or semi-supervised training via cell-type labels.
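Schematically, and assuming the reconstruction and KL terms are summed over both modalities, the weighted objective reads:
\[ \mathcal{L} = \lambda_{\text{data}}\,\mathcal{L}_{\text{data}} + \lambda_{\text{kl}}\,\mathcal{L}_{\text{KL}} + \lambda_{\text{graph}}\,\mathcal{L}_{\text{graph}} + \lambda_{\text{sinkhorn}}\,\mathcal{L}_{\text{sinkhorn}} + \lambda_{\text{class}}\,\mathcal{L}_{\text{class}}, \]
where each \(\lambda_{\ast}\) corresponds to the lam_* argument of the same name, and the classification term drops out when no labels are provided.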
The lam_* parameters control the relative importance of within-modality reconstruction (lam_data, lam_kl), cross-modality alignment (lam_graph, lam_sinkhorn), and optional supervision (lam_class). DiagVI uses sensible defaults for all of these values, but they may require further tuning depending on the type of data being integrated (see Practical guidance).
Practical guidance
Guidance graph creation
When initializing DIAGVI, three ways to specify the guidance graph are supported:
Automatic construction: If neither guidance_graph nor mapping_df is provided, DiagVI constructs a graph from overlapping feature names across modalities (e.g., shared gene symbols).
Custom mapping via DataFrame: Pass a pandas.DataFrame to mapping_df in which each column corresponds to a modality name (matching the keys in input_dict) and each row specifies a feature pair. This is useful when feature naming conventions differ between modalities (e.g., genes vs. proteins). You can also use construct_custom_guidance_graph() to create a graph with custom edge weights and signs.
Explicit graph specification: For full control, pass a pre-constructed torch_geometric.data.Data object directly to guidance_graph. The graph must include node features, edge indices, edge weights, edge signs, and modality-specific feature index tensors.
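A mapping_df for the custom-mapping option can be built with plain pandas. The modality names "rna" and "protein" and the marker pairs below are illustrative assumptions. The filtering step is a defensive choice, since this section does not state how DiagVI treats pairs whose features are missing from a dataset. The constructor call is left as a comment because its exact signature should be checked in the DIAGVI API reference.

```python
import pandas as pd

# Columns are modality names and must match the keys of input_dict
# ("rna" and "protein" are hypothetical names for illustration).
mapping_df = pd.DataFrame(
    {
        "rna": ["CD3E", "MS4A1", "PTPRC", "PTPRC", "CD19"],
        "protein": ["CD3", "CD20", "CD45", "CD45RA", "CD19"],
    }
)

# Stand-ins for adata_rna.var_names and adata_protein.var_names.
rna_features = {"CD3E", "MS4A1", "PTPRC", "CD8A"}
protein_features = {"CD3", "CD20", "CD45", "CD45RA"}

# Keep only pairs whose features exist in both datasets.
keep = mapping_df["rna"].isin(rna_features) & mapping_df["protein"].isin(protein_features)
mapping_df = mapping_df[keep].reset_index(drop=True)
print(mapping_df)

# Hypothetical usage; verify argument names against the API reference:
# model = DIAGVI({"rna": adata_rna, "protein": adata_protein}, mapping_df=mapping_df)
```

A single gene may map to several proteins (here PTPRC to both CD45 and the CD45RA isoform). Each such pair becomes its own row and, presumably, its own edge.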
Loss weights
DiagVI provides default values for all loss weights that have been tested across multiple datasets and integration tasks. However, depending on the modalities and the integration setting, tuning some of these weights can improve performance.
In practice, we recommend paying particular attention to:
lam_sinkhorn (cross-modality alignment strength), and
lam_class (classification strength, if cell-type labels are provided).
These two weights control the trade-off between alignment across modalities and separation of cell types within the latent space.
When integrating very different modalities (e.g., scRNA-seq + spatial proteomics), stronger alignment is typically required. In such cases:
Higher lam_sinkhorn
Lower lam_class (if labels are used)
This encourages the model to prioritize matching global cell distributions across modalities, even if the feature spaces differ substantially.
When integrating more similar modalities (e.g., scRNA-seq + spatial transcriptomics), less alignment pressure is needed. In such cases:
Lower lam_sinkhorn
Higher lam_class (if labels are used)
Here, the modalities are already structurally similar, so stronger supervision can help maintain clearer separation of biologically distinct cell types.
As a general rule, we recommend lam_sinkhorn and lam_class values between 1 and 100, which work well in most settings.
The default loss weight values place relatively strong emphasis on cross-modality alignment and are therefore well suited for integrating heterogeneous modalities. For more similar modalities, we recommend reducing lam_sinkhorn (e.g., 5-20) and increasing lam_class (e.g., 60-80), depending on label availability and desired separation.
In practice, small grid searches over lam_sinkhorn and lam_class are often sufficient to find well-performing settings.
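A small grid search over the two weights can be organized as follows. The scoring function is a user-supplied placeholder. It should train a model with the given weights and return a metric where higher is better, such as label-transfer accuracy on held-out labeled cells. Where lam_sinkhorn and lam_class are passed into DiagVI (model construction or training arguments) is not specified in this section, so the toy scorer below does not call DiagVI.

```python
import itertools


def grid_search(score_fn, lam_sinkhorn_values, lam_class_values):
    """Score every (lam_sinkhorn, lam_class) pair and return the results best-first."""
    results = []
    for lam_sinkhorn, lam_class in itertools.product(lam_sinkhorn_values, lam_class_values):
        score = score_fn(lam_sinkhorn=lam_sinkhorn, lam_class=lam_class)
        results.append({"lam_sinkhorn": lam_sinkhorn, "lam_class": lam_class, "score": score})
    return sorted(results, key=lambda r: r["score"], reverse=True)


def toy_score(lam_sinkhorn, lam_class):
    # Placeholder: replace with "train DiagVI with these weights, then evaluate".
    return -abs(lam_sinkhorn - 10) - abs(lam_class - 70)


results = grid_search(toy_score, [5, 10, 20, 50], [20, 40, 70])
print(results[0])  # {'lam_sinkhorn': 10, 'lam_class': 70, 'score': 0}
```

Choosing the grid values from the ranges recommended above (1 to 100, with lower lam_sinkhorn and higher lam_class for similar modalities) keeps the search small.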
Sinkhorn parameters
Three parameters determine the behavior of the Sinkhorn divergence used for cross-modality alignment:
p: order of the ground cost (Wasserstein-p distance). The default is p = 2, corresponding to a squared Euclidean ground cost.
blur: controls the strength of entropic regularization and therefore the smoothness of the transport plan. Larger values increase regularization, resulting in smoother and more diffuse transport plans.
reach: controls the marginal relaxation parameter in unbalanced optimal transport, penalizing deviations from strict mass conservation. Larger values enforce stricter marginal constraints, while smaller values allow greater mass variation between the matched distributions.
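For reference, here is a sketch of how these parameters enter the objective, following GeomLoss's documented conventions as we understand them. The ground cost is \(C(x,y) = \lVert x - y \rVert^p / p\), with \(\varepsilon = \text{blur}^p\) and \(\rho = \text{reach}^p\). KL denotes the generalized Kullback-Leibler divergence, and \(m(\cdot)\) the total mass of a measure:
\[ \mathrm{OT}_{\varepsilon,\rho}(\alpha,\beta) = \min_{\pi \ge 0} \; \langle \pi, C \rangle + \varepsilon \, \mathrm{KL}(\pi \,\|\, \alpha \otimes \beta) + \rho \, \mathrm{KL}(\pi \mathbf{1} \,\|\, \alpha) + \rho \, \mathrm{KL}(\pi^\top \mathbf{1} \,\|\, \beta) \]
\[ S_{\varepsilon,\rho}(\alpha,\beta) = \mathrm{OT}_{\varepsilon,\rho}(\alpha,\beta) - \tfrac{1}{2}\mathrm{OT}_{\varepsilon,\rho}(\alpha,\alpha) - \tfrac{1}{2}\mathrm{OT}_{\varepsilon,\rho}(\beta,\beta) + \tfrac{\varepsilon}{2}\big(m(\alpha) - m(\beta)\big)^2 \]
A larger blur increases \(\varepsilon\) and smooths the plan. Letting reach (and hence \(\rho\)) grow to infinity recovers balanced OT, where the marginals must be matched exactly.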
Although these parameters can be specified manually, we generally do not recommend modifying them.
By default, DiagVI follows the heuristic strategy introduced in OTT-JAX. In this setting:
blur is computed adaptively at each optimization step from the minibatch cost matrix, and
reach is derived as a function of blur.
This adaptive strategy makes the Sinkhorn loss scale-aware and robust across datasets, so manual tuning is rarely necessary in practice.
Tasks
Here we provide an overview of some of the tasks that DiagVI can perform. Please see DIAGVI for the full API reference.
Dimensionality reduction
DiagVI provides aligned low-dimensional representations for each modality. By default, the mean of the approximate posterior \(q_\eta(\mathbf{z}_n \mid x_n)\) is returned.
Latent representations can be obtained using:
>>> latents = model.get_latent_representation()
>>> adata_mod1.obsm["X_diagvi"] = latents["mod1"]
>>> adata_mod2.obsm["X_diagvi"] = latents["mod2"]
The aligned representations can be used for downstream analyses either within each modality or jointly across modalities.
For example, within-modality visualization:
>>> import scanpy as sc
>>> sc.pp.neighbors(adata_mod1, use_rep="X_diagvi")
>>> sc.tl.umap(adata_mod1)
For joint analysis across modalities:
>>> adata_combined = sc.concat(
... [adata_mod1, adata_mod2],
... axis=0,
... join="inner",
... label="modality",
... keys=["mod1", "mod2"],
... )
>>> sc.pp.neighbors(adata_combined, use_rep="X_diagvi")
>>> sc.tl.umap(adata_combined)
Cross-modal feature imputation
DiagVI can impute features from one modality into another using get_imputed_values().
For example, to impute protein expression for RNA cells:
>>> imputed_protein = model.get_imputed_values(
... query_name="rna",
... query_adata=adata_rna,
... )
>>> adata_rna.obsm["imputed_protein"] = imputed_protein
It is also possible to perform counterfactual predictions by specifying a reference batch or reference library size (see also Counterfactual prediction):
>>> imputed_rna = model.get_imputed_values(
... query_name="protein",
... reference_batch="batch_1",
... reference_libsize=10000,
... )
>>> adata_protein.obsm["imputed_rna"] = imputed_rna
This allows imputation under specified sequencing depth or batch conditions.
Cell label transfer
Aligned latent representations can be used for cross-modal cell label transfer.
For example, when integrating scRNA-seq reference data with spatial proteomics data, cell-type labels can be transferred from the RNA reference to the spatial data. This can be done either using DiagVI’s built-in classifier (when cell-type labels have been provided for the reference modality during training) or with external tools such as CellMapper, a toolkit for cross-modal cell mapping and evaluation.
Using DiagVI’s trained classifier, cell labels can be transferred as follows:
>>> preds = model.predict_celltype(labeled_modality="rna")
>>> adata_protein.obs["celltype_pred"] = preds["predictions"]
>>> adata_protein.obs["celltype_conf"] = preds["confidence"]
Alternatively, using CellMapper:
>>> latents = model.get_latent_representation()
>>> adata_rna.obsm["X_diagvi"] = latents["rna"]
>>> adata_protein.obsm["X_diagvi"] = latents["protein"]
>>>
>>> from cellmapper import CellMapper
>>> cmap = CellMapper(adata_protein, adata_rna)
>>> cmap.map("cell_type_labels", use_rep="X_diagvi")
>>> cmap.evaluate_label_transfer("cell_type_labels")
Because DiagVI produces a shared latent space, any downstream method that operates on low-dimensional embeddings (e.g., kNN mapping, clustering, classification) can be applied directly.
Data simulation
Data can be generated from the model using the posterior predictive distribution in posterior_predictive_sample(). This is equivalent to feeding a cell through the model, sampling from the posterior distributions of the latent variables, retrieving the likelihood parameters, and finally, sampling from this distribution.
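The four steps just described can be sketched with NumPy for a count modality. This illustrates the procedure rather than the implementation of posterior_predictive_sample(). The encoder outputs (mean, log_var) and the decode_mu callable are hypothetical stand-ins for a trained model's encoder and decoder, and the negative binomial is drawn as a gamma-Poisson mixture.

```python
import numpy as np


def posterior_predictive_counts(mean, log_var, decode_mu, theta, n_samples, rng):
    """Monte Carlo posterior predictive draws for a count modality (sketch).

    mean, log_var: (n_cells, d) parameters of q(z_n | x_n);
    decode_mu: callable mapping z of shape (..., d) to NB means (..., n_features);
    theta: (n_features,) inverse dispersion.
    """
    eps = rng.standard_normal((n_samples,) + mean.shape)
    z = mean + np.exp(0.5 * log_var) * eps           # 1) sample latents from q
    mu = decode_mu(z)                                # 2) retrieve likelihood parameters
    rate = rng.gamma(shape=theta, scale=mu / theta)  # 3) NB as gamma-Poisson mixture
    return rng.poisson(rate)                         # 4) sample counts


rng = np.random.default_rng(0)
w = rng.normal(size=(2, 5))  # hypothetical decoder weights (d = 2, 5 features)


def decode_mu(z):
    # Toy decoder: softmax over features times a fixed library size of 100.
    e = np.exp(z @ w)
    return 100.0 * e / e.sum(axis=-1, keepdims=True)


samples = posterior_predictive_counts(
    mean=np.zeros((4, 2)), log_var=np.zeros((4, 2)),
    decode_mu=decode_mu, theta=np.full(5, 10.0), n_samples=3, rng=rng,
)
print(samples.shape)  # (3, 4, 5): samples x cells x features
```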