from typing import Optional, Sequence
import numpy as np
import torch
from torch.distributions import Normal, Categorical, kl_divergence as kl
from scvi.models.classifier import Classifier
from scvi.models.modules import Decoder, Encoder
from scvi.models.utils import broadcast_labels
from scvi.models.vae import VAE
class SCANVI(VAE):
    """Single-cell annotation using variational inference.

    This is an implementation of the scANVI model described in [Xu19]_,
    inspired by the M1 + M2 model described in https://arxiv.org/pdf/1406.5298.pdf.

Parameters
----------
n_input
Number of input genes
n_batch
Number of batches
n_labels
Number of labels
n_hidden
Number of nodes per hidden layer
n_latent
Dimensionality of the latent space
n_layers
Number of hidden layers used for encoder and decoder NNs
dropout_rate
Dropout rate for neural networks
dispersion
One of the following
* ``'gene'`` - dispersion parameter of NB is constant per gene across cells
* ``'gene-batch'`` - dispersion can differ between different batches
* ``'gene-label'`` - dispersion can differ between different labels
* ``'gene-cell'`` - dispersion can differ for every gene in every cell
log_variational
Log(data+1) prior to encoding for numerical stability. Not normalization.
reconstruction_loss
One of
* ``'nb'`` - Negative binomial distribution
* ``'zinb'`` - Zero-inflated negative binomial distribution
y_prior
If None, initialized to uniform probability over cell types
    labels_groups
        Label group designations: one group index per label
    use_labels_groups
        Whether to classify hierarchically using the label groups
Examples
--------
    >>> gene_dataset = CortexDataset()
    >>> scanvi = SCANVI(gene_dataset.nb_genes, n_batch=0,
    ...     n_labels=gene_dataset.n_labels)

    >>> gene_dataset = SyntheticDataset(n_labels=3)
    >>> scanvi = SCANVI(gene_dataset.nb_genes, n_batch=0, n_labels=3,
    ...     y_prior=torch.tensor([[0.1, 0.5, 0.4]]), labels_groups=[0, 0, 1])
"""
def __init__(
self,
n_input: int,
n_batch: int = 0,
n_labels: int = 0,
n_hidden: int = 128,
n_latent: int = 10,
n_layers: int = 1,
dropout_rate: float = 0.1,
dispersion: str = "gene",
log_variational: bool = True,
reconstruction_loss: str = "zinb",
y_prior=None,
        labels_groups: Optional[Sequence[int]] = None,
use_labels_groups: bool = False,
        classifier_parameters: Optional[dict] = None,
):
super().__init__(
n_input,
n_hidden=n_hidden,
n_latent=n_latent,
n_layers=n_layers,
dropout_rate=dropout_rate,
n_batch=n_batch,
dispersion=dispersion,
log_variational=log_variational,
reconstruction_loss=reconstruction_loss,
)
self.n_labels = n_labels
        # The classifier takes z1 (of size n_latent) as input and outputs label
        # probabilities; the defaults below mirror the encoder and can be
        # overridden via `classifier_parameters`
cls_parameters = {
"n_layers": n_layers,
"n_hidden": n_hidden,
"dropout_rate": dropout_rate,
}
        cls_parameters.update(classifier_parameters or {})
self.classifier = Classifier(n_latent, n_labels=n_labels, **cls_parameters)
self.encoder_z2_z1 = Encoder(
n_latent,
n_latent,
n_cat_list=[self.n_labels],
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
)
self.decoder_z1_z2 = Decoder(
n_latent,
n_latent,
n_cat_list=[self.n_labels],
n_layers=n_layers,
n_hidden=n_hidden,
)
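        # Generative hierarchy: z2 has a standard normal prior and z1 is decoded
        # from (z2, y), mirroring encoder_z2_z1 in the inference direction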
self.y_prior = torch.nn.Parameter(
y_prior
if y_prior is not None
else (1 / n_labels) * torch.ones(1, n_labels),
requires_grad=False,
)
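        # Stored as a frozen (1, n_labels) parameter so it can be broadcast
        # against per-cell label probabilities in `forward`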
self.use_labels_groups = use_labels_groups
self.labels_groups = (
np.array(labels_groups) if labels_groups is not None else None
)
if self.use_labels_groups:
assert labels_groups is not None, "Specify label groups"
unique_groups = np.unique(self.labels_groups)
self.n_groups = len(unique_groups)
assert (unique_groups == np.arange(self.n_groups)).all()
            self.classifier_groups = Classifier(
                n_latent,
                n_hidden=n_hidden,
                n_labels=self.n_groups,
                n_layers=n_layers,
                dropout_rate=dropout_rate,
            )
self.groups_index = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.tensor(
(self.labels_groups == i).astype(np.uint8),
dtype=torch.uint8,
),
requires_grad=False,
)
for i in range(self.n_groups)
]
)
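            # Example: labels_groups=[0, 0, 1] yields two boolean masks over the
            # labels, [1, 1, 0] for group 0 and [0, 0, 1] for group 1, used in
            # `classify` to renormalize label probabilities within each group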
    def classify(self, x):
        if self.log_variational:
            x = torch.log(1 + x)
        qz_m, _, z = self.z_encoder(x)
        # Classify from the inferred posterior mean of z1 rather than a sample,
        # which removes sampling noise from the predictions
        z = qz_m
if self.use_labels_groups:
w_g = self.classifier_groups(z)
unw_y = self.classifier(z)
w_y = torch.zeros_like(unw_y)
            for i, group_index in enumerate(self.groups_index):
                unw_y_g = unw_y[:, group_index]
                # Renormalize probabilities within the group, then scale by the
                # group's own probability
                w_y[:, group_index] = unw_y_g / (
                    unw_y_g.sum(dim=-1, keepdim=True) + 1e-8
                )
                w_y[:, group_index] *= w_g[:, [i]]
else:
w_y = self.classifier(z)
return w_y
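    # A minimal usage sketch (hypothetical names: `model` is a trained SCANVI,
    # `x` an (n_cells, n_genes) count tensor):
    #
    #     probs = model.classify(x)         # (n_cells, n_labels) soft assignments
    #     y_pred = probs.argmax(dim=-1)     # hard label per cell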
    def get_latents(self, x, y=None):
        # z1 comes from the base VAE encoder; z2 additionally conditions on the label
        zs = super().get_latents(x)
        qz2_m, _, z2 = self.encoder_z2_z1(zs[0], y)
        if not self.training:
            # At evaluation time, return the posterior mean instead of a sample
            z2 = qz2_m
        return [zs[0], z2]
    def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        is_labelled = y is not None
outputs = self.inference(x, batch_index, y)
px_r = outputs["px_r"]
px_rate = outputs["px_rate"]
px_dropout = outputs["px_dropout"]
qz1_m = outputs["qz_m"]
qz1_v = outputs["qz_v"]
z1 = outputs["z"]
ql_m = outputs["ql_m"]
ql_v = outputs["ql_v"]
# Enumerate choices of label
ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels)
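        # `broadcast_labels` one-hot encodes y when it is given; when y is None it
        # pairs every cell with each of the n_labels candidate labels, so the
        # per-label terms below are computed in one batched pass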
qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)
reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)
# KL Divergence
mean = torch.zeros_like(qz2_m)
scale = torch.ones_like(qz2_v)
kl_divergence_z2 = kl(
Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)
).sum(dim=1)
loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
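        # Together these give a single-sample estimate of KL(q(z1|x) || p(z1|z2, y)):
        # `loss_z1_unweight` is -log p(z1|z2, y) for each enumerated label and
        # `loss_z1_weight` is +log q(z1|x), both evaluated at the sampled z1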
kl_divergence_l = kl(
Normal(ql_m, torch.sqrt(ql_v)),
Normal(local_l_mean, torch.sqrt(local_l_var)),
).sum(dim=1)
if is_labelled:
return (
reconst_loss + loss_z1_weight + loss_z1_unweight,
kl_divergence_z2 + kl_divergence_l,
0.0,
)
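        # Unlabelled case: treat y as latent and marginalize it out. Each per-label
        # term is weighted by q(y|z1) (`probs` below), and a categorical
        # KL(q(y|z1) || p(y)) against `y_prior` is added, following the
        # M1 + M2 semi-supervised objective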
probs = self.classifier(z1)
reconst_loss += loss_z1_weight + (
(loss_z1_unweight).view(self.n_labels, -1).t() * probs
).sum(dim=1)
kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(
dim=1
)
kl_divergence += kl(
Categorical(probs=probs),
Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
)
kl_divergence += kl_divergence_l
return reconst_loss, kl_divergence, 0.0