from typing import Dict, Iterable, Optional
import numpy as np
import torch
from torch.distributions import Normal, Poisson
from torch.distributions import kl_divergence as kld
from scvi import REGISTRY_KEYS
from scvi._compat import Literal
from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from scvi.module._peakvae import Decoder as DecoderPeakVI
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.nn import DecoderSCVI, Encoder, FCLayers
class LibrarySizeEncoder(torch.nn.Module):
    """
    Feed-forward network that maps each input row to a single output value.

    Used by the MULTIVAE module below as the expression library-size encoder.

    Parameters
    ----------
    n_input
        Dimensionality of the input.
    n_cat_list
        Number of categories for each categorical covariate, forwarded to ``FCLayers``.
    n_layers
        Number of hidden layers in ``FCLayers``.
    n_hidden
        Width of each hidden layer.
    use_batch_norm
        Whether ``FCLayers`` uses batch normalization.
    use_layer_norm
        Whether ``FCLayers`` uses layer normalization.
    deep_inject_covariates
        Forwarded to ``FCLayers`` as ``inject_covariates``.
    """

    def __init__(
        self,
        n_input: int,
        n_cat_list: Optional[Iterable[int]] = None,
        n_layers: int = 2,
        n_hidden: int = 128,
        use_batch_norm: bool = False,
        use_layer_norm: bool = True,
        deep_inject_covariates: bool = False,
    ):
        super().__init__()
        # Hidden trunk: no dropout, LeakyReLU activations.
        self.px_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
            activation_fn=torch.nn.LeakyReLU,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            inject_covariates=deep_inject_covariates,
        )
        # Head: project the hidden representation to one value, then LeakyReLU.
        projection = torch.nn.Linear(n_hidden, 1)
        self.output = torch.nn.Sequential(projection, torch.nn.LeakyReLU())

    def forward(self, x: torch.Tensor, *cat_list: int):
        """Run ``x`` (plus categorical covariates ``cat_list``) through trunk and head."""
        hidden = self.px_decoder(x, *cat_list)
        return self.output(hidden)
class MULTIVAE(BaseModuleClass):
    """
    Variational auto-encoder model for joint paired + unpaired RNA-seq and ATAC-seq data.

    Parameters
    ----------
    n_input_regions
        Number of input regions.
    n_input_genes
        Number of input genes.
    n_batch
        Number of batches, if 0, no batch correction is performed.
    gene_likelihood
        The distribution to use for gene expression data. One of the following

        * ``'zinb'`` - Zero-Inflated Negative Binomial
        * ``'nb'`` - Negative Binomial
        * ``'poisson'`` - Poisson
    n_hidden
        Number of nodes per hidden layer. If `None`, defaults to the square root
        of the total number of input features (regions + genes).
    n_latent
        Dimensionality of the latent space. If `None`, defaults to square root
        of `n_hidden`.
    n_layers_encoder
        Number of hidden layers used for encoder NN.
    n_layers_decoder
        Number of hidden layers used for decoder NN.
    n_continuous_cov
        Number of continuous covariates. They are appended to the decoder input
        and, if `encode_covariates` is True, to the encoder inputs.
    n_cats_per_cov
        Number of categories for each extra categorical covariate (in addition
        to the batch). `None` means there are no extra categorical covariates.
    dropout_rate
        Dropout rate for the latent (z) encoder networks.
    region_factors
        Include region-specific factors in the model
    use_batch_norm
        One of the following

        * ``'encoder'`` - use batch normalization in the encoder only
        * ``'decoder'`` - use batch normalization in the decoder only
        * ``'none'`` - do not use batch normalization
        * ``'both'`` - use batch normalization in both the encoder and decoder
    use_layer_norm
        One of the following

        * ``'encoder'`` - use layer normalization in the encoder only
        * ``'decoder'`` - use layer normalization in the decoder only
        * ``'none'`` - do not use layer normalization
        * ``'both'`` - use layer normalization in both the encoder and decoder
    latent_distribution
        which latent distribution to use, options are

        * ``'normal'`` - Normal distribution
        * ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
    deeply_inject_covariates
        Whether to deeply inject covariates into all layers of the decoder. If False,
        covariates will only be included in the input layer.
    encode_covariates
        If True, include covariates in the input to the encoder.
    use_size_factor_key
        Use size_factor AnnDataField defined by the user as scaling factor in mean of
        conditional RNA distribution.
    """

    # TODO: replace n_input_regions and n_input_genes with a gene/region mask
    # (we don't dictate which comes first or that they're even contiguous)
    def __init__(
        self,
        n_input_regions: int = 0,
        n_input_genes: int = 0,
        n_batch: int = 0,
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
        n_hidden: Optional[int] = None,
        n_latent: Optional[int] = None,
        n_layers_encoder: int = 2,
        n_layers_decoder: int = 2,
        n_continuous_cov: int = 0,
        n_cats_per_cov: Optional[Iterable[int]] = None,
        dropout_rate: float = 0.1,
        region_factors: bool = True,
        use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none",
        use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        latent_distribution: str = "normal",
        deeply_inject_covariates: bool = False,
        encode_covariates: bool = False,
        use_size_factor_key: bool = False,
    ):
        super().__init__()

        # INIT PARAMS
        self.n_input_regions = n_input_regions
        self.n_input_genes = n_input_genes
        self.n_hidden = (
            int(np.sqrt(self.n_input_regions + self.n_input_genes))
            if n_hidden is None
            else n_hidden
        )
        self.n_batch = n_batch
        self.gene_likelihood = gene_likelihood
        self.latent_distribution = latent_distribution
        self.n_latent = int(np.sqrt(self.n_hidden)) if n_latent is None else n_latent
        self.n_layers_encoder = n_layers_encoder
        self.n_layers_decoder = n_layers_decoder
        self.n_cats_per_cov = n_cats_per_cov
        self.n_continuous_cov = n_continuous_cov
        self.dropout_rate = dropout_rate

        self.use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
        self.use_batch_norm_decoder = use_batch_norm in ("decoder", "both")
        self.use_layer_norm_encoder = use_layer_norm in ("encoder", "both")
        self.use_layer_norm_decoder = use_layer_norm in ("decoder", "both")
        self.encode_covariates = encode_covariates
        self.deeply_inject_covariates = deeply_inject_covariates
        self.use_size_factor_key = use_size_factor_key

        # `batch_index` is always passed as the first categorical input in
        # `inference`/`generative`, so the batch needs an entry in `cat_list`
        # even when there are no extra categorical covariates.
        extra_cats = [] if n_cats_per_cov is None else list(n_cats_per_cov)
        cat_list = [n_batch] + extra_cats

        n_input_encoder_acc = (
            self.n_input_regions + n_continuous_cov * encode_covariates
        )
        n_input_encoder_exp = self.n_input_genes + n_continuous_cov * encode_covariates
        encoder_cat_list = cat_list if encode_covariates else None

        # accessibility encoder
        self.z_encoder_accessibility = Encoder(
            n_input=n_input_encoder_acc,
            n_layers=self.n_layers_encoder,
            n_output=self.n_latent,
            n_hidden=self.n_hidden,
            n_cat_list=encoder_cat_list,
            dropout_rate=self.dropout_rate,
            activation_fn=torch.nn.LeakyReLU,
            distribution=self.latent_distribution,
            var_eps=0,
            use_batch_norm=self.use_batch_norm_encoder,
            use_layer_norm=self.use_layer_norm_encoder,
        )

        # expression encoder
        self.z_encoder_expression = Encoder(
            n_input=n_input_encoder_exp,
            n_layers=self.n_layers_encoder,
            n_output=self.n_latent,
            n_hidden=self.n_hidden,
            n_cat_list=encoder_cat_list,
            dropout_rate=self.dropout_rate,
            activation_fn=torch.nn.LeakyReLU,
            distribution=self.latent_distribution,
            var_eps=0,
            use_batch_norm=self.use_batch_norm_encoder,
            use_layer_norm=self.use_layer_norm_encoder,
        )

        # expression decoder
        self.z_decoder_expression = DecoderSCVI(
            self.n_latent + self.n_continuous_cov,
            n_input_genes,
            n_cat_list=cat_list,
            n_layers=n_layers_decoder,
            n_hidden=self.n_hidden,
            inject_covariates=self.deeply_inject_covariates,
            use_batch_norm=self.use_batch_norm_decoder,
            use_layer_norm=self.use_layer_norm_decoder,
            scale_activation="softplus" if use_size_factor_key else "softmax",
        )

        # accessibility decoder
        self.z_decoder_accessibility = DecoderPeakVI(
            n_input=self.n_latent + self.n_continuous_cov,
            n_output=n_input_regions,
            n_hidden=self.n_hidden,
            n_cat_list=cat_list,
            n_layers=self.n_layers_decoder,
            use_batch_norm=self.use_batch_norm_decoder,
            use_layer_norm=self.use_layer_norm_decoder,
            deep_inject_covariates=self.deeply_inject_covariates,
        )

        # accessibility region-specific factors
        self.region_factors = None
        if region_factors:
            self.region_factors = torch.nn.Parameter(torch.zeros(self.n_input_regions))

        # expression dispersion parameters (stored in log space; exp'd in `generative`)
        self.px_r = torch.nn.Parameter(torch.randn(n_input_genes))

        # expression library size encoder
        self.l_encoder_expression = LibrarySizeEncoder(
            n_input_encoder_exp,
            n_cat_list=encoder_cat_list,
            n_layers=self.n_layers_encoder,
            n_hidden=self.n_hidden,
            use_batch_norm=self.use_batch_norm_encoder,
            use_layer_norm=self.use_layer_norm_encoder,
            deep_inject_covariates=self.deeply_inject_covariates,
        )

        # accessibility library size encoder
        self.l_encoder_accessibility = DecoderPeakVI(
            n_input=n_input_encoder_acc,
            n_output=1,
            n_hidden=self.n_hidden,
            n_cat_list=encoder_cat_list,
            n_layers=self.n_layers_encoder,
            use_batch_norm=self.use_batch_norm_encoder,
            use_layer_norm=self.use_layer_norm_encoder,
            deep_inject_covariates=self.deeply_inject_covariates,
        )

    def _split_modalities(self, x):
        """Split ``x`` into expression / accessibility parts and per-cell presence masks.

        The first ``n_input_genes`` columns are expression, the remaining columns are
        accessibility. A cell "has" a modality when that modality's row sum is > 0.
        """
        x_rna = x[:, : self.n_input_genes]
        x_chr = x[:, self.n_input_genes :]
        mask_expr = x_rna.sum(dim=1) > 0
        mask_acc = x_chr.sum(dim=1) > 0
        return x_rna, x_chr, mask_expr, mask_acc

    @staticmethod
    def _expand_samples(t, n_samples):
        """Prepend a sample dimension of size ``n_samples`` (a broadcast view of ``t``)."""
        return t.unsqueeze(0).expand(n_samples, *t.shape)

    def _get_inference_input(self, tensors):
        """Collect the arguments of `inference` from the minibatch ``tensors``."""
        return dict(
            x=tensors[REGISTRY_KEYS.X_KEY],
            batch_index=tensors[REGISTRY_KEYS.BATCH_KEY],
            cont_covs=tensors.get(REGISTRY_KEYS.CONT_COVS_KEY),
            cat_covs=tensors.get(REGISTRY_KEYS.CAT_COVS_KEY),
        )

    @auto_move_data
    def inference(
        self,
        x,
        batch_index,
        cont_covs,
        cat_covs,
        n_samples=1,
    ) -> Dict[str, torch.Tensor]:
        """Encode both modalities and combine them into a single latent posterior.

        For paired cells (both modalities non-empty) the combined posterior mean is
        the average of the two modality means; cells with only one modality use that
        modality's posterior (see `_mix_modalities`).

        Returns a dict with keys ``z``, ``qz_m``, ``qz_v``, ``z_expr``, ``qzm_expr``,
        ``qzv_expr``, ``z_acc``, ``qzm_acc``, ``qzv_acc``, ``libsize_expr`` and
        ``libsize_acc``. When ``n_samples > 1`` the posterior parameters and library
        sizes gain a leading sample dimension, and ``z_expr`` / ``z_acc`` are redrawn
        with ``.sample()`` (no reparameterization gradient).
        """
        x_rna, x_chr, mask_expr, mask_acc = self._split_modalities(x)

        if cont_covs is not None and self.encode_covariates:
            encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1)
            encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1)
        else:
            encoder_input_expression = x_rna
            encoder_input_accessibility = x_chr

        if cat_covs is not None and self.encode_covariates:
            categorical_input = torch.split(cat_covs, 1, dim=1)
        else:
            categorical_input = tuple()

        # Z encoders
        qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
            encoder_input_accessibility, batch_index, *categorical_input
        )
        qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(
            encoder_input_expression, batch_index, *categorical_input
        )

        # L encoders
        libsize_expr = self.l_encoder_expression(
            encoder_input_expression, batch_index, *categorical_input
        )
        libsize_acc = self.l_encoder_accessibility(
            encoder_input_accessibility, batch_index, *categorical_input
        )

        # Reformat outputs; the sampling order (accessibility first) is kept fixed.
        if n_samples > 1:
            qzm_acc = self._expand_samples(qzm_acc, n_samples)
            qzv_acc = self._expand_samples(qzv_acc, n_samples)
            untran_za = Normal(qzm_acc, qzv_acc.sqrt()).sample()
            z_acc = self.z_encoder_accessibility.z_transformation(untran_za)

            qzm_expr = self._expand_samples(qzm_expr, n_samples)
            qzv_expr = self._expand_samples(qzv_expr, n_samples)
            untran_zr = Normal(qzm_expr, qzv_expr.sqrt()).sample()
            z_expr = self.z_encoder_expression.z_transformation(untran_zr)

            libsize_expr = self._expand_samples(libsize_expr, n_samples)
            libsize_acc = self._expand_samples(libsize_acc, n_samples)

        # Sample from the average distribution.
        # NOTE(review): the variance of the mean of two independent Gaussians would
        # be (v1 + v2) / 4; this divides by sqrt(2). Kept as-is because changing it
        # alters trained models -- confirm intent.
        qzp_m = (qzm_acc + qzm_expr) / 2
        qzp_v = (qzv_acc + qzv_expr) / (2**0.5)
        zp = Normal(qzp_m, qzp_v.sqrt()).rsample()

        # choose the correct latent representation based on the modality
        qz_m = self._mix_modalities(qzp_m, qzm_expr, qzm_acc, mask_expr, mask_acc)
        qz_v = self._mix_modalities(qzp_v, qzv_expr, qzv_acc, mask_expr, mask_acc)
        z = self._mix_modalities(zp, z_expr, z_acc, mask_expr, mask_acc)

        return dict(
            z=z,
            qz_m=qz_m,
            qz_v=qz_v,
            z_expr=z_expr,
            qzm_expr=qzm_expr,
            qzv_expr=qzv_expr,
            z_acc=z_acc,
            qzm_acc=qzm_acc,
            qzv_acc=qzv_acc,
            libsize_expr=libsize_expr,
            libsize_acc=libsize_acc,
        )

    def _get_generative_input(self, tensors, inference_outputs, transform_batch=None):
        """Collect the arguments of `generative`.

        ``size_factor`` is the log of the registered size factors when present,
        otherwise None. If ``transform_batch`` is given, every cell's batch index is
        replaced by it.
        """
        size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY
        size_factor = (
            torch.log(tensors[size_factor_key])
            if size_factor_key in tensors.keys()
            else None
        )

        batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch

        cont_key = REGISTRY_KEYS.CONT_COVS_KEY
        cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None
        cat_key = REGISTRY_KEYS.CAT_COVS_KEY
        cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None

        return dict(
            z=inference_outputs["z"],
            qz_m=inference_outputs["qz_m"],
            batch_index=batch_index,
            cont_covs=cont_covs,
            cat_covs=cat_covs,
            libsize_expr=inference_outputs["libsize_expr"],
            size_factor=size_factor,
        )

    @auto_move_data
    def generative(
        self,
        z,
        qz_m,
        batch_index,
        cont_covs=None,
        cat_covs=None,
        libsize_expr=None,
        size_factor=None,
        use_z_mean=False,
    ):
        """Runs the generative model.

        Decodes ``z`` (or ``qz_m`` when ``use_z_mean``), concatenated with the
        continuous covariates if any, into accessibility output ``p`` and expression
        parameters. Unless ``use_size_factor_key`` is set, the encoded
        ``libsize_expr`` is used as the expression library instead of ``size_factor``.
        """
        if cat_covs is not None:
            categorical_input = torch.split(cat_covs, 1, dim=1)
        else:
            categorical_input = tuple()

        latent = qz_m if use_z_mean else z
        decoder_input = (
            latent if cont_covs is None else torch.cat([latent, cont_covs], dim=-1)
        )

        # Accessibility Decoder
        p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input)

        # Expression Decoder
        if not self.use_size_factor_key:
            size_factor = libsize_expr
        px_scale, _, px_rate, px_dropout = self.z_decoder_expression(
            "gene", decoder_input, size_factor, batch_index, *categorical_input
        )

        return dict(
            p=p,
            px_scale=px_scale,
            px_r=torch.exp(self.px_r),
            px_rate=px_rate,
            px_dropout=px_dropout,
        )

    def loss(
        self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0
    ):
        """Compute the MultiVI objective for a minibatch.

        Per cell: the reconstruction loss of the available modality/modalities,
        plus ``kl_weight`` times KL(q(z) || N(0, I)), plus (paired cells only) the
        symmetric KL between the expression and accessibility posteriors. The
        returned loss is the mean over cells.
        """
        x = tensors[REGISTRY_KEYS.X_KEY]
        x_rna, x_chr, mask_expr, mask_acc = self._split_modalities(x)

        # Compute Accessibility loss
        rl_accessibility = self.get_reconstruction_loss_accessibility(
            x_chr, generative_outputs["p"], inference_outputs["libsize_acc"]
        )

        # Compute Expression loss
        rl_expression = self.get_reconstruction_loss_expression(
            x_rna,
            generative_outputs["px_rate"],
            generative_outputs["px_r"],
            generative_outputs["px_dropout"],
        )

        # mix losses to get the correct loss for each cell
        recon_loss = self._mix_modalities(
            rl_accessibility + rl_expression,  # paired
            rl_expression,  # expression
            rl_accessibility,  # accessibility
            mask_expr,
            mask_acc,
        )

        # Compute KLD between Z and N(0,I)
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        kl_div_z = kld(
            Normal(qz_m, torch.sqrt(qz_v)),
            Normal(0, 1),
        ).sum(dim=1)

        # Symmetric KLD between the two modality posteriors, for paired cells only
        q_expr = Normal(
            inference_outputs["qzm_expr"], torch.sqrt(inference_outputs["qzv_expr"])
        )
        q_acc = Normal(
            inference_outputs["qzm_acc"], torch.sqrt(inference_outputs["qzv_acc"])
        )
        kld_paired = kld(q_expr, q_acc) + kld(q_acc, q_expr)
        is_paired = torch.logical_and(mask_acc, mask_expr).unsqueeze(-1)
        kld_paired = torch.where(
            is_paired, kld_paired, torch.zeros_like(kld_paired)
        ).sum(dim=-1)

        # KL WARMUP
        weighted_kl_local = kl_weight * kl_div_z

        # TOTAL LOSS
        loss = torch.mean(recon_loss + weighted_kl_local + kld_paired)

        kl_local = dict(kl_divergence_z=kl_div_z)
        kl_global = torch.tensor(0.0)
        return LossRecorder(loss, recon_loss, kl_local, kl_global)

    def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout):
        """Per-cell negative log-likelihood of expression counts ``x``.

        Raises
        ------
        ValueError
            If ``self.gene_likelihood`` is not one of ``'zinb'``, ``'nb'``, ``'poisson'``.
        """
        if self.gene_likelihood == "zinb":
            dist = ZeroInflatedNegativeBinomial(
                mu=px_rate, theta=px_r, zi_logits=px_dropout
            )
        elif self.gene_likelihood == "nb":
            dist = NegativeBinomial(mu=px_rate, theta=px_r)
        elif self.gene_likelihood == "poisson":
            dist = Poisson(px_rate)
        else:
            # Previously this fell through to a float 0.0, which silently dropped the
            # expression term and then failed obscurely in `_mix_modalities`.
            raise ValueError(f"Unsupported gene_likelihood: {self.gene_likelihood!r}")
        return -dist.log_prob(x).sum(dim=-1)

    def get_reconstruction_loss_accessibility(self, x, p, d):
        """Per-cell binary cross-entropy between ``p * d * f`` and binarized ``x``.

        ``f`` is sigmoid(region_factors) when region factors are enabled, else 1.
        """
        f = torch.sigmoid(self.region_factors) if self.region_factors is not None else 1
        return torch.nn.BCELoss(reduction="none")(p * d * f, (x > 0).float()).sum(
            dim=-1
        )

    @staticmethod
    def _mix_modalities(x_paired, x_expr, x_acc, mask_expr, mask_acc):
        """
        Mixes modality-specific vectors according to the modality masks.

        In positions where both `mask_expr` and `mask_acc` are True (cells for which
        both expression and accessibility data are available), values from `x_paired`
        are used. If only `mask_expr` is True, values from `x_expr` are used;
        otherwise (only `mask_acc`, or neither mask) values from `x_acc` are used.

        The masks index cells. For 1-D inputs cells are the only axis; for inputs
        with 2 or more dims cells are assumed to be on dim -2, which covers both
        ``(cells, d)`` and ``(n_samples, cells, d)``.

        Parameters
        ----------
        x_paired
            the values for paired cells (both modalities available).
        x_expr
            the values for expression-only cells.
        x_acc
            the values for accessibility-only cells.
        mask_expr
            the expression mask, indicating which cells have expression data
        mask_acc
            the accessibility mask, indicating which cells have accessibility data
        """
        if x_expr.dim() > 1:
            # Broadcast the per-cell masks over the trailing feature dimension.
            mask_expr = mask_expr.unsqueeze(-1)
            mask_acc = mask_acc.unsqueeze(-1)
        x = torch.where(mask_expr, x_expr, x_acc)
        return torch.where(torch.logical_and(mask_acc, mask_expr), x_paired, x)