totalVI [1] (total Variational Inference; Python class TOTALVI) posits a flexible generative model of CITE-seq RNA and protein data that can subsequently be used for many common downstream tasks.

The advantages of totalVI are:

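  • Comprehensive in capabilities: one model covers dimensionality reduction, normalization and denoising, differential expression, and data simulation.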

  • Scalable to very large datasets (>1 million cells).

The limitations of totalVI include:

  • Effectively requires a GPU for fast inference.

  • Difficult to understand the balance between RNA and protein data in the low-dimensional representation of cells.


totalVI takes as input a scRNA-seq gene expression matrix of UMI counts \(X\) with \(N\) cells and \(G\) genes, along with a paired matrix of protein abundance (UMI counts) \(Y\), also of \(N\) cells, but with \(T\) proteins. Thus, for each cell, we observe both RNA and protein information. Additionally, a design matrix \(S\) containing \(p\) observed covariates for each of the cells (such as day, donor, etc.) is an optional input. While \(S\) can include both categorical and continuous covariates, in the following we assume it contains only one categorical covariate with \(K\) categories, which represents the common case of having multiple batches of data.

Generative process

We posit each cell’s protein and RNA expression to be generated by the following process:

First, for each cell \(n\),

\begin{align} z_n &\sim \textrm{Normal}(0, I) \tag{1} \\ \rho_{n} &= f_\rho(z_n, s_n) \tag{2} \\ \alpha_n &= g_\alpha(z_n, s_n) \tag{3} \\ \pi_n &= h_\pi(z_n, s_n) \tag{4} \\ l_n &\sim \textrm{LogNormal}(l_\mu^\top s_n, l_{\sigma^2}^\top s_n) \tag{5}\\ \end{align}

The prior parameters \(l_\mu\) and \(l_{\sigma^2}\) are computed per batch as the mean and variance of the log library size over cells. The generative process of totalVI uses neural networks:
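As a rough illustration of the library size prior, the per-batch mean and variance of the log library size could be computed along the following lines. This is a minimal NumPy sketch, not the scvi-tools implementation; the toy count matrix `X`, the batch labels `batch`, and the helper name `log_library_size_prior` are assumptions made for the example.

```python
import numpy as np


def log_library_size_prior(X, batch, n_batches):
    """Per-batch mean and variance of the log library size (hypothetical helper)."""
    log_lib = np.log(X.sum(axis=1))  # log total RNA UMI count per cell
    l_mu = np.array([log_lib[batch == k].mean() for k in range(n_batches)])
    l_sigma2 = np.array([log_lib[batch == k].var() for k in range(n_batches)])
    return l_mu, l_sigma2


rng = np.random.default_rng(0)
X = rng.poisson(5.0, size=(200, 50)) + 1  # toy counts; +1 keeps every library size positive
batch = rng.integers(0, 2, size=200)      # one categorical covariate with K = 2 batches
l_mu, l_sigma2 = log_library_size_prior(X, batch, n_batches=2)

# Draw l_n ~ LogNormal(l_mu^T s_n, l_sigma2^T s_n) using each cell's batch.
# NumPy expects a standard deviation, so the variance is square-rooted.
l = rng.lognormal(mean=l_mu[batch], sigma=np.sqrt(l_sigma2[batch]))
```

Indexing `l_mu[batch]` plays the role of the product \(l_\mu^\top s_n\) when \(s_n\) is a one-hot batch indicator.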

\begin{align} f_\rho(z_n, s_n) &: \mathbb{R}^d \times \{0, 1\}^K \to \Delta^{G-1} \tag{6} \\ g_\alpha(z_n, s_n) &: \mathbb{R}^d \times \{0, 1\}^K \to [1, \infty)^T \tag{7}\\ h_\pi(z_n, s_n) &: \mathbb{R}^d \times \{0, 1\}^K \to (0, 1)^T \tag{8} \end{align}

where \(d\) is the dimension of the latent space (associated with latent variable \(z\)). We also have global parameters \(\theta_g\) and \(\phi_t\), which represent gene- and protein-specific (respectively) overdispersion.
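The output ranges in (6) to (8) can be enforced with simple output activations. The sketch below shows one possible choice (softmax for the simplex, one plus softplus for \([1, \infty)\), and sigmoid for \((0, 1)\)); it is an illustration under stated assumptions, and the random linear maps `W_rho`, `W_alpha`, and `W_pi` are hypothetical stand-ins for trained networks rather than the activations or architecture used in the library.

```python
import numpy as np


def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)


def softplus(a):
    return np.logaddexp(0.0, a)  # log(1 + exp(a)), computed stably


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


rng = np.random.default_rng(0)
d, K, G, T = 4, 2, 6, 3               # latent dim, batches, genes, proteins
z = rng.normal(size=d)                # z_n ~ Normal(0, I)
s = np.eye(K)[1]                      # one-hot batch indicator s_n
h = np.concatenate([z, s])            # decoder input (z_n, s_n)

# Random linear maps stand in for the trained networks (hypothetical weights).
W_rho, W_alpha, W_pi = (rng.normal(size=(m, d + K)) for m in (G, T, T))

rho = softmax(W_rho @ h)              # f_rho: a point on the simplex
alpha = 1.0 + softplus(W_alpha @ h)   # g_alpha: every entry is at least 1
pi = sigmoid(W_pi @ h)                # h_pi: every entry lies in (0, 1)
```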

Then for each gene \(g\) in cell \(n\),

\begin{align} x_{ng} &\sim \textrm{NegativeBinomial}\left(l_n\rho_{ng}, \theta_g \right), \tag{10}\\ \end{align}

where the negative binomial is parameterized by its mean and inverse dispersion. Finally, for each protein \(t\) in cell \(n\),

\begin{align} \beta_{nt} &\sim \textrm{LogNormal}(c_t^\top s_n, d_t^\top s_n) \tag{11}\\ v_{nt} &\sim \textrm{Bernoulli}(\pi_{nt}) \tag{12}\\ y_{nt} &\sim \textrm{NegativeBinomial}\left(v_{nt}\beta_{nt} + (1-v_{nt})\beta_{nt}\alpha_{nt}, \phi_t \right) \tag{14} \end{align}

Integrating out \(v_{nt}\) yields a negative binomial mixture conditional distribution for \(y_{nt}\). Furthermore, \(\beta_{nt}\) represents background protein signal due to ambient antibodies or non-specific antibody binding. The prior parameters \(c_t\) and \(d_t\) are unfortunately called background_pro_alpha and background_pro_log_beta in the code. They are learned during inference, but are initialized through a procedure that fits a two-component Gaussian mixture model for each cell and records the mean and variance of the component with the smaller mean, aggregating across all cells. This can be disabled by setting empirical_protein_background_prior=False, which then forces a random initialization.
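The protein part of the generative process can be simulated directly for a single protein. The sketch below is a toy illustration, not library code: the helper `sample_nb` and all parameter values are assumptions, and \(d_t\) is treated as a variance by analogy with equation (5). The negative binomial with mean \(\mu\) and inverse dispersion \(\phi\) is drawn through its gamma-Poisson representation.

```python
import numpy as np


def sample_nb(rng, mean, inv_disp):
    """Negative binomial draw via its gamma-Poisson representation.

    The result has mean `mean` and variance mean + mean**2 / inv_disp.
    """
    rate = rng.gamma(shape=inv_disp, scale=mean / inv_disp)
    return rng.poisson(rate)


rng = np.random.default_rng(0)
n_draws = 200_000
pi, alpha, phi = 0.3, 5.0, 2.0   # background probability, foreground scaling, inverse dispersion
c, d = 1.0, 0.1                  # toy prior parameters for log(beta); d is used as a variance

beta = rng.lognormal(mean=c, sigma=np.sqrt(d), size=n_draws)   # background intensity
v = rng.binomial(1, pi, size=n_draws)                          # 1 selects the background component
mean = v * beta + (1 - v) * beta * alpha                       # NB mean of the selected component
y = sample_nb(rng, mean, phi)

# Marginal mean of y after integrating out v and beta:
# (pi + (1 - pi) * alpha) * E[beta], with E[beta] = exp(c + d / 2).
expected = (pi + (1 - pi) * alpha) * np.exp(c + d / 2)
```

With enough draws, `y.mean()` should be close to `expected`, which is a quick way to check that the two mixture components are wired up as in the equations above.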

Figure: totalVI graphical model. Shaded nodes represent observed data; unshaded nodes represent latent variables.

The latent variables, along with their descriptions, are summarized in the following table:

| Latent variable | Description | Code variable (if different) |
|---|---|---|
| \(z_n \in \mathbb{R}^d\) | Low-dimensional representation capturing joint state of a cell | |
| \(\rho_n \in \Delta^{G-1}\) | Denoised/normalized gene expression. This is a vector that sums to 1 within a cell, unless size_factor_key is not None in setup_anndata, in which case this is only forced to be non-negative via softplus. | |
| \(\alpha_n \in [1, \infty)^T\) | Foreground scaling factor for proteins, identifies the mixture distribution (see below) | |
| \(\pi_n \in (0, 1)^T\) | Probability of background for each protein | py_["mixing"] (logits scale) |
| \(l_n \in (0, \infty)\) | Library size for RNA. Here it is modeled as a latent variable, but the recent default for totalVI is to treat library size as observed, equal to the total RNA UMI count of a cell. This can be controlled by passing use_observed_lib_size=False to TOTALVI. The library size can also be set manually using size_factor_key in setup_anndata. | |
| \(\beta_{nt} \in (0, \infty)\) | Protein background intensity. Used twice to identify the protein mixture model. | |


totalVI uses variational inference, and specifically auto-encoding variational Bayes (see Variational Inference), to learn both the model parameters (the neural network parameters, dispersion parameters, etc.) and an approximate posterior distribution with the following factorization:

\begin{align} q_\eta(\beta_n, z_n, l_n \mid x_n, y_n, s_n) := q_\eta(\beta_n \mid z_n,s_n)q_\eta(z_n \mid x_n, y_n,s_n)q_\eta(l_n \mid x_n, y_n, s_n). \end{align}

Here \(\eta\) is a set of parameters corresponding to inference neural networks, which we do not describe in detail here, but are described in the totalVI paper. totalVI can also handle missing proteins (i.e., a dataset comprised of multiple batches, where each batch potentially has a different antibody panel, or no protein data at all). We refer the reader to the original publication for these details.


Dimensionality reduction

For dimensionality reduction, we by default return the mean of the approximate posterior \(q_\eta(z_n \mid x_n, y_n,s_n)\). This is achieved using the method:

>>> latent = model.get_latent_representation()
>>> adata.obsm["X_totalvi"] = latent

Users may also return samples from this distribution, as opposed to the mean, by passing the argument give_mean=False. The latent representation can be used to create a nearest neighbor graph with scanpy:

>>> import scanpy as sc
>>> sc.pp.neighbors(adata, use_rep="X_totalvi")
>>> adata.obsp["distances"]

Normalization and denoising of RNA and protein expression

In get_normalized_expression(), totalVI returns, for RNA, the expected value of \(l_n\rho_n\) under the approximate posterior, and for proteins, the expected value of \((1 - \pi_{nt})\beta_{nt}\alpha_{nt}\). For one cell \(n\), in the case of RNA, this can be written as:

\begin{align} \mathbb{E}_{q_\eta(z_n \mid x_n, y_n,s_n)}\left[l_n'f_\rho\left( z_n, s_n \right) \right], \end{align}

where \(l_n'\) is by default set to 1. See the library_size parameter for more details. The expectation is approximated using Monte Carlo, and the number of samples can be passed as an argument in the code:

>>> rna, protein = model.get_normalized_expression(n_samples=10)

By default the mean over these samples is returned, but users may pass return_mean=False to retrieve all the samples.
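The Monte Carlo approximation can be pictured with a toy decoder. In this sketch the linear map `W` followed by a softmax is a hypothetical stand-in for \(f_\rho\), the posterior mean and standard deviation are made-up values for a single cell, and the batch input \(s_n\) is omitted for brevity; none of this mirrors the library's internals.

```python
import numpy as np


def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
d, G, n_samples = 4, 6, 10
W = rng.normal(size=(d, G))               # hypothetical stand-in for the decoder f_rho
q_mean = rng.normal(size=d)               # toy posterior mean of z_n for one cell
q_std = np.full(d, 0.5)                   # toy posterior standard deviation

# Draw z_n from the approximate posterior and decode every sample.
z = q_mean + q_std * rng.normal(size=(n_samples, d))
samples = 1.0 * softmax(z @ W)            # l_n' = 1; shape (n_samples, G)

# Averaging over samples gives the Monte Carlo estimate of the expectation;
# keeping `samples` as is corresponds to asking for all samples instead.
mc_estimate = samples.mean(axis=0)
```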

In the case of proteins, there are a few important options that control what constitutes denoised protein expression. For example, include_protein_background=True will result in estimating the expectation of \((1 - \pi_{nt})\beta_{nt}\alpha_{nt} + \pi_{nt}\beta_{nt}\). Setting sampling_protein_mixing=True will result in sampling \(v_{nt} \sim \textrm{Bernoulli}(\pi_{nt})\) and replacing \(\pi_{nt}\) with \(v_{nt}\).
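For a single posterior sample, the three protein quantities described above differ only in how the mixture terms are combined. The following NumPy sketch uses made-up values for \(\pi_{nt}\), \(\beta_{nt}\), and \(\alpha_{nt}\) and is meant only to make the formulas concrete; the variable names are assumptions, and in practice these quantities are averaged over posterior samples.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 4
pi = rng.uniform(0.05, 0.95, size=T)         # background probability per protein
beta = rng.uniform(0.5, 2.0, size=T)         # background intensity
alpha = 1.0 + rng.uniform(0.0, 5.0, size=T)  # foreground scaling, at least 1

foreground = (1 - pi) * beta * alpha         # default denoised protein value
with_background = foreground + pi * beta     # include_protein_background=True

v = rng.binomial(1, pi)                      # sampling_protein_mixing=True: v ~ Bernoulli(pi)
sampled = (1 - v) * beta * alpha             # pi replaced by the sampled v
```

When \(v_{nt} = 1\) the sampled foreground value is exactly zero, so the sampled variant is sparser than the default expectation.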

Notably, this function also has the transform_batch parameter that allows counterfactual prediction of expression in an unobserved batch. See the Counterfactual prediction guide.

Differential expression

Differential expression analysis is achieved with differential_expression(). totalVI tests differences in magnitude of \(f_\rho\left( z_n, s_n \right)\) for RNA, and of \((1 - \pi_{nt})\beta_{nt}\alpha_{nt}\) for proteins, with similar options to change this quantity as in the normalized expression function. More information on the mathematics behind differential expression is in Differential Expression.

Data simulation

Data can be generated from the model using the posterior predictive distribution in posterior_predictive_sample(). This is equivalent to feeding a cell through the model, sampling from the posterior distributions of the latent variables, retrieving the likelihood parameters, and finally, sampling from this distribution.