scvi.nn.Decoder
- class scvi.nn.Decoder(n_input, n_output, n_cat_list=None, n_layers=1, n_hidden=128, **kwargs)
Bases: torch.nn.modules.module.Module
Decodes data from latent space to data space.
Maps n_input dimensions to n_output dimensions using a fully-connected neural network with n_layers hidden layers of n_hidden nodes each. The output is the mean and variance of a multivariate Gaussian.
- Parameters
  - n_input : int
    The dimensionality of the input (latent space).
  - n_output : int
    The dimensionality of the output (data space).
  - n_cat_list : Iterable[int] | None (default: None)
    A list containing the number of categories for each categorical covariate of interest. Each covariate is included using a one-hot encoding.
  - n_layers : int (default: 1)
    The number of fully-connected hidden layers.
  - n_hidden : int (default: 128)
    The number of nodes per hidden layer.
  - dropout_rate
    Dropout rate to apply to each of the hidden layers.
  - kwargs
    Keyword args for FCLayers.
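The architecture described above can be sketched in plain PyTorch. This is a hypothetical re-implementation for illustration, not the scvi-tools source: the class name `SketchDecoder` is made up, the hidden stack uses only `Linear` + `ReLU` (the real FCLayers may add batch normalization or dropout, or inject covariates into every layer), categorical covariates are one-hot encoded and concatenated to the latent input only, and the variance head uses `exp` to keep values positive, which is an assumption about the activation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SketchDecoder(nn.Module):
    """Hypothetical stand-in for scvi.nn.Decoder, for illustration only."""

    def __init__(self, n_input, n_output, n_cat_list=None, n_layers=1, n_hidden=128):
        super().__init__()
        # Number of categories for each categorical covariate (one-hot encoded below).
        self.n_cat_list = list(n_cat_list) if n_cat_list is not None else []
        layers = []
        n_in = n_input + sum(self.n_cat_list)
        # Assumes n_layers >= 1 so the heads always see n_hidden features.
        for _ in range(n_layers):
            # Assumption: Linear + ReLU only; the real FCLayers may differ.
            layers += [nn.Linear(n_in, n_hidden), nn.ReLU()]
            n_in = n_hidden
        self.hidden = nn.Sequential(*layers)
        # Two linear heads on top of the shared hidden representation.
        self.mean_decoder = nn.Linear(n_hidden, n_output)
        self.var_decoder = nn.Linear(n_hidden, n_output)

    def forward(self, x, *cat_list):
        # x: (batch, n_input); each covariate: (batch,) or (batch, 1) integer codes.
        one_hots = [
            F.one_hot(cat.long().view(-1), num_classes=n).float()
            for cat, n in zip(cat_list, self.n_cat_list)
        ]
        h = self.hidden(torch.cat([x, *one_hots], dim=-1))
        mean = self.mean_decoder(h)
        # Assumption: exp maps the raw output to a strictly positive variance.
        var = torch.exp(self.var_decoder(h))
        return mean, var


# Usage: decode 8 latent vectors with one categorical covariate of 3 levels.
torch.manual_seed(0)
decoder = SketchDecoder(n_input=10, n_output=100, n_cat_list=[3], n_layers=2, n_hidden=32)
z = torch.randn(8, 10)
batch_index = torch.randint(0, 3, (8, 1))
mean, var = decoder(z, batch_index)
```

Sharing one hidden stack between the two heads keeps the mean and variance conditioned on the same representation of the latent code.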
Attributes table
- T_destination
- dump_patches
- training

Methods table
| forward(x, *cat_list) | The forward computation for a single sample. |
Attributes
T_destination
- Decoder.T_destination
  alias of TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
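In PyTorch, T_destination is the type variable used for the optional `destination` argument and the return type of `Module.state_dict()`: a mapping passed in is filled in place and returned. A small illustration on a plain `nn.Linear`, chosen only to keep the example independent of scvi-tools:

```python
import torch
from torch import nn

module = nn.Linear(4, 2)

# Any dict from parameter/buffer names to tensors can serve as the destination;
# state_dict() fills it in place and returns the same object.
destination = {}
result = module.state_dict(destination=destination)

print(result is destination)  # True
print(sorted(result))         # ['bias', 'weight']
```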
dump_patches
- Decoder.dump_patches: bool = False
This allows better backward-compatibility (BC) support for load_state_dict(). In state_dict(), the version number is saved in the attribute _metadata of the returned state dict, and is thus pickled. _metadata is a dictionary whose keys follow the naming convention of the state dict. See _load_from_state_dict for how to use this information when loading. If parameters or buffers are added to or removed from a module, this number should be bumped, and the module's _load_from_state_dict method can compare the version number and make the appropriate changes if the state dict predates the change.
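The version metadata described above can be inspected on the state dict returned by `state_dict()`. A sketch on a small `nn.Sequential`, again independent of scvi-tools; `_metadata` is a private attribute, so treat this as a debugging aid rather than a stable API:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
state = model.state_dict()

# _metadata maps module prefixes to per-module info: "" is the root module,
# "0", "1" and "2" are the submodules of the Sequential.
for prefix, info in state._metadata.items():
    print(repr(prefix), info)

# Loading calls each module's _load_from_state_dict with its matching
# metadata entry, which is where a version comparison would happen.
fresh = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
fresh.load_state_dict(state)
```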
training
Methods
forward
- Decoder.forward(x, *cat_list)
The forward computation for a single sample.
Decodes the data from the latent space using the decoder network.
Returns tensors for the mean and variance of a multivariate Gaussian.
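- Parameters
  - x : torch.Tensor
    Tensor of shape (n_input,) in the latent space.
  - cat_list
    Category membership(s) for this sample, one entry per categorical covariate in n_cat_list.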
- Returns
  2-tuple of torch.Tensor
  Mean and variance tensors of shape (n_output,).
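Because forward returns one mean and one variance per output dimension, the pair can be wrapped as a diagonal Gaussian with `torch.distributions`, for example to draw samples or score observations. This is a sketch: the helper name `to_gaussian` is hypothetical, and the stand-in tensors below replace an actual call to `Decoder.forward`. `Normal` expects a standard deviation, so the variance is square-rooted.

```python
import torch
from torch.distributions import Independent, Normal


def to_gaussian(mean: torch.Tensor, var: torch.Tensor) -> Independent:
    """Hypothetical helper: wrap a (mean, variance) pair as a diagonal Gaussian."""
    # Normal is parameterized by the standard deviation, not the variance.
    base = Normal(loc=mean, scale=var.sqrt())
    # Treat the last axis (n_output) as a single multivariate event.
    return Independent(base, 1)


# Stand-ins for the output of Decoder.forward on a batch of 4 samples, n_output = 3.
mean = torch.zeros(4, 3)
var = torch.ones(4, 3)

dist = to_gaussian(mean, var)
sample = dist.sample()            # shape (4, 3)
log_prob = dist.log_prob(sample)  # shape (4,): one log-density per sample
```

Wrapping with `Independent` sums the per-dimension log-densities, which matches treating each output vector as one draw from a Gaussian with diagonal covariance.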