scvi.model.SCVI

class scvi.model.SCVI(adata, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, dispersion='gene', gene_likelihood='zinb', latent_distribution='normal', **model_kwargs)

single-cell Variational Inference [Lopez18].

Parameters
adata : AnnData

AnnData object that has been registered via setup_anndata().

n_hidden : int (default: 128)

Number of nodes per hidden layer.

n_latent : int (default: 10)

Dimensionality of the latent space.

n_layers : int (default: 1)

Number of hidden layers used for encoder and decoder NNs.

dropout_rate : float (default: 0.1)

Dropout rate for neural networks.

dispersion : Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] (default: 'gene')

One of the following:

  • 'gene' - the NB dispersion parameter is constant per gene across cells

  • 'gene-batch' - dispersion can differ between batches

  • 'gene-label' - dispersion can differ between labels

  • 'gene-cell' - dispersion can differ for every gene in every cell
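The four modes differ only in which index the dispersion parameter is looked up by. The sketch below illustrates that lookup for one gene in one cell; the function `dispersion_for` and its arguments are hypothetical and not part of scvi-tools, and storing the parameters on the log scale (so that theta stays positive after `exp`) is an assumption of this sketch.

```python
import math
from typing import Optional, Sequence


def dispersion_for(
    mode: str,
    gene: int,
    per_gene: Sequence[float],
    per_gene_batch: Sequence[Sequence[float]],
    per_gene_label: Sequence[Sequence[float]],
    per_gene_cell: Optional[Sequence[float]] = None,
    batch: int = 0,
    label: int = 0,
) -> float:
    """Hypothetical sketch: NB dispersion theta for one gene in one cell."""
    if mode == "gene":
        log_theta = per_gene[gene]  # one value per gene, shared by all cells
    elif mode == "gene-batch":
        log_theta = per_gene_batch[gene][batch]  # one value per (gene, batch)
    elif mode == "gene-label":
        log_theta = per_gene_label[gene][label]  # one value per (gene, label)
    elif mode == "gene-cell":
        if per_gene_cell is None:
            raise ValueError("'gene-cell' needs a per-cell value for every gene")
        log_theta = per_gene_cell[gene]  # this cell's own value for the gene
    else:
        raise ValueError(f"unknown dispersion mode: {mode!r}")
    return math.exp(log_theta)  # log-scale storage keeps theta > 0
```

With 'gene', every cell shares the same theta for a gene; the other modes trade more parameters for the ability to model batch-, label-, or cell-specific overdispersion.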

gene_likelihood : Literal['zinb', 'nb', 'poisson'] (default: 'zinb')

One of:

  • 'nb' - Negative binomial distribution

  • 'zinb' - Zero-inflated negative binomial distribution

  • 'poisson' - Poisson distribution
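For reference, the three likelihoods can be written as log-probability functions of a count x. This is a minimal standalone sketch, assuming the common mean/inverse-dispersion parameterization (mean `mu`, dispersion `theta`) and a zero-inflation probability `pi`; the function names are hypothetical and this is not the scvi-tools implementation.

```python
import math


def nb_log_prob(x: int, mu: float, theta: float) -> float:
    """log NB(x; mean=mu, inverse dispersion=theta)."""
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        - theta * math.log1p(mu / theta)  # theta * log(theta / (theta + mu))
        + x * (math.log(mu) - math.log(theta + mu))  # x * log(mu / (theta + mu))
    )


def zinb_log_prob(x: int, mu: float, theta: float, pi: float) -> float:
    """Zero-inflated NB: with probability pi the count is an extra zero."""
    if x == 0:
        # a zero can come from the inflation component or from the NB itself
        return math.log(pi + (1.0 - pi) * math.exp(nb_log_prob(0, mu, theta)))
    return math.log1p(-pi) + nb_log_prob(x, mu, theta)


def poisson_log_prob(x: int, mu: float) -> float:
    """log Poisson(x; rate=mu)."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)
```

Setting `pi = 0` reduces ZINB to NB, and NB approaches Poisson as `theta` grows, so 'poisson' can be read as the no-overdispersion limit of 'nb'.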

latent_distribution : Literal['normal', 'ln'] (default: 'normal')

One of:

  • 'normal' - Normal distribution

  • 'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)
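The 'ln' option's description (Normal(0, I) transformed by softmax) can be illustrated by drawing from a standard normal and applying a softmax, which yields a point on the probability simplex. This is a standalone sketch using only the standard library; `sample_logistic_normal` is a hypothetical name, not an scvi-tools function.

```python
import math
import random
from typing import List


def sample_logistic_normal(dim: int, rng: random.Random) -> List[float]:
    """Draw z ~ Normal(0, I) and map it through softmax."""
    z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    m = max(z)  # subtract the max before exp for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]  # positive entries summing to 1
```

Because the output sums to one, the transformed latent coordinates behave like proportions rather than unconstrained real values.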

**model_kwargs

Keyword arguments passed to VAE.

Examples

>>> import anndata
>>> import scvi
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.data.setup_anndata(adata, batch_key="batch")
>>> vae = scvi.model.SCVI(adata)
>>> vae.train()
>>> adata.obsm["X_scVI"] = vae.get_latent_representation()
>>> adata.obsm["X_normalized_scVI"] = vae.get_normalized_expression()

Notes

See further usage examples in the following tutorials:

  1. Introduction to scvi-tools

  2. Atlas-level integration and label transfer

  3. Online update of scvi-tools models with query datasets

  4. Interoperability with R and Seurat

Attributes

device

history

Returns metrics computed during training.

is_trained

test_indices

train_indices

validation_indices

Methods

differential_expression([adata, groupby, …])

A unified method for differential expression analysis.

get_elbo([adata, indices, batch_size])

Return the ELBO for the data.

get_feature_correlation_matrix([adata, …])

Generate a gene-gene correlation matrix using scVI uncertainty and expression.

get_latent_library_size([adata, indices, …])

Returns the latent library size for each cell.

get_latent_representation([adata, indices, …])

Return the latent representation for each cell.

get_likelihood_parameters([adata, indices, …])

Estimates for the parameters of the likelihood \(p(x \mid z)\).

get_marginal_ll([adata, indices, …])

Return the marginal log-likelihood for the data.

get_normalized_expression([adata, indices, …])

Returns the normalized (decoded) gene expression.

get_reconstruction_error([adata, indices, …])

Return the reconstruction error for the data.

load(dir_path[, adata, use_gpu])

Instantiate a model from the saved output.

load_query_data(adata, reference_model[, …])

Online update of a reference model with scArches algorithm [Lotfollahi20].

posterior_predictive_sample([adata, …])

Generate observation samples from the posterior predictive distribution.

save(dir_path[, overwrite, save_anndata])

Save the state of the model.

to_device(device)

Move model to device.

train([max_epochs, use_gpu, train_size, …])

Train the model.
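Methods such as get_elbo and get_reconstruction_error report quantities from the variational objective. As a simplified sketch, assuming the textbook VAE evidence lower bound with a diagonal-Gaussian posterior q(z | x) and a standard-normal prior (the actual scVI objective may include further terms, and the exact sign conventions of those methods are not assumed here), a one-sample ELBO is the reconstruction log-likelihood minus a closed-form KL divergence. The function names are hypothetical.

```python
import math
from typing import Sequence


def kl_diag_gaussian_to_standard_normal(mu: Sequence[float], var: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(var)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(m * m + v - 1.0 - math.log(v) for m, v in zip(mu, var))


def elbo_single_sample(log_px_given_z: float, mu: Sequence[float], var: Sequence[float]) -> float:
    """One-sample Monte Carlo ELBO: log p(x | z) minus KL(q(z | x) || p(z))."""
    return log_px_given_z - kl_diag_gaussian_to_standard_normal(mu, var)
```

The KL term is zero exactly when the posterior equals the prior (mu = 0, var = 1), so it acts as a regularizer pulling each cell's latent posterior toward N(0, I).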