class scvi.model.TOTALVI(adata, n_latent=20, gene_dispersion='gene', protein_dispersion='protein', gene_likelihood='nb', latent_distribution='normal', empirical_protein_background_prior=None, **model_kwargs)

total Variational Inference [GayosoSteier21].

adata : AnnData

AnnData object that has been registered via setup_anndata().

n_latent : int (default: 20)

Dimensionality of the latent space.

gene_dispersion : {'gene', 'gene-batch', 'gene-label', 'gene-cell'} (default: 'gene')

One of the following:

  • 'gene' - genes_dispersion parameter of NB is constant per gene across cells

  • 'gene-batch' - genes_dispersion can differ between different batches

  • 'gene-label' - genes_dispersion can differ between different labels
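To make the three listed options concrete, the sketch below maps each gene_dispersion value to the shape a dispersion parameter would need under that option. The helper name dispersion_shape and its arguments are hypothetical and not part of scvi-tools; the shapes are an assumption about how such a parameter could be stored, not a statement about the library's internals.

```python
def dispersion_shape(mode, n_genes, n_batches=1, n_labels=1):
    """Return the shape of a dispersion parameter implied by `mode`.

    Hypothetical helper for illustration only; not an scvi-tools function.
    """
    if mode == "gene":
        # One value per gene, shared by every cell.
        return (n_genes,)
    if mode == "gene-batch":
        # One value per gene and batch.
        return (n_genes, n_batches)
    if mode == "gene-label":
        # One value per gene and label.
        return (n_genes, n_labels)
    raise ValueError(f"unsupported dispersion mode: {mode!r}")
```

Under these assumptions, dispersion_shape("gene-batch", 4000, n_batches=3) gives (4000, 3), so the number of dispersion values grows with the number of batches.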

protein_dispersion : {'protein', 'protein-batch', 'protein-label'} (default: 'protein')

One of the following:

  • 'protein' - protein_dispersion parameter is constant per protein across cells

  • 'protein-batch' - protein_dispersion can differ between different batches (not tested)

  • 'protein-label' - protein_dispersion can differ between different labels (not tested)

gene_likelihood : {'zinb', 'nb'} (default: 'nb')

One of:

  • 'nb' - Negative binomial distribution

  • 'zinb' - Zero-inflated negative binomial distribution
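The difference between the two likelihoods can be sketched with plain Python. The functions below are illustrative only and not scvi-tools code: they assume a negative binomial parameterized by a mean mu and an inverse dispersion theta, and a zero-inflation probability pi for the extra mass at zero.

```python
import math


def nb_log_pmf(x, mu, theta):
    """Log-probability of count x under NB with mean mu and inverse dispersion theta."""
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def zinb_log_pmf(x, mu, theta, pi):
    """Log-probability under ZINB: with probability pi the count is a structural zero."""
    nb = math.exp(nb_log_pmf(x, mu, theta))
    if x == 0:
        # Zeros come from either the inflation component or the NB itself.
        return math.log(pi + (1.0 - pi) * nb)
    # Non-zero counts can only come from the NB component.
    return math.log((1.0 - pi) * nb)
```

With pi greater than zero, the ZINB assigns more probability to a zero count than the NB with the same mu and theta, which is the only way the two distributions differ.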

latent_distribution : {'normal', 'ln'} (default: 'normal')

One of:

  • 'normal' - Normal distribution

  • 'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)
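The 'ln' option describes a Normal(0, I) draw transformed by softmax. A minimal standard-library sketch of that transformation follows; the function name sample_logistic_normal is hypothetical, and the code only illustrates the distribution, not how the model samples it.

```python
import math
import random


def sample_logistic_normal(n_latent, rng):
    """Draw z ~ Normal(0, I) and map it onto the simplex with softmax."""
    z = [rng.gauss(0.0, 1.0) for _ in range(n_latent)]
    shift = max(z)  # subtract the maximum for numerical stability
    exps = [math.exp(v - shift) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


# Seeded generator so the draw is reproducible.
sample = sample_logistic_normal(20, random.Random(0))
```

Each sample has strictly positive entries that sum to one, so a logistic normal latent vector can be read as a set of proportions, whereas a 'normal' latent vector is unconstrained.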

empirical_protein_background_prior : bool | None (default: None)

Whether to initialize the protein background prior empirically. When True, a Gaussian mixture model (GMM) is fit to each of 100 cells per batch and the resulting distributions are averaged. This only sets the initial value of a parameter that is still learned during inference. If False, the prior is initialized randomly. The default (None) sets this to True if more than 10 proteins are used.
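The rule for the default value can be written out as a small helper. This is a sketch of the documented behavior only; resolve_empirical_prior and its arguments are hypothetical names, not part of scvi-tools.

```python
def resolve_empirical_prior(flag, n_proteins):
    """Resolve the tri-state option to a boolean.

    Hypothetical helper: None means "decide from the number of proteins".
    """
    if flag is None:
        # Documented default: empirical initialization only with more than 10 proteins.
        return n_proteins > 10
    return bool(flag)
```

For example, a panel of exactly 10 proteins with the default None resolves to False, while an explicit True or False is always respected regardless of panel size.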


**model_kwargs

Keyword arguments for TOTALVAE.


>>> import anndata
>>> import scvi
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.data.setup_anndata(adata, batch_key="batch", protein_expression_obsm_key="protein_expression")
>>> vae = scvi.model.TOTALVI(adata)
>>> vae.train()
>>> adata.obsm["X_totalVI"] = vae.get_latent_representation()


See further usage examples in the following tutorials:

  1. CITE-seq analysis with totalVI

  2. Integration of CITE-seq and scRNA-seq data

  3. Online update of scvi-tools models with query datasets

history

Returns computed metrics during training.

differential_expression([adata, groupby, …])

A unified method for differential expression analysis.

get_elbo([adata, indices, batch_size])

Return the ELBO for the data.

get_feature_correlation_matrix([adata, …])

Generate gene-gene correlation matrix using scvi uncertainty and expression.

get_latent_library_size([adata, indices, …])

Returns the latent library size for each cell.

get_latent_representation([adata, indices, …])

Return the latent representation for each cell.

get_likelihood_parameters([adata, indices, …])

Estimates for the parameters of the likelihood \(p(x, y \mid z)\).

get_marginal_ll([adata, indices, …])

Return the marginal log-likelihood (LL) for the data.

get_normalized_expression([adata, indices, …])

Returns the normalized gene expression and protein expression.

get_protein_background_mean(adata, indices, …)

get_protein_foreground_probability([adata, …])

Returns the foreground probability for proteins.

get_reconstruction_error([adata, indices, …])

Return the reconstruction error for the data.

load(dir_path[, adata, use_gpu])

Instantiate a model from the saved output.

load_query_data(adata, reference_model[, …])

Online update of a reference model with scArches algorithm [Lotfollahi20].

posterior_predictive_sample([adata, …])

Generate observation samples from the posterior predictive distribution.

save(dir_path[, overwrite, save_anndata])

Save the state of the model.


to_device(device)

Move model to device.

train([max_epochs, lr, use_gpu, train_size, …])

Trains the model using amortized variational inference.