class scvi.model.PEAKVI(adata, n_hidden=None, n_latent=None, n_layers_encoder=2, n_layers_decoder=2, dropout_rate=0.1, model_depth=True, region_factors=True, use_batch_norm='none', use_layer_norm='both', latent_distribution='normal', deeply_inject_covariates=False, encode_covariates=False, **model_kwargs)

Peak Variational Inference [Ashuach21]

adata : AnnData

AnnData object that has been registered via setup_anndata().

n_hidden : int | None (default: None)

Number of nodes per hidden layer. If None, defaults to the square root of the number of regions.

n_latent : int | None (default: None)

Dimensionality of the latent space. If None, defaults to the square root of n_hidden.
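The default sizing rule for n_hidden and n_latent can be sketched in plain Python. default_dims is a hypothetical helper, not scvi-tools code, and it assumes the square roots are truncated to integers; the library's exact rounding may differ.

```python
import math


def default_dims(n_regions, n_hidden=None, n_latent=None):
    """Resolve PEAKVI-style default layer sizes (hypothetical helper).

    Assumption: square roots are truncated to integers.
    """
    if n_hidden is None:
        # n_hidden defaults to sqrt(number of regions)
        n_hidden = int(math.sqrt(n_regions))
    if n_latent is None:
        # n_latent defaults to sqrt(n_hidden), using the resolved n_hidden
        n_latent = int(math.sqrt(n_hidden))
    return n_hidden, n_latent


print(default_dims(100_000))  # (316, 17)
```

Because n_latent is derived from the resolved n_hidden, passing only n_hidden=64 would give a latent size of 8 under this sketch.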

n_layers_encoder : int (default: 2)

Number of hidden layers used for the encoder NN.

n_layers_decoder : int (default: 2)

Number of hidden layers used for the decoder NN.

dropout_rate : float (default: 0.1)

Dropout rate for the neural networks.

model_depth : bool (default: True)

Whether to model sequencing depth (library size).

region_factors : bool (default: True)

Whether to include region-specific factors in the model.
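To show what model_depth and region_factors control, here is a minimal sketch of a Bernoulli reconstruction loss for binarized accessibility, in which a cell-specific depth factor and a per-region factor both scale the decoder's probability. This follows our reading of the PeakVI formulation [Ashuach21]; the multiplicative form, the sigmoid on the region parameters, and the helper names are assumptions, not the library's code.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bernoulli_nll(x, decoded, depth=1.0, region_logits=None):
    """Negative log-likelihood of binarized accessibility for one cell.

    x: 0/1 observations, one per region.
    decoded: decoder probabilities in (0, 1), one per region.
    depth: cell-specific factor in (0, 1]; 1.0 mimics model_depth=False.
    region_logits: unconstrained per-region parameters squashed by a sigmoid;
        None mimics region_factors=False (assumption).
    """
    eps = 1e-8  # clamp so log() stays finite
    total = 0.0
    for r, (obs, p) in enumerate(zip(x, decoded)):
        f = sigmoid(region_logits[r]) if region_logits is not None else 1.0
        prob = min(max(p * depth * f, eps), 1.0 - eps)
        # binary cross-entropy for this region
        total -= obs * math.log(prob) + (1 - obs) * math.log(1.0 - prob)
    return total


print(round(bernoulli_nll([1, 0], [0.5, 0.5]), 4))  # 1.3863
```

Under this sketch, a low depth factor lowers every predicted probability for a shallowly sequenced cell, so missing peaks in that cell are penalized less.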

latent_distribution : {'normal', 'ln'} (default: 'normal')

One of

  • 'normal' - Normal distribution (Default)

  • 'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)
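The 'ln' option can be illustrated with a stdlib sketch: draw a standard normal vector and push it through a softmax, which yields a point on the probability simplex. The function names are hypothetical, and this shows only the transformation, not how scvi-tools wires it into the encoder.

```python
import math
import random


def softmax(values):
    # Subtract the max before exponentiating for numerical stability
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sample_logistic_normal(dim, rng):
    # z ~ Normal(0, I), then softmax maps z onto the simplex
    z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    return softmax(z)


rng = random.Random(0)
point = sample_logistic_normal(10, rng)
print(round(sum(point), 6))  # 1.0
```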

deeply_inject_covariates : bool (default: False)

Whether to deeply inject covariates into all layers of the decoder. If False (default), covariates are only included in the input layer.
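The effect of deeply_inject_covariates on layer sizes can be sketched by computing the input width of each decoder layer, assuming covariates are injected by concatenating them to a layer's input. decoder_input_widths is a hypothetical helper, and the concatenation scheme is our assumption about the implementation.

```python
def decoder_input_widths(n_latent, n_hidden, n_layers, n_covariates, deep):
    """Input width of each decoder layer (hypothetical helper).

    deep=False: covariates are concatenated to the first layer's input only.
    deep=True: covariates are concatenated to every layer's input.
    """
    widths = []
    for layer in range(n_layers):
        # The first layer reads the latent code; later layers read hidden units
        base = n_latent if layer == 0 else n_hidden
        inject = deep or layer == 0
        widths.append(base + (n_covariates if inject else 0))
    return widths


print(decoder_input_widths(17, 316, 2, 3, deep=False))  # [20, 316]
print(decoder_input_widths(17, 316, 2, 3, deep=True))   # [20, 319]
```

Deep injection gives every layer direct access to batch information, at the cost of a few extra input units per layer.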


**model_kwargs

Keyword arguments for PEAKVAE.


>>> import anndata
>>> import scvi
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.model.PEAKVI.setup_anndata(adata, batch_key="batch")
>>> vae = scvi.model.PEAKVI(adata)
>>> vae.train()


See further usage examples in the following tutorials:

  1. /user_guide/notebooks/PeakVI




history

Returns computed metrics during training.






differential_accessibility([adata, groupby, ...])

A unified method for differential accessibility analysis.

get_accessibility_estimates([adata, ...])

Impute the full accessibility matrix.

get_elbo([adata, indices, batch_size])

Return the ELBO for the data.

get_latent_representation([adata, indices, ...])

Return the latent representation for each cell.

get_library_size_factors([adata, indices, ...])

Return library size factors.

get_marginal_ll([adata, indices, ...])

Return the marginal LL for the data.

get_reconstruction_error([adata, indices, ...])

Return the reconstruction error for the data.


get_region_factors()

Return region-specific factors.

load(dir_path[, adata, use_gpu])

Instantiate a model from the saved output.

load_query_data(adata, reference_model[, ...])

Online update of a reference model with scArches algorithm [Lotfollahi21].

save(dir_path[, overwrite, save_anndata])

Save the state of the model.

setup_anndata(adata[, batch_key, layer, ...])

Sets up the AnnData object for this model.


to_device(device)

Move model to device.

train([max_epochs, lr, use_gpu, train_size, ...])

Trains the model using amortized variational inference.
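Amortized variational inference, which train() uses and whose objective get_elbo() reports, can be sketched with three pieces: the reparameterization trick, the closed-form KL divergence to a standard normal prior, and the ELBO that combines them. In a real model an encoder network produces mu and log_var for each cell (the amortization) and a decoder supplies the log-likelihood; here both are plain numbers, and the function names are hypothetical.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    # z = mu + sigma * eps with eps ~ N(0, 1); randomness lives in eps,
    # so gradients could flow through mu and log_var in an autodiff framework
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions
    return sum(0.5 * (math.exp(lv) + m * m - 1.0 - lv) for m, lv in zip(mu, log_var))


def elbo(log_likelihood, mu, log_var):
    # ELBO = E_q[log p(x | z)] - KL(q(z | x) || p(z)); the expectation is
    # approximated by the log-likelihood of a single sampled z
    return log_likelihood - kl_to_standard_normal(mu, log_var)


rng = random.Random(0)
mu, log_var = [0.3, -0.1], [-1.0, 0.2]
z = reparameterize(mu, log_var, rng)  # would be fed to the decoder
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0
```

Training maximizes the ELBO (equivalently, minimizes its negative) over encoder and decoder parameters jointly.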