class scvi.nn.Encoder(n_input, n_output, n_cat_list=None, n_layers=1, n_hidden=128, dropout_rate=0.1, distribution='normal', var_eps=0.0001, var_activation=None, **kwargs)

Bases: torch.nn.modules.module.Module

Encodes data of n_input dimensions into a latent space of n_output dimensions.

Uses a fully-connected neural network with n_layers hidden layers of n_hidden nodes each.

n_input : int

The dimensionality of the input (data space)

n_output : int

The dimensionality of the output (latent space)

n_cat_list : Optional[Iterable[int]] (default: None)

A list containing the number of categories for each category of interest. Each category will be included using a one-hot encoding

n_layers : int (default: 1)

The number of fully-connected hidden layers

n_hidden : int (default: 128)

The number of nodes per hidden layer

dropout_rate : float (default: 0.1)

Dropout rate to apply to each of the hidden layers

distribution : str (default: 'normal')

Distribution of the latent variable z: 'normal' or 'ln' (logistic normal)

var_eps : float (default: 0.0001)

Minimum value for the variance; used for numerical stability

var_activation : Optional[Callable] (default: None)

Callable used to ensure positivity of the variance. When None, defaults to torch.exp.


**kwargs

Keyword arguments passed to FCLayers
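The one-hot handling described for n_cat_list can be illustrated with a small, library-free sketch. It assumes the simplest scheme, in which each categorical covariate is one-hot encoded and concatenated to the input vector. The helper names `one_hot` and `append_covariates` are hypothetical and are not part of scvi; FCLayers may differ in detail, for example in where the covariates are injected.

```python
from typing import List, Sequence


def one_hot(index: int, n_cat: int) -> List[float]:
    """Return a one-hot vector of length n_cat with 1.0 at position index."""
    if not 0 <= index < n_cat:
        raise ValueError(f"category index {index} out of range for {n_cat} categories")
    return [1.0 if i == index else 0.0 for i in range(n_cat)]


def append_covariates(
    x: Sequence[float],
    cat_list: Sequence[int],
    n_cat_list: Sequence[int],
) -> List[float]:
    """Concatenate x with one one-hot block per categorical covariate (assumed scheme)."""
    if len(cat_list) != len(n_cat_list):
        raise ValueError("expected one category index per entry of n_cat_list")
    out = list(x)
    for index, n_cat in zip(cat_list, n_cat_list):
        out.extend(one_hot(index, n_cat))
    return out


# Hypothetical sample: 3 features, category 1 of 2 for the first covariate,
# category 0 of 3 for the second.
augmented = append_covariates([0.5, 1.0, 2.0], cat_list=[1, 0], n_cat_list=[2, 3])
print(augmented)  # [0.5, 1.0, 2.0, 0.0, 1.0, 1.0, 0.0, 0.0]
```

Under this assumption, the first layer sees n_input + sum(n_cat_list) values per sample, here 3 + 2 + 3 = 8.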

Attributes table

Methods table

forward(x, *cat_list)

The forward computation for a single sample.




Encoder.T_destination

alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])

Encoder.dump_patches: bool = False

This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the _metadata attribute of the returned state dict, and is thus pickled. _metadata is a dictionary whose keys follow the naming convention of the state dict. See _load_from_state_dict for how to use this information when loading.

If parameters or buffers are added to or removed from a module, this number shall be bumped, and the module's _load_from_state_dict method can compare the version number and make appropriate changes if the state dict predates the change.

Encoder.training: bool



Encoder.forward(x, *cat_list)

The forward computation for a single sample.

  1. Encodes the data into latent space using the encoder network

  2. Generates a mean \( q_m \) and variance \( q_v \)

  3. Samples a new value from a multivariate normal with diagonal covariance, \( z \sim \mathcal{N}(q_m, \mathbf{I}q_v) \)

x : Tensor

tensor with shape (n_input,)

cat_list : int

list of category membership(s) for this sample


Returns a 3-tuple of torch.Tensor, each of shape (n_latent,), where n_latent is the n_output of the constructor: the mean, the variance, and the sample
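The sampling part of the forward computation can be sketched without PyTorch. The sketch below assumes the variance is obtained as var_activation(raw) + var_eps, with `math.exp` standing in for the torch.exp default, and that the sample is drawn by the usual reparameterization, mean plus standard deviation times standard normal noise. The function `sample_latent` and its `raw_var` and `rng` arguments are hypothetical illustrations, not part of the scvi API, and the network that produces the mean and raw variance is omitted.

```python
import math
import random
from typing import Callable, List, Optional, Sequence, Tuple


def sample_latent(
    mean: Sequence[float],
    raw_var: Sequence[float],
    var_eps: float = 1e-4,
    var_activation: Callable[[float], float] = math.exp,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[float], List[float]]:
    """Return (q_m, q_v, latent) for one sample, mirroring the 3-tuple described above."""
    rng = rng or random.Random()
    q_m = list(mean)
    # Assumed form: the activation makes the variance positive, var_eps bounds it below.
    q_v = [var_activation(v) + var_eps for v in raw_var]
    # Reparameterized draw: z = q_m + sqrt(q_v) * eps, with eps ~ N(0, 1) per dimension.
    latent = [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(q_m, q_v)]
    return q_m, q_v, latent


# Hypothetical encoder outputs for a 2-dimensional latent space.
q_m, q_v, latent = sample_latent([0.0, 1.0], [0.0, -2.0], rng=random.Random(0))
print(q_m, q_v, latent)
```

Because var_eps is added after the activation, every entry of q_v stays strictly above var_eps even when the raw output is very negative, which keeps the square root and any later log-variance terms well defined.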