scvi.module.AutoZIVAE

class scvi.module.AutoZIVAE(n_input, alpha_prior=0.5, beta_prior=0.5, minimal_dropout=0.01, zero_inflation='gene', **args)

Bases: scvi.module._vae.VAE

Implementation of the AutoZI model [Clivio19].

Parameters
n_input : int

Number of input genes.

alpha_prior : Optional[float] (default: 0.5)

Float denoting the alpha parameter of the prior Beta distribution of the zero-inflation Bernoulli parameter. Should lie strictly between 0 and 1. When set to None, it is set to 1 - beta_prior if beta_prior is not None; otherwise, the prior Beta distribution is learned in an empirical Bayes fashion.

beta_prior : Optional[float] (default: 0.5)

Float denoting the beta parameter of the prior Beta distribution of the zero-inflation Bernoulli parameter. Should lie strictly between 0 and 1. When set to None, it is set to 1 - alpha_prior if alpha_prior is not None; otherwise, the prior Beta distribution is learned in an empirical Bayes fashion.
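The defaulting rules for alpha_prior and beta_prior can be summarized in a few lines of plain Python. The helper resolve_beta_prior below is hypothetical (it is not part of scvi-tools); it only mirrors the rules stated above and says nothing about how the empirical Bayes prior is fitted when both values are None. With both values below 1, as in the defaults, the Beta density is U-shaped and puts most of its mass near 0 and 1.

```python
from typing import Optional, Tuple


def resolve_beta_prior(
    alpha_prior: Optional[float], beta_prior: Optional[float]
) -> Optional[Tuple[float, float]]:
    """Hypothetical helper mirroring the alpha_prior / beta_prior rules.

    Returns the (alpha, beta) pair of the Beta prior, or None when both
    inputs are None, i.e. when the prior is to be learned (empirical Bayes).
    """
    if alpha_prior is None and beta_prior is None:
        return None  # prior parameters learned from the data
    if alpha_prior is None:
        alpha_prior = 1.0 - beta_prior  # complement rule
    elif beta_prior is None:
        beta_prior = 1.0 - alpha_prior  # complement rule
    for name, value in (("alpha_prior", alpha_prior), ("beta_prior", beta_prior)):
        if not 0.0 < value < 1.0:
            raise ValueError(f"{name} should lie strictly between 0 and 1, got {value}")
    return alpha_prior, beta_prior


print(resolve_beta_prior(0.5, 0.5))    # defaults: a symmetric Beta(0.5, 0.5) prior
print(resolve_beta_prior(None, 0.3))   # alpha_prior filled in as 1 - 0.3
print(resolve_beta_prior(None, None))  # None: prior is learned instead
```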

minimal_dropout : float (default: 0.01)

Float denoting the lower bound of the cell-gene zero-inflation (ZI) rate in the ZINB component. Must be non-negative. It can be set to 0, but this is not recommended because it may make the mixture problem ill-defined.
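One simple way to honor such a floor is to squash an unconstrained logit through a sigmoid and then rescale the result affinely into [minimal_dropout, 1]. The sketch below does this for a single scalar using a hypothetical helper, rescale_dropout_prob; whether AutoZIVAE.rescale_dropout uses exactly this mapping is an assumption, not something this page states.

```python
import math


def rescale_dropout_prob(logit: float, minimal_dropout: float = 0.01) -> float:
    """Map an unconstrained logit to a ZI probability bounded below by minimal_dropout.

    Hypothetical helper: illustrates a lower bound on the dropout rate,
    not the exact AutoZIVAE.rescale_dropout implementation.
    """
    if minimal_dropout < 0:
        raise ValueError("minimal_dropout must be non-negative")
    # Numerically stable sigmoid: avoids overflow in math.exp for large |logit|.
    if logit >= 0:
        p = 1.0 / (1.0 + math.exp(-logit))
    else:
        e = math.exp(logit)
        p = e / (1.0 + e)
    # Affine rescaling from [0, 1] to [minimal_dropout, 1].
    return minimal_dropout + (1.0 - minimal_dropout) * p


print(rescale_dropout_prob(-10.0))  # just above the 0.01 floor
print(rescale_dropout_prob(0.0))    # 0.01 + 0.99 * 0.5
```

With minimal_dropout=0, very negative logits drive the rate all the way to 0, which is the degenerate case the parameter description warns about.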

zero_inflation : str (default: 'gene')

One of the following:

  • 'gene' - zero-inflation Bernoulli parameter of AutoZI is constant per gene across cells

  • 'gene-batch' - zero-inflation Bernoulli parameter can differ between different batches

  • 'gene-label' - zero-inflation Bernoulli parameter can differ between different labels

  • 'gene-cell' - zero-inflation Bernoulli parameter can differ for every gene in every cell
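To make the four modes concrete, the sketch below returns the shape of the table of zero-inflation Bernoulli parameters that each mode implies for a dataset with a given number of genes, batches, labels, and cells. The helper zi_param_shape and its argument names are hypothetical, and the shapes describe the modeling assumption stated in each bullet, not the layout of the tensors AutoZIVAE actually stores.

```python
from typing import Tuple


def zi_param_shape(
    zero_inflation: str,
    n_genes: int,
    n_batch: int = 1,
    n_labels: int = 1,
    n_cells: int = 1,
) -> Tuple[int, ...]:
    """Hypothetical helper: shape of the ZI Bernoulli parameter table per mode."""
    shapes = {
        "gene": (n_genes,),                   # one parameter per gene, shared by all cells
        "gene-batch": (n_genes, n_batch),     # one parameter per gene and batch
        "gene-label": (n_genes, n_labels),    # one parameter per gene and label
        "gene-cell": (n_cells, n_genes),      # one parameter per cell-gene pair
    }
    try:
        return shapes[zero_inflation]
    except KeyError:
        raise ValueError(f"unknown zero_inflation mode: {zero_inflation!r}") from None


for mode in ("gene", "gene-batch", "gene-label", "gene-cell"):
    print(mode, zi_param_shape(mode, n_genes=100, n_batch=3, n_labels=5, n_cells=50))
```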

See the docstring of the VAE base class (scvi.module._vae.VAE) for the remaining parameters, which are passed through **args. Do not specify reconstruction_loss.

Examples

>>> import scvi
>>> from scvi.module import AutoZIVAE
>>> adata = scvi.data.cortex()  # Cortex dataset as an AnnData object
>>> autozivae = AutoZIVAE(adata.n_vars, alpha_prior=0.5, beta_prior=0.5, minimal_dropout=0.01)

Methods

compute_global_kl_divergence()

Return type: Tensor

generative(z, library[, batch_index, y, …])

Runs the generative model.

get_alphas_betas([as_numpy])

Return type: Dict[str, Union[Tensor, ndarray]]

get_reconstruction_loss(x, px_rate, px_r, …)

Return type: Tensor

loss(tensors, inference_outputs, …[, …])

Compute the loss for a minibatch of data.

rescale_dropout(px_dropout[, eps_log])

Return type: Tensor

reshape_bernoulli(bernoulli_params[, …])

Return type: Tensor

sample_bernoulli_params([batch_index, y, …])

Return type: Tensor

sample_from_beta_distribution(alpha, beta[, …])

Return type: Tensor
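The signature of sample_from_beta_distribution(alpha, beta[, …]) does not say how the samples are drawn. For reference, one standard construction uses two Gamma draws: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1) are independent, then X / (X + Y) ~ Beta(alpha, beta). The standard-library sketch below illustrates that identity with a hypothetical helper, sample_beta_via_gamma; it is not the scvi-tools implementation, which operates on PyTorch tensors and may rely on a different (for example, reparameterized) sampler.

```python
import random


def sample_beta_via_gamma(alpha: float, beta: float, rng: random.Random) -> float:
    """Hypothetical helper: one Beta(alpha, beta) draw as X / (X + Y).

    X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1) are drawn independently.
    """
    if alpha <= 0 or beta <= 0:
        raise ValueError("alpha and beta must be positive")
    x = rng.gammavariate(alpha, 1.0)  # shape alpha, scale 1
    y = rng.gammavariate(beta, 1.0)   # shape beta, scale 1
    return x / (x + y)


rng = random.Random(0)
draws = [sample_beta_via_gamma(0.5, 0.5, rng) for _ in range(10_000)]
print(sum(draws) / len(draws))  # should be close to 0.5, the Beta(0.5, 0.5) mean
```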