scvi.nn.Encoder

class scvi.nn.Encoder(n_input, n_output, n_cat_list=None, n_layers=1, n_hidden=128, dropout_rate=0.1, distribution='normal', var_eps=0.0001, var_activation=None, return_dist=False, **kwargs)

Bases: Module

Encode data of n_input dimensions into a latent space of n_output dimensions.

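Uses a fully-connected neural network with n_layers hidden layers of n_hidden nodes each, followed by two separate linear layers that output the mean and the variance of the latent distribution.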
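The layer layout implied by n_input, n_output, n_layers and n_hidden can be sketched as follows. This is a minimal sketch, assuming the hidden stack maps n_input to n_hidden and then n_hidden to n_hidden (the FCLayers convention), followed by parallel mean and variance heads; categorical covariates and dropout are ignored, and the helper name `encoder_layer_shapes` is hypothetical, not part of scvi-tools.

```python
from typing import List, Tuple


def encoder_layer_shapes(
    n_input: int, n_output: int, n_layers: int = 1, n_hidden: int = 128
) -> List[Tuple[str, int, int]]:
    """Return (name, in_features, out_features) for each linear layer.

    Hypothetical helper: assumes n_layers hidden layers of width n_hidden,
    then separate mean and variance heads mapping n_hidden -> n_output.
    One-hot covariate inputs are not counted.
    """
    if n_layers < 1:
        raise ValueError("n_layers must be at least 1")
    # Input width of the first layer, then the width after each hidden layer.
    dims = [n_input] + [n_hidden] * n_layers
    shapes = [
        (f"hidden_{i}", d_in, d_out)
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:]))
    ]
    # Two parallel heads read the same last hidden representation.
    shapes.append(("mean_head", n_hidden, n_output))
    shapes.append(("var_head", n_hidden, n_output))
    return shapes


for name, d_in, d_out in encoder_layer_shapes(n_input=2000, n_output=10, n_layers=2):
    print(f"{name}: {d_in} -> {d_out}")
```

With n_layers=2 this prints four layers: 2000 -> 128 and 128 -> 128 for the hidden stack, then two 128 -> 10 heads. Note that n_hidden sets the width, while n_layers sets the depth.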

Parameters:
  • n_input (int) – The dimensionality of the input (data space)

  • n_output (int) – The dimensionality of the output (latent space)

  • n_cat_list (Iterable[int] (default: None)) – A list containing the number of categories for each categorical covariate of interest. Each covariate is included using a one-hot encoding

  • n_layers (int (default: 1)) – The number of fully-connected hidden layers

  • n_hidden (int (default: 128)) – The number of nodes per hidden layer

  • dropout_rate (float (default: 0.1)) – Dropout rate to apply to each of the hidden layers

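  • distribution (str (default: 'normal')) – Distribution of z: 'normal' for a Gaussian latent space, or 'ln' for a logistic normal, in which the Gaussian sample is passed through a softmax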

  • var_eps (float (default: 0.0001)) – Small constant added to the variance so that it never falls below this value; used for numerical stability

  • var_activation (Optional[Callable] (default: None)) – Callable used to ensure positivity of the variance. Defaults to torch.exp().

  • return_dist (bool (default: False)) – If True, forward returns the distribution of z together with the sample, instead of the mean and variance of z.

  • **kwargs – Keyword arguments passed to FCLayers
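The one-hot encoding controlled by n_cat_list can be illustrated with a small stdlib sketch. It assumes each categorical covariate arrives as an integer code in range(n_cat) and that its one-hot vector is concatenated onto the input features; the function `append_one_hot` is hypothetical and does not reproduce FCLayers' internals, such as how covariates are injected into deeper layers.

```python
from typing import List, Sequence


def append_one_hot(
    x: Sequence[float], categories: Sequence[int], n_cat_list: Sequence[int]
) -> List[float]:
    """Concatenate one one-hot vector per categorical covariate onto x.

    Hypothetical helper: categories[i] is the integer code of covariate i,
    which must lie in range(n_cat_list[i]).
    """
    if len(categories) != len(n_cat_list):
        raise ValueError("need exactly one category code per entry in n_cat_list")
    out = list(x)
    for code, n_cat in zip(categories, n_cat_list):
        if not 0 <= code < n_cat:
            raise ValueError(f"category {code} out of range for {n_cat} categories")
        # A vector of n_cat zeros with a single 1.0 at the category's index.
        one_hot = [0.0] * n_cat
        one_hot[code] = 1.0
        out.extend(one_hot)
    return out


# Two covariates: a batch with 3 levels and a condition with 2 levels.
print(append_one_hot([0.5, 1.5], categories=[2, 0], n_cat_list=[3, 2]))
# → [0.5, 1.5, 0.0, 0.0, 1.0, 1.0, 0.0]
```

Each covariate therefore widens the input by its number of categories, which is why n_cat_list holds category counts rather than the category values themselves.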

Attributes table

training

Methods table

forward(x, *cat_list)

The forward computation for a single sample.

Attributes

Encoder.training: bool

Methods

Encoder.forward(x, *cat_list)

The forward computation for a single sample.

  1. Encodes the data into latent space using the encoder network

  2. Generates a mean \( q_m \) and variance \( q_v \)

  3. Samples a new value \( z \sim \mathcal{N}(q_m, \mathbf{I} q_v) \) from an i.i.d. multivariate normal
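Steps 2 and 3 can be sketched without a deep-learning framework by treating the raw outputs of the mean and variance heads as plain lists. This is a minimal sketch, assuming the variance is computed as var_activation(raw) + var_eps (with math.exp standing in for the torch.exp default) and that the sample is drawn with the reparameterization \( z = q_m + \sqrt{q_v}\,\epsilon \), \( \epsilon \sim \mathcal{N}(0, 1) \) per dimension. The helper `sample_latent` and its argument names are hypothetical, and the softmax branch for distribution='ln' is included as an assumption.

```python
import math
import random
from typing import Callable, List, Optional, Sequence, Tuple


def sample_latent(
    mean_raw: Sequence[float],
    var_raw: Sequence[float],
    var_eps: float = 1e-4,
    var_activation: Callable[[float], float] = math.exp,
    distribution: str = "normal",
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[float], List[float]]:
    """Return (q_m, q_v, z) for one sample, dimension by dimension.

    Hypothetical helper: mean_raw and var_raw stand in for the outputs
    of the mean and variance heads.
    """
    if rng is None:
        rng = random.Random()
    q_m = list(mean_raw)
    # A positive activation plus var_eps keeps every variance above var_eps.
    q_v = [var_activation(v) + var_eps for v in var_raw]
    # Reparameterized draw from N(q_m, diag(q_v)): mean + std * standard noise.
    z = [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(q_m, q_v)]
    if distribution == "ln":
        # Assumed logistic-normal case: a numerically stable softmax
        # maps the Gaussian sample onto the probability simplex.
        shift = max(z)
        exps = [math.exp(value - shift) for value in z]
        total = sum(exps)
        z = [e / total for e in exps]
    return q_m, q_v, z


q_m, q_v, z = sample_latent([0.0, 1.0], [-2.0, 0.0], rng=random.Random(0))
print(q_v)  # each entry equals exp(raw) + 1e-4
```

Drawing the noise separately from the mean and standard deviation is what keeps the sample differentiable with respect to \( q_m \) and \( q_v \) in the real, tensor-based version.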

Parameters:
  • x (Tensor) – tensor with shape (n_input,)

  • cat_list (int) – Category membership(s) for this sample, one value per categorical covariate in n_cat_list

Returns:

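A 3-tuple of torch.Tensor, each of shape (n_latent,): the mean \( q_m \), the variance \( q_v \), and the sample. If return_dist is True, a 2-tuple of the distribution of z and the sample is returned instead.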