scvi.module.VAE#
- class scvi.module.VAE(n_input, n_batch=0, n_labels=0, n_hidden=128, n_latent=10, n_layers=1, n_continuous_cov=0, n_cats_per_cov=None, dropout_rate=0.1, dispersion='gene', log_variational=True, gene_likelihood='zinb', latent_distribution='normal', encode_covariates=False, deeply_inject_covariates=True, use_batch_norm='both', use_layer_norm='none', use_size_factor_key=False, use_observed_lib_size=True, library_log_means=None, library_log_vars=None, var_activation=None)[source]#
Bases: scvi.module.base._base_module.BaseModuleClass
Variational auto-encoder model.
This is an implementation of the scVI model described in [Lopez18].
- Parameters
- n_input : int
  Number of input genes
- n_batch : int (default: 0)
  Number of batches. If 0, no batch correction is performed.
- n_labels : int (default: 0)
  Number of labels
- n_hidden : int (default: 128)
  Number of nodes per hidden layer
- n_latent : int (default: 10)
  Dimensionality of the latent space
- n_layers : int (default: 1)
  Number of hidden layers used for encoder and decoder NNs
- n_continuous_cov : int (default: 0)
  Number of continuous covariates
- n_cats_per_cov : Iterable[int] | None (default: None)
  Number of categories for each extra categorical covariate
- dropout_rate : float (default: 0.1)
  Dropout rate for neural networks
- dispersion : str (default: 'gene')
  One of the following:
  - 'gene' - dispersion parameter of NB is constant per gene across cells
  - 'gene-batch' - dispersion can differ between different batches
  - 'gene-label' - dispersion can differ between different labels
  - 'gene-cell' - dispersion can differ for every gene in every cell
- log_variational : bool (default: True)
  Log(data + 1) prior to encoding for numerical stability. Not normalization.
- gene_likelihood : str (default: 'zinb')
  One of:
  - 'nb' - Negative binomial distribution
  - 'zinb' - Zero-inflated negative binomial distribution
  - 'poisson' - Poisson distribution
- latent_distribution : str (default: 'normal')
  One of:
  - 'normal' - Isotropic normal
  - 'ln' - Logistic normal with normal params N(0, 1)
- encode_covariates : bool (default: False)
  Whether to concatenate covariates to expression in encoder
- deeply_inject_covariates : bool (default: True)
  Whether to concatenate covariates into the output of hidden layers in the encoder/decoder. This option only applies when n_layers > 1. The covariates are concatenated to the input of subsequent hidden layers.
- use_batch_norm : Literal['encoder', 'decoder', 'none', 'both'] (default: 'both')
  Whether to use batch norm in layers
- use_layer_norm : Literal['encoder', 'decoder', 'none', 'both'] (default: 'none')
  Whether to use layer norm in layers
- use_size_factor_key : bool (default: False)
  Use size_factor AnnDataField defined by the user as scaling factor in mean of conditional distribution. Takes priority over use_observed_lib_size.
- use_observed_lib_size : bool (default: True)
  Use observed library size for RNA as scaling factor in mean of conditional distribution
- library_log_means : ndarray | None (default: None)
  1 x n_batch array of means of the log library sizes. Parameterizes prior on library size if not using observed library size.
- library_log_vars : ndarray | None (default: None)
  1 x n_batch array of variances of the log library sizes. Parameterizes prior on library size if not using observed library size.
- var_activation : Callable | None (default: None)
  Callable used to ensure positivity of the variational distributions' variance. When None, defaults to torch.exp.
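A minimal construction sketch, not taken from the original page; the argument values here are illustrative, and n_input must match the number of genes in your data:

```python
from scvi.module import VAE

# Illustrative values only.
module = VAE(
    n_input=2000,            # number of input genes
    n_batch=2,               # two batches -> batch correction is performed
    n_latent=10,             # dimensionality of the latent space
    gene_likelihood="zinb",  # zero-inflated negative binomial likelihood
    dispersion="gene",       # one NB dispersion parameter per gene
)
```

In typical workflows this module is constructed internally by scvi.model.SCVI rather than instantiated directly.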
Attributes table#

| Attribute | Description |
|---|---|
| T_destination | |
| device | |
| dump_patches | |
| training | |

Methods table#

| Method | Description |
|---|---|
| generative | Runs the generative model. |
| get_reconstruction_loss | |
| inference | High level inference method. |
| loss | Compute the loss for a minibatch of data. |
| marginal_ll | |
| sample | Generate observation samples from the posterior predictive distribution. |
Attributes#
T_destination#
- VAE.T_destination#
alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
device#
- VAE.device#
dump_patches#
- VAE.dump_patches: bool = False#
This allows better BC support for load_state_dict(). In state_dict(), the version number will be saved in the attribute _metadata of the returned state dict, and thus pickled. _metadata is a dictionary with keys that follow the naming convention of state dict. See _load_from_state_dict on how to use this information in loading.
If new parameters/buffers are added/removed from a module, this number shall be bumped, and the module's _load_from_state_dict method can compare the version number and do appropriate changes if the state dict is from before the change.
training#
Methods#
generative#
get_reconstruction_loss#
inference#
loss#
- VAE.loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)[source]#
Compute the loss for a minibatch of data.
This function uses the outputs of the inference and generative functions to compute a loss. This may optionally include other penalty terms, which should be computed here.
This function should return an object of type LossRecorder.
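A hedged sketch of how loss composes with the inference and generative steps; `tensors` is assumed to be a minibatch dict yielded by an scvi-tools data loader, and the private `_get_*_input` helpers mirror what BaseModuleClass.forward does internally during training:

```python
# `module` is a constructed VAE (see the construction sketch above).
inference_outputs = module.inference(**module._get_inference_input(tensors))
generative_outputs = module.generative(
    **module._get_generative_input(tensors, inference_outputs)
)
loss_recorder = module.loss(
    tensors, inference_outputs, generative_outputs, kl_weight=1.0
)
print(loss_recorder.loss)  # scalar training objective for the minibatch
```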
marginal_ll#
sample#
- VAE.sample(tensors, n_samples=1, library_size=1)[source]#
Generate observation samples from the posterior predictive distribution.
The posterior predictive distribution is written as \(p(\hat{x} \mid x)\).
- Parameters
- tensors
  Tensors dict
- n_samples (default: 1)
  Number of required samples for each cell
- library_size (default: 1)
  Library size to scale samples to
- Return type
  Tensor
- Returns
  x_new : torch.Tensor
  Tensor with shape (n_cells, n_genes, n_samples)
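A hedged usage sketch, reusing the constructed `module` and minibatch dict `tensors` assumed in the earlier examples:

```python
import torch

# Draw 25 posterior predictive count matrices for the cells in `tensors`.
with torch.no_grad():
    x_new = module.sample(tensors, n_samples=25, library_size=1)

print(x_new.shape)  # (n_cells, n_genes, n_samples)
```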