scvi.external.contrastivevi.ContrastiveVAE

- class scvi.external.contrastivevi.ContrastiveVAE(n_input, n_batch=0, n_hidden=128, n_background_latent=10, n_salient_latent=10, n_layers=1, dropout_rate=0.1, use_observed_lib_size=True, library_log_means=None, library_log_vars=None, wasserstein_penalty=0)
Bases: BaseModuleClass
Variational inference for contrastive analysis of RNA-seq data.
Implements the contrastiveVI model of [Weinberger et al., 2023].
- Parameters:
  - n_input (int) – Number of input genes.
  - n_batch (int, default: 0) – Number of batches. If 0, no batch effect correction is performed.
  - n_hidden (int, default: 128) – Number of nodes per hidden layer.
  - n_background_latent (int, default: 10) – Dimensionality of the background latent space.
  - n_salient_latent (int, default: 10) – Dimensionality of the salient latent space.
  - n_layers (int, default: 1) – Number of hidden layers used for the encoder and decoder NNs.
  - dropout_rate (float, default: 0.1) – Dropout rate for the neural networks.
  - use_observed_lib_size (bool, default: True) – Use the observed library size for RNA as the scaling factor in the mean of the conditional distribution.
  - library_log_means (ndarray | None, default: None) – 1 x n_batch array of means of the log library sizes. Parameterizes the prior on library size if not using the observed library size.
  - library_log_vars (ndarray | None, default: None) – 1 x n_batch array of variances of the log library sizes. Parameterizes the prior on library size if not using the observed library size.
  - wasserstein_penalty (float, default: 0) – Weight of the Wasserstein distance loss that further discourages shared variations from leaking into the salient latent space.
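A minimal construction sketch (the values are illustrative; in practice this module is usually built for you by a higher-level model class rather than instantiated by hand):

>>> from scvi.external.contrastivevi import ContrastiveVAE
>>> module = ContrastiveVAE(
...     n_input=2000,  # number of genes
...     n_batch=2,  # two experimental batches
...     n_background_latent=10,
...     n_salient_latent=10,
...     wasserstein_penalty=0,
... )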
Attributes table

- training

Methods table

- generative(background, target) – Run the generative model.
- inference(background, target[, n_samples]) – Run the recognition model.
- latent_kl_divergence(variational_mean, ...) – Computes KL divergence between a variational posterior and prior Gaussian.
- library_kl_divergence(batch_index, ...) – Computes KL divergence between library size variational posterior and prior.
- loss(concat_tensors, inference_outputs, ...) – Computes loss terms for contrastiveVI.
- reconstruction_loss(x, px_rate, px_r, px_dropout) – Computes likelihood loss for zero-inflated negative binomial distribution.
- sample – Generate samples from the learned model.
Attributes

- ContrastiveVAE.training: bool

Methods
- ContrastiveVAE.generative(background, target)

Run the generative model.

This function should return the parameters associated with the likelihood of the data. This is typically written as \(p(x|z)\).

It should return a dictionary with str keys and Tensor values.
- ContrastiveVAE.inference(background, target, n_samples=1)

Run the recognition model.

In the case of variational inference, this function will perform steps related to computing variational distribution parameters. In a VAE, this will involve running data through encoder networks.

It should return a dictionary with str keys and Tensor values.
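A hedged sketch of the recognition-then-generation flow (variable names are hypothetical, and feeding the per-group inference outputs to generative is an assumption based on the matching parameter names and the "background"/"target" keys described under loss below):

>>> # background_tensors / target_tensors: mini-batch dicts of Tensors (hypothetical)
>>> inference_outputs = module.inference(background_tensors, target_tensors, n_samples=1)
>>> generative_outputs = module.generative(
...     inference_outputs["background"], inference_outputs["target"]
... )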
- static ContrastiveVAE.latent_kl_divergence(variational_mean, variational_var, prior_mean, prior_var)

Computes KL divergence between a variational posterior and prior Gaussian.

- Parameters:
  - variational_mean (Tensor) – Mean of the variational posterior Gaussian.
  - variational_var (Tensor) – Variance of the variational posterior Gaussian.
  - prior_mean (Tensor) – Mean of the prior Gaussian.
  - prior_var (Tensor) – Variance of the prior Gaussian.
- Return type: Tensor
- Returns:
  KL divergence for each data point. If the number of latent samples == 1, the tensor has shape (batch_size,). If the number of latent samples > 1, the tensor has shape (n_samples, batch_size).
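For intuition, the same quantity can be reproduced with torch.distributions (a sketch of the math, not the module's internal code; note that the variances must be converted to standard deviations for Normal):

>>> import torch
>>> from torch.distributions import Normal, kl_divergence
>>> q_mean, q_var = torch.zeros(8, 10), torch.ones(8, 10)  # variational posterior
>>> p_mean, p_var = torch.zeros(8, 10), torch.ones(8, 10)  # prior
>>> kl = kl_divergence(
...     Normal(q_mean, q_var.sqrt()), Normal(p_mean, p_var.sqrt())
... ).sum(dim=-1)  # sum over latent dimensions -> one value per cell
>>> kl.shape
torch.Size([8])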
- ContrastiveVAE.library_kl_divergence(batch_index, variational_library_mean, variational_library_var, library)

Computes KL divergence between library size variational posterior and prior.

Both the variational posterior and prior are Log-Normal.

- Parameters:
  - batch_index (Tensor) – Batch indices for batch-specific library size mean and variance.
  - variational_library_mean (Tensor) – Mean of the variational Log-Normal.
  - variational_library_var (Tensor) – Variance of the variational Log-Normal.
  - library (Tensor) – Sampled library size.
- Return type: Tensor
- Returns:
  KL divergence for each data point. If the number of latent samples == 1, the tensor has shape (batch_size,). If the number of latent samples > 1, the tensor has shape (n_samples, batch_size).
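Since the KL divergence between two Log-Normals equals the KL divergence between their underlying Normals, the computation reduces to the Gaussian case above, with the prior parameters looked up per cell from the 1 x n_batch library_log_means / library_log_vars arrays (a sketch under those assumptions, with hypothetical values):

>>> import torch
>>> from torch.distributions import Normal, kl_divergence
>>> library_log_means = torch.tensor([[6.9, 7.2]])  # 1 x n_batch (hypothetical)
>>> library_log_vars = torch.tensor([[0.5, 0.4]])
>>> batch_index = torch.tensor([[0], [1], [0]])  # one batch label per cell
>>> p_mean = library_log_means[0, batch_index.squeeze(-1)].unsqueeze(-1)
>>> p_var = library_log_vars[0, batch_index.squeeze(-1)].unsqueeze(-1)
>>> q_mean, q_var = torch.zeros(3, 1), torch.ones(3, 1)  # variational parameters
>>> kl_lib = kl_divergence(
...     Normal(q_mean, q_var.sqrt()), Normal(p_mean, p_var.sqrt())
... ).sum(dim=-1)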
- ContrastiveVAE.loss(concat_tensors, inference_outputs, generative_outputs, kl_weight=1.0)

Computes loss terms for contrastiveVI.

- Parameters:
  - concat_tensors (dict[str, dict[str, Tensor]]) – Data mini-batches: one entry holds the background data mini-batch, the other the target data mini-batch.
  - inference_outputs (dict[str, dict[str, Tensor]]) – Dictionary of inference step outputs. The keys are "background" and "target" for the corresponding outputs.
  - generative_outputs (dict[str, dict[str, Tensor]]) – Dictionary of generative step outputs. The keys are "background" and "target" for the corresponding outputs.
  - kl_weight (float, default: 1.0) – Importance weight for the KL divergence of the background and salient latent variables, relative to the KL divergence of the library size.
- Return type: LossOutput
- Returns:
  An scvi.module.base.LossOutput instance that records the following:
  - loss – One-dimensional tensor for the overall loss used for optimization.
  - reconstruction_loss – Reconstruction loss with shape (n_samples, batch_size) if the number of latent samples > 1, or (batch_size,) if the number of latent samples == 1.
  - kl_local – KL divergence term with shape (n_samples, batch_size) if the number of latent samples > 1, or (batch_size,) if the number of latent samples == 1.
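Schematically, kl_weight scales the latent KL terms but not the library size KL (a minimal sketch of that composition with illustrative per-cell values; the actual method additionally handles the background/target split and the optional Wasserstein penalty):

>>> import torch
>>> recon = torch.rand(8)  # per-cell reconstruction loss (illustrative)
>>> kl_latent = torch.rand(8)  # background + salient latent KL
>>> kl_library = torch.rand(8)  # library size KL
>>> total = torch.mean(recon + 1.0 * kl_latent + kl_library)  # kl_weight = 1.0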
- static ContrastiveVAE.reconstruction_loss(x, px_rate, px_r, px_dropout)

Computes likelihood loss for the zero-inflated negative binomial distribution.

- Parameters:
  - x (Tensor) – Input data.
  - px_rate (Tensor) – Mean of the distribution.
  - px_r (Tensor) – Inverse dispersion.
  - px_dropout (Tensor) – Logits scale of the zero-inflation probability.
- Return type: Tensor
- Returns:
  Negative log likelihood (reconstruction loss) for each data point. If the number of latent samples == 1, the tensor has shape (batch_size,). If the number of latent samples > 1, the tensor has shape (n_samples, batch_size).
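This corresponds to the negative log likelihood under a ZINB distribution; a sketch using scvi.distributions.ZeroInflatedNegativeBinomial with illustrative tensors (summing log probabilities over genes yields one value per cell):

>>> import torch
>>> from scvi.distributions import ZeroInflatedNegativeBinomial
>>> x = torch.randint(0, 20, (8, 200)).float()  # counts: 8 cells x 200 genes
>>> px_rate = torch.rand(8, 200) + 1e-4  # ZINB mean
>>> px_r = torch.rand(200) + 1e-4  # inverse dispersion
>>> px_dropout = torch.randn(8, 200)  # zero-inflation logits
>>> nll = -ZeroInflatedNegativeBinomial(
...     mu=px_rate, theta=px_r, zi_logits=px_dropout
... ).log_prob(x).sum(dim=-1)
>>> nll.shape
torch.Size([8])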