import collections
from typing import Iterable, List
import torch
from torch import nn as nn
from torch.distributions import Normal
from torch.nn import ModuleList
from scvi.models.utils import one_hot
def reparameterize_gaussian(mu, var):
    """Draw a differentiable sample from a Gaussian given its mean and variance.

    Parameters
    ----------
    mu
        Mean tensor of the Gaussian.
    var
        Variance tensor (same shape as ``mu``); its square root is used as the
        standard deviation.

    Returns
    -------
    :py:class:`torch.Tensor`
        A sample obtained with ``rsample`` (reparameterisation trick), so
        gradients flow back to ``mu`` and ``var``.
    """
    std = var.sqrt()
    return Normal(loc=mu, scale=std).rsample()
def identity(x):
    """Return the argument unchanged (used as a no-op latent transformation)."""
    result = x
    return result
class FCLayers(nn.Module):
    """A helper class to build fully-connected layers for a neural network.

    Every hidden layer is an ``nn.Sequential`` of ``Linear``, ``BatchNorm1d``,
    ``ReLU`` and ``Dropout``. A disabled component is stored as ``None``, so the
    positional indices of the remaining submodules, and therefore ``state_dict``
    keys such as ``"fc_layers.Layer 0.1.running_mean"``, do not depend on the
    configuration.

    Parameters
    ----------
    n_in
        The dimensionality of the input
    n_out
        The dimensionality of the output
    n_cat_list
        A list containing, for each category of interest, the number of
        categories. Each category is appended to the input of every ``Linear``
        layer as a one-hot encoding. Entries <= 1 carry no information and are
        ignored (stored as 0).
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers. No ``Dropout``
        module is created when it is 0.
    use_batch_norm
        Whether to have `BatchNorm` layers or not
    use_relu
        Whether to have `ReLU` layers or not
    bias
        Whether to learn bias in linear layers or not
    """

    def __init__(
        self,
        n_in: int,
        n_out: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        use_batch_norm: bool = True,
        use_relu: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        # Single-level categories are recorded with width 0 so they add nothing
        # to the Linear input size and are skipped in forward().
        self.n_cat_list = (
            []
            if n_cat_list is None
            else [k if k > 1 else 0 for k in n_cat_list]
        )
        n_cov = sum(self.n_cat_list)
        dims = [n_in] + [n_hidden] * (n_layers - 1) + [n_out]

        blocks = collections.OrderedDict()
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            blocks["Layer {}".format(idx)] = nn.Sequential(
                nn.Linear(d_in + n_cov, d_out, bias=bias),
                # momentum=0.01 / eps=0.001 are TensorFlow's batch-norm defaults,
                # kept so results match the earlier TensorFlow implementation.
                # The PyTorch defaults have not been tested here.
                nn.BatchNorm1d(d_out, momentum=0.01, eps=0.001)
                if use_batch_norm
                else None,
                nn.ReLU() if use_relu else None,
                nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None,
            )
        self.fc_layers = nn.Sequential(blocks)

    def forward(self, x: torch.Tensor, *cat_list: int, instance_id: int = 0):
        """Forward computation on ``x``.

        Parameters
        ----------
        x
            Input tensor whose last dimension has size ``n_in``. Both 2-D and
            3-D inputs are handled. For 3-D input, batch norm is applied
            separately to each slice along the first dimension, and the one-hot
            covariates are broadcast along that dimension.
        cat_list
            Category membership(s) for this sample, one entry per element of
            ``n_cat_list``. An entry whose second dimension already equals its
            number of categories is used as-is (already one-hot). Any other
            entry is passed to ``one_hot``. Entries for ignored (<= 1)
            categories may be ``None``.
        instance_id
            Accepted for interface compatibility; this implementation does not
            read it.

        Returns
        -------
        :py:class:`torch.Tensor`
            Tensor whose last dimension has size ``n_out``.
        """
        assert len(self.n_cat_list) <= len(
            cat_list
        ), "nb. categorical args provided doesn't match init. params."

        # Encode the covariates once; they are re-appended before every Linear.
        covariates = []
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            assert not (
                n_cat and cat is None
            ), "cat not provided while n_cat != 0 in init. params."
            if not n_cat > 1:
                continue  # single-level category: no additional information
            covariates.append(cat if cat.size(1) == n_cat else one_hot(cat, n_cat))

        for block in self.fc_layers:
            for module in block:
                if module is None:
                    continue  # placeholder for a disabled component
                if isinstance(module, nn.BatchNorm1d):
                    # BatchNorm1d would read a 3-D tensor as (N, C, L); normalise
                    # each leading slice as an ordinary 2-D batch instead.
                    if x.dim() == 3:
                        x = torch.stack([module(part) for part in x], dim=0)
                    else:
                        x = module(x)
                    continue
                if isinstance(module, nn.Linear):
                    extra = covariates
                    if x.dim() == 3:
                        extra = [
                            c.unsqueeze(0).expand((x.size(0), c.size(0), c.size(1)))
                            for c in covariates
                        ]
                    x = torch.cat((x, *extra), dim=-1)
                x = module(x)
        return x
# Encoder
class Encoder(nn.Module):
    """Encode ``n_input``-dimensional data into an ``n_output``-dimensional latent space.

    A fully-connected network (:class:`FCLayers`) maps the input to ``n_hidden``
    features. Two linear heads then give the mean and variance of a diagonal
    Gaussian, from which a reparameterised sample is drawn.

    Parameters
    ----------
    n_input
        The dimensionality of the input (data space)
    n_output
        The dimensionality of the output (latent space)
    n_cat_list
        A list containing the number of categories for each category of
        interest. Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    distribution
        Distribution of z. With ``"ln"`` (logistic normal) a softmax is applied
        to the Gaussian sample; any other value leaves the sample unchanged.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        distribution: str = "normal",
    ):
        super().__init__()
        self.distribution = distribution
        self.encoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.mean_encoder = nn.Linear(n_hidden, n_output)
        self.var_encoder = nn.Linear(n_hidden, n_output)
        self.z_transformation = (
            nn.Softmax(dim=-1) if distribution == "ln" else identity
        )

    def forward(self, x: torch.Tensor, *cat_list: int):
        """Encode ``x`` and sample from the resulting latent distribution.

        #. Encodes the data with the ``FCLayers`` encoder network
        #. Predicts a mean ``q_m`` and a variance ``q_v``
        #. Draws a reparameterised sample from N(q_m, diag(q_v)) and passes it
           through ``z_transformation``

        Parameters
        ----------
        x
            Input tensor whose last dimension is ``n_input``
        cat_list
            Category membership(s) for this sample

        Returns
        -------
        3-tuple of :py:class:`torch.Tensor`
            ``(q_m, q_v, latent)``, each with last dimension ``n_output``
        """
        hidden = self.encoder(x, *cat_list)
        q_m = self.mean_encoder(hidden)
        # exp() keeps the variance positive; the small offset keeps it away from 0.
        q_v = self.var_encoder(hidden).exp() + 1e-4
        latent = self.z_transformation(reparameterize_gaussian(q_m, q_v))
        return q_m, q_v, latent
# Decoder
class DecoderSCVI(nn.Module):
    """Decode an ``n_input``-dimensional latent vector into ZINB parameters over ``n_output`` genes.

    A fully-connected network (:class:`FCLayers`, no dropout) produces
    ``n_hidden`` features. Three linear heads read those features: a softmax
    head for the normalised expression, a dispersion head and a dropout-logit
    head.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output
        The dimensionality of the output (data space)
    n_cat_list
        A list containing the number of categories for each category of
        interest. Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
    ):
        super().__init__()
        self.px_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
        )
        # normalised mean expression (rows sum to 1)
        self.px_scale_decoder = nn.Sequential(
            nn.Linear(n_hidden, n_output), nn.Softmax(dim=-1)
        )
        # per gene-and-cell dispersion (only used when dispersion == "gene-cell")
        self.px_r_decoder = nn.Linear(n_hidden, n_output)
        # zero-inflation (dropout) logits
        self.px_dropout_decoder = nn.Linear(n_hidden, n_output)

    def forward(
        self, dispersion: str, z: torch.Tensor, library: torch.Tensor, *cat_list: int
    ):
        """Compute the ZINB parameters of expression for ``z``.

        Parameters
        ----------
        dispersion
            One of the following:

            * ``'gene'`` - dispersion parameter of NB is constant per gene across cells
            * ``'gene-batch'`` - dispersion can differ between different batches
            * ``'gene-label'`` - dispersion can differ between different labels
            * ``'gene-cell'`` - dispersion can differ for every gene in every cell

            Only ``'gene-cell'`` makes this method compute ``px_r``; for any
            other value ``px_r`` is ``None``.
        z
            Latent tensor whose last dimension is ``n_input``
        library
            Library size; it is exponentiated, i.e. treated as a log library size
        cat_list
            Category membership(s) for this sample

        Returns
        -------
        4-tuple of :py:class:`torch.Tensor`
            ``(px_scale, px_r, px_rate, px_dropout)``
        """
        hidden = self.px_decoder(z, *cat_list)
        px_scale = self.px_scale_decoder(hidden)
        px_dropout = self.px_dropout_decoder(hidden)
        px_rate = library.exp() * px_scale
        px_r = None
        if dispersion == "gene-cell":
            px_r = self.px_r_decoder(hidden)
        return px_scale, px_r, px_rate, px_dropout
class LinearDecoderSCVI(nn.Module):
    """Single-layer (no ReLU, no dropout) decoder from latent space to ZINB parameters.

    Both the expression factors and the dropout logits come from one
    :class:`FCLayers` layer each (``n_layers=1``, ``use_relu=False``,
    ``dropout_rate=0``).

    Parameters
    ----------
    n_input
        Dimensionality of the latent input
    n_output
        Dimensionality of the output (number of genes)
    n_cat_list
        Number of categories for each one-hot encoded categorical covariate
    use_batch_norm
        Whether the two layers include batch normalisation
    bias
        Whether the linear layers learn a bias
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        use_batch_norm: bool = True,
        bias: bool = False,
    ):
        super().__init__()
        layer_kwargs = dict(
            n_in=n_input,
            n_out=n_output,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=use_batch_norm,
            bias=bias,
            dropout_rate=0,
        )
        # unnormalised expression factors (softmax applied in forward)
        self.factor_regressor = FCLayers(**layer_kwargs)
        # zero-inflation (dropout) logits
        self.px_dropout_decoder = FCLayers(**layer_kwargs)

    def forward(
        self, dispersion: str, z: torch.Tensor, library: torch.Tensor, *cat_list: int
    ):
        """Compute the ZINB parameters of expression for ``z``.

        Parameters
        ----------
        dispersion
            Accepted for interface compatibility with the other decoders; not
            read. ``px_r`` is always ``None``.
        z
            Latent tensor whose last dimension is ``n_input``
        library
            Library size; it is exponentiated, i.e. treated as a log library size
        cat_list
            Category membership(s) for this sample

        Returns
        -------
        4-tuple
            ``(px_scale, None, px_rate, px_dropout)``
        """
        px_scale = torch.softmax(self.factor_regressor(z, *cat_list), dim=-1)
        px_dropout = self.px_dropout_decoder(z, *cat_list)
        px_rate = library.exp() * px_scale
        return px_scale, None, px_rate, px_dropout
# Decoder
class Decoder(nn.Module):
    """Decode from latent space (``n_input`` dims) to data space (``n_output`` dims).

    A fully-connected network (:class:`FCLayers`, no dropout) is followed by
    two linear heads. Together they give the mean and variance of a
    multivariate Gaussian.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output
        The dimensionality of the output (data space)
    n_cat_list
        A list containing the number of categories for each category of
        interest. Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
    ):
        super().__init__()
        self.decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
        )
        self.mean_decoder = nn.Linear(n_hidden, n_output)
        self.var_decoder = nn.Linear(n_hidden, n_output)

    def forward(self, x: torch.Tensor, *cat_list: int):
        """Return the mean and variance of the decoded Gaussian.

        Parameters
        ----------
        x
            Latent tensor whose last dimension is ``n_input``
        cat_list
            Category membership(s) for this sample

        Returns
        -------
        2-tuple of :py:class:`torch.Tensor`
            ``(p_m, p_v)``, each with last dimension ``n_output``; ``p_v`` is
            positive because it is an exponential
        """
        hidden = self.decoder(x, *cat_list)
        p_m = self.mean_decoder(hidden)
        p_v = self.var_decoder(hidden).exp()
        return p_m, p_v
class MultiEncoder(nn.Module):
    """Encoder with one input-specific head per entry of ``n_input_list`` plus a shared trunk.

    Head ``head_id`` maps its input to ``n_hidden`` features. A shared
    :class:`FCLayers` trunk then feeds two linear heads for the mean and
    variance of a diagonal Gaussian latent.

    Parameters
    ----------
    n_heads
        Number of input-specific heads; head ``i`` expects ``n_input_list[i]`` features
    n_input_list
        Input dimensionality of each head
    n_output
        Dimensionality of the latent space
    n_hidden
        Width of every hidden layer
    n_layers_individual
        Number of fully-connected layers in each head
    n_layers_shared
        Number of fully-connected layers in the shared trunk
    n_cat_list
        Number of categories for each categorical covariate. The covariates are
        one-hot encoded and given to both the heads and the shared trunk.
    dropout_rate
        Dropout rate of the hidden layers
    """

    def __init__(
        self,
        n_heads: int,
        n_input_list: List[int],
        n_output: int,
        n_hidden: int = 128,
        n_layers_individual: int = 1,
        n_layers_shared: int = 2,
        n_cat_list: Iterable[int] = None,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.encoders = ModuleList(
            FCLayers(
                n_in=n_input_list[head],
                n_out=n_hidden,
                n_cat_list=n_cat_list,
                n_layers=n_layers_individual,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
                use_batch_norm=True,
            )
            for head in range(n_heads)
        )
        self.encoder_shared = FCLayers(
            n_in=n_hidden,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers_shared,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.mean_encoder = nn.Linear(n_hidden, n_output)
        self.var_encoder = nn.Linear(n_hidden, n_output)

    def forward(self, x: torch.Tensor, head_id: int, *cat_list: int):
        """Encode ``x`` through head ``head_id`` and the shared trunk, then sample.

        Returns
        -------
        3-tuple of :py:class:`torch.Tensor`
            ``(q_m, q_v, latent)``; ``latent`` is a reparameterised Gaussian sample
        """
        hidden = self.encoders[head_id](x, *cat_list)
        hidden = self.encoder_shared(hidden, *cat_list)
        q_m = self.mean_encoder(hidden)
        q_v = self.var_encoder(hidden).exp()
        return q_m, q_v, reparameterize_gaussian(q_m, q_v)
class MultiDecoder(nn.Module):
    """Decoder with an optional conditioned stack and an optional shared stack, then ZINB heads.

    Parameters
    ----------
    n_input
        Dimensionality of the latent input
    n_output
        Dimensionality of the output (number of genes)
    n_hidden_conditioned
        Width of the conditioned stack when a shared stack follows it
    n_hidden_shared
        Width of the shared stack. It is also the output width of the
        conditioned stack when there is no shared stack.
    n_layers_conditioned
        Number of layers in the conditioned stack (0 disables it)
    n_layers_shared
        Number of layers in the shared stack (0 disables it)
    n_cat_list
        Number of categories for each one-hot covariate. Only the conditioned
        stack uses it; the shared stack is built with an empty list.
    dropout_rate
        Dropout rate of the hidden layers
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_hidden_conditioned: int = 32,
        n_hidden_shared: int = 128,
        n_layers_conditioned: int = 1,
        n_layers_shared: int = 1,
        n_cat_list: Iterable[int] = None,
        dropout_rate: float = 0.2,
    ):
        super().__init__()
        # With a shared stack after it, the conditioned stack emits
        # n_hidden_conditioned features. Otherwise it feeds the heads directly
        # with n_hidden_shared features.
        cond_width = n_hidden_conditioned if n_layers_shared else n_hidden_shared

        if n_layers_conditioned:
            self.px_decoder_conditioned = FCLayers(
                n_in=n_input,
                n_out=cond_width,
                n_cat_list=n_cat_list,
                n_layers=n_layers_conditioned,
                n_hidden=n_hidden_conditioned,
                dropout_rate=dropout_rate,
                use_batch_norm=True,
            )
            width = cond_width
        else:
            self.px_decoder_conditioned = None
            width = n_input

        if n_layers_shared:
            self.px_decoder_final = FCLayers(
                n_in=width,
                n_out=n_hidden_shared,
                n_cat_list=[],
                n_layers=n_layers_shared,
                n_hidden=n_hidden_shared,
                dropout_rate=dropout_rate,
                use_batch_norm=True,
            )
            width = n_hidden_shared
        else:
            self.px_decoder_final = None

        self.px_scale_decoder = nn.Sequential(
            nn.Linear(width, n_output), nn.Softmax(dim=-1)
        )
        self.px_r_decoder = nn.Linear(width, n_output)
        self.px_dropout_decoder = nn.Linear(width, n_output)

    def forward(
        self,
        z: torch.Tensor,
        dataset_id: int,
        library: torch.Tensor,
        dispersion: str,
        *cat_list: int
    ):
        """Compute ZINB parameters for ``z``.

        Parameters
        ----------
        z
            Latent tensor whose last dimension is ``n_input``
        dataset_id
            Forwarded as ``instance_id`` to the conditioned stack
        library
            Library size; it is exponentiated, i.e. treated as a log library size
        dispersion
            ``"gene-cell"`` makes this method compute ``px_r``; any other value
            gives ``None``
        cat_list
            Category membership(s). Also passed to the shared stack, which was
            built with an empty ``n_cat_list``.

        Returns
        -------
        4-tuple
            ``(px_scale, px_r, px_rate, px_dropout)``
        """
        hidden = z
        if self.px_decoder_conditioned is not None:
            hidden = self.px_decoder_conditioned(
                hidden, *cat_list, instance_id=dataset_id
            )
        if self.px_decoder_final is not None:
            hidden = self.px_decoder_final(hidden, *cat_list)

        px_scale = self.px_scale_decoder(hidden)
        px_dropout = self.px_dropout_decoder(hidden)
        px_rate = library.exp() * px_scale
        px_r = None
        if dispersion == "gene-cell":
            px_r = self.px_r_decoder(hidden)
        return px_scale, px_r, px_rate, px_dropout
class DecoderTOTALVI(nn.Module):
    """Decode a latent vector into gene (ZINB) and protein (mixture NB) parameters.

    Each output family goes through its own fully-connected stack
    (:class:`FCLayers`). The second-stage layers see the hidden features
    concatenated with the latent input ``z``.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output_genes
        The dimensionality of the output (gene space)
    n_output_proteins
        The dimensionality of the output (protein space)
    n_cat_list
        A list containing the number of categories for each category of
        interest. Each category will be included using a one-hot encoding
    n_layers
        Number of fully-connected layers in each first-stage decoder
    n_hidden
        Width of the first-stage hidden layers
    dropout_rate
        Dropout rate of the first-stage hidden layers
    """

    def __init__(
        self,
        n_input: int,
        n_output_genes: int,
        n_output_proteins: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 256,
        dropout_rate: float = 0,
    ):
        super().__init__()
        self.n_output_genes = n_output_genes
        self.n_output_proteins = n_output_proteins

        self.px_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # mean gamma (softmax applied in forward)
        self.px_scale_decoder = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_genes,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )

        # background mean first decoder
        self.py_back_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # background mean parameters second decoder
        self.py_back_mean_log_alpha = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )
        self.py_back_mean_log_beta = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )

        # foreground increment decoder step 1
        self.py_fore_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # foreground increment decoder step 2 (ReLU keeps the increment >= 0)
        self.py_fore_scale_decoder = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=True,
            use_batch_norm=False,
            dropout_rate=0,
        )

        # dropout (mixture component for proteins, ZI probability for genes)
        self.sigmoid_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.px_dropout_decoder_gene = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_genes,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )
        self.py_background_decoder = FCLayers(
            n_in=n_hidden + n_input,
            n_out=n_output_proteins,
            n_cat_list=n_cat_list,
            n_layers=1,
            use_relu=False,
            use_batch_norm=False,
            dropout_rate=0,
        )

    def forward(self, z: torch.Tensor, library_gene: torch.Tensor, *cat_list: int):
        """Compute gene and protein likelihood parameters for ``z``.

        #. Decodes the data from the latent space using the decoder networks
        #. Returns local parameters for the ZINB distribution for genes
        #. Returns local parameters for the Mixture NB distribution for proteins

        The dictionary ``px_`` holds the gene parameters. ``rate`` is the NB
        mean, ``dropout`` holds the zero-inflation logits, and ``scale`` is the
        softmax-normalised expression on which differential expression is
        performed.

        The dictionary ``py_`` holds the protein parameters:

        * ``rate_back`` is the background mean, a log-normal sample with
          log-mean ``back_alpha`` and log-scale ``back_beta``.
        * ``rate_fore`` is the foreground mean, equal to ``rate_back`` times
          ``fore_scale``, where ``fore_scale > 1``.
        * ``mixing`` holds the logits of the background probability.
        * ``scale`` is the foreground mean weighted by ``1 - sigmoid(mixing)``
          and L1-normalised over the last dimension.

        Parameters
        ----------
        z
            Latent tensor whose last dimension is ``n_input``
        library_gene
            Gene library size, multiplied directly into ``px_["rate"]``
        cat_list
            Category membership(s) for this sample

        Returns
        -------
        3-tuple
            ``(px_, py_, log_pro_back_mean)``: two :py:class:`dict` of tensors
            and the log background-mean sample
        """
        px_ = {}
        py_ = {}

        # Genes
        px = self.px_decoder(z, *cat_list)
        px_cat_z = torch.cat([px, z], dim=-1)
        px_["scale"] = torch.softmax(
            self.px_scale_decoder(px_cat_z, *cat_list), dim=-1
        )
        px_["rate"] = library_gene * px_["scale"]

        # Protein background: log-normal sample of the background mean
        py_back = self.py_back_decoder(z, *cat_list)
        py_back_cat_z = torch.cat([py_back, z], dim=-1)
        py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list)
        py_["back_beta"] = torch.exp(
            self.py_back_mean_log_beta(py_back_cat_z, *cat_list)
        )
        log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample()
        py_["rate_back"] = torch.exp(log_pro_back_mean)

        # Protein foreground: background scaled by (ReLU output + 1 + eps) > 1
        py_fore = self.py_fore_decoder(z, *cat_list)
        py_fore_cat_z = torch.cat([py_fore, z], dim=-1)
        py_["fore_scale"] = (
            self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8
        )
        py_["rate_fore"] = py_["rate_back"] * py_["fore_scale"]

        # Mixing logits: gene zero-inflation and protein background probability
        p_mixing = self.sigmoid_decoder(z, *cat_list)
        p_mixing_cat_z = torch.cat([p_mixing, z], dim=-1)
        px_["dropout"] = self.px_dropout_decoder_gene(p_mixing_cat_z, *cat_list)
        py_["mixing"] = self.py_background_decoder(p_mixing_cat_z, *cat_list)

        # torch.sigmoid is the numerically stable form of 1 / (1 + exp(-x)).
        # The explicit formula overflows exp() for large negative logits, and
        # its backward pass then produces NaN gradients.
        protein_mixing = torch.sigmoid(py_["mixing"])
        py_["scale"] = torch.nn.functional.normalize(
            (1 - protein_mixing) * py_["rate_fore"], p=1, dim=-1
        )
        return (px_, py_, log_pro_back_mean)
# Encoder
class EncoderTOTALVI(nn.Module):
    """Encode ``n_input``-dimensional data into a latent ``z`` and a gene library size.

    A shared :class:`FCLayers` network is followed by two branches. One branch
    gives a diagonal Gaussian over ``z`` (optionally softmax-transformed). The
    other gives a Gaussian over the log gene library size, which is then
    exponentiated.

    Parameters
    ----------
    n_input
        The dimensionality of the input (data space)
    n_output
        The dimensionality of the output (latent space)
    n_cat_list
        A list containing the number of categories for each category of
        interest. Each category will be included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    distribution
        Distribution of the latent space, one of

        * ``'normal'`` - Normal distribution
        * ``'ln'`` - Logistic normal (softmax applied to the Gaussian sample)
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 2,
        n_hidden: int = 256,
        dropout_rate: float = 0.1,
        distribution: str = "ln",
    ):
        super().__init__()
        self.encoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        self.z_encoder = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
        )
        self.z_mean_encoder = nn.Linear(n_hidden, n_output)
        self.z_var_encoder = nn.Linear(n_hidden, n_output)

        self.l_gene_encoder = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
        )
        self.l_gene_mean_encoder = nn.Linear(n_hidden, 1)
        self.l_gene_var_encoder = nn.Linear(n_hidden, 1)

        self.distribution = distribution
        if distribution == "ln":
            self.z_transformation = nn.Softmax(dim=-1)
        else:
            self.z_transformation = identity
        self.l_transformation = torch.exp

    def reparameterize_transformation(self, mu, var):
        """Draw a reparameterised Gaussian sample and apply ``z_transformation``.

        Parameters
        ----------
        mu
            Mean of the Gaussian
        var
            Variance of the Gaussian (its square root is the standard deviation)

        Returns
        -------
        2-tuple of :py:class:`torch.Tensor`
            ``(z, untran_z)``: the transformed sample and the raw Normal sample
        """
        untran_z = Normal(mu, var.sqrt()).rsample()
        z = self.z_transformation(untran_z)
        return z, untran_z

    def forward(self, data: torch.Tensor, *cat_list: int):
        """Encode ``data`` and sample the latent variables.

        #. Encodes the data into latent space using the encoder network
        #. Generates a mean ``q_m`` and variance ``q_v`` for each latent variable
        #. Samples new values from the latent distributions

        The dictionary ``latent`` contains the samples of the latent variables.
        ``untran_latent`` contains their untransformed versions. The library
        size is log-normally distributed, so ``untran_latent["l"]`` is the
        normal sample that was exponentiated to give ``latent["l"]``. Likewise,
        the logistic normal is a softmax applied to a normal sample.

        Parameters
        ----------
        data
            Input tensor whose last dimension is ``n_input``
        cat_list
            Category membership(s) for this sample

        Returns
        -------
        6-tuple
            ``(qz_m, qz_v, ql_m, ql_v, latent, untran_latent)``. The first four
            are :py:class:`torch.Tensor`; the last two are :py:class:`dict` of
            tensors keyed by ``"z"`` and ``"l"``.
        """
        # Parameters for latent distribution
        q = self.encoder(data, *cat_list)
        qz = self.z_encoder(q)
        qz_m = self.z_mean_encoder(qz)
        qz_v = torch.exp(self.z_var_encoder(qz)) + 1e-4
        z, untran_z = self.reparameterize_transformation(qz_m, qz_v)

        ql_gene = self.l_gene_encoder(q)
        ql_m = self.l_gene_mean_encoder(ql_gene)
        ql_v = torch.exp(self.l_gene_var_encoder(ql_gene)) + 1e-4
        # Cap the log library size at 15 before exponentiating.
        log_library_gene = torch.clamp(reparameterize_gaussian(ql_m, ql_v), max=15)
        library_gene = self.l_transformation(log_library_gene)

        latent = {}
        untran_latent = {}
        latent["z"] = z
        latent["l"] = library_gene
        untran_latent["z"] = untran_z
        untran_latent["l"] = log_library_gene
        return qz_m, qz_v, ql_m, ql_v, latent, untran_latent