scvi.model.PEAKVI

class scvi.model.PEAKVI(adata, n_hidden=None, n_latent=None, n_layers_encoder=2, n_layers_decoder=2, dropout_rate=0.1, model_depth=True, region_factors=True, use_batch_norm='none', use_layer_norm='both', latent_distribution='normal', deeply_inject_covariates=False, encode_covariates=False, **model_kwargs)

Peak Variational Inference [Ashuach21]

Parameters
adata : AnnData

AnnData object that has been registered via setup_anndata().

n_hidden : int | None (default: None)

Number of nodes per hidden layer. If None, defaults to the square root of the number of regions.

n_latent : int | None (default: None)

Dimensionality of the latent space. If None, defaults to the square root of n_hidden.
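The two defaults chain together: the hidden width depends on the number of regions, and the latent size depends on the hidden width. The following is a minimal sketch of that arithmetic. It assumes the square roots are truncated to integers, and the helper name `default_sizes` is hypothetical, not part of scvi-tools.

```python
import math
from typing import Optional, Tuple


def default_sizes(
    n_regions: int,
    n_hidden: Optional[int] = None,
    n_latent: Optional[int] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: resolve n_hidden / n_latent when left as None.

    Assumes integer truncation of the square roots (illustrative only).
    """
    if n_hidden is None:
        # Hidden width: square root of the number of regions.
        n_hidden = int(math.sqrt(n_regions))
    if n_latent is None:
        # Latent size: square root of the (possibly just resolved) hidden width.
        n_latent = int(math.sqrt(n_hidden))
    return n_hidden, n_latent


print(default_sizes(10_000))  # (100, 10)
print(default_sizes(50_000))  # (223, 14)
```

For a peak matrix with tens of thousands of regions, this gives a hidden width in the low hundreds and a latent space of roughly 10 to 20 dimensions.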

n_layers_encoder : int (default: 2)

Number of hidden layers in the encoder neural network.

n_layers_decoder : int (default: 2)

Number of hidden layers in the decoder neural network.

dropout_rate : float (default: 0.1)

Dropout rate for the neural networks.

model_depth : bool (default: True)

Whether to model sequencing depth / library size.

region_factors : bool (default: True)

Whether to include region-specific factors in the model.
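One way to read `model_depth` and `region_factors` is this: the decoder produces a per-cell, per-region accessibility probability, which is then scaled by a per-cell depth factor and a per-region factor before entering a Bernoulli likelihood. The sketch below shows that scaling with plain Python lists. The function name, the sigmoid used to keep region factors in (0, 1), and the factor of 1 used when a flag is off are all assumptions for illustration. This is not a transcription of PEAKVAE.

```python
import math
from typing import List, Optional


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def accessibility_probs(
    decoder_probs: List[float],
    depth_factor: Optional[float] = None,
    region_logits: Optional[List[float]] = None,
) -> List[float]:
    """Hypothetical helper: scale one cell's decoder output.

    decoder_probs: per-region probabilities in (0, 1) from the decoder.
    depth_factor:  per-cell factor in (0, 1); None mimics model_depth=False.
    region_logits: unconstrained per-region parameters; None mimics
                   region_factors=False. Squashed by a sigmoid (assumption).
    """
    d = 1.0 if depth_factor is None else depth_factor
    probs = []
    for j, p in enumerate(decoder_probs):
        f = 1.0 if region_logits is None else sigmoid(region_logits[j])
        # A product of terms in (0, 1] is still a valid Bernoulli probability.
        probs.append(p * d * f)
    return probs


print(accessibility_probs([0.8, 0.4], depth_factor=0.5, region_logits=[0.0, 0.0]))
# [0.2, 0.1]
```

Because the depth factor is shared across all regions of a cell, it can absorb sequencing-depth differences between cells. The region factor can absorb differences that are shared by every cell for a given region.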

latent_distribution : {'normal', 'ln'} (default: 'normal')

One of

  • 'normal' - Normal distribution (Default)

  • 'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)
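The 'ln' option can be sketched as drawing a Gaussian sample and passing it through a softmax, which places the latent vector on the probability simplex. The following standard-library sketch assumes a diagonal Gaussian with mean `mu` and standard deviation `sigma`. Those names and the function `sample_latent` are hypothetical.

```python
import math
import random
from typing import List, Optional


def softmax(z: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sample_latent(
    mu: List[float],
    sigma: List[float],
    distribution: str = "normal",
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical helper: one latent draw for either latent_distribution value."""
    rng = rng or random.Random()
    # Reparameterized draw: z = mu + sigma * eps, with eps ~ Normal(0, 1).
    z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
    if distribution == "normal":
        return z
    if distribution == "ln":
        # Logistic normal: map the Gaussian draw onto the simplex.
        return softmax(z)
    raise ValueError(f"unknown latent_distribution: {distribution!r}")


weights = sample_latent([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], "ln", random.Random(0))
print(round(sum(weights), 6))  # 1.0
```

With 'ln', every coordinate is positive and the coordinates sum to one. With 'normal', the draw is left unconstrained.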

deeply_inject_covariates : bool (default: False)

Whether to deeply inject covariates into all layers of the decoder. If False (default), covariates are only included in the input layer.
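The two settings differ in each decoder layer's input width. With shallow injection, only the first layer sees the concatenated covariates. With deep injection, they are concatenated at every hidden layer. The hypothetical helper below computes those widths. It assumes one-hot encoded covariates of total width `n_cov` and equal hidden widths, and it does not reflect the actual decoder code.

```python
from typing import List


def decoder_input_widths(
    n_latent: int,
    n_hidden: int,
    n_cov: int,
    n_layers: int,
    deeply_inject_covariates: bool,
) -> List[int]:
    """Hypothetical helper: input width of each hidden layer in the decoder."""
    widths = []
    for layer in range(n_layers):
        base = n_latent if layer == 0 else n_hidden
        # Covariates always join the first layer's input; later layers only
        # receive them when deep injection is enabled.
        inject = layer == 0 or deeply_inject_covariates
        widths.append(base + (n_cov if inject else 0))
    return widths


print(decoder_input_widths(10, 100, 3, 2, deeply_inject_covariates=False))  # [13, 100]
print(decoder_input_widths(10, 100, 3, 2, deeply_inject_covariates=True))   # [13, 103]
```

Deep injection adds only `n_cov` extra inputs per layer. In exchange, the deeper layers see the covariates directly rather than through the first layer's transformation.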

**model_kwargs

Keyword arguments for PEAKVAE.

Examples

>>> import anndata
>>> import scvi
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.data.setup_anndata(adata, batch_key="batch")
>>> vae = scvi.model.PEAKVI(adata)
>>> vae.train()

Notes

See further usage examples in the following tutorials:

  1. PeakVI: Analyzing scATACseq data

Attributes

device

history

Returns computed metrics during training.

is_trained

test_indices

train_indices

validation_indices

Methods

differential_accessibility([adata, groupby, …])

A unified method for differential accessibility analysis.

get_accessibility_estimates([adata, …])

Impute the full accessibility matrix.

get_elbo([adata, indices, batch_size])

Return the ELBO for the data.

get_latent_representation([adata, indices, …])

Return the latent representation for each cell.

get_library_size_factors([adata, indices, …])

get_marginal_ll([adata, indices, …])

Return the marginal LL for the data.

get_reconstruction_error([adata, indices, …])

Return the reconstruction error for the data.

get_region_factors()

load(dir_path[, adata, use_gpu])

Instantiate a model from the saved output.

load_query_data(adata, reference_model[, …])

Online update of a reference model with the scArches algorithm [Lotfollahi20].

save(dir_path[, overwrite, save_anndata])

Save the state of the model.

to_device(device)

Move model to device.

train([max_epochs, lr, use_gpu, train_size, …])

Trains the model using amortized variational inference.
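Several of the methods above (get_reconstruction_error, get_elbo, get_marginal_ll) report quantities built from the same pieces. The sketch below makes three assumptions: a Bernoulli likelihood on binarized accessibility, so the reconstruction error is a binary cross-entropy; a diagonal Gaussian posterior; and a standard-normal prior. It shows how a reconstruction term and a KL term combine into a per-cell ELBO, and how an importance-sampled estimate of the marginal log-likelihood is formed. The function names are hypothetical, and this is not the scvi-tools implementation. The sign convention of reported values may also differ, since a training loss is the negative ELBO.

```python
import math
from typing import Sequence


def bernoulli_reconstruction_error(
    x: Sequence[int], p: Sequence[float], eps: float = 1e-8
) -> float:
    """Binary cross-entropy summed over regions for one cell (lower is better)."""
    total = 0.0
    for xi, pi in zip(x, p):
        pi = min(max(pi, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= xi * math.log(pi) + (1 - xi) * math.log(1.0 - pi)
    return total


def gaussian_kl(mu: Sequence[float], sigma: Sequence[float]) -> float:
    """Closed-form KL(Normal(mu, diag(sigma^2)) || Normal(0, I))."""
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s) for m, s in zip(mu, sigma))


def elbo(
    x: Sequence[int], p: Sequence[float], mu: Sequence[float], sigma: Sequence[float]
) -> float:
    """ELBO = log p(x | z) - KL(q(z | x) || p(z)), using one decoded sample p."""
    return -bernoulli_reconstruction_error(x, p) - gaussian_kl(mu, sigma)


def logsumexp(values: Sequence[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def importance_sampled_log_marginal(log_weights: Sequence[float]) -> float:
    """log p(x) ~= logsumexp_k[log p(x, z_k) - log q(z_k | x)] - log K."""
    return logsumexp(log_weights) - math.log(len(log_weights))


print(gaussian_kl([0.0, 0.0], [1.0, 1.0]))  # 0.0
```

The average of the log weights is itself a Monte Carlo estimate of the ELBO. By Jensen's inequality, the importance-sampled value is never below that average.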