Source code for scvi.models.jvae

# -*- coding: utf-8 -*-
"""Main module."""
from typing import List, Optional, Union, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Normal, kl_divergence as kl
from torch.nn import ModuleList

from scvi.models.distributions import (
    NegativeBinomial,
    ZeroInflatedNegativeBinomial,
    Poisson,
)
from scvi.models.modules import Encoder
from scvi.models.modules import MultiEncoder, MultiDecoder
from scvi.models.utils import one_hot

torch.backends.cudnn.benchmark = True


class JVAE(nn.Module):
    """Joint variational auto-encoder for imputing missing genes in spatial data.

    Implementation of gimVI [Lopez19]_.

    Parameters
    ----------
    dim_input_list
        List of number of input genes for each dataset. If the datasets have
        different sizes, the dataloader will loop on the smallest until it
        reaches the size of the longest one.
    total_genes
        Total number of different genes.
    indices_mappings
        List of mappings from the inputs of each dataset to the model output.
        E.g. ``[[0, 2], [0, 1, 3, 2]]`` means the first dataset has 2 genes that
        will be reconstructed at locations ``[0, 2]`` and the second dataset has
        4 genes that will be reconstructed at ``[0, 1, 3, 2]``.
    reconstruction_losses
        List of distributions to use in the generative process:
        'zinb', 'nb' or 'poisson'.
    model_library_bools
        List of bools: whether to model the library size of each dataset with a
        latent variable or to use the observed value.
    n_latent
        Dimension of the latent space.
    n_layers_encoder_individual
        Number of individual (per-dataset) layers in the encoder.
    n_layers_encoder_shared
        Number of shared layers in the encoder.
    dim_hidden_encoder
        Dimension of the hidden layers in the encoder.
    n_layers_decoder_individual
        Number of layers that are conditionally batch-normed in the decoder.
    n_layers_decoder_shared
        Number of shared layers in the decoder.
    dim_hidden_decoder_individual
        Dimension of the individual hidden layers in the decoder.
    dim_hidden_decoder_shared
        Dimension of the shared hidden layers in the decoder.
    dropout_rate_encoder
        Dropout rate of the encoder.
    dropout_rate_decoder
        Dropout rate of the decoder.
    n_batch
        Total number of batches.
    n_labels
        Total number of labels.
    dispersion
        See ``vae.py``.
    log_variational
        Log(data + 1) prior to encoding for numerical stability. Not normalization.
    """

    def __init__(
        self,
        dim_input_list: List[int],
        total_genes: int,
        indices_mappings: List[Union[np.ndarray, slice]],
        reconstruction_losses: List[str],
        model_library_bools: List[bool],
        n_latent: int = 10,
        n_layers_encoder_individual: int = 1,
        n_layers_encoder_shared: int = 1,
        dim_hidden_encoder: int = 128,
        n_layers_decoder_individual: int = 0,
        n_layers_decoder_shared: int = 0,
        dim_hidden_decoder_individual: int = 32,
        dim_hidden_decoder_shared: int = 128,
        dropout_rate_encoder: float = 0.1,
        dropout_rate_decoder: float = 0.3,
        n_batch: int = 0,
        n_labels: int = 0,
        dispersion: str = "gene-batch",
        log_variational: bool = True,
    ):
        super().__init__()

        self.n_input_list = dim_input_list
        self.total_genes = total_genes
        self.indices_mappings = indices_mappings
        self.reconstruction_losses = reconstruction_losses
        self.model_library_bools = model_library_bools
        self.n_latent = n_latent

        self.n_batch = n_batch
        self.n_labels = n_labels

        self.dispersion = dispersion
        self.log_variational = log_variational

        self.z_encoder = MultiEncoder(
            n_heads=len(dim_input_list),
            n_input_list=dim_input_list,
            n_output=self.n_latent,
            n_hidden=dim_hidden_encoder,
            n_layers_individual=n_layers_encoder_individual,
            n_layers_shared=n_layers_encoder_shared,
            dropout_rate=dropout_rate_encoder,
        )

        self.l_encoders = ModuleList(
            [
                Encoder(
                    self.n_input_list[i],
                    1,
                    n_layers=1,
                    dropout_rate=dropout_rate_encoder,
                )
                if self.model_library_bools[i]
                else None
                for i in range(len(self.n_input_list))
            ]
        )

        self.decoder = MultiDecoder(
            self.n_latent,
            self.total_genes,
            n_hidden_conditioned=dim_hidden_decoder_individual,
            n_hidden_shared=dim_hidden_decoder_shared,
            n_layers_conditioned=n_layers_decoder_individual,
            n_layers_shared=n_layers_decoder_shared,
            n_cat_list=[self.n_batch],
            dropout_rate=dropout_rate_decoder,
        )

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(self.total_genes, n_labels))
        else:  # gene-cell
            pass
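    # Illustrative construction (a sketch, not part of the original module): the
    # dataset sizes, the ``spatial_gene_positions`` index array and the loss
    # choices below are hypothetical placeholders.
    #
    #     spatial_gene_positions = np.arange(200)  # positions of the 200 spatial
    #                                               # genes in the joint gene space
    #     model = JVAE(
    #         dim_input_list=[200, 2000],
    #         total_genes=2000,
    #         indices_mappings=[spatial_gene_positions, slice(None)],
    #         reconstruction_losses=["nb", "zinb"],
    #         model_library_bools=[False, True],
    #     )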
    def sample_from_posterior_z(
        self, x: torch.Tensor, mode: Optional[int] = None, deterministic: bool = False
    ) -> torch.Tensor:
        """Sample a tensor of latent values from the posterior.

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
        mode
            head id to use in the encoder
        deterministic
            bool - if True, return the posterior mean instead of sampling

        Returns
        -------
        type
            tensor of shape ``(batch_size, n_latent)``
        """
        if mode is None:
            if len(self.n_input_list) == 1:
                mode = 0
            else:
                raise Exception("Must provide a mode when having multiple datasets")
        qz_m, _, z, _, _, _ = self.encode(x, mode)
        if deterministic:
            z = qz_m
        return z
    def sample_from_posterior_l(
        self, x: torch.Tensor, mode: Optional[int] = None, deterministic: bool = False
    ) -> torch.Tensor:
        """Sample the tensor of library sizes from the posterior.

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
            or ``(batch_size, n_input_fish)`` depending on the mode
        mode
            head id to use in the encoder
        deterministic
            bool - if True, return the posterior mean instead of sampling

        Returns
        -------
        type
            tensor of shape ``(batch_size, 1)``
        """
        _, _, _, ql_m, _, library = self.encode(x, mode)
        if deterministic and ql_m is not None:
            library = ql_m
        return library
    def sample_scale(
        self,
        x: torch.Tensor,
        mode: int,
        batch_index: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        decode_mode: Optional[int] = None,
    ) -> torch.Tensor:
        """Return the tensor of predicted frequencies of expression.

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
            or ``(batch_size, n_input_fish)`` depending on the mode
        mode
            int encode mode (which input head to use in the model)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size``
        y
            tensor of cell-type labels with shape ``(batch_size, n_labels)``
        deterministic
            bool - whether to sample or use the posterior means
        decode_mode
            int - use a decode mode different from the encoding mode

        Returns
        -------
        type
            tensor of predicted expression
        """
        if decode_mode is None:
            decode_mode = mode
        qz_m, qz_v, z, ql_m, ql_v, library = self.encode(x, mode)
        if deterministic:
            z = qz_m
            if ql_m is not None:
                library = ql_m
        px_scale, px_r, px_rate, px_dropout = self.decode(
            z, decode_mode, library, batch_index, y
        )
        return px_scale
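    # Illustrative use of ``sample_scale`` for cross-dataset imputation (a sketch,
    # not part of the original module; ``model``, ``x_mode0`` and ``batch_index``
    # are hypothetical objects):
    #
    #     # encode cells from dataset 0, decode them with the head of dataset 1
    #     imputed_scale = model.sample_scale(
    #         x=x_mode0, mode=0, batch_index=batch_index,
    #         deterministic=True, decode_mode=1,
    #     )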
    # This is a potential wrapper for a VAE-like ``get_sample_rate``
    def get_sample_rate(self, x, batch_index, *_, **__):
        return self.sample_rate(x, 0, batch_index)
    def sample_rate(
        self,
        x: torch.Tensor,
        mode: int,
        batch_index: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        decode_mode: Optional[int] = None,
    ) -> torch.Tensor:
        """Return the tensor of scaled frequencies of expression.

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
            or ``(batch_size, n_input_fish)`` depending on the mode
        y
            tensor of cell-type labels with shape ``(batch_size, n_labels)``
        mode
            int encode mode (which input head to use in the model)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size``
        deterministic
            bool - whether to sample or use the posterior means
        decode_mode
            int - use a decode mode different from the encoding mode

        Returns
        -------
        type
            tensor of means of the scaled frequencies
        """
        if decode_mode is None:
            decode_mode = mode
        qz_m, qz_v, z, ql_m, ql_v, library = self.encode(x, mode)
        if deterministic:
            z = qz_m
            if ql_m is not None:
                library = ql_m
        px_scale, px_r, px_rate, px_dropout = self.decode(
            z, decode_mode, library, batch_index, y
        )
        return px_rate
    def reconstruction_loss(
        self,
        x: torch.Tensor,
        px_rate: torch.Tensor,
        px_r: torch.Tensor,
        px_dropout: torch.Tensor,
        mode: int,
    ) -> torch.Tensor:
        reconstruction_loss = None
        if self.reconstruction_losses[mode] == "zinb":
            reconstruction_loss = (
                -ZeroInflatedNegativeBinomial(
                    mu=px_rate, theta=px_r, zi_logits=px_dropout
                )
                .log_prob(x)
                .sum(dim=-1)
            )
        elif self.reconstruction_losses[mode] == "nb":
            reconstruction_loss = (
                -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1)
            )
        elif self.reconstruction_losses[mode] == "poisson":
            reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=1)
        return reconstruction_loss
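    # In the likelihoods above, ``mu`` (= ``px_rate``) is the mean and ``theta``
    # (= ``px_r``) the inverse dispersion, so for the "nb" case the per-gene
    # log-probability being summed is (standard mean / inverse-dispersion
    # parametrization, stated here for reference):
    #
    #     log NB(x; mu, theta) = lgamma(x + theta) - lgamma(theta) - lgamma(x + 1)
    #         + theta * log(theta / (theta + mu)) + x * log(mu / (theta + mu))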
    def encode(
        self, x: torch.Tensor, mode: int
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        torch.Tensor,
    ]:
        x_ = x
        if self.log_variational:
            x_ = torch.log(1 + x_)

        qz_m, qz_v, z = self.z_encoder(x_, mode)
        ql_m, ql_v, library = None, None, None
        if self.model_library_bools[mode]:
            ql_m, ql_v, library = self.l_encoders[mode](x_)
        else:
            library = torch.log(torch.sum(x, dim=1)).view(-1, 1)

        return qz_m, qz_v, z, ql_m, ql_v, library
    def decode(
        self,
        z: torch.Tensor,
        mode: int,
        library: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            z, mode, library, self.dispersion, batch_index, y
        )
        if self.dispersion == "gene-label":
            px_r = F.linear(one_hot(y, self.n_labels), self.px_r)
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r.view(1, self.px_r.size(0))
        px_r = torch.exp(px_r)

        px_scale = px_scale / torch.sum(
            px_scale[:, self.indices_mappings[mode]], dim=1
        ).view(-1, 1)
        px_rate = px_scale * torch.exp(library)

        return px_scale, px_r, px_rate, px_dropout
    def forward(
        self,
        x: torch.Tensor,
        local_l_mean: torch.Tensor,
        local_l_var: torch.Tensor,
        batch_index: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        mode: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Return the reconstruction loss and the Kullback-Leibler divergences.

        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, n_input)``
            or ``(batch_size, n_input_fish)`` depending on the mode
        local_l_mean
            tensor of means of the prior distribution of latent variable l
            with shape ``(batch_size, 1)``
        local_l_var
            tensor of variances of the prior distribution of latent variable l
            with shape ``(batch_size, 1)``
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size``
        y
            tensor of cell-type labels with shape ``(batch_size, n_labels)``
        mode
            indicates which head/tail to use in the joint network

        Returns
        -------
        the reconstruction loss and the Kullback-Leibler divergences
        """
        if mode is None:
            if len(self.n_input_list) == 1:
                mode = 0
            else:
                raise Exception("Must provide a mode")

        qz_m, qz_v, z, ql_m, ql_v, library = self.encode(x, mode)
        px_scale, px_r, px_rate, px_dropout = self.decode(
            z, mode, library, batch_index, y
        )

        # mask loss to observed genes
        mapping_indices = self.indices_mappings[mode]
        reconstruction_loss = self.reconstruction_loss(
            x,
            px_rate[:, mapping_indices],
            px_r[:, mapping_indices],
            px_dropout[:, mapping_indices],
            mode,
        )

        # KL divergence
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)
        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )
        if self.model_library_bools[mode]:
            kl_divergence_l = kl(
                Normal(ql_m, torch.sqrt(ql_v)),
                Normal(local_l_mean, torch.sqrt(local_l_var)),
            ).sum(dim=1)
        else:
            kl_divergence_l = torch.zeros_like(kl_divergence_z)

        return reconstruction_loss, kl_divergence_l + kl_divergence_z, 0.0
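    # Illustrative training-step sketch (not part of the original module): a
    # trainer would typically combine the three returned terms roughly as below,
    # where ``kl_weight`` is a hypothetical annealing coefficient:
    #
    #     reconstruction_loss, kl_local, kl_global = model(
    #         x, local_l_mean, local_l_var, batch_index=batch_index, mode=mode
    #     )
    #     loss = torch.mean(reconstruction_loss + kl_weight * kl_local) + kl_global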