# -*- coding: utf-8 -*-
"""Main module."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Poisson, kl_divergence as kl
from scvi.models.distributions import (
ZeroInflatedNegativeBinomial,
NegativeBinomial,
)
from scvi.models.modules import Encoder, DecoderSCVI, LinearDecoderSCVI
from scvi.models.utils import one_hot
from typing import Dict, List, Tuple
torch.backends.cudnn.benchmark = True
# VAE model
class VAE(nn.Module):
"""Variational auto-encoder model.
This is an implementation of the scVI model described in [Lopez18]_.
Parameters
----------
n_input
Number of input genes
n_batch
Number of batches, if 0, no batch correction is performed.
n_labels
Number of labels
n_hidden
Number of nodes per hidden layer
n_latent
Dimensionality of the latent space
n_layers
Number of hidden layers used for encoder and decoder NNs
dropout_rate
Dropout rate for neural networks
dispersion
One of the following
* ``'gene'`` - dispersion parameter of NB is constant per gene across cells
* ``'gene-batch'`` - dispersion can differ between different batches
* ``'gene-label'`` - dispersion can differ between different labels
* ``'gene-cell'`` - dispersion can differ for every gene in every cell
log_variational
Whether to log-transform the data as log(1 + x) before encoding, for numerical stability (this is not a normalization).
reconstruction_loss
One of
* ``'nb'`` - Negative binomial distribution
* ``'zinb'`` - Zero-inflated negative binomial distribution
* ``'poisson'`` - Poisson distribution
Examples
--------
>>> gene_dataset = CortexDataset()
>>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
... n_labels=gene_dataset.n_labels)
"""
def __init__(
self,
n_input: int,
n_batch: int = 0,
n_labels: int = 0,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
dropout_rate: float = 0.1,
dispersion: str = "gene",
log_variational: bool = True,
reconstruction_loss: str = "zinb",
latent_distribution: str = "normal",
):
super().__init__()
self.dispersion = dispersion
self.n_latent = n_latent
self.log_variational = log_variational
self.reconstruction_loss = reconstruction_loss
# Automatically deactivate if useless
self.n_batch = n_batch
self.n_labels = n_labels
self.latent_distribution = latent_distribution
if self.dispersion == "gene":
self.px_r = torch.nn.Parameter(torch.randn(n_input))
elif self.dispersion == "gene-batch":
self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
elif self.dispersion == "gene-label":
self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
elif self.dispersion == "gene-cell":
pass
else:
raise ValueError(
"dispersion must be one of ['gene', 'gene-batch',"
" 'gene-label', 'gene-cell'], but input was "
"{}".format(self.dispersion)
)
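# Note: ``px_r`` stores the NB dispersion on a log scale; it is
# exponentiated in ``inference`` before being used as ``theta``.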
# z encoder goes from the n_input-dimensional data to an n_latent-d
# latent space representation
self.z_encoder = Encoder(
n_input,
n_latent,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
distribution=latent_distribution,
)
# l encoder goes from n_input-dimensional data to 1-d library size
self.l_encoder = Encoder(
n_input, 1, n_layers=1, n_hidden=n_hidden, dropout_rate=dropout_rate
)
# decoder goes from n_latent-dimensional space to n_input-d data
self.decoder = DecoderSCVI(
n_latent,
n_input,
n_cat_list=[n_batch],
n_layers=n_layers,
n_hidden=n_hidden,
)
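# A minimal usage sketch (illustration only; the toy tensors below are
# hypothetical and not part of this module):
#
#     vae = VAE(n_input=1000)
#     x = torch.randint(0, 10, (64, 1000)).float()  # toy count matrix
#     outputs = vae.inference(x)
#     outputs["px_rate"].shape  # torch.Size([64, 1000])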
def get_latents(self, x, y=None) -> List[torch.Tensor]:
"""Returns the result of ``sample_from_posterior_z`` inside a list
Parameters
----------
x
tensor of values with shape ``(batch_size, n_input)``
y
tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
Returns
-------
list
one-element list containing a tensor of shape ``(batch_size, n_latent)``
"""
return [self.sample_from_posterior_z(x, y)]
def sample_from_posterior_z(
self, x, y=None, give_mean=False, n_samples=5000
) -> torch.Tensor:
"""Samples the tensor of latent values from the posterior
Parameters
----------
x
tensor of values with shape ``(batch_size, n_input)``
y
tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
give_mean
is True when we want the mean of the posterior distribution rather than sampling (Default value = False)
n_samples
how many MC samples to average over for transformed mean (Default value = 5000)
Returns
-------
torch.Tensor
tensor of shape ``(batch_size, n_latent)``
"""
if self.log_variational:
x = torch.log(1 + x)
qz_m, qz_v, z = self.z_encoder(x, y) # y only used in VAEC
if give_mean:
if self.latent_distribution == "ln":
samples = Normal(qz_m, qz_v.sqrt()).sample([n_samples])
z = self.z_encoder.z_transformation(samples)
z = z.mean(dim=0)
else:
z = qz_m
return z
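# When the latent distribution is logistic-normal ("ln"), the softmax
# transformation makes the posterior mean intractable in closed form,
# hence the Monte Carlo average above. Hedged usage sketch (hypothetical
# tensors):
#
#     z_mean = vae.sample_from_posterior_z(x, give_mean=True)
#     z_mean.shape  # torch.Size([batch_size, n_latent])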
def sample_from_posterior_l(self, x) -> torch.Tensor:
"""Samples the tensor of library sizes from the posterior
Parameters
----------
x
tensor of values with shape ``(batch_size, n_input)``
Returns
-------
torch.Tensor
tensor of shape ``(batch_size, 1)``
"""
if self.log_variational:
x = torch.log(1 + x)
ql_m, ql_v, library = self.l_encoder(x)
return library
def get_sample_scale(
self, x, batch_index=None, y=None, n_samples=1, transform_batch=None
) -> torch.Tensor:
"""Returns the tensor of predicted frequencies of expression
Parameters
----------
x
tensor of values with shape ``(batch_size, n_input)``
batch_index
array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
y
tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
n_samples
number of samples (Default value = 1)
transform_batch
int of batch to transform samples into (Default value = None)
Returns
-------
torch.Tensor
tensor of predicted frequencies of expression with shape ``(batch_size, n_input)``
"""
return self.inference(
x,
batch_index=batch_index,
y=y,
n_samples=n_samples,
transform_batch=transform_batch,
)["px_scale"]
def get_sample_rate(
self, x, batch_index=None, y=None, n_samples=1, transform_batch=None
) -> torch.Tensor:
"""Returns the tensor of means of the negative binomial distribution
Parameters
----------
x
tensor of values with shape ``(batch_size, n_input)``
batch_index
array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
y
tensor of cell-types labels with shape ``(batch_size, n_labels)`` (Default value = None)
n_samples
number of samples (Default value = 1)
transform_batch
int of batch to transform samples into (Default value = None)
Returns
-------
torch.Tensor
tensor of means of the negative binomial distribution with shape ``(batch_size, n_input)``
"""
return self.inference(
x,
batch_index=batch_index,
y=y,
n_samples=n_samples,
transform_batch=transform_batch,
)["px_rate"]
def get_reconstruction_loss(
self, x, px_rate, px_r, px_dropout, **kwargs
) -> torch.Tensor:
"""Computes the reconstruction loss (negative log-likelihood of the data)."""
# Reconstruction loss
if self.reconstruction_loss == "zinb":
reconst_loss = (
-ZeroInflatedNegativeBinomial(
mu=px_rate, theta=px_r, zi_logits=px_dropout
)
.log_prob(x)
.sum(dim=-1)
)
elif self.reconstruction_loss == "nb":
reconst_loss = (
-NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)
)
elif self.reconstruction_loss == "poisson":
reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
return reconst_loss
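# For reference, the NB likelihood above uses the (mu, theta)
# parameterization, where theta is the inverse dispersion:
#
#     log NB(x; mu, theta) = lgamma(x + theta) - lgamma(theta) - lgamma(x + 1)
#                            + theta * log(theta / (theta + mu))
#                            + x * log(mu / (theta + mu))
#
# The ZINB variant mixes this with a point mass at zero whose logits are
# given by ``px_dropout``.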
def inference(
self, x, batch_index=None, y=None, n_samples=1, transform_batch=None
) -> Dict[str, torch.Tensor]:
"""Helper function used in forward pass
"""
x_ = x
if self.log_variational:
x_ = torch.log(1 + x_)
# Sampling
qz_m, qz_v, z = self.z_encoder(x_, y)
ql_m, ql_v, library = self.l_encoder(x_)
if n_samples > 1:
qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
# when z is normal, untran_z == z
untran_z = Normal(qz_m, qz_v.sqrt()).sample()
z = self.z_encoder.z_transformation(untran_z)
ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
library = Normal(ql_m, ql_v.sqrt()).sample()
if transform_batch is not None:
dec_batch_index = transform_batch * torch.ones_like(batch_index)
else:
dec_batch_index = batch_index
px_scale, px_r, px_rate, px_dropout = self.decoder(
self.dispersion, z, library, dec_batch_index, y
)
if self.dispersion == "gene-label":
px_r = F.linear(
one_hot(y, self.n_labels), self.px_r
) # px_r gets transposed - last dimension is nb genes
elif self.dispersion == "gene-batch":
px_r = F.linear(one_hot(dec_batch_index, self.n_batch), self.px_r)
elif self.dispersion == "gene":
px_r = self.px_r
px_r = torch.exp(px_r)
return dict(
px_scale=px_scale,
px_r=px_r,
px_rate=px_rate,
px_dropout=px_dropout,
qz_m=qz_m,
qz_v=qz_v,
z=z,
ql_m=ql_m,
ql_v=ql_v,
library=library,
)
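# How the decoder outputs relate (as implemented in ``DecoderSCVI``):
# ``px_scale`` holds mean-normalized expression frequencies (softmax over
# genes), ``px_rate = exp(library) * px_scale`` is the NB mean, and
# ``px_r`` is the exponentiated inverse dispersion.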
def forward(
self, x, local_l_mean, local_l_var, batch_index=None, y=None
) -> Tuple[torch.Tensor, torch.Tensor, float]:
"""Returns the reconstruction loss and the KL divergences
Parameters
----------
x
tensor of values with shape (batch_size, n_input)
local_l_mean
tensor of means of the prior distribution of latent variable l
with shape (batch_size, 1)
local_l_var
tensor of variances of the prior distribution of latent variable l
with shape (batch_size, 1)
batch_index
array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
y
tensor of cell-types labels with shape (batch_size, n_labels) (Default value = None)
Returns
-------
tuple
the reconstruction loss (with the library-size KL folded in), the
Kullback-Leibler divergence for the latent ``z``, and a global KL term (0.0 here)
"""
# Parameters for z latent distribution
outputs = self.inference(x, batch_index, y)
qz_m = outputs["qz_m"]
qz_v = outputs["qz_v"]
ql_m = outputs["ql_m"]
ql_v = outputs["ql_v"]
px_rate = outputs["px_rate"]
px_r = outputs["px_r"]
px_dropout = outputs["px_dropout"]
# KL Divergence
mean = torch.zeros_like(qz_m)
scale = torch.ones_like(qz_v)
kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
dim=1
)
kl_divergence_l = kl(
Normal(ql_m, torch.sqrt(ql_v)),
Normal(local_l_mean, torch.sqrt(local_l_var)),
).sum(dim=1)
kl_divergence = kl_divergence_z
reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)
return reconst_loss + kl_divergence_l, kl_divergence, 0.0
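# Sketch of how the values returned by ``forward`` are typically combined
# into the training objective (hypothetical training loop; ``kl_weight``
# is an annealing coefficient, not part of this module):
#
#     reconst_loss, kl_local, kl_global = vae(x, local_l_mean, local_l_var)
#     loss = torch.mean(reconst_loss + kl_weight * kl_local) + kl_global
#     loss.backward()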
class LDVAE(VAE):
"""Linear-decoded Variational auto-encoder model.
Implementation of [Svensson20]_.
This model uses a linear decoder, directly mapping the latent representation
to gene expression levels. It still uses a deep neural network to encode
the latent representation.
Compared to the standard VAE, this model is less expressive, but it can be
used to inspect which genes contribute to variation in the dataset. It may
also be used for all scVI tasks (differential expression, batch correction,
imputation, etc.), although batch correction may be weaker because the
decoder assumes a linear model.
Parameters
----------
n_input
Number of input genes
n_batch
Number of batches
n_labels
Number of labels
n_hidden
Number of nodes per hidden layer (for encoder)
n_latent
Dimensionality of the latent space
n_layers_encoder
Number of hidden layers used for encoder NNs
dropout_rate
Dropout rate for neural networks
dispersion
One of the following
* ``'gene'`` - dispersion parameter of NB is constant per gene across cells
* ``'gene-batch'`` - dispersion can differ between different batches
* ``'gene-label'`` - dispersion can differ between different labels
* ``'gene-cell'`` - dispersion can differ for every gene in every cell
log_variational
Whether to log-transform the data as log(1 + x) before encoding, for numerical stability (this is not a normalization).
reconstruction_loss
One of
* ``'nb'`` - Negative binomial distribution
* ``'zinb'`` - Zero-inflated negative binomial distribution
use_batch_norm
Whether to use batch normalization in the decoder
bias
Whether to include a bias term in the linear decoder
"""
def __init__(
self,
n_input: int,
n_batch: int = 0,
n_labels: int = 0,
n_hidden: int = 128,
n_latent: int = 10,
n_layers_encoder: int = 1,
dropout_rate: float = 0.1,
dispersion: str = "gene",
log_variational: bool = True,
reconstruction_loss: str = "nb",
use_batch_norm: bool = True,
bias: bool = False,
latent_distribution: str = "normal",
):
super().__init__(
n_input,
n_batch,
n_labels,
n_hidden,
n_latent,
n_layers_encoder,
dropout_rate,
dispersion,
log_variational,
reconstruction_loss,
latent_distribution,
)
self.use_batch_norm = use_batch_norm
self.z_encoder = Encoder(
n_input,
n_latent,
n_layers=n_layers_encoder,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
distribution=latent_distribution,
)
self.decoder = LinearDecoderSCVI(
n_latent,
n_input,
n_cat_list=[n_batch],
use_batch_norm=use_batch_norm,
bias=bias,
)
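# With a linear decoder, the map from latent space to expression
# frequencies is (up to batch covariates and optional batch norm) a single
# affine transformation, roughly ``px_scale = softmax(W @ z + b)``, which
# is what makes the per-gene loadings extracted below interpretable.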
@torch.no_grad()
def get_loadings(self) -> np.ndarray:
"""Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder."""
# Loadings are B @ W, where B = diag(gamma / sigma) comes from batch norm
# and W is the decoder weight matrix
if self.use_batch_norm:
w = self.decoder.factor_regressor.fc_layers[0][0].weight
bn = self.decoder.factor_regressor.fc_layers[0][1]
sigma = torch.sqrt(bn.running_var + bn.eps)
gamma = bn.weight
b = gamma / sigma
bI = torch.diag(b)
loadings = torch.matmul(bI, w)
else:
loadings = self.decoder.factor_regressor.fc_layers[0][0].weight
loadings = loadings.detach().cpu().numpy()
if self.n_batch > 1:
loadings = loadings[:, : -self.n_batch]
return loadings
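# Usage sketch (hypothetical tensors; illustration only):
#
#     ldvae = LDVAE(n_input=1000, n_latent=10)
#     loadings = ldvae.get_loadings()               # shape: (n_input, n_latent)
#     top_genes = np.argsort(loadings[:, 0])[-10:]  # top genes for factor 0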