Source code for scvi.models.modules

import collections
from typing import Iterable, List

import torch
from torch import nn as nn
from torch.distributions import Normal
from torch.nn import ModuleList

from scvi.models.utils import one_hot


def reparameterize_gaussian(mu, var):
    return Normal(mu, var.sqrt()).rsample()


def identity(x):
    return x
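
# Illustrative sketch (not part of the original module): the reparameterization
# trick used by `reparameterize_gaussian`. Sampling via `rsample` keeps the draw
# differentiable with respect to `mu` and `var`; the shapes below are assumptions.
def _example_reparameterize_gaussian():
    mu = torch.zeros(4, requires_grad=True)
    var = torch.ones(4, requires_grad=True)
    z = reparameterize_gaussian(mu, var)
    z.sum().backward()  # gradients flow back into mu and var
    return mu.grad, var.grad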
class FCLayers(nn.Module):
    """A helper class to build fully-connected layers for a neural network.

    Parameters
    ----------
    n_in
        The dimensionality of the input
    n_out
        The dimensionality of the output
    n_cat_list
        A list containing, for each category of interest, the number of categories.
        Each category will be included using a one-hot encoding.
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    use_batch_norm
        Whether to have `BatchNorm` layers or not
    use_relu
        Whether to have `ReLU` layers or not
    bias
        Whether to learn bias in linear layers or not
    """

    def __init__(
        self,
        n_in: int,
        n_out: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        use_batch_norm: bool = True,
        use_relu: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]

        if n_cat_list is not None:
            # n_cat = 1 will be ignored
            self.n_cat_list = [n_cat if n_cat > 1 else 0 for n_cat in n_cat_list]
        else:
            self.n_cat_list = []

        self.fc_layers = nn.Sequential(
            collections.OrderedDict(
                [
                    (
                        "Layer {}".format(i),
                        nn.Sequential(
                            nn.Linear(n_in + sum(self.n_cat_list), n_out, bias=bias),
                            # Below, 0.01 and 0.001 are the default values for `momentum` and `eps` from
                            # the tensorflow implementation of batch norm; we're using those settings
                            # here too so that the results match our old tensorflow code. The default
                            # setting from pytorch would probably be fine too but we haven't tested that.
                            nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001)
                            if use_batch_norm
                            else None,
                            nn.ReLU() if use_relu else None,
                            nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
                        ),
                    )
                    for i, (n_in, n_out) in enumerate(
                        zip(layers_dim[:-1], layers_dim[1:])
                    )
                ]
            )
        )
    def forward(self, x: torch.Tensor, *cat_list: int, instance_id: int = 0):
        """Forward computation on ``x``.

        Parameters
        ----------
        x
            tensor of values with shape ``(n_in,)``
        cat_list
            list of category membership(s) for this sample
        instance_id
            Use a specific conditional instance normalization (batchnorm)

        Returns
        -------
        :py:class:`torch.Tensor`
            tensor of shape ``(n_out,)``
        """
        one_hot_cat_list = []  # for generality; some indices in this list may be unused

        assert len(self.n_cat_list) <= len(
            cat_list
        ), "nb. categorical args provided doesn't match init. params."
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            assert not (
                n_cat and cat is None
            ), "cat not provided while n_cat != 0 in init. params."
            if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for layers in self.fc_layers:
            for layer in layers:
                if layer is not None:
                    if isinstance(layer, nn.BatchNorm1d):
                        if x.dim() == 3:
                            x = torch.cat(
                                [(layer(slice_x)).unsqueeze(0) for slice_x in x], dim=0
                            )
                        else:
                            x = layer(x)
                    else:
                        if isinstance(layer, nn.Linear):
                            if x.dim() == 3:
                                one_hot_cat_list_layer = [
                                    o.unsqueeze(0).expand(
                                        (x.size(0), o.size(0), o.size(1))
                                    )
                                    for o in one_hot_cat_list
                                ]
                            else:
                                one_hot_cat_list_layer = one_hot_cat_list
                            x = torch.cat((x, *one_hot_cat_list_layer), dim=-1)
                        x = layer(x)
        return x
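
# Illustrative usage sketch (not part of the original module): FCLayers with a
# single categorical covariate. The sizes (100 inputs, 2 batches, minibatch of 16)
# are assumptions chosen for the example; the integer codes are one-hot encoded
# inside `forward` and concatenated to the input of every linear layer.
def _example_fclayers():
    fc = FCLayers(n_in=100, n_out=10, n_cat_list=[2], n_layers=2, n_hidden=64)
    x = torch.randn(16, 100)
    batch_index = torch.randint(0, 2, (16, 1))
    h = fc(x, batch_index)  # shape (16, 10)
    return h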
# Encoder
class Encoder(nn.Module):
    """Encodes data of ``n_input`` dimensions into a latent space of ``n_output``
    dimensions using a fully-connected neural network of ``n_layers`` layers
    with ``n_hidden`` nodes per layer.

    Parameters
    ----------
    n_input
        The dimensionality of the input (data space)
    n_output
        The dimensionality of the output (latent space)
    n_cat_list
        A list containing the number of categories for each category of interest.
        Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    distribution
        Distribution of z
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        distribution: str = "normal",
    ):
        super().__init__()

        self.distribution = distribution
        self.encoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.mean_encoder = nn.Linear(n_hidden, n_output)
        self.var_encoder = nn.Linear(n_hidden, n_output)

        if distribution == "ln":
            self.z_transformation = nn.Softmax(dim=-1)
        else:
            self.z_transformation = identity
    def forward(self, x: torch.Tensor, *cat_list: int):
        """The forward computation for a single sample.

        #. Encodes the data into latent space using the encoder network
        #. Generates a mean \\( q_m \\) and variance \\( q_v \\)
        #. Samples a new value from an i.i.d. multivariate normal \\( \\sim N(q_m, \\mathbf{I}q_v) \\)

        Parameters
        ----------
        x
            tensor with shape ``(n_input,)``
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        3-tuple of :py:class:`torch.Tensor`
            tensors of shape ``(n_latent,)`` for mean and var, and sample
        """
        # Parameters for latent distribution
        q = self.encoder(x, *cat_list)
        q_m = self.mean_encoder(q)
        q_v = torch.exp(self.var_encoder(q)) + 1e-4
        latent = self.z_transformation(reparameterize_gaussian(q_m, q_v))
        return q_m, q_v, latent
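
# Illustrative usage sketch (not part of the original module): encoding a minibatch
# into a 10-dimensional latent space. The input dimensionality (200) and batch size
# (8) are assumptions chosen for the example.
def _example_encoder():
    encoder = Encoder(n_input=200, n_output=10)
    x = torch.randn(8, 200)
    q_m, q_v, z = encoder(x)  # mean, variance and reparameterized sample, each (8, 10)
    return q_m, q_v, z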
# Decoder
class DecoderSCVI(nn.Module):
    """Decodes data from latent space of ``n_input`` dimensions to ``n_output``
    dimensions using a fully-connected neural network of ``n_layers`` layers
    with ``n_hidden`` nodes per layer.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output
        The dimensionality of the output (data space)
    n_cat_list
        A list containing the number of categories for each category of interest.
        Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
    ):
        super().__init__()
        self.px_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
        )

        # mean gamma
        self.px_scale_decoder = nn.Sequential(
            nn.Linear(n_hidden, n_output), nn.Softmax(dim=-1)
        )

        # dispersion: here we only deal with gene-cell dispersion case
        self.px_r_decoder = nn.Linear(n_hidden, n_output)

        # dropout
        self.px_dropout_decoder = nn.Linear(n_hidden, n_output)
    def forward(
        self, dispersion: str, z: torch.Tensor, library: torch.Tensor, *cat_list: int
    ):
        """The forward computation for a single sample.

        #. Decodes the data from the latent space using the decoder network
        #. Returns parameters for the ZINB distribution of expression
        #. If ``dispersion != 'gene-cell'``, the value for that parameter is ``None``

        Parameters
        ----------
        dispersion
            One of the following

            * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
            * ``'gene-batch'`` - dispersion can differ between different batches
            * ``'gene-label'`` - dispersion can differ between different labels
            * ``'gene-cell'`` - dispersion can differ for every gene in every cell
        z
            tensor with shape ``(n_input,)``
        library
            library size
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        4-tuple of :py:class:`torch.Tensor`
            parameters for the ZINB distribution of expression
        """
        # The decoder returns values for the parameters of the ZINB distribution
        px = self.px_decoder(z, *cat_list)
        px_scale = self.px_scale_decoder(px)
        px_dropout = self.px_dropout_decoder(px)
        # Clamp to high value: exp(12) ~ 160000 to avoid nans (computational stability)
        px_rate = torch.exp(library) * px_scale  # torch.clamp( , max=12)
        px_r = self.px_r_decoder(px) if dispersion == "gene-cell" else None
        return px_scale, px_r, px_rate, px_dropout
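
# Illustrative usage sketch (not part of the original module): decoding a latent
# minibatch into ZINB parameters. The sizes (latent dim 10, 200 genes, batch of 8)
# are assumptions; `library` is a log library size, as expected by `forward`.
def _example_decoder_scvi():
    decoder = DecoderSCVI(n_input=10, n_output=200)
    z = torch.randn(8, 10)
    library = torch.zeros(8, 1)  # log library size; exp(0) = 1 here
    px_scale, px_r, px_rate, px_dropout = decoder("gene", z, library)
    # with dispersion="gene", px_r is None and the dispersion is handled elsewhere
    return px_scale, px_r, px_rate, px_dropout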
class LinearDecoderSCVI(nn.Module):
    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        use_batch_norm: bool = True,
        bias: bool = False,
    ):
        super(LinearDecoderSCVI, self).__init__()

        # mean gamma
        self.factor_regressor = FCLayers(
            n_in=n_input,
            n_out=n_output,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=use_batch_norm,
            bias=bias,
            dropout_rate=0,
        )

        # dropout
        self.px_dropout_decoder = FCLayers(
            n_in=n_input,
            n_out=n_output,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=use_batch_norm,
            bias=bias,
            dropout_rate=0,
        )
    def forward(
        self, dispersion: str, z: torch.Tensor, library: torch.Tensor, *cat_list: int
    ):
        # The decoder returns values for the parameters of the ZINB distribution
        raw_px_scale = self.factor_regressor(z, *cat_list)
        px_scale = torch.softmax(raw_px_scale, dim=-1)
        px_dropout = self.px_dropout_decoder(z, *cat_list)
        px_rate = torch.exp(library) * px_scale
        px_r = None

        return px_scale, px_r, px_rate, px_dropout
# Decoder
class Decoder(nn.Module):
    """Decodes data from latent space (``n_input`` dimensions) to data space
    (``n_output`` dimensions) using a fully-connected neural network of
    ``n_layers`` layers with ``n_hidden`` nodes per layer. Output is the mean
    and variance of a multivariate Gaussian.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output
        The dimensionality of the output (data space)
    n_cat_list
        A list containing the number of categories for each category of interest.
        Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
    ):
        super().__init__()
        self.decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
        )

        self.mean_decoder = nn.Linear(n_hidden, n_output)
        self.var_decoder = nn.Linear(n_hidden, n_output)
    def forward(self, x: torch.Tensor, *cat_list: int):
        """The forward computation for a single sample.

        #. Decodes the data from the latent space using the decoder network
        #. Returns tensors for the mean and variance of a multivariate distribution

        Parameters
        ----------
        x
            tensor with shape ``(n_input,)``
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        2-tuple of :py:class:`torch.Tensor`
            Mean and variance tensors of shape ``(n_output,)``
        """
        # Parameters for latent distribution
        p = self.decoder(x, *cat_list)
        p_m = self.mean_decoder(p)
        p_v = torch.exp(self.var_decoder(p))
        return p_m, p_v
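
# Illustrative usage sketch (not part of the original module): the generic Decoder
# returns the mean and variance of a Gaussian over the output space. The sizes
# (latent dim 10, output dim 50, batch of 8) are assumptions for the example.
def _example_decoder_gaussian():
    decoder = Decoder(n_input=10, n_output=50)
    z = torch.randn(8, 10)
    p_m, p_v = decoder(z)  # each of shape (8, 50)
    sample = Normal(p_m, p_v.sqrt()).sample()
    return sample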
class MultiEncoder(nn.Module):
    def __init__(
        self,
        n_heads: int,
        n_input_list: List[int],
        n_output: int,
        n_hidden: int = 128,
        n_layers_individual: int = 1,
        n_layers_shared: int = 2,
        n_cat_list: Iterable[int] = None,
        dropout_rate: float = 0.1,
    ):
        super().__init__()

        self.encoders = ModuleList(
            [
                FCLayers(
                    n_in=n_input_list[i],
                    n_out=n_hidden,
                    n_cat_list=n_cat_list,
                    n_layers=n_layers_individual,
                    n_hidden=n_hidden,
                    dropout_rate=dropout_rate,
                    use_batch_norm=True,
                )
                for i in range(n_heads)
            ]
        )

        self.encoder_shared = FCLayers(
            n_in=n_hidden,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers_shared,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )

        self.mean_encoder = nn.Linear(n_hidden, n_output)
        self.var_encoder = nn.Linear(n_hidden, n_output)
    def forward(self, x: torch.Tensor, head_id: int, *cat_list: int):
        q = self.encoders[head_id](x, *cat_list)
        q = self.encoder_shared(q, *cat_list)

        q_m = self.mean_encoder(q)
        q_v = torch.exp(self.var_encoder(q))
        latent = reparameterize_gaussian(q_m, q_v)

        return q_m, q_v, latent
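
# Illustrative usage sketch (not part of the original module): two datasets with
# different input dimensions share a latent space through dataset-specific heads
# followed by a shared encoder. All sizes are assumptions chosen for the example.
def _example_multi_encoder():
    encoder = MultiEncoder(n_heads=2, n_input_list=[100, 150], n_output=10)
    x0 = torch.randn(8, 100)
    q_m, q_v, z = encoder(x0, 0)  # route the minibatch through head 0
    return q_m, q_v, z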
class MultiDecoder(nn.Module):
    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_hidden_conditioned: int = 32,
        n_hidden_shared: int = 128,
        n_layers_conditioned: int = 1,
        n_layers_shared: int = 1,
        n_cat_list: Iterable[int] = None,
        dropout_rate: float = 0.2,
    ):
        super().__init__()

        n_out = n_hidden_conditioned if n_layers_shared else n_hidden_shared
        if n_layers_conditioned:
            self.px_decoder_conditioned = FCLayers(
                n_in=n_input,
                n_out=n_out,
                n_cat_list=n_cat_list,
                n_layers=n_layers_conditioned,
                n_hidden=n_hidden_conditioned,
                dropout_rate=dropout_rate,
                use_batch_norm=True,
            )
            n_in = n_out
        else:
            self.px_decoder_conditioned = None
            n_in = n_input

        if n_layers_shared:
            self.px_decoder_final = FCLayers(
                n_in=n_in,
                n_out=n_hidden_shared,
                n_cat_list=[],
                n_layers=n_layers_shared,
                n_hidden=n_hidden_shared,
                dropout_rate=dropout_rate,
                use_batch_norm=True,
            )
            n_in = n_hidden_shared
        else:
            self.px_decoder_final = None

        self.px_scale_decoder = nn.Sequential(
            nn.Linear(n_in, n_output), nn.Softmax(dim=-1)
        )
        self.px_r_decoder = nn.Linear(n_in, n_output)
        self.px_dropout_decoder = nn.Linear(n_in, n_output)
    def forward(
        self,
        z: torch.Tensor,
        dataset_id: int,
        library: torch.Tensor,
        dispersion: str,
        *cat_list: int
    ):
        px = z
        if self.px_decoder_conditioned:
            px = self.px_decoder_conditioned(px, *cat_list, instance_id=dataset_id)
        if self.px_decoder_final:
            px = self.px_decoder_final(px, *cat_list)

        px_scale = self.px_scale_decoder(px)
        px_dropout = self.px_dropout_decoder(px)
        px_rate = torch.exp(library) * px_scale
        px_r = self.px_r_decoder(px) if dispersion == "gene-cell" else None

        return px_scale, px_r, px_rate, px_dropout
class DecoderTOTALVI(nn.Module):
    """Decodes data from latent space of ``n_input`` dimensions into ``n_output``
    dimensions using a linear decoder.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output_genes
        The dimensionality of the output (gene space)
    n_output_proteins
        The dimensionality of the output (protein space)
    n_cat_list
        A list containing the number of categories for each category of interest.
        Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    """

    def __init__(
        self,
        n_input: int,
        n_output_genes: int,
        n_output_proteins: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 256,
        dropout_rate: float = 0,
    ):
        super().__init__()
        self.n_output_genes = n_output_genes
        self.n_output_proteins = n_output_proteins

        self.px_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )

        # mean gamma
        self.px_scale_decoder = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_genes,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )

        # background mean first decoder
        self.py_back_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # background mean parameters second decoder
        self.py_back_mean_log_alpha = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )
        self.py_back_mean_log_beta = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )

        # foreground increment decoder step 1
        self.py_fore_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # foreground increment decoder step 2
        self.py_fore_scale_decoder = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=True,
            use_batch_norm=False,
            dropout_rate=0,
        )

        # dropout (mixture component for proteins, ZI probability for genes)
        self.sigmoid_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.px_dropout_decoder_gene = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_genes,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )
        self.py_background_decoder = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )
    def forward(self, z: torch.Tensor, library_gene: torch.Tensor, *cat_list: int):
        """The forward computation for a single sample.

        #. Decodes the data from the latent space using the decoder network
        #. Returns local parameters for the ZINB distribution for genes
        #. Returns local parameters for the Mixture NB distribution for proteins

        We use the dictionary `px_` to contain the parameters of the ZINB/NB for genes.
        The rate refers to the mean of the NB, dropout refers to Bernoulli mixing parameters.
        `scale` refers to the quantity upon which differential expression is performed. For genes,
        this can be viewed as the mean of the underlying gamma distribution.

        We use the dictionary `py_` to contain the parameters of the Mixture NB distribution for proteins.
        `rate_fore` refers to the foreground mean, while `rate_back` refers to the background mean.
        `scale` refers to the foreground mean adjusted for background probability and scaled to reside
        in the simplex. `back_alpha` and `back_beta` are the posterior parameters for `rate_back`.
        `fore_scale` is the scaling factor that enforces `rate_fore` > `rate_back`.

        Parameters
        ----------
        z
            tensor with shape ``(n_input,)``
        library_gene
            library size
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        3-tuple (first 2 are :py:class:`dict`, last is :py:class:`torch.Tensor`)
            parameters for the ZINB distribution of expression
        """
        px_ = {}
        py_ = {}

        px = self.px_decoder(z, *cat_list)
        px_cat_z = torch.cat([px, z], dim=-1)
        px_["scale"] = nn.Softmax(dim=-1)(self.px_scale_decoder(px_cat_z, *cat_list))
        px_["rate"] = library_gene * px_["scale"]

        py_back = self.py_back_decoder(z, *cat_list)
        py_back_cat_z = torch.cat([py_back, z], dim=-1)

        py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list)
        py_["back_beta"] = torch.exp(
            self.py_back_mean_log_beta(py_back_cat_z, *cat_list)
        )
        log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample()
        py_["rate_back"] = torch.exp(log_pro_back_mean)

        py_fore = self.py_fore_decoder(z, *cat_list)
        py_fore_cat_z = torch.cat([py_fore, z], dim=-1)
        py_["fore_scale"] = (
            self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8
        )
        py_["rate_fore"] = py_["rate_back"] * py_["fore_scale"]

        p_mixing = self.sigmoid_decoder(z, *cat_list)
        p_mixing_cat_z = torch.cat([p_mixing, z], dim=-1)
        px_["dropout"] = self.px_dropout_decoder_gene(p_mixing_cat_z, *cat_list)
        py_["mixing"] = self.py_background_decoder(p_mixing_cat_z, *cat_list)

        protein_mixing = 1 / (1 + torch.exp(-py_["mixing"]))
        py_["scale"] = torch.nn.functional.normalize(
            (1 - protein_mixing) * py_["rate_fore"], p=1, dim=-1
        )

        return (px_, py_, log_pro_back_mean)
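
# Illustrative usage sketch (not part of the original module): decoding a latent
# sample into the gene (ZINB) and protein (mixture NB) parameter dictionaries.
# Sizes are assumptions; `library_gene` is passed as an already exponentiated
# gene library size, since `forward` multiplies it directly with the gene scale.
def _example_decoder_totalvi():
    decoder = DecoderTOTALVI(n_input=20, n_output_genes=200, n_output_proteins=10)
    z = torch.randn(8, 20)
    library_gene = torch.ones(8, 1)
    px_, py_, log_pro_back_mean = decoder(z, library_gene)
    # px_["rate"] has shape (8, 200); py_["rate_fore"] has shape (8, 10)
    return px_, py_, log_pro_back_mean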
# Encoder
class EncoderTOTALVI(nn.Module):
    """Encodes data of ``n_input`` dimensions into a latent space of ``n_output``
    dimensions using a fully-connected neural network of ``n_layers`` layers
    with ``n_hidden`` nodes per layer.

    Parameters
    ----------
    n_input
        The dimensionality of the input (data space)
    n_output
        The dimensionality of the output (latent space)
    n_cat_list
        A list containing the number of categories for each category of interest.
        Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    distribution
        Distribution of the latent space, one of

        * ``'normal'`` - Normal distribution
        * ``'ln'`` - Logistic normal
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 2,
        n_hidden: int = 256,
        dropout_rate: float = 0.1,
        distribution: str = "ln",
    ):
        super().__init__()

        self.encoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.z_encoder = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
        )
        self.z_mean_encoder = nn.Linear(n_hidden, n_output)
        self.z_var_encoder = nn.Linear(n_hidden, n_output)

        self.l_gene_encoder = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
        )
        self.l_gene_mean_encoder = nn.Linear(n_hidden, 1)
        self.l_gene_var_encoder = nn.Linear(n_hidden, 1)

        self.distribution = distribution

        if distribution == "ln":
            self.z_transformation = nn.Softmax(dim=-1)
        else:
            self.z_transformation = identity

        self.l_transformation = torch.exp
    def reparameterize_transformation(self, mu, var):
        untran_z = Normal(mu, var.sqrt()).rsample()
        z = self.z_transformation(untran_z)
        return z, untran_z
    def forward(self, data: torch.Tensor, *cat_list: int):
        """The forward computation for a single sample.

        #. Encodes the data into latent space using the encoder network
        #. Generates a mean \\( q_m \\) and variance \\( q_v \\)
        #. Samples a new value from an i.i.d. latent distribution

        The dictionary ``latent`` contains the samples of the latent variables, while ``untran_latent``
        contains the untransformed versions of these latent variables. For example, the library size is
        log normally distributed, so ``untran_latent["l"]`` gives the normal sample that was later
        exponentiated to become ``latent["l"]``. The logistic normal distribution is equivalent to
        applying softmax to a normal sample.

        Parameters
        ----------
        data
            tensor with shape ``(n_input,)``
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        6-tuple
            First 4 are :py:class:`torch.Tensor` of shape ``(n_latent,)`` for mean and var;
            the last 2 are :py:class:`dict` of :py:class:`torch.Tensor` holding the samples
        """
        # Parameters for latent distribution
        q = self.encoder(data, *cat_list)
        qz = self.z_encoder(q)
        qz_m = self.z_mean_encoder(qz)
        qz_v = torch.exp(self.z_var_encoder(qz)) + 1e-4
        z, untran_z = self.reparameterize_transformation(qz_m, qz_v)

        ql_gene = self.l_gene_encoder(q)
        ql_m = self.l_gene_mean_encoder(ql_gene)
        ql_v = torch.exp(self.l_gene_var_encoder(ql_gene)) + 1e-4
        log_library_gene = torch.clamp(reparameterize_gaussian(ql_m, ql_v), max=15)
        library_gene = self.l_transformation(log_library_gene)

        latent = {}
        untran_latent = {}
        latent["z"] = z
        latent["l"] = library_gene
        untran_latent["z"] = untran_z
        untran_latent["l"] = log_library_gene

        return qz_m, qz_v, ql_m, ql_v, latent, untran_latent
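
# Illustrative usage sketch (not part of the original module): jointly encoding
# concatenated gene and protein counts. The input size (210) and batch size (8)
# are assumptions; with the default logistic-normal distribution, latent["z"]
# lies on the simplex and latent["l"] is the (exponentiated) gene library size.
def _example_encoder_totalvi():
    encoder = EncoderTOTALVI(n_input=210, n_output=20)
    data = torch.randn(8, 210)
    qz_m, qz_v, ql_m, ql_v, latent, untran_latent = encoder(data)
    return latent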