Source code for scvi.models.totalvi

# -*- coding: utf-8 -*-
"""Main module."""
from typing import Dict, Optional, Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli, kl_divergence as kl

from scvi.models.distributions import ZeroInflatedNegativeBinomial, NegativeBinomial
from scvi.models.log_likelihood import log_mixture_nb
from scvi.models.modules import DecoderTOTALVI, EncoderTOTALVI
from scvi.models.utils import one_hot
import numpy as np

torch.backends.cudnn.benchmark = True


# VAE model
# VAE model
class TOTALVI(nn.Module):
    """Total variational inference for CITE-seq data

    Implements the totalVI model of [GayosoSteier20]_.

    Parameters
    ----------
    n_input_genes
        Number of input genes
    n_input_proteins
        Number of input proteins
    n_batch
        Number of batches
    n_labels
        Number of labels
    n_hidden
        Number of nodes per hidden layer for the z encoder (protein+genes),
        genes library encoder, z->genes+proteins decoder
    n_latent
        Dimensionality of the latent space
    n_layers_encoder
        Number of hidden layers used for the encoder NN
    n_layers_decoder
        Number of hidden layers used for the decoder NN
    dropout_rate_encoder
        Dropout rate for the encoder neural networks
    dropout_rate_decoder
        Dropout rate for the decoder neural network
    gene_dispersion
        One of the following

        * ``'gene'`` - gene_dispersion parameter of NB is constant per gene across cells
        * ``'gene-batch'`` - gene_dispersion can differ between different batches
        * ``'gene-label'`` - gene_dispersion can differ between different labels
    protein_dispersion
        One of the following

        * ``'protein'`` - protein_dispersion parameter is constant per protein across cells
        * ``'protein-batch'`` - protein_dispersion can differ between different batches NOT TESTED
        * ``'protein-label'`` - protein_dispersion can differ between different labels NOT TESTED
    log_variational
        Log(data+1) prior to encoding for numerical stability. Not normalization.
    reconstruction_loss_gene
        One of

        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
    latent_distribution
        One of

        * ``'normal'`` - Isotropic normal
        * ``'ln'`` - Logistic normal with normal params N(0, 1)
    protein_batch_mask
        List of arrays, one per batch, masking which proteins were measured in that batch
    encoder_batch
        Whether to provide the batch index to the encoder

    Examples
    --------
    >>> dataset = Dataset10X(dataset_name="pbmc_10k_protein_v3", save_path=save_path)
    >>> totalvae = TOTALVI(dataset.nb_genes, len(dataset.protein_names))
    """

    def __init__(
        self,
        n_input_genes: int,
        n_input_proteins: int,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 256,
        n_latent: int = 20,
        n_layers_encoder: int = 1,
        n_layers_decoder: int = 1,
        dropout_rate_decoder: float = 0.2,
        dropout_rate_encoder: float = 0.2,
        gene_dispersion: str = "gene",
        protein_dispersion: str = "protein",
        log_variational: bool = True,
        reconstruction_loss_gene: str = "nb",
        latent_distribution: str = "ln",
        protein_batch_mask: List[np.ndarray] = None,
        encoder_batch: bool = True,
    ):
        super().__init__()
        self.gene_dispersion = gene_dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.reconstruction_loss_gene = reconstruction_loss_gene
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.n_input_genes = n_input_genes
        self.n_input_proteins = n_input_proteins
        self.protein_dispersion = protein_dispersion
        self.latent_distribution = latent_distribution
        self.protein_batch_mask = protein_batch_mask

        # parameters for prior on rate_back (background protein mean)
        if n_batch > 0:
            self.background_pro_alpha = torch.nn.Parameter(
                torch.randn(n_input_proteins, n_batch)
            )
            self.background_pro_log_beta = torch.nn.Parameter(
                torch.clamp(torch.randn(n_input_proteins, n_batch), -10, 1)
            )
        else:
            self.background_pro_alpha = torch.nn.Parameter(
                torch.randn(n_input_proteins)
            )
            self.background_pro_log_beta = torch.nn.Parameter(
                torch.clamp(torch.randn(n_input_proteins), -10, 1)
            )

        if self.gene_dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input_genes))
        elif self.gene_dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_batch))
        elif self.gene_dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input_genes, n_labels))
        else:  # gene-cell
            pass

        if self.protein_dispersion == "protein":
            self.py_r = torch.nn.Parameter(torch.ones(self.n_input_proteins))
        elif self.protein_dispersion == "protein-batch":
            self.py_r = torch.nn.Parameter(torch.ones(self.n_input_proteins, n_batch))
        elif self.protein_dispersion == "protein-label":
            self.py_r = torch.nn.Parameter(torch.ones(self.n_input_proteins, n_labels))
        else:  # protein-cell
            pass

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.encoder = EncoderTOTALVI(
            n_input_genes + self.n_input_proteins,
            n_latent,
            n_layers=n_layers_encoder,
            n_cat_list=[n_batch] if encoder_batch else None,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate_encoder,
            distribution=latent_distribution,
        )
        self.decoder = DecoderTOTALVI(
            n_latent,
            n_input_genes,
            self.n_input_proteins,
            n_layers=n_layers_decoder,
            n_cat_list=[n_batch],
            n_hidden=n_hidden,
            dropout_rate=dropout_rate_decoder,
        )
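    # --- Illustrative usage sketch (not part of the module) ---
    # Constructing the model for a hypothetical CITE-seq dataset with 4,000 genes,
    # 14 proteins and 2 batches; the argument names come from the signature above,
    # the dimensions are made up for illustration.
    #
    #     model = TOTALVI(
    #         n_input_genes=4000,
    #         n_input_proteins=14,
    #         n_batch=2,
    #         n_latent=20,
    #         gene_dispersion="gene-batch",
    #         reconstruction_loss_gene="nb",
    #         latent_distribution="ln",
    #     )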
    def sample_from_posterior_z(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        give_mean: bool = False,
        n_samples: int = 5000,
    ) -> torch.Tensor:
        """Access the tensor of latent values from the posterior

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input_genes)``
        y
            tensor of values with shape ``(batch_size, n_input_proteins)``
        batch_index
            tensor of batch indices
        give_mean
            Whether to sample, or give mean of distribution
        n_samples
            number of posterior samples used to approximate the mean when
            ``latent_distribution == "ln"``

        Returns
        -------
        type
            tensor of shape ``(batch_size, n_latent)``
        """
        if self.log_variational:
            x = torch.log(1 + x)
            y = torch.log(1 + y)
        qz_m, qz_v, _, _, latent, _ = self.encoder(
            torch.cat((x, y), dim=-1), batch_index
        )
        z = latent["z"]
        if give_mean:
            if self.latent_distribution == "ln":
                samples = Normal(qz_m, qz_v.sqrt()).sample([n_samples])
                z = self.encoder.z_transformation(samples)
                z = z.mean(dim=0)
            else:
                z = qz_m
        return z
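    # Illustrative sketch (hypothetical tensors ``x``, ``y``, ``batch_index``):
    # extracting posterior means of the latent space, e.g. for clustering or embedding.
    # ``sample_from_posterior_l`` below follows the same call pattern for the library size.
    #
    #     z = model.sample_from_posterior_z(x, y, batch_index=batch_index, give_mean=True)
    #     # z has shape (batch_size, n_latent)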
    def sample_from_posterior_l(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        give_mean: bool = True,
    ) -> torch.Tensor:
        """Provides the tensor of library size from the posterior

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input_genes)``
        y
            tensor of values with shape ``(batch_size, n_input_proteins)``

        Returns
        -------
        type
            tensor of shape ``(batch_size, 1)``
        """
        if self.log_variational:
            x = torch.log(1 + x)
            y = torch.log(1 + y)
        _, _, ql_m, ql_v, latent, _ = self.encoder(
            torch.cat((x, y), dim=-1), batch_index
        )
        library_gene = latent["l"]
        if give_mean is True:
            return torch.exp(ql_m + 0.5 * ql_v)
        else:
            return library_gene
    def get_sample_rate(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
        n_samples: int = 1,
    ) -> torch.Tensor:
        """Returns the tensor of negative binomial mean for genes

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input_genes)``
        y
            tensor of values with shape ``(batch_size, n_input_proteins)``
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size``
        label
            tensor of cell-types labels with shape ``(batch_size, n_labels)``
        n_samples
            number of samples

        Returns
        -------
        type
            tensor of means of the negative binomial distribution with shape
            ``(batch_size, n_input_genes)``
        """
        outputs = self.inference(
            x, y, batch_index=batch_index, label=label, n_samples=n_samples
        )
        rate = outputs["px_"]["rate"]
        return rate
    def get_sample_dispersion(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
        n_samples: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the tensors of dispersions for genes and proteins

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input_genes)``
        y
            tensor of values with shape ``(batch_size, n_input_proteins)``
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size``
        label
            tensor of cell-types labels with shape ``(batch_size, n_labels)``
        n_samples
            number of samples

        Returns
        -------
        type
            tensors of dispersions of the negative binomial distribution
        """
        outputs = self.inference(
            x, y, batch_index=batch_index, label=label, n_samples=n_samples
        )
        px_r = outputs["px_"]["r"]
        py_r = outputs["py_"]["r"]
        return px_r, py_r
    def get_sample_scale(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
        n_samples: int = 1,
        transform_batch: Optional[int] = None,
        eps=0,
        normalize_pro=False,
        sample_bern=True,
        include_bg=False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns tuple of gene and protein scales.

        These scales can also be transformed into a particular batch. This function is
        the core of differential expression.

        Parameters
        ----------
        transform_batch
            Int of batch to "transform" all cells into
        eps
            Prior count to add to protein normalized expression (Default value = 0)
        normalize_pro
            bool, whether to make protein expression sum to one in a cell (Default value = False)
        sample_bern
            bool, whether to sample the Bernoulli mixing variable (Default value = True)
        include_bg
            bool, whether to include the background component of expression (Default value = False)

        Returns
        -------
        type
            tuple of gene scales and background-adjusted protein expression
        """
        outputs = self.inference(
            x,
            y,
            batch_index=batch_index,
            label=label,
            n_samples=n_samples,
            transform_batch=transform_batch,
        )
        px_ = outputs["px_"]
        py_ = outputs["py_"]
        protein_mixing = 1 / (1 + torch.exp(-py_["mixing"]))
        if sample_bern is True:
            protein_mixing = Bernoulli(protein_mixing).sample()
        pro_value = (1 - protein_mixing) * py_["rate_fore"]
        if include_bg is True:
            pro_value = (1 - protein_mixing) * py_["rate_fore"] + protein_mixing * py_[
                "rate_back"
            ]
        if normalize_pro is True:
            pro_value = torch.nn.functional.normalize(pro_value, p=1, dim=-1)

        return px_["scale"], pro_value + eps
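    # Illustrative sketch (hypothetical tensors): denoised expression for differential
    # expression, "transforming" all cells into batch 0 and dropping the protein
    # background component.
    #
    #     px_scale, py_scale = model.get_sample_scale(
    #         x, y, batch_index=batch_index, transform_batch=0,
    #         normalize_pro=False, sample_bern=True, include_bg=False,
    #     )
    #     # px_scale: normalized gene expression
    #     # py_scale: background-corrected protein expression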
    def get_reconstruction_loss(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        px_: Dict[str, torch.Tensor],
        py_: Dict[str, torch.Tensor],
        pro_batch_mask_minibatch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute reconstruction loss"""
        # Reconstruction Loss
        if self.reconstruction_loss_gene == "zinb":
            reconst_loss_gene = (
                -ZeroInflatedNegativeBinomial(
                    mu=px_["rate"], theta=px_["r"], zi_logits=px_["dropout"]
                )
                .log_prob(x)
                .sum(dim=-1)
            )
        else:
            reconst_loss_gene = (
                -NegativeBinomial(mu=px_["rate"], theta=px_["r"])
                .log_prob(x)
                .sum(dim=-1)
            )

        reconst_loss_protein_full = -log_mixture_nb(
            y, py_["rate_back"], py_["rate_fore"], py_["r"], None, py_["mixing"]
        )
        if pro_batch_mask_minibatch is not None:
            temp_pro_loss_full = torch.zeros_like(reconst_loss_protein_full)
            temp_pro_loss_full.masked_scatter_(
                pro_batch_mask_minibatch.bool(), reconst_loss_protein_full
            )
            reconst_loss_protein = temp_pro_loss_full.sum(dim=-1)
        else:
            reconst_loss_protein = reconst_loss_protein_full.sum(dim=-1)

        return reconst_loss_gene, reconst_loss_protein
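    # Illustrative sketch (not the library implementation): the protein term above is a
    # two-component negative binomial mixture with a shared dispersion. Assuming the
    # ``mixing`` logits weight the background component (consistent with
    # ``get_sample_scale`` above), an equivalent per-entry log-likelihood could be
    # written as:
    #
    #     def mixture_nb_log_prob(y, rate_back, rate_fore, theta, mixing_logits):
    #         log_p_back = NegativeBinomial(mu=rate_back, theta=theta).log_prob(y)
    #         log_p_fore = NegativeBinomial(mu=rate_fore, theta=theta).log_prob(y)
    #         log_pi_back = F.logsigmoid(mixing_logits)   # log P(background)
    #         log_pi_fore = F.logsigmoid(-mixing_logits)  # log P(foreground)
    #         return torch.logsumexp(
    #             torch.stack([log_pi_back + log_p_back, log_pi_fore + log_p_fore]), dim=0
    #         )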
    def inference(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
        n_samples=1,
        transform_batch: Optional[int] = None,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Internal helper function to compute necessary inference quantities

        We use the dictionary ``px_`` to contain the parameters of the ZINB/NB for genes.
        The rate refers to the mean of the NB, dropout refers to Bernoulli mixing
        parameters. ``scale`` refers to the quantity upon which differential expression
        is performed. For genes, this can be viewed as the mean of the underlying gamma
        distribution.

        We use the dictionary ``py_`` to contain the parameters of the Mixture NB
        distribution for proteins. ``rate_fore`` refers to the foreground mean, while
        ``rate_back`` refers to the background mean. ``scale`` refers to the foreground
        mean adjusted for background probability and scaled to reside in the simplex.
        ``back_alpha`` and ``back_beta`` are the posterior parameters for ``rate_back``.
        ``fore_scale`` is the scaling factor that enforces ``rate_fore`` > ``rate_back``.

        ``px_["r"]`` and ``py_["r"]`` are the inverse dispersion parameters for genes and
        proteins, respectively.
        """
        x_ = x
        y_ = y
        if self.log_variational:
            x_ = torch.log(1 + x_)
            y_ = torch.log(1 + y_)

        # Sampling - Encoder gets concatenated genes + proteins
        qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder(
            torch.cat((x_, y_), dim=-1), batch_index
        )
        z = latent["z"]
        library_gene = latent["l"]
        untran_z = untran_latent["z"]
        untran_l = untran_latent["l"]

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.encoder.z_transformation(untran_z)
            ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
            untran_l = Normal(ql_m, ql_v.sqrt()).sample()
            library_gene = self.encoder.l_transformation(untran_l)

        if self.gene_dispersion == "gene-label":
            # px_r gets transposed - last dimension is nb genes
            px_r = F.linear(one_hot(label, self.n_labels), self.px_r)
        elif self.gene_dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.gene_dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        if self.protein_dispersion == "protein-label":
            # py_r gets transposed - last dimension is n_proteins
            py_r = F.linear(one_hot(label, self.n_labels), self.py_r)
        elif self.protein_dispersion == "protein-batch":
            py_r = F.linear(one_hot(batch_index, self.n_batch), self.py_r)
        elif self.protein_dispersion == "protein":
            py_r = self.py_r
        py_r = torch.exp(py_r)

        # Background regularization
        if self.n_batch > 0:
            py_back_alpha_prior = F.linear(
                one_hot(batch_index, self.n_batch), self.background_pro_alpha
            )
            py_back_beta_prior = F.linear(
                one_hot(batch_index, self.n_batch),
                torch.exp(self.background_pro_log_beta),
            )
        else:
            py_back_alpha_prior = self.background_pro_alpha
            py_back_beta_prior = torch.exp(self.background_pro_log_beta)
        self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)

        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch
        px_, py_, log_pro_back_mean = self.decoder(z, library_gene, batch_index, label)
        px_["r"] = px_r
        py_["r"] = py_r

        return dict(
            px_=px_,
            py_=py_,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            untran_z=untran_z,
            ql_m=ql_m,
            ql_v=ql_v,
            library_gene=library_gene,
            untran_l=untran_l,
            log_pro_back_mean=log_pro_back_mean,
        )
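    # Illustrative sketch (hypothetical tensors): reading the dictionaries returned by
    # ``inference``. The interpretation of ``scale`` and ``mixing`` follows the docstring
    # above and ``get_sample_scale``.
    #
    #     outputs = model.inference(x, y, batch_index=batch_index, n_samples=1)
    #     px_, py_ = outputs["px_"], outputs["py_"]
    #     denoised_genes = px_["scale"]                        # normalized gene expression
    #     foreground_prob = 1 - torch.sigmoid(py_["mixing"])   # P(protein counts are foreground)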
    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        local_l_mean_gene: torch.Tensor,
        local_l_var_gene: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        label: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Returns the reconstruction losses and the Kullback-Leibler divergences

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input_genes)``
        y
            tensor of values with shape ``(batch_size, n_input_proteins)``
        local_l_mean_gene
            tensor of means of the prior distribution of latent variable l
            with shape ``(batch_size, 1)``
        local_l_var_gene
            tensor of variances of the prior distribution of latent variable l
            with shape ``(batch_size, 1)``
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size``
        label
            tensor of cell-types labels with shape ``(batch_size, n_labels)``

        Returns
        -------
        type
            the reconstruction losses and the Kullback-Leibler divergences
        """
        # Parameters for z latent distribution
        outputs = self.inference(x, y, batch_index, label)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_ = outputs["px_"]
        py_ = outputs["py_"]

        if self.protein_batch_mask is not None:
            pro_batch_mask_minibatch = torch.zeros_like(y)
            for b in np.arange(len(torch.unique(batch_index))):
                b_indices = (batch_index == b).reshape(-1)
                pro_batch_mask_minibatch[b_indices] = torch.tensor(
                    self.protein_batch_mask[b].astype(np.float32), device=y.device
                )
        else:
            pro_batch_mask_minibatch = None

        reconst_loss_gene, reconst_loss_protein = self.get_reconstruction_loss(
            x, y, px_, py_, pro_batch_mask_minibatch
        )

        # KL Divergence
        kl_div_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)
        kl_div_l_gene = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean_gene, torch.sqrt(local_l_var_gene)),
        ).sum(dim=1)

        kl_div_back_pro_full = kl(
            Normal(py_["back_alpha"], py_["back_beta"]), self.back_mean_prior
        )
        if pro_batch_mask_minibatch is not None:
            kl_div_back_pro = (pro_batch_mask_minibatch * kl_div_back_pro_full).sum(
                dim=1
            )
        else:
            kl_div_back_pro = kl_div_back_pro_full.sum(dim=1)

        return (
            reconst_loss_gene,
            reconst_loss_protein,
            kl_div_z,
            kl_div_l_gene,
            kl_div_back_pro,
        )
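# Illustrative sketch (not part of the module): a training loop would combine the five
# terms returned by ``forward`` into a single objective. The weighting below, including
# the ``kl_weight`` warm-up factor, is an assumption for illustration only.
#
#     reconst_gene, reconst_pro, kl_z, kl_l_gene, kl_back_pro = model(
#         x, y, local_l_mean_gene, local_l_var_gene, batch_index
#     )
#     loss = torch.mean(
#         reconst_gene + reconst_pro + kl_weight * kl_z + kl_l_gene + kl_weight * kl_back_pro
#     )
#     loss.backward()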