import torch
import torch.nn.functional as F
from torch.distributions import Normal, Beta, Gamma, kl_divergence as kl
import numpy as np
from scipy.special import logit
from scvi.models.distributions import ZeroInflatedNegativeBinomial, NegativeBinomial
from scvi.models.vae import VAE
from scvi.models.utils import one_hot
from typing import Dict, Optional, Tuple, Union
torch.backends.cudnn.benchmark = True
class AutoZIVAE(VAE):
r"""AutoZI variational auto-encoder model.
Implementation of AutoZI model [Clivio19]_.
Parameters
----------
n_input
Number of input genes
alpha_prior
Float denoting the alpha parameter of the prior Beta distribution of
the zero-inflation Bernoulli parameter. Must lie strictly between 0 and 1.
When set to ``None``, it is set to ``1 - beta_prior`` if ``beta_prior`` is not ``None``;
otherwise the prior Beta distribution is learned in an empirical Bayes fashion.
beta_prior
Float denoting the beta parameter of the prior Beta distribution of
the zero-inflation Bernoulli parameter. Must lie strictly between 0 and 1.
When set to ``None``, it is set to ``1 - alpha_prior`` if ``alpha_prior`` is not ``None``;
otherwise the prior Beta distribution is learned in an empirical Bayes fashion.
minimal_dropout
Float denoting the lower bound of the cell-gene ZI rate in the ZINB component.
Must be non-negative. It can be set to 0, but this is not recommended, as it may make
the mixture problem ill-defined.
zero_inflation: One of the following
* ``'gene'`` - zero-inflation Bernoulli parameter of AutoZI is constant per gene across cells
* ``'gene-batch'`` - zero-inflation Bernoulli parameter can differ between different batches
* ``'gene-label'`` - zero-inflation Bernoulli parameter can differ between different labels
* ``'gene-cell'`` - zero-inflation Bernoulli parameter can differ for every gene in every cell
See VAE docstring (scvi/models/vae.py) for more parameters. ``reconstruction_loss`` should not be specified.
Examples
--------
>>> gene_dataset = CortexDataset()
>>> autozivae = AutoZIVAE(gene_dataset.nb_genes, alpha_prior=0.5, beta_prior=0.5, minimal_dropout=0.01)
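A hypothetical follow-up query (model training omitted): retrieve the Beta parameters of the
zero-inflation Bernoulli variable with :meth:`get_alphas_betas`. Since this Bernoulli parameter
weights the NB component of the mixture (see :meth:`get_reconstruction_loss`), genes with a small
posterior mean ``alpha / (alpha + beta)`` are better described by the zero-inflated (ZINB) component.
>>> params = autozivae.get_alphas_betas()
>>> sorted(params.keys())
['alpha_posterior', 'alpha_prior', 'beta_posterior', 'beta_prior']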
"""
def __init__(
self,
n_input: int,
alpha_prior: Optional[float] = 0.5,
beta_prior: Optional[float] = 0.5,
minimal_dropout: float = 0.01,
zero_inflation: str = "gene",
**args,
) -> None:
if "reconstruction_loss" in args:
raise ValueError(
"The reconstruction loss cannot be specified for AutoZI: it is fixed to 'autozinb'."
)
super().__init__(n_input, **args)
self.zero_inflation = zero_inflation
self.reconstruction_loss = "autozinb"
self.minimal_dropout = minimal_dropout
# Parameters of the Beta prior on the zero-inflation Bernoulli parameter: alpha + beta = 1 if only one is specified
if beta_prior is None and alpha_prior is not None:
beta_prior = 1.0 - alpha_prior
if alpha_prior is None and beta_prior is not None:
alpha_prior = 1.0 - beta_prior
# Create parameters for Bernoulli Beta prior and posterior distributions
# Each parameter, whose values are in (0,1), is encoded as its logit, in the set of real numbers
if self.zero_inflation == "gene":
self.alpha_posterior_logit = torch.nn.Parameter(torch.randn(n_input))
self.beta_posterior_logit = torch.nn.Parameter(torch.randn(n_input))
self.alpha_prior_logit = (
torch.nn.Parameter(torch.randn(1))
if alpha_prior is None
else torch.Tensor([logit(alpha_prior)])
)
self.beta_prior_logit = (
torch.nn.Parameter(torch.randn(1))
if beta_prior is None
else torch.Tensor([logit(beta_prior)])
)
elif self.zero_inflation == "gene-batch":
self.alpha_posterior_logit = torch.nn.Parameter(
torch.randn(n_input, self.n_batch)
)
self.beta_posterior_logit = torch.nn.Parameter(
torch.randn(n_input, self.n_batch)
)
self.alpha_prior_logit = (
torch.nn.Parameter(torch.randn(1, self.n_batch))
if alpha_prior is None
else torch.Tensor([logit(alpha_prior)])
)
self.beta_prior_logit = (
torch.nn.Parameter(torch.randn(1, self.n_batch))
if beta_prior is None
else torch.Tensor([logit(beta_prior)])
)
elif self.zero_inflation == "gene-label":
self.alpha_posterior_logit = torch.nn.Parameter(
torch.randn(n_input, self.n_labels)
)
self.beta_posterior_logit = torch.nn.Parameter(
torch.randn(n_input, self.n_labels)
)
self.alpha_prior_logit = (
torch.nn.Parameter(torch.randn(1, self.n_labels))
if alpha_prior is None
else torch.Tensor([logit(alpha_prior)])
)
self.beta_prior_logit = (
torch.nn.Parameter(torch.randn(1, self.n_labels))
if beta_prior is None
else torch.Tensor([logit(beta_prior)])
)
else: # gene-cell
raise NotImplementedError("Gene-cell zero-inflation is not implemented yet for AutoZI")
def cuda(self, device: Optional[str] = None) -> torch.nn.Module:
r"""Moves all model parameters, as well as the fixed prior alpha and beta values when relevant, to the GPU.
Parameters
----------
device
string denoting the GPU device on which parameters and prior distribution values are copied.
Returns
-------
the model itself, with its parameters and fixed prior logits moved to the GPU.
"""
self = super().cuda(device)
if isinstance(self.alpha_prior_logit, torch.Tensor):
self.alpha_prior_logit = self.alpha_prior_logit.cuda(device)
if isinstance(self.beta_prior_logit, torch.Tensor):
self.beta_prior_logit = self.beta_prior_logit.cuda(device)
return self
def get_alphas_betas(
self, as_numpy: bool = True
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
# Return parameters of Bernoulli Beta distributions in a dictionary
outputs = {}
outputs["alpha_posterior"] = torch.sigmoid(self.alpha_posterior_logit)
outputs["beta_posterior"] = torch.sigmoid(self.beta_posterior_logit)
outputs["alpha_prior"] = torch.sigmoid(self.alpha_prior_logit)
outputs["beta_prior"] = torch.sigmoid(self.beta_prior_logit)
if as_numpy:
for key, value in outputs.items():
outputs[key] = (
value.detach().cpu().numpy()
if value.requires_grad
else value.cpu().numpy()
)
return outputs
def sample_from_beta_distribution(
self,
alpha: torch.Tensor,
beta: torch.Tensor,
eps_gamma: float = 1e-30,
eps_sample: float = 1e-7,
) -> torch.Tensor:
# Sample from a Beta distribution using the reparameterization trick.
# Problem: it is not implemented on CUDA yet.
# Workaround: sample X and Y from Gamma(alpha, 1) and Gamma(beta, 1); the Beta sample is X / (X + Y).
# Warning: use logs and perform a logsumexp to avoid numerical issues.
# Sample from Gamma
sample_x_log = torch.log(Gamma(alpha, 1).rsample() + eps_gamma)
sample_y_log = torch.log(Gamma(beta, 1).rsample() + eps_gamma)
# Sum using a logsumexp (note: eps_gamma is used to prevent numerical issues with
# perfect 0 and 1 final Beta samples)
sample_xy_log_max = torch.max(sample_x_log, sample_y_log)
sample_xplusy_log = sample_xy_log_max + torch.log(
torch.exp(sample_x_log - sample_xy_log_max)
+ torch.exp(sample_y_log - sample_xy_log_max)
)
sample_log = sample_x_log - sample_xplusy_log
sample = eps_sample + (1 - 2 * eps_sample) * torch.exp(sample_log)
return sample
def reshape_bernoulli(
self,
bernoulli_params: torch.Tensor,
batch_index: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
) -> torch.Tensor:
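# Map per-gene-and-category Bernoulli parameters to per-cell values by multiplying them
# with the one-hot encoding of each cell's label or batch (no-op when zero_inflation is "gene").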
if self.zero_inflation == "gene-label":
one_hot_label = one_hot(y, self.n_labels)
if len(bernoulli_params.shape) == 2:
bernoulli_params = F.linear(one_hot_label, bernoulli_params)
# If we sampled several random Bernoulli parameters
else:
bernoulli_params_res = []
for sample in range(bernoulli_params.shape[0]):
bernoulli_params_res.append(
F.linear(one_hot_label, bernoulli_params[sample])
)
bernoulli_params = torch.stack(bernoulli_params_res)
elif self.zero_inflation == "gene-batch":
one_hot_batch = one_hot(batch_index, self.n_batch)
if len(bernoulli_params.shape) == 2:
bernoulli_params = F.linear(one_hot_batch, bernoulli_params)
# If we sampled several random Bernoulli parameters
else:
bernoulli_params_res = []
for sample in range(bernoulli_params.shape[0]):
bernoulli_params_res.append(
F.linear(one_hot_batch, bernoulli_params[sample])
)
bernoulli_params = torch.stack(bernoulli_params_res)
return bernoulli_params
def sample_bernoulli_params(
self,
batch_index: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
n_samples: int = 1,
) -> torch.Tensor:
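# Draw Bernoulli parameters from their Beta posterior, using the Gamma-based
# reparameterization above, and reshape them to one value per cell and gene.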
outputs = self.get_alphas_betas(as_numpy=False)
alpha_posterior = outputs["alpha_posterior"]
beta_posterior = outputs["beta_posterior"]
if n_samples > 1:
alpha_posterior = (
alpha_posterior.unsqueeze(0).expand(
(n_samples, alpha_posterior.size(0))
)
if self.zero_inflation == "gene"
else alpha_posterior.unsqueeze(0).expand(
(n_samples, alpha_posterior.size(0), alpha_posterior.size(1))
)
)
beta_posterior = (
beta_posterior.unsqueeze(0).expand((n_samples, beta_posterior.size(0)))
if self.zero_inflation == "gene"
else beta_posterior.unsqueeze(0).expand(
(n_samples, beta_posterior.size(0), beta_posterior.size(1))
)
)
bernoulli_params = self.sample_from_beta_distribution(
alpha_posterior, beta_posterior
)
bernoulli_params = self.reshape_bernoulli(bernoulli_params, batch_index, y)
return bernoulli_params
def rescale_dropout(
self, px_dropout: torch.Tensor, eps_log: float = 1e-8
) -> torch.Tensor:
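# Rescale the ZINB dropout probability from (0, 1) to (minimal_dropout, 1) and return
# the corresponding logit, so that the cell-gene ZI rate never falls below the lower bound.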
if self.minimal_dropout > 0.0:
dropout_prob_rescaled = self.minimal_dropout + (
1.0 - self.minimal_dropout
) * torch.sigmoid(px_dropout)
px_dropout_rescaled = torch.log(
dropout_prob_rescaled / (1.0 - dropout_prob_rescaled + eps_log)
)
else:
px_dropout_rescaled = px_dropout
return px_dropout_rescaled
def inference(
self,
x,
batch_index: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
n_samples: int = 1,
eps_log: float = 1e-8,
) -> Dict[str, torch.Tensor]:
outputs = super().inference(
x, batch_index=batch_index, y=y, n_samples=n_samples
)
# Rescale dropout
outputs["px_dropout"] = self.rescale_dropout(
outputs["px_dropout"], eps_log=eps_log
)
# Bernoulli parameters
outputs["bernoulli_params"] = self.sample_bernoulli_params(
batch_index, y, n_samples=n_samples
)
return outputs
def compute_global_kl_divergence(self) -> torch.Tensor:
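# KL divergence between the Beta posterior and the Beta prior on the zero-inflation
# Bernoulli parameters, summed over all genes (and batches or labels, when applicable).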
outputs = self.get_alphas_betas(as_numpy=False)
alpha_posterior = outputs["alpha_posterior"]
beta_posterior = outputs["beta_posterior"]
alpha_prior = outputs["alpha_prior"]
beta_prior = outputs["beta_prior"]
return kl(
Beta(alpha_posterior, beta_posterior), Beta(alpha_prior, beta_prior)
).sum()
def get_reconstruction_loss(
self,
x: torch.Tensor,
px_rate: torch.Tensor,
px_r: torch.Tensor,
px_dropout: torch.Tensor,
bernoulli_params: torch.Tensor,
eps_log: float = 1e-8,
**kwargs,
) -> torch.Tensor:
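# Negative log-likelihood of the two-component mixture delta * NB + (1 - delta) * ZINB,
# with delta = bernoulli_params, evaluated stably in log space with a logsumexp.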
# LLs for NB and ZINB
ll_zinb = torch.log(
1.0 - bernoulli_params + eps_log
) + ZeroInflatedNegativeBinomial(
mu=px_rate, theta=px_r, zi_logits=px_dropout
).log_prob(
x
)
ll_nb = torch.log(bernoulli_params + eps_log) + NegativeBinomial(
mu=px_rate, theta=px_r
).log_prob(x)
# Reconstruction loss using a logsumexp-type computation
ll_max = torch.max(ll_zinb, ll_nb)
ll_tot = ll_max + torch.log(
torch.exp(ll_nb - ll_max) + torch.exp(ll_zinb - ll_max)
)
reconst_loss = -ll_tot.sum(dim=-1)
return reconst_loss
def forward(
self,
x: torch.Tensor,
local_l_mean: torch.Tensor,
local_l_var: torch.Tensor,
batch_index: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r""" Returns the reconstruction loss and the Kullback divergences
Parameters
----------
x
tensor of values with shape (batch_size, n_input)
local_l_mean
tensor of means of the prior distribution of latent variable l
with shape (batch_size, 1)
local_l_var
tensor of variancess of the prior distribution of latent variable l
with shape (batch_size, 1)
batch_index
array that indicates which batch the cells belong to with shape ``batch_size``
y
tensor of cell-types labels with shape (batch_size, n_labels)
Returns
-------
2-tuple of :py:class:`torch.FloatTensor`
the reconstruction loss and the Kullback divergences
"""
# Parameters for z latent distribution
outputs = self.inference(x, batch_index, y)
qz_m = outputs["qz_m"]
qz_v = outputs["qz_v"]
ql_m = outputs["ql_m"]
ql_v = outputs["ql_v"]
px_rate = outputs["px_rate"]
px_r = outputs["px_r"]
px_dropout = outputs["px_dropout"]
bernoulli_params = outputs["bernoulli_params"]
# KL divergences wrt z_n,l_n
mean = torch.zeros_like(qz_m)
scale = torch.ones_like(qz_v)
kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
dim=1
)
kl_divergence_l = kl(
Normal(ql_m, torch.sqrt(ql_v)),
Normal(local_l_mean, torch.sqrt(local_l_var)),
).sum(dim=1)
# KL divergence wrt Bernoulli parameters
kl_divergence_bernoulli = self.compute_global_kl_divergence()
# Reconstruction loss
reconst_loss = self.get_reconstruction_loss(
x, px_rate, px_r, px_dropout, bernoulli_params
)
return reconst_loss + kl_divergence_l, kl_divergence_z, kl_divergence_bernoulli
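# --- Illustrative usage sketch (not part of scVI) ---------------------------------------
# A minimal smoke test of the forward pass on synthetic counts; every shape, value and
# hyper-parameter below is an arbitrary assumption made for illustration only.
if __name__ == "__main__":
    n_cells, n_genes = 8, 100
    model = AutoZIVAE(n_genes, alpha_prior=0.5, beta_prior=0.5, minimal_dropout=0.01)
    x = torch.randint(0, 20, (n_cells, n_genes)).float()  # synthetic count matrix
    library_log = torch.log(x.sum(dim=1, keepdim=True))  # per-cell log library size
    local_l_mean = torch.full_like(library_log, library_log.mean().item())
    local_l_var = torch.ones_like(library_log)
    reconst_loss, kl_z, kl_global = model(x, local_l_mean, local_l_var)
    print(reconst_loss.shape, kl_z.shape, kl_global.shape)  # expect (8,), (8,) and a scalar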