scvi.nn.EncoderTOTALVI#
- class scvi.nn.EncoderTOTALVI(n_input, n_output, n_cat_list=None, n_layers=2, n_hidden=256, dropout_rate=0.1, distribution='ln', use_batch_norm=True, use_layer_norm=False)[source]#
Bases: Module
Encodes data of n_input dimensions into a latent space of n_output dimensions.
Uses a fully-connected neural network with n_layers hidden layers of n_hidden nodes each.
- Parameters:
  - n_input (int) – The dimensionality of the input (data space)
  - n_output (int) – The dimensionality of the output (latent space)
  - n_cat_list (Iterable[int] (default: None)) – A list containing the number of categories for each category of interest. Each category will be included using a one-hot encoding
  - n_layers (int (default: 2)) – The number of fully-connected hidden layers
  - n_hidden (int (default: 256)) – The number of nodes per hidden layer
  - dropout_rate (float (default: 0.1)) – Dropout rate to apply to each of the hidden layers
  - distribution (str (default: 'ln')) – Distribution of the latent space, one of
    - 'normal' - Normal distribution
    - 'ln' - Logistic normal
  - use_batch_norm (bool (default: True)) – Whether to use batch norm in layers
  - use_layer_norm (bool (default: False)) – Whether to use layer norm
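As a rough illustration of what n_cat_list describes, the sketch below one-hot encodes each categorical covariate and appends it to a single input vector. It is plain Python written for illustration, not scvi-tools code: the helper names one_hot and append_covariates are hypothetical, and exactly where the encoded covariates enter the network is an assumption.

```python
def one_hot(index, n_cat):
    """Return a list of length n_cat with 1.0 at position index and 0.0 elsewhere."""
    if not 0 <= index < n_cat:
        raise ValueError(f"category index {index} out of range for {n_cat} categories")
    return [1.0 if i == index else 0.0 for i in range(n_cat)]


def append_covariates(data, cat_list, n_cat_list):
    """Concatenate one-hot encoded covariates to a single input vector (hypothetical helper)."""
    if len(cat_list) != len(n_cat_list):
        raise ValueError("expected one category index per entry of n_cat_list")
    encoded = list(data)
    for index, n_cat in zip(cat_list, n_cat_list):
        encoded.extend(one_hot(index, n_cat))
    return encoded


# Made-up example: 3 input features, one 2-level and one 3-level covariate.
x = append_covariates([0.5, 1.0, 2.0], cat_list=[1, 0], n_cat_list=[2, 3])
print(len(x))  # 3 features + 2 + 3 one-hot entries
```

With this layout, the first hidden layer would see n_input + sum(n_cat_list) values per sample.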
Attributes table#
Methods table#
| forward(data, *cat_list) | The forward computation for a single sample. |
| | Reparameterization trick to sample from a normal distribution. |
Attributes#
- EncoderTOTALVI.training: bool#
Methods#
- EncoderTOTALVI.forward(data, *cat_list)[source]#
The forward computation for a single sample.
1. Encodes the data into latent space using the encoder network
2. Generates a mean \( q_m \) and variance \( q_v \)
3. Samples a new value from an i.i.d. latent distribution
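The mean, variance, and sampling steps above are commonly implemented with the reparameterization trick. The following is a minimal pure-Python sketch, assuming an independent normal distribution per latent dimension. The function reparameterize and the values of q_m and q_v are hypothetical and do not reproduce the library's implementation.

```python
import math
import random


def reparameterize(q_m, q_v, rng):
    """Draw z = q_m + sqrt(q_v) * eps with eps ~ N(0, 1), independently per dimension."""
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(q_m, q_v)]


rng = random.Random(0)   # seeded only so the sketch is reproducible
q_m = [0.0, 1.0, -2.0]   # made-up means, standing in for the encoder output
q_v = [1.0, 0.25, 0.0]   # made-up variances (must be non-negative)
z = reparameterize(q_m, q_v, rng)
print(z)
```

Because the noise eps is drawn separately from q_m and q_v, the sample is a deterministic function of the encoder outputs plus noise, which is what lets gradients flow through the sampling step in an autodiff framework.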
The dictionary latent contains the samples of the latent variables, while untran_latent contains the untransformed versions of these latent variables. For example, the library size is log-normally distributed, so untran_latent["l"] gives the normal sample that was later exponentiated to become latent["l"]. The logistic normal distribution is equivalent to applying softmax to a normal sample.
- Parameters:
- Returns:
6-tuple. The first 4 elements are torch.Tensor; the next 2 are dict of torch.Tensor. The tensors have shape (n_latent,) and hold the mean, the variance, and the sample.
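The relation between latent and untran_latent described for forward can be sketched in plain Python. This is an illustrative example under stated assumptions, not the library's code: the "l" key comes from the description above, the "z" key is assumed here for the latent representation, and all numbers are made up.

```python
import math


def softmax(values):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical untransformed normal samples; the "z" key is an assumption.
untran_latent = {"z": [0.2, -1.0, 0.7], "l": [1.5]}

latent = {
    # Logistic normal: softmax applied to a normal sample.
    "z": softmax(untran_latent["z"]),
    # Log normal: a normal sample that is exponentiated.
    "l": [math.exp(v) for v in untran_latent["l"]],
}
print(latent)
```

The softmax output is non-negative and sums to 1, while the exponentiated library size is strictly positive.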