class scvi.module.VAE(n_input, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, n_continuous_cov=0, n_cats_per_cov=None, dropout_rate=0.1, dispersion='gene', log_variational=True, gene_likelihood='zinb', latent_distribution='normal', encode_covariates=False, deeply_inject_covariates=True, use_batch_norm='both', use_layer_norm='none', use_size_factor_key=False, use_observed_lib_size=True, library_log_means=None, library_log_vars=None, var_activation=None)

Bases: scvi.module.base._base_module.BaseModuleClass

Variational auto-encoder model.

This is an implementation of the scVI model described in [Lopez18].

Parameters

n_input : int

Number of input genes

n_batch : int (default: 0)

Number of batches. If 0, no batch correction is performed.

n_labels : int (default: 0)

Number of labels

n_hidden : int (default: 128)

Number of nodes per hidden layer

n_latent : int (default: 10)

Dimensionality of the latent space

n_layers : int (default: 1)

Number of hidden layers used for encoder and decoder NNs

n_continuous_cov : int (default: 0)

Number of continuous covariates

n_cats_per_cov : Iterable[int] | None (default: None)

Number of categories for each extra categorical covariate

dropout_rate : float (default: 0.1)

Dropout rate for neural networks

dispersion : str (default: 'gene')

One of the following

  • 'gene' - dispersion parameter of NB is constant per gene across cells

  • 'gene-batch' - dispersion can differ between different batches

  • 'gene-label' - dispersion can differ between different labels

  • 'gene-cell' - dispersion can differ for every gene in every cell
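The four dispersion options differ mainly in the shape of the dispersion parameter. The numpy sketch below illustrates one way to expand it to one value per cell and gene. The helper `dispersion_per_cell` is hypothetical, and it assumes the raw parameter is stored unconstrained (log scale) and exponentiated to obtain a positive inverse dispersion theta; for 'gene-cell' it assumes the values come from the decoder rather than a free parameter.

```python
import numpy as np

def dispersion_per_cell(px_r, dispersion, n_cells, batch_index=None, labels=None):
    """Expand a raw dispersion parameter to shape (n_cells, n_genes).

    Assumed shapes of px_r (hypothetical sketch):
      'gene'       -> (n_genes,)
      'gene-batch' -> (n_genes, n_batch)
      'gene-label' -> (n_genes, n_labels)
      'gene-cell'  -> (n_cells, n_genes), produced per cell by the decoder
    """
    if dispersion == "gene":
        out = np.broadcast_to(px_r, (n_cells, px_r.shape[0]))
    elif dispersion == "gene-batch":
        out = px_r[:, batch_index].T  # each cell picks its batch's column
    elif dispersion == "gene-label":
        out = px_r[:, labels].T  # each cell picks its label's column
    elif dispersion == "gene-cell":
        out = px_r
    else:
        raise ValueError(f"unknown dispersion: {dispersion!r}")
    # Exponentiate the unconstrained parameter to get a positive theta.
    return np.exp(out)
```

With 'gene-batch', two cells in different batches receive different dispersion vectors even though they share the same genes.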

log_variational : bool (default: True)

Whether to apply log(data + 1) to the input before encoding, for numerical stability. This is not a normalization step.

gene_likelihood : str (default: 'zinb')

One of

  • 'nb' - Negative binomial distribution

  • 'zinb' - Zero-inflated negative binomial distribution

  • 'poisson' - Poisson distribution
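As a reference for the three likelihoods, the sketch below writes their log-probabilities in the mean/inverse-dispersion parameterization (mean mu, inverse dispersion theta). It is plain Python for illustration, not the module's own code; in particular, it takes the zero-inflation probability pi directly, whereas the module may work with logits.

```python
import math

def log_nb(x, mu, theta):
    """log NB(x) with mean mu > 0 and inverse dispersion theta > 0."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )

def log_zinb(x, mu, theta, pi):
    """Zero-inflated NB: with probability pi an extra zero is emitted."""
    if x == 0:
        # A zero can come from the inflation component or from the NB itself.
        return math.log(pi + (1 - pi) * math.exp(log_nb(0, mu, theta)))
    return math.log(1 - pi) + log_nb(x, mu, theta)

def log_poisson(x, mu):
    """log Poisson(x) with rate mu > 0."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)
```

Compared with 'nb', 'zinb' moves extra probability mass onto zero counts; 'poisson' has no dispersion parameter, so its variance equals its mean.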

latent_distribution : str (default: 'normal')

One of

  • 'normal' - Isotropic normal

  • 'ln' - Logistic normal with normal params N(0, 1)
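A logistic-normal latent variable can be obtained by drawing from a Gaussian and passing the draw through a softmax, so each latent vector lies on the simplex. The numpy sketch below shows that difference between the two options; `sample_latent` is a hypothetical name and the reparameterized draw is a standard choice, not a statement about the module's internals.

```python
import numpy as np

def sample_latent(qz_m, qz_v, latent_distribution="normal", rng=None):
    """Draw z given variational mean qz_m and variance qz_v (same shape)."""
    rng = np.random.default_rng() if rng is None else rng
    # Reparameterized sample from N(qz_m, qz_v).
    untran_z = qz_m + np.sqrt(qz_v) * rng.standard_normal(qz_m.shape)
    if latent_distribution == "normal":
        return untran_z
    if latent_distribution == "ln":
        # Logistic normal: softmax over the latent dimensions (stabilized).
        e = np.exp(untran_z - untran_z.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)
    raise ValueError(f"unknown latent_distribution: {latent_distribution!r}")
```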

encode_covariates : bool (default: False)

Whether to concatenate covariates to expression in encoder

deeply_inject_covariates : bool (default: True)

Whether to concatenate covariates to the output of each hidden layer in the encoder/decoder. This option only applies when n_layers > 1: the covariates are concatenated to the input of each subsequent hidden layer.
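The effect of deeply_inject_covariates can be seen in a small forward pass: the first layer always receives the covariates, and later layers receive them again only when the flag is on. The numpy sketch below (hypothetical `fc_forward`, plain ReLU layers without the batch norm, layer norm, or dropout the real layers may use) makes the resulting change in each layer's input width explicit.

```python
import numpy as np

def one_hot(index, n):
    """One-hot encode an integer array of categories into shape (len, n)."""
    return np.eye(n)[index]

def fc_forward(x, cov, weights, deeply_inject_covariates=True):
    """Run x through ReLU layers given as a list of (W, b) pairs.

    Each W must have an input dimension matching what is concatenated at that
    layer (hypothetical shapes chosen by the caller).
    """
    h = x
    for i, (W, b) in enumerate(weights):
        # First layer: always concatenate covariates.
        # Later layers: only when deep injection is enabled.
        if i == 0 or deeply_inject_covariates:
            h = np.concatenate([h, cov], axis=-1)
        h = np.maximum(h @ W + b, 0.0)
    return h
```

With 4 genes and 2 one-hot batch categories, the second layer's weight matrix needs 8 + 2 input rows when covariates are injected deeply, but only 8 otherwise.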

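use_batch_norm : {'encoder', 'decoder', 'none', 'both'} (default: 'both')

Where to use batch normalization in layers: in the encoder, the decoder, both, or neither ('none').

use_layer_norm : {'encoder', 'decoder', 'none', 'both'} (default: 'none')

Where to use layer normalization in layers: in the encoder, the decoder, both, or neither ('none').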

use_size_factor_key : bool (default: False)

Whether to use the user-defined size_factor AnnDataField as the scaling factor in the mean of the conditional distribution. Takes priority over use_observed_lib_size.

use_observed_lib_size : bool (default: True)

Whether to use the observed RNA library size as the scaling factor in the mean of the conditional distribution.
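The priority between use_size_factor_key and use_observed_lib_size can be sketched as a small decision function. `log_library_for_scaling` is a hypothetical name, and working on the log scale (so that the mean is later formed as exp(library) times normalized expression) is an assumption of this sketch.

```python
import numpy as np

def log_library_for_scaling(counts, size_factor=None,
                            use_size_factor_key=False, use_observed_lib_size=True):
    """Return per-cell log scaling factors of shape (n_cells, 1), or None.

    None means the library size is not fixed here and would instead be treated
    as a latent variable inferred by the model.
    """
    if use_size_factor_key:
        # A user-supplied size factor takes priority.
        return np.log(size_factor)
    if use_observed_lib_size:
        # Observed library size: total counts per cell.
        return np.log(counts.sum(axis=1, keepdims=True))
    return None
```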

library_log_means : ndarray | None (default: None)

1 x n_batch array of means of the log library sizes. Parameterizes the prior on library size when the observed library size is not used.

library_log_vars : ndarray | None (default: None)

1 x n_batch array of variances of the log library sizes. Parameterizes the prior on library size when the observed library size is not used.
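One way to construct library_log_means and library_log_vars from a count matrix is to take, per batch, the mean and variance of the log total counts. The helper `init_library_size` below is a hypothetical numpy sketch of that idea; it assumes every batch has at least one cell and every cell has a nonzero total count.

```python
import numpy as np

def init_library_size(counts, batch_index, n_batch):
    """Per-batch mean and variance of log library sizes, each shaped (1, n_batch)."""
    log_lib = np.log(counts.sum(axis=1))  # log total counts per cell
    means = np.zeros(n_batch)
    variances = np.zeros(n_batch)
    for b in range(n_batch):
        lib_b = log_lib[batch_index == b]
        means[b] = lib_b.mean()
        variances[b] = lib_b.var()
    return means.reshape(1, -1), variances.reshape(1, -1)
```

A cell in batch b would then use library_log_means[0, b] and library_log_vars[0, b] as the mean and variance of a normal prior on its log library size.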

var_activation : Callable | None (default: None)

Callable used to ensure positivity of the variational distributions’ variance. When None, defaults to torch.exp.
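The role of var_activation is to map an unconstrained encoder output to a strictly positive variance. The numpy sketch below mirrors the documented default (exponential) and shows softplus as an alternative; with PyTorch one might pass something like torch.nn.functional.softplus instead. `variational_variance` is a hypothetical name.

```python
import numpy as np

def softplus(a):
    """Numerically stable log(1 + exp(a))."""
    return np.logaddexp(0.0, a)

def variational_variance(raw, var_activation=None):
    """Turn raw encoder outputs into positive variances."""
    activation = np.exp if var_activation is None else var_activation
    return activation(raw)
```

Softplus grows linearly for large inputs, whereas exp grows exponentially, which can matter when raw outputs become large during training.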

Attributes

Methods

generative(z, library, batch_index[, ...])

Runs the generative model.

get_reconstruction_loss(x, px_rate, px_r, ...)



inference(x, batch_index[, cont_covs, ...])

High-level inference method.

loss(tensors, inference_outputs, ...[, ...])

Compute the loss for a minibatch of data.

marginal_ll(tensors, n_mc_samples)

sample(tensors[, n_samples, library_size])

Generate observation samples from the posterior predictive distribution.




VAE.T_destination

alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])

VAE.device



VAE.dump_patches: bool = False

This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the attribute _metadata of the returned state dict, and is thus pickled. _metadata is a dictionary with keys that follow the naming convention of the state dict. See _load_from_state_dict for how this information is used during loading.

If new parameters/buffers are added/removed from a module, this number shall be bumped, and the module’s _load_from_state_dict method can compare the version number and do appropriate changes if the state dict is from before the change.
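The version-bump mechanism above can be illustrated without PyTorch. In the sketch below, the metadata dictionary stands in for a module's entry in _metadata (where the version is stored under the key "version"); the parameter names and the rename performed for version 2 are purely hypothetical.

```python
CURRENT_VERSION = 2  # hypothetical: bumped when a parameter was renamed

def migrate_state_dict(state_dict, metadata):
    """Upgrade a state dict saved by an older module version.

    Hypothetical change in version 2: parameter "px_r" was renamed "px_r_raw".
    """
    version = metadata.get("version")
    if version is None or version < CURRENT_VERSION:
        if "px_r" in state_dict:
            state_dict = dict(state_dict)  # do not mutate the caller's dict
            state_dict["px_r_raw"] = state_dict.pop("px_r")
    return state_dict
```

State dicts saved at the current version pass through unchanged; older ones are rewritten before the parameters are copied into the module.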

VAE.training: bool



VAE.generative(z, library, batch_index, cont_covs=None, cat_covs=None, size_factor=None, y=None, transform_batch=None)

Runs the generative model.
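A central step of a generative model like this one is turning the decoder output into the mean of the count distribution: per-cell gene proportions (a softmax over genes) scaled by the library size. The numpy sketch below assumes the library is on the log scale, so the mean is exp(library) * px_scale; the function name and the use of raw decoder logits as input are assumptions for illustration.

```python
import numpy as np

def generative_mean(decoder_logits, library):
    """Return (px_scale, px_rate) from decoder logits (n_cells, n_genes).

    library: log library size per cell, shape (n_cells, 1).
    """
    shifted = decoder_logits - decoder_logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    px_scale = e / e.sum(axis=-1, keepdims=True)  # proportions summing to 1 per cell
    px_rate = np.exp(library) * px_scale  # expected counts per gene
    return px_scale, px_rate
```

Because px_scale sums to one per cell, the expected counts of each cell sum to its library size.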


VAE.get_reconstruction_loss(x, px_rate, px_r, px_dropout)
Return type



VAE.inference(x, batch_index, cont_covs=None, cat_covs=None, n_samples=1)

High-level inference method.

Runs the inference (encoder) model.
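The sketch below outlines an encoder pass under stated assumptions: the counts are optionally log(data + 1)-transformed (the log_variational option), a hypothetical `encode` callable returns the variational mean and variance, and requesting n_samples > 1 adds a leading sample axis to the latent draws. It is a numpy illustration, not the module's implementation.

```python
import numpy as np

def inference_sketch(x, encode, n_samples=1, log_variational=True, rng=None):
    """Encode counts x (n_cells, n_genes) and draw latent samples z.

    `encode` is a hypothetical callable returning (qz_m, qz_v), each of shape
    (n_cells, n_latent).
    """
    rng = np.random.default_rng() if rng is None else rng
    # log(data + 1) for numerical stability; this is not a normalization.
    x_ = np.log1p(x) if log_variational else x
    qz_m, qz_v = encode(x_)
    if n_samples > 1:
        # Add a leading sample axis: (n_samples, n_cells, n_latent).
        qz_m = np.broadcast_to(qz_m, (n_samples,) + qz_m.shape)
        qz_v = np.broadcast_to(qz_v, (n_samples,) + qz_v.shape)
    # Reparameterized draw from N(qz_m, qz_v).
    return qz_m + np.sqrt(qz_v) * rng.standard_normal(qz_m.shape)
```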


VAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)

Compute the loss for a minibatch of data.

This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.

This function should return an object of type LossRecorder.
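For a VAE of this kind, the loss is typically the negative ELBO: a per-cell reconstruction loss plus a KL term between the variational posterior q(z | x) and a standard normal prior, with kl_weight scaling the KL term (for example during warm-up). The numpy sketch below shows only that core; it omits the LossRecorder wrapper and any KL term for a latent library size, and the function names are hypothetical.

```python
import numpy as np

def kl_normal_std(qz_m, qz_v):
    """KL(N(qz_m, diag(qz_v)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(qz_v + qz_m**2 - 1.0 - np.log(qz_v), axis=-1)

def elbo_loss(reconst_loss, qz_m, qz_v, kl_weight=1.0):
    """Minibatch mean of reconstruction loss + kl_weight * KL."""
    kl = kl_normal_std(qz_m, qz_v)
    return np.mean(reconst_loss + kl_weight * kl)
```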


VAE.marginal_ll(tensors, n_mc_samples)
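No description is given for marginal_ll. A common way to estimate log p(x) with n_mc_samples Monte Carlo draws is importance sampling with the variational posterior as proposal: log p(x) is approximated by logsumexp over k of [log p(x, z_k) - log q(z_k | x)] minus log K. The plain-Python sketch below illustrates that estimator on a toy model; treating this as what the method computes is an assumption, and all names are hypothetical.

```python
import math

def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def importance_log_marginal(x, z_samples, log_joint, log_q):
    """Importance-sampling estimate of log p(x) from K samples z_k ~ q(z | x)."""
    terms = [log_joint(x, z) - log_q(z, x) for z in z_samples]
    return logsumexp(terms) - math.log(len(terms))
```

In the toy model z ~ N(0, 1), x | z ~ N(z, 1), the exact posterior is N(x/2, 1/2) and the marginal is N(0, 2); using the exact posterior as proposal makes every importance weight equal p(x), so the estimate is exact for any set of samples.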


VAE.sample(tensors, n_samples=1, library_size=1)

Generate observation samples from the posterior predictive distribution.

The posterior predictive distribution is written as \(p(\hat{x} \mid x)\).

Parameters

tensors

Tensors dict

n_samples (default: 1)

Number of required samples for each cell

library_size (default: 1)

Library size to scale samples to

Returns

x_new : torch.Tensor

Tensor with shape (n_cells, n_genes, n_samples)
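To illustrate what a posterior predictive draw with the documented output shape looks like, the numpy sketch below samples negative binomial counts through the Gamma-Poisson mixture (lambda ~ Gamma(shape=theta, scale=mu/theta), then x ~ Poisson(lambda)) and moves the sample axis last. It assumes the NB mean px_rate and inverse dispersion px_r are already available; how library_size enters the computation is not shown, and `sample_nb` is a hypothetical name.

```python
import numpy as np

def sample_nb(px_rate, px_r, n_samples=1, rng=None):
    """Draw NB counts of shape (n_cells, n_genes, n_samples).

    px_rate: NB means, shape (n_cells, n_genes).
    px_r: inverse dispersion, broadcastable to px_rate (e.g. (n_genes,)).
    """
    rng = np.random.default_rng() if rng is None else rng
    size = (n_samples,) + px_rate.shape
    # Gamma-Poisson mixture: marginally NB with mean px_rate, inverse dispersion px_r.
    lam = rng.gamma(shape=px_r, scale=px_rate / px_r, size=size)
    counts = rng.poisson(lam)
    # Put the sample axis last: (n_cells, n_genes, n_samples).
    return np.moveaxis(counts, 0, -1)
```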