class scvi.model.SCANVI(adata, unlabeled_category, n_hidden=128, n_latent=10, n_layers=1, dropout_rate=0.1, dispersion='gene', gene_likelihood='zinb', **model_kwargs)[source]

Single-cell annotation using variational inference [Xu21].

Inspired by the M1 + M2 model, as described in (

adata : AnnData

AnnData object that has been registered via setup_anndata().

unlabeled_category : str | int | float

Value that marks unlabeled cells in the labels_key column used to set up AnnData with scvi.

n_hidden : int (default: 128)

Number of nodes per hidden layer.

n_latent : int (default: 10)

Dimensionality of the latent space.

n_layers : int (default: 1)

Number of hidden layers used for encoder and decoder NNs.

dropout_rate : float (default: 0.1)

Dropout rate for neural networks.

dispersion : {'gene', 'gene-batch', 'gene-label', 'gene-cell'} (default: 'gene')

One of the following:

  • 'gene' - dispersion parameter of NB is constant per gene across cells

  • 'gene-batch' - dispersion can differ between different batches

  • 'gene-label' - dispersion can differ between different labels

  • 'gene-cell' - dispersion can differ for every gene in every cell
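One way to picture the four options is by the shape of the dispersion parameter each one implies. The sketch below is illustrative only: the helper `dispersion_shape` and its arguments are hypothetical and are not part of scvi, and the shapes describe the options as worded above rather than how the library necessarily stores them.

```python
def dispersion_shape(dispersion, n_genes, n_batches=1, n_labels=1, n_cells=1):
    """Return the shape of the dispersion parameter implied by each option.

    Hypothetical helper for illustration; not an scvi function.
    """
    if dispersion == "gene":
        # One value per gene, shared by all cells.
        return (n_genes,)
    if dispersion == "gene-batch":
        # One value per gene and batch.
        return (n_genes, n_batches)
    if dispersion == "gene-label":
        # One value per gene and label.
        return (n_genes, n_labels)
    if dispersion == "gene-cell":
        # One value per gene in every cell.
        return (n_cells, n_genes)
    raise ValueError(f"unknown dispersion option: {dispersion!r}")


for option in ("gene", "gene-batch", "gene-label", "gene-cell"):
    shape = dispersion_shape(option, n_genes=2000, n_batches=3, n_labels=5, n_cells=10000)
    print(option, shape)
```

The number of free dispersion values grows from 'gene' to 'gene-cell', so the later options are more flexible but have more to estimate.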

gene_likelihood : {'zinb', 'nb', 'poisson'} (default: 'zinb')

One of:

  • 'nb' - Negative binomial distribution

  • 'zinb' - Zero-inflated negative binomial distribution

  • 'poisson' - Poisson distribution
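To make the difference between the three likelihoods concrete, the following sketch evaluates each one on a single count using only the standard library. It assumes the negative binomial is parameterized by a mean `mu` and an inverse dispersion `theta`, and that the zero-inflated variant mixes in a point mass at zero with probability `pi`. These function names and this parameterization are assumptions made for illustration, not the scvi implementation.

```python
import math


def nb_pmf(x, mu, theta):
    """Negative binomial probability of count x (mean mu, inverse dispersion theta)."""
    log_p = (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )
    return math.exp(log_p)


def zinb_pmf(x, mu, theta, pi):
    """Zero-inflated NB: with probability pi the count is forced to zero."""
    return pi * (x == 0) + (1.0 - pi) * nb_pmf(x, mu, theta)


def poisson_pmf(x, mu):
    """Poisson probability of count x with mean mu."""
    return math.exp(x * math.log(mu) - mu - math.lgamma(x + 1))


# Probability of observing a zero count under each likelihood (illustrative values).
mu, theta, pi = 2.0, 1.5, 0.3
print("poisson:", poisson_pmf(0, mu))
print("nb:     ", nb_pmf(0, mu, theta))
print("zinb:   ", zinb_pmf(0, mu, theta, pi))
```

With the same mean, the zero-inflated model assigns more probability to zeros than the plain negative binomial, which is the usual motivation for choosing it on sparse count data.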


**model_kwargs

Keyword args for SCANVAE.


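>>> import anndata
>>> import scvi
>>> adata = anndata.read_h5ad(path_to_anndata)  # path_to_anndata: path to an .h5ad file
>>> # Register the batch and label columns of adata.obs with the model
>>> scvi.model.SCANVI.setup_anndata(adata, batch_key="batch", labels_key="labels")
>>> # "Unknown" is the unlabeled_category value used in adata.obs["labels"]
>>> vae = scvi.model.SCANVI(adata, "Unknown")
>>> vae.train()
>>> adata.obsm["X_scVI"] = vae.get_latent_representation()
>>> adata.obs["pred_label"] = vae.predict()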
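The example above passes "Unknown" as unlabeled_category, so every cell without an annotation needs that exact value in the labels column. A minimal sketch of preparing such a column in plain Python follows. The cell identifiers, the annotation dictionary, and the helper `fill_unlabeled` are hypothetical and only serve the illustration.

```python
UNLABELED = "Unknown"  # must match the value passed as unlabeled_category


def fill_unlabeled(cell_ids, known_labels, unlabeled=UNLABELED):
    """Return one label per cell, using `unlabeled` where no annotation exists."""
    return [known_labels.get(cell_id, unlabeled) for cell_id in cell_ids]


# Hypothetical data: only two of the four cells have been annotated.
cell_ids = ["cell_0", "cell_1", "cell_2", "cell_3"]
known_labels = {"cell_0": "T cell", "cell_2": "B cell"}

labels = fill_unlabeled(cell_ids, known_labels)
print(labels)  # ['T cell', 'Unknown', 'B cell', 'Unknown']
```

The resulting list could then be stored as the column named by labels_key (for instance adata.obs["labels"]) before calling setup_anndata.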


See further usage examples in the following tutorials:

  1. /user_guide/notebooks/harmonization

  2. /user_guide/notebooks/scarches_scvi_tools

  3. /user_guide/notebooks/seed_labeling




Returns computed metrics during training.






differential_expression([adata, groupby, ...])

A unified method for differential expression analysis.

from_scvi_model(scvi_model, unlabeled_category)

Initialize scanVI model with weights from pretrained scVI model.

get_elbo([adata, indices, batch_size])

Return the ELBO for the data.

get_feature_correlation_matrix([adata, ...])

Generate gene-gene correlation matrix using scvi uncertainty and expression.

get_latent_library_size([adata, indices, ...])

Returns the latent library size for each cell.

get_latent_representation([adata, indices, ...])

Return the latent representation for each cell.

get_likelihood_parameters([adata, indices, ...])

Estimates for the parameters of the likelihood \(p(x \mid z)\).

get_marginal_ll([adata, indices, ...])

Return the marginal LL for the data.

get_normalized_expression([adata, indices, ...])

Returns the normalized (decoded) gene expression.

get_reconstruction_error([adata, indices, ...])

Return the reconstruction error for the data.

load(dir_path[, adata, use_gpu])

Instantiate a model from the saved output.

load_query_data(adata, reference_model[, ...])

Online update of a reference model with scArches algorithm [Lotfollahi21].

posterior_predictive_sample([adata, ...])

Generate observation samples from the posterior predictive distribution.

predict([adata, indices, soft, batch_size])

Return cell label predictions.

save(dir_path[, overwrite, save_anndata])

Save the state of the model.

setup_anndata(adata, labels_key[, ...])

Sets up the AnnData object for this model.


Move model to device.

train([max_epochs, n_samples_per_label, ...])

Train the model.
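The predict method listed above takes a soft argument. Assuming that soft predictions are per-label probabilities and that the default output is the single most probable label per cell (an assumption about the library's behavior that this page does not spell out), the relationship between the two can be sketched as follows. The helper `hard_labels` and the probabilities are made up for illustration.

```python
def hard_labels(soft_predictions):
    """Pick the most probable label for each cell from per-label probabilities."""
    return [max(probs, key=probs.get) for probs in soft_predictions]


# Hypothetical soft predictions for two cells.
soft = [
    {"T cell": 0.7, "B cell": 0.2, "NK cell": 0.1},
    {"T cell": 0.1, "B cell": 0.3, "NK cell": 0.6},
]

print(hard_labels(soft))  # ['T cell', 'NK cell']
```

Keeping the soft output is useful when low-confidence cells should be flagged rather than forced into a label.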