from typing import Optional, Union, List, Callable, Tuple
import logging
import torch
from torch.distributions import Poisson, Gamma, Bernoulli, Normal
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from scvi.inference import Posterior
from . import UnsupervisedTrainer
from scvi.dataset import GeneExpressionDataset
from scvi.models import TOTALVI, Classifier
from scvi.models.utils import one_hot
logger = logging.getLogger(__name__)
class TotalPosterior(Posterior):
"""The functional data unit for totalVI.
A `TotalPosterior` instance is instantiated with a model and
a `gene_dataset`, as well as additional arguments for PyTorch's `DataLoader`. A subset of indices
can be specified, for purposes such as splitting the data into train/test/validation sets. Each `TotalTrainer` instance can therefore have multiple
`TotalPosterior` instances to train a model. A `TotalPosterior` instance also comes with many methods and
utilities for its corresponding data.
Parameters
----------
model :
A model instance from class ``TOTALVI``
gene_dataset :
A gene_dataset instance like ``CbmcDataset()`` with attribute ``protein_expression``
shuffle :
Specifies if a `RandomSampler` or a `SequentialSampler` should be used
indices :
Specifies how the data should be split with regard to train/test or labelled/unlabelled
use_cuda :
Default: ``True``
data_loader_kwargs :
Keyword arguments to be passed into the `DataLoader`
Examples
--------
Let us instantiate a `trainer`, with a gene_dataset and a model
>>> gene_dataset = CbmcDataset()
>>> totalvi = TOTALVI(gene_dataset.nb_genes, len(gene_dataset.protein_names),
... n_batch=gene_dataset.n_batches)
>>> trainer = TotalTrainer(totalvi, gene_dataset, use_cuda=True)
>>> trainer.train(n_epochs=500)
"""
def __init__(
self,
model: TOTALVI,
gene_dataset: GeneExpressionDataset,
shuffle: bool = False,
indices: Optional[np.ndarray] = None,
use_cuda: bool = True,
data_loader_kwargs=dict(),
):
super().__init__(
model,
gene_dataset,
shuffle=shuffle,
indices=indices,
use_cuda=use_cuda,
data_loader_kwargs=data_loader_kwargs,
)
# Add protein tensor as another tensor to be loaded
self.data_loader_kwargs.update(
{
"collate_fn": gene_dataset.collate_fn_builder(
{"protein_expression": np.float32}
)
}
)
self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs)
def corrupted(self):
return self.update(
{
"collate_fn": self.gene_dataset.collate_fn_builder(
{"protein_expression": np.float32}, corrupted=True
)
}
)
def uncorrupted(self):
return self.update(
{
"collate_fn": self.gene_dataset.collate_fn_builder(
{"protein_expression": np.float32}
)
}
)
@torch.no_grad()
def elbo(self):
elbo = self.compute_elbo(self.model)
return elbo
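# `mode` is set as an attribute on the method itself (scVI convention): it marks
# that lower values of this metric are better, e.g. for monitoring / early stopping.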
elbo.mode = "min"
@torch.no_grad()
def reconstruction_error(self, mode="total"):
ll_gene, ll_protein = self.compute_reconstruction_error(self.model)
if mode == "total":
return ll_gene + ll_protein
elif mode == "gene":
return ll_gene
else:
return ll_protein
reconstruction_error.mode = "min"
@torch.no_grad()
def marginal_ll(self, n_mc_samples=1000):
ll = self.compute_marginal_log_likelihood(n_samples_mc=n_mc_samples)
return ll
@torch.no_grad()
def get_protein_background_mean(self):
background_mean = []
for tensors in self:
x, _, _, batch_index, label, y = tensors
outputs = self.model.inference(
x, y, batch_index=batch_index, label=label, n_samples=1
)
b_mean = outputs["py_"]["rate_back"]
background_mean += [np.array(b_mean.cpu())]
return np.concatenate(background_mean)
def compute_elbo(self, vae: TOTALVI, **kwargs):
"""Computes the ELBO.
The ELBO is the sum of the reconstruction error and the KL divergences
between the variational distributions and the priors.
It differs from the marginal log likelihood:
the ELBO is a lower bound on the marginal log likelihood,
and the gap between the two is the KL divergence between the approximate and the true posterior.
It still gives good insight into the modeling of the data, and is fast to compute.
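Schematically (a sketch of the decomposition matched by the five loss terms summed below,
with :math:`z` the latent cell representation, :math:`\ell` the RNA library size and
:math:`\beta` the protein background intensity):
.. math::
    -\mathrm{ELBO} = \mathbb{E}_{q}\left[-\log p(x \mid z, \ell) - \log p(y \mid z, \beta)\right]
    + \mathrm{KL}\left(q(z) \,\|\, p(z)\right)
    + \mathrm{KL}\left(q(\ell) \,\|\, p(\ell)\right)
    + \mathrm{KL}\left(q(\beta) \,\|\, p(\beta)\right)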
Parameters
----------
vae
**kwargs
Returns
-------
"""
# Iterate once over the posterior and accumulate the ELBO terms
elbo = 0
for i_batch, tensors in enumerate(self):
x, local_l_mean, local_l_var, batch_index, labels, y = tensors
(
reconst_loss_gene,
reconst_loss_protein,
kl_div_z,
kl_div_gene_l,
kl_div_back_pro,
) = vae(
x,
y,
local_l_mean,
local_l_var,
batch_index=batch_index,
label=labels,
**kwargs,
)
elbo += torch.sum(
reconst_loss_gene
+ reconst_loss_protein
+ kl_div_z
+ kl_div_gene_l
+ kl_div_back_pro
).item()
n_samples = len(self.indices)
return elbo / n_samples
def compute_reconstruction_error(self, vae: TOTALVI, **kwargs):
r"""Computes :math:`\log p(x \mid z)`, the reconstruction error.
It differs from the marginal log likelihood, but still gives good
insights on the modeling of the data, and is fast to compute.
This is really a helper function for `reconstruction_error`.
"""
# Iterate once over the posterior and compute the total log-likelihood
log_lkl_gene = 0
log_lkl_protein = 0
for i_batch, tensors in enumerate(self):
x, local_l_mean, local_l_var, batch_index, labels, y = tensors
(
reconst_loss_gene,
reconst_loss_protein,
kl_div_z,
kl_div_l_gene,
kl_div_back_pro,
) = vae(
x,
y,
local_l_mean,
local_l_var,
batch_index=batch_index,
label=labels,
**kwargs,
)
log_lkl_gene += torch.sum(reconst_loss_gene).item()
log_lkl_protein += torch.sum(reconst_loss_protein).item()
n_samples = len(self.indices)
return log_lkl_gene / n_samples, log_lkl_protein / n_samples
def compute_marginal_log_likelihood(
self, n_samples_mc: int = 100, batch_size: int = 96
):
"""Computes a biased estimator for log p(x, y), which is the marginal log likelihood.
Despite its bias, the estimator still converges to the real value
of log p(x, y) as n_samples_mc (the number of Monte Carlo samples) goes to infinity
(a fairly high value like 100 should be enough; 5000 is the standard in machine learning publications).
Due to the Monte Carlo sampling, this method is not as computationally efficient
as computing only the reconstruction loss.
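Schematically (a sketch of the importance-sampling estimator computed below), with
:math:`M` Monte Carlo samples :math:`z_m, \ell_m, \beta_m \sim q(\cdot \mid x, y)`:
.. math::
    \log \hat{p}(x, y) = \log \frac{1}{M} \sum_{m=1}^{M}
    \frac{p(x, y \mid z_m, \ell_m, \beta_m)\, p(z_m)\, p(\ell_m)\, p(\beta_m)}
    {q(z_m \mid x, y)\, q(\ell_m \mid x)\, q(\beta_m \mid x, y)}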
Parameters
----------
n_samples_mc
(Default value = 100)
batch_size
(Default value = 96)
Returns
-------
"""
# Uses MC sampling to compute a tighter lower bound on log p(x, y)
log_lkl = 0
for i_batch, tensors in enumerate(self.update({"batch_size": batch_size})):
x, local_l_mean, local_l_var, batch_index, labels, y = tensors
to_sum = torch.zeros(x.size()[0], n_samples_mc)
for i in range(n_samples_mc):
# Distribution parameters and sampled variables
outputs = self.model.inference(x, y, batch_index, labels)
qz_m = outputs["qz_m"]
qz_v = outputs["qz_v"]
ql_m = outputs["ql_m"]
ql_v = outputs["ql_v"]
px_ = outputs["px_"]
py_ = outputs["py_"]
log_library = outputs["untran_l"]
# we need the un-transformed (pre-softmax) random variable z here,
# since the prior/posterior densities below are defined over that variable
z = outputs["untran_z"]
log_pro_back_mean = outputs["log_pro_back_mean"]
# Reconstruction Loss
(
reconst_loss_gene,
reconst_loss_protein,
) = self.model.get_reconstruction_loss(x, y, px_, py_)
# Log-probabilities
p_l_gene = (
Normal(local_l_mean, local_l_var.sqrt())
.log_prob(log_library)
.sum(dim=-1)
)
p_z = Normal(0, 1).log_prob(z).sum(dim=-1)
p_mu_back = self.model.back_mean_prior.log_prob(log_pro_back_mean).sum(
dim=-1
)
p_xy_zl = -(reconst_loss_gene + reconst_loss_protein)
q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)
q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(log_library).sum(dim=-1)
q_mu_back = (
Normal(py_["back_alpha"], py_["back_beta"])
.log_prob(log_pro_back_mean)
.sum(dim=-1)
)
to_sum[:, i] = (
p_z + p_l_gene + p_mu_back + p_xy_zl - q_z_x - q_l_x - q_mu_back
)
batch_log_lkl = torch.logsumexp(to_sum, dim=-1) - np.log(n_samples_mc)
log_lkl += torch.sum(batch_log_lkl).item()
n_samples = len(self.indices)
# The minus sign is there because we actually look at the negative log likelihood
return -log_lkl / n_samples
@torch.no_grad()
def get_latent(
self, sample: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Output posterior z mean or sample, batch index, and label
Parameters
----------
sample
z mean or z sample
Returns
-------
type
4-tuple of latent, batch_indices, labels, library_gene
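Examples
--------
A minimal sketch; ``post`` is assumed to be a ``TotalPosterior`` built around a
trained model, e.g. ``post = TotalPosterior(totalvi, gene_dataset)``:
>>> latent, batch_indices, labels, library_gene = post.get_latent(sample=False)
>>> latent.shape  # (n_cells, n_latent)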
"""
latent = []
batch_indices = []
labels = []
library_gene = []
for tensors in self:
x, local_l_mean, local_l_var, batch_index, label, y = tensors
give_mean = not sample
latent += [
self.model.sample_from_posterior_z(
x, y, batch_index, give_mean=give_mean
).cpu()
]
batch_indices += [batch_index.cpu()]
labels += [label.cpu()]
library_gene += [
self.model.sample_from_posterior_l(
x, y, batch_index, give_mean=give_mean
).cpu()
]
return (
np.array(torch.cat(latent)),
np.array(torch.cat(batch_indices)),
np.array(torch.cat(labels)).ravel(),
np.array(torch.cat(library_gene)).ravel(),
)
@torch.no_grad()
def differential_expression_stats(self, M_sampling: int = 100):
raise NotImplementedError
@torch.no_grad()
def generate(
self, n_samples: int = 100, batch_size: int = 64
) -> Tuple[np.ndarray, np.ndarray]:
"""Sample from posterior predictive. Proteins are concatenated to genes.
Parameters
----------
n_samples
Number of posterior predictive samples
batch_size
mini batch size for loaded data. Lower for less memory usage
Returns
-------
x_new : :py:class:`numpy.ndarray`
array with shape (n_cells, n_genes + n_proteins, n_samples)
x_old : :py:class:`numpy.ndarray`
array with shape (n_cells, n_genes + n_proteins)
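Examples
--------
A minimal sketch, assuming ``post`` is a ``TotalPosterior`` over a trained model:
>>> x_new, x_old = post.generate(n_samples=10, batch_size=64)
>>> x_new.shape  # (n_cells, n_genes + n_proteins, 10)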
"""
original_list = []
posterior_list = []
for tensors in self.update({"batch_size": batch_size}):
x, _, _, batch_index, labels, y = tensors
with torch.no_grad():
outputs = self.model.inference(
x, y, batch_index=batch_index, label=labels, n_samples=n_samples
)
px_ = outputs["px_"]
py_ = outputs["py_"]
pi = 1 / (1 + torch.exp(-py_["mixing"]))
mixing_sample = Bernoulli(pi).sample()
protein_rate = (
py_["rate_fore"] * (1 - mixing_sample)
+ py_["rate_back"] * mixing_sample
)
rate = torch.cat((px_["rate"], protein_rate), dim=-1)
if len(px_["r"].size()) == 2:
px_dispersion = px_["r"]
else:
px_dispersion = torch.ones_like(x) * px_["r"]
if len(py_["r"].size()) == 2:
py_dispersion = py_["r"]
else:
py_dispersion = torch.ones_like(y) * py_["r"]
dispersion = torch.cat((px_dispersion, py_dispersion), dim=-1)
# This gamma is really l*w using scVI manuscript notation
p = rate / (rate + dispersion)
r = dispersion
l_train = Gamma(r, (1 - p) / p).sample()
data = Poisson(l_train).sample().cpu().numpy()
# """
# In numpy (shape, scale) => (concentration, rate), with scale = p /(1 - p)
# rate = (1 - p) / p # = 1/scale # used in pytorch
# """
original_list += [np.array(torch.cat((x, y), dim=-1).cpu())]
posterior_list += [data]
posterior_list[-1] = np.transpose(posterior_list[-1], (1, 2, 0))
return (
np.concatenate(posterior_list, axis=0),
np.concatenate(original_list, axis=0),
)
@torch.no_grad()
def get_sample_dropout(self, n_samples: int = 1, give_mean: bool = True):
"""Zero-inflation mixing component for genes
Parameters
----------
n_samples
(Default value = 1)
give_mean
(Default value = True)
Returns
-------
"""
px_dropouts = []
for tensors in self:
x, _, _, batch_index, label, y = tensors
outputs = self.model.inference(
x, y, batch_index=batch_index, label=label, n_samples=n_samples
)
px_dropout = torch.sigmoid(outputs["px_"]["dropout"])
px_dropouts += [px_dropout.cpu()]
if n_samples > 1:
# concatenate along batch dimension -> result shape = (samples, cells, features)
px_dropouts = torch.cat(px_dropouts, dim=1)
# (cells, features, samples)
px_dropouts = px_dropouts.permute(1, 2, 0)
else:
px_dropouts = torch.cat(px_dropouts, dim=0)
if give_mean is True and n_samples > 1:
px_dropouts = torch.mean(px_dropouts, dim=-1)
px_dropouts = px_dropouts.cpu().numpy()
return px_dropouts
@torch.no_grad()
def get_sample_mixing(
self,
n_samples: int = 1,
give_mean: bool = True,
transform_batch: Optional[Union[int, List[int]]] = None,
) -> np.ndarray:
"""Returns mixing bernoulli parameter for protein negative binomial mixtures (probability background)
Parameters
----------
n_samples
number of samples from posterior distribution
give_mean
bool, whether to return samples along first axis or average over samples
transform_batch
Batches to condition on.
If transform_batch is:
- None, then real observed batch is used
- int, then batch transform_batch is used
- list of int, then values are averaged over provided batches.
Returns
-------
array of probability background
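Examples
--------
A minimal sketch, assuming ``post`` is a ``TotalPosterior`` over a trained model:
>>> prob_background = post.get_sample_mixing(n_samples=25, give_mean=True)
>>> prob_background.shape  # (n_cells, n_proteins)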
"""
py_mixings = []
if (transform_batch is None) or (isinstance(transform_batch, int)):
transform_batch = [transform_batch]
for tensors in self:
x, _, _, batch_index, label, y = tensors
py_mixing = torch.zeros_like(y)
if n_samples > 1:
py_mixing = torch.stack(n_samples * [py_mixing])
for b in transform_batch:
outputs = self.model.inference(
x,
y,
batch_index=batch_index,
label=label,
n_samples=n_samples,
transform_batch=b,
)
py_mixing += torch.sigmoid(outputs["py_"]["mixing"])
py_mixing /= len(transform_batch)
py_mixings += [py_mixing.cpu()]
if n_samples > 1:
# concatenate along batch dimension -> result shape = (samples, cells, features)
py_mixings = torch.cat(py_mixings, dim=1)
# (cells, features, samples)
py_mixings = py_mixings.permute(1, 2, 0)
else:
py_mixings = torch.cat(py_mixings, dim=0)
if give_mean is True and n_samples > 1:
py_mixings = torch.mean(py_mixings, dim=-1)
py_mixings = py_mixings.cpu().numpy()
return py_mixings
@torch.no_grad()
def get_sample_scale(
self,
transform_batch=None,
eps=0.5,
normalize_pro=False,
sample_bern=True,
include_bg=False,
) -> np.ndarray:
"""Helper function to provide normalized expression for DE testing.
For normalized, denoised expression, please use
`get_normalized_denoised_expression()`
Parameters
----------
transform_batch
Int of batch to "transform" all cells into (Default value = None)
eps
Prior count to add to protein normalized expression (Default value = 0.5)
normalize_pro
bool, whether to make protein expression sum to one in a cell (Default value = False)
include_bg
bool, whether to include the background component of expression (Default value = False)
sample_bern
bool, whether to sample the mixing Bernoulli rather than use its mean (Default value = True)
Returns
-------
"""
scales = []
for tensors in self:
x, _, _, batch_index, label, y = tensors
model_scale = self.model.get_sample_scale(
x,
y,
batch_index=batch_index,
label=label,
n_samples=1,
transform_batch=transform_batch,
eps=eps,
normalize_pro=normalize_pro,
sample_bern=sample_bern,
include_bg=include_bg,
)
# prior count for proteins
scales += [torch.cat(model_scale, dim=-1).cpu().numpy()]
return np.concatenate(scales)
@torch.no_grad()
def get_normalized_denoised_expression(
self,
n_samples: int = 1,
give_mean: bool = True,
transform_batch: Optional[Union[int, List[int]]] = None,
sample_protein_mixing: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
"""Returns the tensors of denoised normalized gene and protein expression
Parameters
----------
n_samples
number of samples from posterior distribution
sample_protein_mixing
Sample mixing bernoulli, setting background to zero
give_mean
bool, whether to return samples along first axis or average over samples
transform_batch
Batches to condition on.
If transform_batch is:
- None, then real observed batch is used
- int, then batch transform_batch is used
- list of int, then values are averaged over provided batches.
Returns
-------
Denoised genes, denoised proteins
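Examples
--------
A minimal sketch, assuming ``post`` is a ``TotalPosterior`` over a trained model:
>>> denoised_genes, denoised_proteins = post.get_normalized_denoised_expression(
...     n_samples=25, give_mean=True)
>>> denoised_proteins.shape  # (n_cells, n_proteins)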
"""
scale_list_gene = []
scale_list_pro = []
if (transform_batch is None) or (isinstance(transform_batch, int)):
transform_batch = [transform_batch]
for tensors in self:
x, _, _, batch_index, label, y = tensors
px_scale = torch.zeros_like(x)
py_scale = torch.zeros_like(y)
if n_samples > 1:
px_scale = torch.stack(n_samples * [px_scale])
py_scale = torch.stack(n_samples * [py_scale])
for b in transform_batch:
outputs = self.model.inference(
x,
y,
batch_index=batch_index,
label=label,
n_samples=n_samples,
transform_batch=b,
)
px_scale += outputs["px_"]["scale"]
py_ = outputs["py_"]
# probability of background
protein_mixing = 1 / (1 + torch.exp(-py_["mixing"]))
if sample_protein_mixing is True:
protein_mixing = Bernoulli(protein_mixing).sample()
py_scale += py_["rate_fore"] * (1 - protein_mixing)
px_scale /= len(transform_batch)
py_scale /= len(transform_batch)
scale_list_gene.append(px_scale.cpu())
scale_list_pro.append(py_scale.cpu())
if n_samples > 1:
# concatenate along batch dimension -> result shape = (samples, cells, features)
scale_list_gene = torch.cat(scale_list_gene, dim=1)
scale_list_pro = torch.cat(scale_list_pro, dim=1)
# (cells, features, samples)
scale_list_gene = scale_list_gene.permute(1, 2, 0)
scale_list_pro = scale_list_pro.permute(1, 2, 0)
else:
scale_list_gene = torch.cat(scale_list_gene, dim=0)
scale_list_pro = torch.cat(scale_list_pro, dim=0)
if give_mean is True and n_samples > 1:
scale_list_gene = torch.mean(scale_list_gene, dim=-1)
scale_list_pro = torch.mean(scale_list_pro, dim=-1)
scale_list_gene = scale_list_gene.cpu().numpy()
scale_list_pro = scale_list_pro.cpu().numpy()
return scale_list_gene, scale_list_pro
@torch.no_grad()
def get_protein_mean(
self,
n_samples: int = 1,
give_mean: bool = True,
transform_batch: Optional[Union[int, List[int]]] = None,
) -> np.ndarray:
"""Returns the tensors of protein mean (with foreground and background)
Parameters
----------
n_samples
number of samples from posterior distribution
give_mean
bool, whether to return samples along first axis or average over samples
transform_batch
Batches to condition on.
If transform_batch is:
- None, then real observed batch is used
- int, then batch transform_batch is used
- list of int, then values are averaged over provided batches.
Returns
-------
Protein NB Mixture mean
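Examples
--------
A minimal sketch, assuming ``post`` is a ``TotalPosterior`` over a trained model:
>>> protein_mean = post.get_protein_mean(n_samples=25, give_mean=True)
>>> protein_mean.shape  # (n_cells, n_proteins)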
"""
if (transform_batch is None) or (isinstance(transform_batch, int)):
transform_batch = [transform_batch]
rate_list_pro = []
for tensors in self:
x, _, _, batch_index, label, y = tensors
protein_rate = torch.zeros_like(y)
if n_samples > 1:
protein_rate = torch.stack(n_samples * [protein_rate])
for b in transform_batch:
outputs = self.model.inference(
x,
y,
batch_index=batch_index,
label=label,
n_samples=n_samples,
transform_batch=b,
)
py_ = outputs["py_"]
pi = 1 / (1 + torch.exp(-py_["mixing"]))
protein_rate += py_["rate_fore"] * (1 - pi) + py_["rate_back"] * pi
protein_rate /= len(transform_batch)
rate_list_pro.append(protein_rate.cpu())
if n_samples > 1:
# concatenate along batch dimension -> result shape = (samples, cells, features)
rate_list_pro = torch.cat(rate_list_pro, dim=1)
# (cells, features, samples)
rate_list_pro = rate_list_pro.permute(1, 2, 0)
else:
rate_list_pro = torch.cat(rate_list_pro, dim=0)
if give_mean is True and n_samples > 1:
rate_list_pro = torch.mean(rate_list_pro, dim=-1)
rate_list_pro = rate_list_pro.cpu().numpy()
return rate_list_pro
@torch.no_grad()
def generate_denoised_samples(
self,
n_samples: int = 25,
batch_size: int = 64,
rna_size_factor: int = 1,
transform_batch: Optional[int] = None,
):
"""Samples from an adjusted posterior predictive. Proteins are concatenated to genes.
Parameters
----------
n_samples
How many samples per cell
batch_size
Mini-batch size for sampling. Lower means less GPU memory footprint
rna_size_factor
size factor for RNA prior to sampling gamma distribution
transform_batch
int of which batch to condition on for all cells
Returns
-------
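Examples
--------
A minimal sketch, assuming ``post`` is a ``TotalPosterior`` over a trained model:
>>> samples = post.generate_denoised_samples(n_samples=25, rna_size_factor=1000)
>>> samples.shape  # (n_cells, n_genes + n_proteins, 25)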
"""
posterior_list = []
for tensors in self.update({"batch_size": batch_size}):
x, _, _, batch_index, labels, y = tensors
with torch.no_grad():
outputs = self.model.inference(
x,
y,
batch_index=batch_index,
label=labels,
n_samples=n_samples,
transform_batch=transform_batch,
)
px_ = outputs["px_"]
py_ = outputs["py_"]
pi = 1 / (1 + torch.exp(-py_["mixing"]))
mixing_sample = Bernoulli(pi).sample()
protein_rate = py_["rate_fore"]
rate = torch.cat((rna_size_factor * px_["scale"], protein_rate), dim=-1)
if len(px_["r"].size()) == 2:
px_dispersion = px_["r"]
else:
px_dispersion = torch.ones_like(x) * px_["r"]
if len(py_["r"].size()) == 2:
py_dispersion = py_["r"]
else:
py_dispersion = torch.ones_like(y) * py_["r"]
dispersion = torch.cat((px_dispersion, py_dispersion), dim=-1)
# This gamma is really l*w using scVI manuscript notation
p = rate / (rate + dispersion)
r = dispersion
l_train = Gamma(r, (1 - p) / p).sample()
data = l_train.cpu().numpy()
# make background 0
data[:, :, self.gene_dataset.nb_genes :] = (
data[:, :, self.gene_dataset.nb_genes :]
* (1 - mixing_sample).cpu().numpy()
)
posterior_list += [data]
posterior_list[-1] = np.transpose(posterior_list[-1], (1, 2, 0))
return np.concatenate(posterior_list, axis=0)
@torch.no_grad()
def generate_feature_correlation_matrix(
self,
n_samples: int = 25,
batch_size: int = 64,
rna_size_factor: int = 1000,
transform_batch: Optional[Union[int, List[int]]] = None,
correlation_mode: str = "pearson",
log_transform: bool = False,
):
"""Wrapper of `generate_denoised_samples()` to create a gene-protein gene-protein corr matrix
Parameters
----------
n_samples
How many samples per cell
batch_size
Mini-batch size for sampling. Lower means less GPU memory footprint
rna_size_factor
size factor for RNA prior to sampling gamma distribution
transform_batch
Batches to condition on.
If transform_batch is:
- None, then real observed batch is used
- int, then batch transform_batch is used
- list of int, then values are averaged over provided batches.
log_transform
Whether to log transform denoised values prior to correlation calculation
Returns
-------
Correlation matrix
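Examples
--------
A minimal sketch, assuming ``post`` is a ``TotalPosterior`` over a trained model:
>>> corr = post.generate_feature_correlation_matrix(
...     n_samples=25, correlation_mode="spearman", log_transform=True)
>>> corr.shape  # (n_genes + n_proteins, n_genes + n_proteins)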
"""
if (transform_batch is None) or (isinstance(transform_batch, int)):
transform_batch = [transform_batch]
corr_mats = []
for b in transform_batch:
denoised_data = self.generate_denoised_samples(
n_samples=n_samples,
batch_size=batch_size,
rna_size_factor=rna_size_factor,
transform_batch=b,
)
flattened = np.zeros(
(denoised_data.shape[0] * n_samples, denoised_data.shape[1])
)
for i in range(n_samples):
flattened[
denoised_data.shape[0] * (i) : denoised_data.shape[0] * (i + 1)
] = denoised_data[:, :, i]
if log_transform is True:
flattened[:, : self.gene_dataset.nb_genes] = np.log(
flattened[:, : self.gene_dataset.nb_genes] + 1e-8
)
flattened[:, self.gene_dataset.nb_genes :] = np.log1p(
flattened[:, self.gene_dataset.nb_genes :]
)
if correlation_mode == "pearson":
corr_matrix = np.corrcoef(flattened, rowvar=False)
else:
# spearmanr returns (correlation, pvalue); keep only the correlation matrix
corr_matrix = spearmanr(flattened, axis=0)[0]
corr_mats.append(corr_matrix)
corr_matrix = np.mean(np.stack(corr_mats), axis=0)
return corr_matrix
@torch.no_grad()
def imputation(self, n_samples: int = 1):
"""Gene imputation
Parameters
----------
n_samples
(Default value = 1)
Returns
-------
"""
imputed_list = []
for tensors in self:
x, _, _, batch_index, label, y = tensors
px_rate = self.model.get_sample_rate(
x, y, batch_index=batch_index, label=label, n_samples=n_samples
)
imputed_list += [np.array(px_rate.cpu())]
imputed_list = np.concatenate(imputed_list)
return imputed_list.squeeze()
@torch.no_grad()
def imputation_list(self, n_samples: int = 1):
"""This code is identical to same function in posterior.py
Except, we use the totalVI definition of `model.get_sample_rate`
Parameters
----------
n_samples
(Default value = 1)
Returns
-------
"""
original_list = []
imputed_list = []
batch_size = self.data_loader_kwargs["batch_size"] // n_samples
for tensors, corrupted_tensors in zip(
self.uncorrupted().sequential(batch_size=batch_size),
self.corrupted().sequential(batch_size=batch_size),
):
batch = tensors[0]
actual_batch_size = batch.size(0)
dropout_x, _, _, batch_index, labels, y = corrupted_tensors
px_rate = self.model.get_sample_rate(
dropout_x, y, batch_index=batch_index, label=labels, n_samples=n_samples
)
px_rate = px_rate[:, : self.gene_dataset.nb_genes]
indices_dropout = torch.nonzero(batch - dropout_x)
if indices_dropout.size() != torch.Size([0]):
i = indices_dropout[:, 0]
j = indices_dropout[:, 1]
batch = batch.unsqueeze(0).expand(
(n_samples, batch.size(0), batch.size(1))
)
original = np.array(batch[:, i, j].view(-1).cpu())
imputed = np.array(px_rate[..., i, j].view(-1).cpu())
cells_index = np.tile(np.array(i.cpu()), n_samples)
original_list += [
original[cells_index == i] for i in range(actual_batch_size)
]
imputed_list += [
imputed[cells_index == i] for i in range(actual_batch_size)
]
else:
original_list = np.array([])
imputed_list = np.array([])
return original_list, imputed_list
@torch.no_grad()
def differential_expression_score(
self,
idx1: Union[List[bool], np.ndarray],
idx2: Union[List[bool], np.ndarray],
mode: Optional[str] = "vanilla",
batchid1: Optional[Union[List[int], np.ndarray]] = None,
batchid2: Optional[Union[List[int], np.ndarray]] = None,
use_observed_batches: Optional[bool] = False,
n_samples: int = 5000,
use_permutation: bool = True,
M_permutation: int = 10000,
all_stats: bool = True,
change_fn: Optional[Union[str, Callable]] = None,
m1_domain_fn: Optional[Callable] = None,
delta: Optional[float] = 0.5,
cred_interval_lvls: Optional[Union[List[float], np.ndarray]] = None,
**kwargs,
) -> pd.DataFrame:
r"""Unified method for differential expression inference.
This function is an extension of the `get_bayes_factors` method
providing additional genes information to the user
Two modes coexist:
- the "vanilla" mode follows protocol described in [Lopez18]_
In this case, we perform hypothesis testing based on the hypotheses
.. math::
M_1: h_1 > h_2 ~\text{and}~ M_2: h_1 \leq h_2
DE can then be based on the study of the Bayes factors
.. math::
\log p(M_1 | x_1, x_2) / p(M_2 | x_1, x_2)
- the "change" mode (described in [Boyeau19]_)
consists in estimating an effect size random variable (e.g., log fold-change) and
performing Bayesian hypothesis testing on this variable.
The `change_fn` function computes the effect size variable r based on two inputs
corresponding to the normalized means in both populations.
Hypotheses:
.. math::
M_1: r \in R_1 ~\text{(effect size r in region inducing differential expression)}
.. math::
M_2: r \notin R_1 ~\text{(no differential expression)}
To characterize the region :math:`R_1`, which induces DE, the user has two choices.
1. A common case is when the region :math:`[-\delta, \delta]` does not induce differential
expression.
If the user specifies a threshold delta,
we suppose that :math:`R_1 = \mathbb{R} \setminus [-\delta, \delta]`
2. Specify a custom indicator function
.. math::
f: \mathbb{R} \mapsto \{0, 1\} ~\text{s.t.}~ r \in R_1 ~\text{iff.}~ f(r) = 1
Decision-making can then be based on the estimates of
.. math::
p(M_1 \mid x_1, x_2)
Both modes require sampling the normalized means posteriors.
To that end, we sample the posterior in the following way:
1. The posterior is sampled n_samples times for each subpopulation
2. For computational efficiency (posterior sampling is quite expensive), instead of
comparing the obtained samples element-wise, we can permute posterior samples.
Remember that computing the Bayes Factor requires sampling
:math:`q(z_A \mid x_A)` and :math:`q(z_B \mid x_B)`
Currently, the code covers several batch handling configurations:
1. If ``use_observed_batches=True``, then batches are considered as observations
and cells' normalized means are conditioned on real batch observations
2. If case (cell group 1) and control (cell group 2) are conditioned on the same
batch ids.
Examples:
>>> set(batchid1) = set(batchid2)
or
>>> batchid1 = batchid2 = None
3. If case and control are conditioned on different batch ids that do not intersect
i.e.,
>>> set(batchid1) != set(batchid2)
and
>>> len(set(batchid1).intersection(set(batchid2))) == 0
This function does not cover other cases yet and will warn users in such cases.
Parameters
----------
mode
one of ["vanilla", "change"]
idx1
bool array masking subpopulation cells 1. Should be True where cell is
from associated population
idx2
bool array masking subpopulation cells 2. Should be True where cell is
from associated population
batchid1
List of batch ids for which you want to perform DE Analysis for
subpopulation 1. By default, all ids are taken into account
batchid2
List of batch ids for which you want to perform DE Analysis for
subpopulation 2. By default, all ids are taken into account
use_observed_batches
Whether normalized means are conditioned on observed
batches
n_samples
Number of posterior samples
use_permutation
Activates step 2 described above.
Simply formulated, pairs obtained from posterior sampling (when calling
`sample_scale_from_batch`) will be randomly permuted so that the number of
pairs used to compute Bayes Factors becomes M_permutation.
M_permutation
Number of times we will "mix" posterior samples in step 2.
Only makes sense when use_permutation=True
change_fn
function computing effect size based on both normalized means
m1_domain_fn
custom indicator function of effect size regions
inducing differential expression
delta
specific case of region inducing differential expression.
In this case, we suppose that :math:`\mathbb{R} \setminus [-\delta, \delta]` does not induce differential expression
(LFC case)
cred_interval_lvls
List of credible interval levels to compute for the posterior
LFC distribution
all_stats
whether additional metrics should be provided
**kwargs
Other keywords arguments for `get_sample_scale`
Returns
-------
diff_exp_results
The most important columns are:
- ``proba_de`` (probability of being differentially expressed in change mode)
- ``bayes_factor`` (bayes factors in the vanilla mode)
- ``scale1`` and ``scale2`` (means of the scales in population 1 and 2)
- When using the change mode, the mean, median, std of the posterior LFC
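Examples
--------
A minimal sketch, assuming ``post`` is a ``TotalPosterior`` over a trained model and
``mask1``/``mask2`` are boolean arrays marking the two cell populations:
>>> de_res = post.differential_expression_score(idx1=mask1, idx2=mask2, mode="change")
>>> de_res.head()  # already sorted by "proba_de" in change mode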
"""
all_info = self.get_bayes_factors(
idx1=idx1,
idx2=idx2,
mode=mode,
batchid1=batchid1,
batchid2=batchid2,
use_observed_batches=use_observed_batches,
n_samples=n_samples,
use_permutation=use_permutation,
M_permutation=M_permutation,
change_fn=change_fn,
m1_domain_fn=m1_domain_fn,
cred_interval_lvls=cred_interval_lvls,
delta=delta,
**kwargs,
)
col_names = np.concatenate(
[self.gene_dataset.gene_names, self.gene_dataset.protein_names]
)
if all_stats is True:
nan = np.array([np.nan] * len(self.gene_dataset.protein_names))
(
mean1,
mean2,
nonz1,
nonz2,
norm_mean1,
norm_mean2,
) = self.gene_dataset.raw_counts_properties(idx1, idx2)
mean1_pro = self.gene_dataset.protein_expression[idx1, :].mean(0)
mean2_pro = self.gene_dataset.protein_expression[idx2, :].mean(0)
nonz1_pro = (self.gene_dataset.protein_expression[idx1, :] > 0).mean(0)
nonz2_pro = (self.gene_dataset.protein_expression[idx2, :] > 0).mean(0)
# TODO implement properties for proteins
genes_properties_dict = dict(
raw_mean1=np.concatenate([mean1, mean1_pro]),
raw_mean2=np.concatenate([mean2, mean2_pro]),
non_zeros_proportion1=np.concatenate([nonz1, nonz1_pro]),
non_zeros_proportion2=np.concatenate([nonz2, nonz2_pro]),
raw_normalized_mean1=np.concatenate([norm_mean1, nan]),
raw_normalized_mean2=np.concatenate([norm_mean2, nan]),
)
all_info = {**all_info, **genes_properties_dict}
res = pd.DataFrame(all_info, index=col_names)
sort_key = "proba_de" if mode == "change" else "bayes_factor"
res = res.sort_values(by=sort_key, ascending=False)
return res
@torch.no_grad()
def generate_parameters(self):
raise NotImplementedError
default_early_stopping_kwargs = {
"early_stopping_metric": "elbo",
"save_best_state_metric": "elbo",
"patience": 45,
"threshold": 0,
"reduce_lr_on_plateau": True,
"lr_patience": 30,
"lr_factor": 0.6,
"posterior_class": TotalPosterior,
}
class TotalTrainer(UnsupervisedTrainer):
"""Unsupervised training for totalVI using variational inference
Parameters
----------
model
A model instance from class ``TOTALVI``
gene_dataset
A gene_dataset instance like ``CbmcDataset()`` with attribute ``protein_expression``
train_size
The train size, a float between 0 and 1 representing the proportion of the dataset
to use for training. Default: ``0.90``.
test_size
The test size, a float between 0 and 1 representing the proportion of the dataset
to use for testing. Default: ``0.10``. Note that if train and test sizes do not add up to 1, the remainder is placed in a validation set.
pro_recons_weight
Scaling factor on the reconstruction loss for proteins. Default: ``1.0``.
n_epochs_kl_warmup
Number of epochs for annealing the KL terms for `z` and `mu` of the ELBO (from 0 to 1). If None, no warmup is performed, unless
`n_iter_kl_warmup` is set.
n_iter_kl_warmup
Number of minibatches for annealing the KL terms for `z` and `mu` of the ELBO (from 0 to 1). If set to "auto", the number
of iterations is equal to 75% of the number of cells. `n_epochs_kl_warmup` takes precedence if it is not None. If both are None, then
no warmup is performed.
discriminator
Classifier used for adversarial training scheme
use_adversarial_loss
Whether to use adversarial classifier to improve mixing
kappa
Scaling factor for adversarial loss. If None, follow inverse of kl warmup schedule.
early_stopping_kwargs
Keyword args for early stopping. If "auto", use totalVI defaults. If None, disable early stopping.
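Examples
--------
A minimal sketch, mirroring the ``TotalPosterior`` example above:
>>> gene_dataset = CbmcDataset()
>>> totalvi = TOTALVI(gene_dataset.nb_genes, len(gene_dataset.protein_names),
... n_batch=gene_dataset.n_batches)
>>> trainer = TotalTrainer(totalvi, gene_dataset, train_size=0.90, use_cuda=True)
>>> trainer.train(n_epochs=500, lr=4e-3)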
"""
default_metrics_to_monitor = ["elbo"]
def __init__(
self,
model: TOTALVI,
dataset: GeneExpressionDataset,
train_size: float = 0.9,
test_size: float = 0.1,
pro_recons_weight: float = 1.0,
n_epochs_kl_warmup: int = None,
n_iter_kl_warmup: Union[str, int] = "auto",
discriminator: Classifier = None,
use_adversarial_loss: bool = False,
kappa: float = None,
early_stopping_kwargs: Union[dict, str, None] = "auto",
**kwargs,
):
train_size = float(train_size)
if train_size > 1.0 or train_size <= 0.0:
raise ValueError(
"train_size needs to be greater than 0 and less than or equal to 1"
)
self.n_genes = dataset.nb_genes
self.n_proteins = model.n_input_proteins
self.use_adversarial_loss = use_adversarial_loss
self.kappa = kappa
self.pro_recons_weight = pro_recons_weight
if early_stopping_kwargs == "auto":
early_stopping_kwargs = default_early_stopping_kwargs
super().__init__(
model,
dataset,
n_epochs_kl_warmup=n_epochs_kl_warmup,
n_iter_kl_warmup=0.75 * len(dataset)
if n_iter_kl_warmup == "auto"
else n_iter_kl_warmup,
early_stopping_kwargs=early_stopping_kwargs,
**kwargs,
)
if use_adversarial_loss is True and discriminator is None:
discriminator = Classifier(
n_input=self.model.n_latent,
n_hidden=32,
n_labels=self.gene_dataset.n_batches,
n_layers=2,
logits=True,
)
self.discriminator = discriminator
if self.use_cuda and self.discriminator is not None:
self.discriminator.cuda()
if type(self) is TotalTrainer:
(
self.train_set,
self.test_set,
self.validation_set,
) = self.train_test_validation(
model, dataset, train_size, test_size, type_class=TotalPosterior
)
self.train_set.to_monitor = []
self.test_set.to_monitor = ["elbo"]
self.validation_set.to_monitor = ["elbo"]
def loss(self, tensors):
(
sample_batch_X,
local_l_mean,
local_l_var,
batch_index,
label,
sample_batch_Y,
) = tensors
(
reconst_loss_gene,
reconst_loss_protein,
kl_div_z,
kl_div_l_gene,
kl_div_back_pro,
) = self.model(
sample_batch_X,
sample_batch_Y,
local_l_mean,
local_l_var,
batch_index,
label,
)
loss = torch.mean(
reconst_loss_gene
+ self.pro_recons_weight * reconst_loss_protein
+ self.kl_weight * kl_div_z
+ kl_div_l_gene
+ self.kl_weight * kl_div_back_pro
)
return loss
def loss_discriminator(
self, z, batch_index, predict_true_class=True, return_details=True
):
n_classes = self.gene_dataset.n_batches
cls_logits = torch.nn.LogSoftmax(dim=1)(self.discriminator(z))
if predict_true_class:
cls_target = one_hot(batch_index, n_classes)
else:
one_hot_batch = one_hot(batch_index, n_classes)
cls_target = torch.zeros_like(one_hot_batch)
# place zeroes where true label is
cls_target.masked_scatter_(
~one_hot_batch.bool(), torch.ones_like(one_hot_batch) / (n_classes - 1)
)
l_soft = cls_logits * cls_target
loss = -l_soft.sum(dim=1).mean()
return loss
def _get_z(self, tensors):
(
sample_batch_X,
local_l_mean,
local_l_var,
batch_index,
label,
sample_batch_Y,
) = tensors
z = self.model.sample_from_posterior_z(
sample_batch_X, sample_batch_Y, batch_index, give_mean=False
)
return z
def train(self, n_epochs=500, lr=4e-3, eps=0.01, params=None):
super().train(n_epochs=n_epochs, lr=lr, eps=eps, params=params)
def on_training_loop(self, tensors_list):
if self.use_adversarial_loss:
if self.kappa is None:
kappa = 1 - self.kl_weight
else:
kappa = self.kappa
batch_index = tensors_list[0][3]
if kappa > 0:
z = self._get_z(*tensors_list)
# Train discriminator
d_loss = self.loss_discriminator(z.detach(), batch_index, True)
d_loss *= kappa
self.d_optimizer.zero_grad()
d_loss.backward()
self.d_optimizer.step()
# Train generative model to fool discriminator
fool_loss = self.loss_discriminator(z, batch_index, False)
fool_loss *= kappa
# Train generative model
self.optimizer.zero_grad()
self.current_loss = loss = self.loss(*tensors_list)
if kappa > 0:
(loss + fool_loss).backward()
else:
loss.backward()
self.optimizer.step()
else:
self.current_loss = loss = self.loss(*tensors_list)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()